Unverified 提交 69b68c52 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov 提交者: GitHub

Merge pull request #1451 from katzyn/aggregate

Add experimental support for aggregates with OVER (ORDER BY *)
......@@ -3050,8 +3050,13 @@ public class Parser {
partitionBy.add(expr);
} while (readIf(COMMA));
}
ArrayList<SelectOrderBy> orderBy = null;
if (readIf(ORDER)) {
read("BY");
orderBy = parseSimpleOrderList();
}
read(CLOSE_PAREN);
aggregate.setOverCondition(new Window(partitionBy));
aggregate.setOverCondition(new Window(partitionBy, orderBy));
currentSelect.setWindowQuery();
} else {
currentSelect.setGroupQuery();
......
......@@ -366,8 +366,10 @@ public class Select extends Query {
groupData = SelectGroups.getInstance(session, expressions, isGroupQuery, groupIndex);
}
groupData.reset();
groupData.resetCounter();
try {
gatherGroup(columnCount, true);
groupData.resetCounter();
processGroupResult(columnCount, result, offset, quickOffset);
} finally {
groupData.reset();
......@@ -379,12 +381,15 @@ public class Select extends Query {
groupData = SelectGroups.getInstance(session, expressions, isGroupQuery, groupIndex);
}
groupData.reset();
groupData.resetCounter();
try {
gatherGroup(columnCount, false);
groupData.resetCounter();
while (groupData.next() != null) {
updateAgg(columnCount, true);
}
groupData.done();
groupData.resetCounter();
try {
isGroupWindowStage2 = true;
processGroupResult(columnCount, result, offset, quickOffset);
......
......@@ -183,6 +183,7 @@ public abstract class SelectGroups {
if (cursor.hasNext()) {
Object[] values = cursor.next();
currentGroupByExprData = values;
currentGroupRowId++;
return ValueArray.get(new Value[0]);
}
return null;
......@@ -315,6 +316,14 @@ public abstract class SelectGroups {
windowData.clear();
}
/**
* Reset the row id counter.
*/
public void resetCounter() {
// TODO merge into reset() and done()
currentGroupRowId = 0;
}
/**
* Invoked for each source row to evaluate group key and setup all necessary
* data for aggregates.
......
......@@ -5,17 +5,24 @@
*/
package org.h2.expression.aggregate;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import org.h2.api.ErrorCode;
import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups;
import org.h2.command.dml.SelectOrderBy;
import org.h2.engine.Session;
import org.h2.expression.Expression;
import org.h2.message.DbException;
import org.h2.result.SortOrder;
import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter;
import org.h2.util.ValueHashMap;
import org.h2.value.Value;
import org.h2.value.ValueArray;
import org.h2.value.ValueInt;
/**
* A base class for aggregates.
......@@ -30,8 +37,22 @@ public abstract class AbstractAggregate extends Expression {
protected Window over;
private SortOrder overOrderBySort;
private int lastGroupRowId;
protected static SortOrder createOrder(Session session, ArrayList<SelectOrderBy> orderBy, int offset) {
int size = orderBy.size();
int[] index = new int[size];
int[] sortType = new int[size];
for (int i = 0; i < size; i++) {
SelectOrderBy o = orderBy.get(i);
index[i] = i + offset;
sortType[i] = o.sortType;
}
return new SortOrder(session.getDatabase(), index, sortType, null);
}
AbstractAggregate(Select select, boolean distinct) {
this.select = select;
this.distinct = distinct;
......@@ -67,6 +88,17 @@ public abstract class AbstractAggregate extends Expression {
}
}
@Override
public Expression optimize(Session session) {
if (over != null) {
ArrayList<SelectOrderBy> orderBy = over.getOrderBy();
if (orderBy != null) {
overOrderBySort = createOrder(session, orderBy, getNumExpressions());
}
}
return this;
}
@Override
public void setEvaluatable(TableFilter tableFilter, boolean b) {
if (filterCondition != null) {
......@@ -116,7 +148,14 @@ public abstract class AbstractAggregate extends Expression {
return;
}
}
updateAggregate(session, getData(session, groupData, false));
if (over != null) {
ArrayList<SelectOrderBy> orderBy = over.getOrderBy();
if (orderBy != null) {
updateOrderedAggregate(session, groupData, groupRowId, orderBy);
return;
}
}
updateAggregate(session, getData(session, groupData, false, false));
}
/**
......@@ -138,7 +177,36 @@ public abstract class AbstractAggregate extends Expression {
*/
protected abstract void updateGroupAggregates(Session session);
protected Object getData(Session session, SelectGroups groupData, boolean ifExists) {
/**
* Returns the number of expressions, excluding FILTER and OVER clauses.
*
* @return the number of expressions
*/
protected abstract int getNumExpressions();
/**
* Stores current values of expressions into the specified array.
*
* @param session
* the session
* @param array
* array to store values of expressions
*/
protected abstract void rememberExpressions(Session session, Value[] array);
/**
* Updates the provided aggregate data from the remembered expressions.
*
* @param session
* the session
* @param aggregateData
* aggregate data
* @param array
* values of expressions
*/
protected abstract void updateFromExpressions(Session session, Object aggregateData, Value[] array);
protected Object getData(Session session, SelectGroups groupData, boolean ifExists, boolean forOrderBy) {
Object data;
if (over != null) {
ValueArray key = over.getCurrentKey(session);
......@@ -157,7 +225,7 @@ public abstract class AbstractAggregate extends Expression {
if (ifExists) {
return null;
}
data = createAggregateData();
data = forOrderBy ? new ArrayList<>() : createAggregateData();
map.put(key, new PartitionData(data));
} else {
data = partition.getData();
......@@ -168,7 +236,7 @@ public abstract class AbstractAggregate extends Expression {
if (ifExists) {
return null;
}
data = createAggregateData();
data = forOrderBy ? new ArrayList<>() : createAggregateData();
groupData.setCurrentGroupExprData(this, new PartitionData(data), true);
} else {
data = partition.getData();
......@@ -180,7 +248,7 @@ public abstract class AbstractAggregate extends Expression {
if (ifExists) {
return null;
}
data = createAggregateData();
data = forOrderBy ? new ArrayList<>() : createAggregateData();
groupData.setCurrentGroupExprData(this, data, false);
}
}
......@@ -195,13 +263,14 @@ public abstract class AbstractAggregate extends Expression {
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
return over == null ? getAggregatedValue(session, getData(session, groupData, true))
return over == null ? getAggregatedValue(session, getData(session, groupData, true, false))
: getWindowResult(session, groupData);
}
private Value getWindowResult(Session session, SelectGroups groupData) {
PartitionData partition;
Object data;
boolean forOrderBy = over.getOrderBy() != null;
ValueArray key = over.getCurrentKey(session);
if (key != null) {
@SuppressWarnings("unchecked")
......@@ -212,7 +281,7 @@ public abstract class AbstractAggregate extends Expression {
}
partition = (PartitionData) map.get(key);
if (partition == null) {
data = createAggregateData();
data = forOrderBy ? new ArrayList<>() : createAggregateData();
partition = new PartitionData(data);
map.put(key, partition);
} else {
......@@ -221,13 +290,16 @@ public abstract class AbstractAggregate extends Expression {
} else {
partition = (PartitionData) groupData.getCurrentGroupExprData(this, true);
if (partition == null) {
data = createAggregateData();
data = forOrderBy ? new ArrayList<>() : createAggregateData();
partition = new PartitionData(data);
groupData.setCurrentGroupExprData(this, partition, true);
} else {
data = partition.getData();
}
}
if (over.getOrderBy() != null) {
return getOrderedResult(session, groupData, partition, data);
}
Value result = partition.getResult();
if (result == null) {
result = getAggregatedValue(session, data);
......@@ -238,6 +310,41 @@ public abstract class AbstractAggregate extends Expression {
protected abstract Value getAggregatedValue(Session session, Object aggregateData);
private void updateOrderedAggregate(Session session, SelectGroups groupData, int groupRowId,
ArrayList<SelectOrderBy> orderBy) {
int ne = getNumExpressions();
int size = orderBy.size();
Value[] array = new Value[ne + size + 1];
rememberExpressions(session, array);
for (int i = 0; i < size; i++) {
SelectOrderBy o = orderBy.get(i);
array[ne++] = o.expression.getValue(session);
}
array[ne] = ValueInt.get(groupRowId);
@SuppressWarnings("unchecked")
ArrayList<Value[]> data = (ArrayList<Value[]>) getData(session, groupData, false, true);
data.add(array);
return;
}
private Value getOrderedResult(Session session, SelectGroups groupData, PartitionData partition, Object data) {
HashMap<Integer, Value> result = partition.getOrderedResult();
if (result == null) {
result = new HashMap<>();
@SuppressWarnings("unchecked")
ArrayList<Value[]> orderedData = (ArrayList<Value[]>) data;
int ne = getNumExpressions();
int last = ne + over.getOrderBy().size();
Collections.sort(orderedData, overOrderBySort);
Object aggregateData = createAggregateData();
for (Value[] row : orderedData) {
updateFromExpressions(session, aggregateData, row);
result.put(row[last].getInt(), getAggregatedValue(session, aggregateData));
}
}
return result.get(groupData.getCurrentGroupRowId());
}
protected StringBuilder appendTailConditions(StringBuilder builder) {
if (filterCondition != null) {
builder.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
......
......@@ -256,18 +256,6 @@ public class Aggregate extends AbstractAggregate {
this.groupConcatSeparator = separator;
}
private SortOrder initOrder(Session session) {
int size = orderByList.size();
int[] index = new int[size];
int[] sortType = new int[size];
for (int i = 0; i < size; i++) {
SelectOrderBy o = orderByList.get(i);
index[i] = i + 1;
sortType[i] = o.sortType;
}
return new SortOrder(session.getDatabase(), index, sortType, null);
}
private void sortWithOrderBy(Value[] array) {
final SortOrder sortOrder = orderBySort;
if (sortOrder != null) {
......@@ -286,13 +274,17 @@ public class Aggregate extends AbstractAggregate {
protected void updateAggregate(Session session, Object aggregateData) {
AggregateData data = (AggregateData) aggregateData;
Value v = on == null ? null : on.getValue(session);
updateData(session, data, v, null);
}
private void updateData(Session session, AggregateData data, Value v, Value[] remembered) {
if (type == AggregateType.GROUP_CONCAT) {
if (v != ValueNull.INSTANCE) {
v = updateCollecting(session, v.convertTo(Value.STRING));
v = updateCollecting(session, v.convertTo(Value.STRING), remembered);
}
} else if (type == AggregateType.ARRAY_AGG) {
if (v != ValueNull.INSTANCE) {
v = updateCollecting(session, v);
v = updateCollecting(session, v, remembered);
}
}
data.add(session.getDatabase(), dataType, distinct, v);
......@@ -310,20 +302,55 @@ public class Aggregate extends AbstractAggregate {
}
}
private Value updateCollecting(Session session, Value v) {
private Value updateCollecting(Session session, Value v, Value[] remembered) {
if (orderByList != null) {
int size = orderByList.size();
Value[] array = new Value[1 + size];
array[0] = v;
for (int i = 0; i < size; i++) {
SelectOrderBy o = orderByList.get(i);
array[i + 1] = o.expression.getValue(session);
if (remembered == null) {
for (int i = 0; i < size; i++) {
SelectOrderBy o = orderByList.get(i);
array[i + 1] = o.expression.getValue(session);
}
} else {
for (int i = 1; i <= size; i++) {
array[i] = remembered[i];
}
}
v = ValueArray.get(array);
}
return v;
}
@Override
protected int getNumExpressions() {
int n = on != null ? 1 : 0;
if (orderByList != null) {
n += orderByList.size();
}
return n;
}
@Override
protected void rememberExpressions(Session session, Value[] array) {
int offset = 0;
if (on != null) {
array[offset++] = on.getValue(session);
}
if (orderByList != null) {
for (SelectOrderBy o : orderByList) {
array[offset++] = o.expression.getValue(session);
}
}
}
@Override
protected void updateFromExpressions(Session session, Object aggregateData, Value[] array) {
AggregateData data = (AggregateData) aggregateData;
Value v = on == null ? null : array[0];
updateData(session, data, v, array);
}
@Override
protected Object createAggregateData() {
return AggregateData.create(type);
......@@ -450,6 +477,7 @@ public class Aggregate extends AbstractAggregate {
@Override
public Expression optimize(Session session) {
super.optimize(session);
if (on != null) {
on = on.optimize(session);
dataType = on.getType();
......@@ -461,7 +489,7 @@ public class Aggregate extends AbstractAggregate {
for (SelectOrderBy o : orderByList) {
o.expression = o.expression.optimize(session);
}
orderBySort = initOrder(session);
orderBySort = createOrder(session, orderByList, 1);
}
if (groupConcatSeparator != null) {
groupConcatSeparator = groupConcatSeparator.optimize(session);
......@@ -583,42 +611,28 @@ public class Aggregate extends AbstractAggregate {
}
private String getSQLGroupConcat() {
StatementBuilder buff = new StatementBuilder("GROUP_CONCAT(");
StringBuilder buff = new StringBuilder("GROUP_CONCAT(");
if (distinct) {
buff.append("DISTINCT ");
}
buff.append(on.getSQL());
if (orderByList != null) {
buff.append(" ORDER BY ");
for (SelectOrderBy o : orderByList) {
buff.appendExceptFirst(", ");
buff.append(o.expression.getSQL());
SortOrder.typeToString(buff.builder(), o.sortType);
}
}
Window.appendOrderBy(buff, orderByList);
if (groupConcatSeparator != null) {
buff.append(" SEPARATOR ").append(groupConcatSeparator.getSQL());
}
buff.append(')');
return appendTailConditions(buff.builder()).toString();
return appendTailConditions(buff).toString();
}
private String getSQLArrayAggregate() {
StatementBuilder buff = new StatementBuilder("ARRAY_AGG(");
StringBuilder buff = new StringBuilder("ARRAY_AGG(");
if (distinct) {
buff.append("DISTINCT ");
}
buff.append(on.getSQL());
if (orderByList != null) {
buff.append(" ORDER BY ");
for (SelectOrderBy o : orderByList) {
buff.appendExceptFirst(", ");
buff.append(o.expression.getSQL());
SortOrder.typeToString(buff.builder(), o.sortType);
}
}
Window.appendOrderBy(buff, orderByList);
buff.append(')');
return appendTailConditions(buff.builder()).toString();
return appendTailConditions(buff).toString();
}
@Override
......
......@@ -116,6 +116,7 @@ public class JavaAggregate extends AbstractAggregate {
@Override
public Expression optimize(Session session) {
super.optimize(session);
userConnection = session.createConnection(false);
int len = args.length;
argTypes = new int[len];
......@@ -194,13 +195,17 @@ public class JavaAggregate extends AbstractAggregate {
@Override
protected void updateAggregate(Session session, Object aggregateData) {
updateData(session, aggregateData, null);
}
private void updateData(Session session, Object aggregateData, Value[] remembered) {
try {
if (distinct) {
AggregateDataCollecting data = (AggregateDataCollecting) aggregateData;
Value[] argValues = new Value[args.length];
Value arg = null;
for (int i = 0, len = args.length; i < len; i++) {
arg = args[i].getValue(session);
arg = remembered == null ? args[i].getValue(session) : remembered[i];
arg = arg.convertTo(argTypes[i]);
argValues[i] = arg;
}
......@@ -210,7 +215,7 @@ public class JavaAggregate extends AbstractAggregate {
Object[] argValues = new Object[args.length];
Object arg = null;
for (int i = 0, len = args.length; i < len; i++) {
Value v = args[i].getValue(session);
Value v = remembered == null ? args[i].getValue(session) : remembered[i];
v = v.convertTo(argTypes[i]);
arg = v.getObject();
argValues[i] = arg;
......@@ -229,6 +234,23 @@ public class JavaAggregate extends AbstractAggregate {
}
}
@Override
protected int getNumExpressions() {
return args.length;
}
@Override
protected void rememberExpressions(Session session, Value[] array) {
for (int i = 0; i < args.length; i++) {
array[i] = args[i].getValue(session);
}
}
@Override
protected void updateFromExpressions(Session session, Object aggregateData, Value[] array) {
updateData(session, aggregateData, array);
}
@Override
protected Object createAggregateData() {
return distinct ? new AggregateDataCollecting() : getInstance();
......
......@@ -5,6 +5,8 @@
*/
package org.h2.expression.aggregate;
import java.util.HashMap;
import org.h2.value.Value;
/**
......@@ -22,6 +24,11 @@ final class PartitionData {
*/
private Value result;
/**
* Evaluated ordered result.
*/
private HashMap<Integer, Value> orderedResult;
/**
* Creates new instance of partition data.
*
......@@ -60,4 +67,23 @@ final class PartitionData {
this.result = result;
}
/**
* Returns the ordered result.
*
* @return the ordered result
*/
HashMap<Integer, Value> getOrderedResult() {
return orderedResult;
}
/**
* Sets the ordered result.
*
* @param orderedResult
* the ordered result to set
*/
void setOrderedResult(HashMap<Integer, Value> orderedResult) {
this.orderedResult = orderedResult;
}
}
......@@ -7,8 +7,10 @@ package org.h2.expression.aggregate;
import java.util.ArrayList;
import org.h2.command.dml.SelectOrderBy;
import org.h2.engine.Session;
import org.h2.expression.Expression;
import org.h2.result.SortOrder;
import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter;
import org.h2.util.StringUtils;
......@@ -22,14 +24,39 @@ public final class Window {
private final ArrayList<Expression> partitionBy;
private final ArrayList<SelectOrderBy> orderBy;
/**
* @param builder
* string builder
* @param orderBy
* ORDER BY clause, or null
*/
static void appendOrderBy(StringBuilder builder, ArrayList<SelectOrderBy> orderBy) {
if (orderBy != null) {
builder.append(" ORDER BY ");
for (int i = 0; i < orderBy.size(); i++) {
SelectOrderBy o = orderBy.get(i);
if (i > 0) {
builder.append(", ");
}
builder.append(o.expression.getSQL());
SortOrder.typeToString(builder, o.sortType);
}
}
}
/**
* Creates a new instance of window clause.
*
* @param partitionBy
* PARTITION BY clause, or null
* @param orderBy
* ORDER BY clause, or null
*/
public Window(ArrayList<Expression> partitionBy) {
public Window(ArrayList<Expression> partitionBy, ArrayList<SelectOrderBy> orderBy) {
this.partitionBy = partitionBy;
this.orderBy = orderBy;
}
/**
......@@ -47,6 +74,11 @@ public final class Window {
e.mapColumns(resolver, level);
}
}
if (orderBy != null) {
for (SelectOrderBy o : orderBy) {
o.expression.mapColumns(resolver, level);
}
}
}
/**
......@@ -65,6 +97,20 @@ public final class Window {
e.setEvaluatable(tableFilter, value);
}
}
if (orderBy != null) {
for (SelectOrderBy o : orderBy) {
o.expression.setEvaluatable(tableFilter, value);
}
}
}
/**
* Returns ORDER BY clause.
*
* @return ORDER BY clause, or null
*/
public ArrayList<SelectOrderBy> getOrderBy() {
return orderBy;
}
/**
......@@ -95,16 +141,20 @@ public final class Window {
* @see Expression#getSQL()
*/
public String getSQL() {
if (partitionBy == null) {
if (partitionBy == null && orderBy == null) {
return "OVER ()";
}
StringBuilder builder = new StringBuilder().append("OVER (PARTITION BY ");
for (int i = 0; i < partitionBy.size(); i++) {
if (i > 0) {
builder.append(", ");
StringBuilder builder = new StringBuilder().append("OVER (");
if (partitionBy != null) {
builder.append("PARTITION BY ");
for (int i = 0; i < partitionBy.size(); i++) {
if (i > 0) {
builder.append(", ");
}
builder.append(StringUtils.unEnclose(partitionBy.get(i).getSQL()));
}
builder.append(StringUtils.unEnclose(partitionBy.get(i).getSQL()));
}
appendOrderBy(builder, orderBy);
return builder.append(')').toString();
}
......@@ -113,7 +163,8 @@ public final class Window {
*
* @param session
* the session
* @param window true for window processing stage, false for group stage
* @param window
* true for window processing stage, false for group stage
* @see Expression#updateAggregate(Session, boolean)
*/
public void updateAggregate(Session session, boolean window) {
......@@ -122,6 +173,11 @@ public final class Window {
expr.updateAggregate(session, window);
}
}
if (orderBy != null) {
for (SelectOrderBy o : orderBy) {
o.expression.updateAggregate(session, false);
}
}
}
@Override
......
......@@ -196,5 +196,52 @@ SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE NAME > 'c') OVER () FR
SELECT ARRAY_AGG(ID) OVER() FROM TEST GROUP BY NAME;
> exception MUST_GROUP_BY_COLUMN_1
SELECT ARRAY_AGG(ID) OVER(PARTITION BY NAME ORDER /**/ BY ID), NAME FROM TEST;
> ARRAY_AGG(ID) OVER (PARTITION BY NAME ORDER BY ID) NAME
> -------------------------------------------------- ----
> (1) a
> (1, 2) a
> (3) b
> (4) c
> (4, 5) c
> (4, 5, 6) c
> rows: 6
SELECT ARRAY_AGG(ID) OVER(PARTITION BY NAME ORDER /**/ BY ID DESC), NAME FROM TEST;
> ARRAY_AGG(ID) OVER (PARTITION BY NAME ORDER BY ID DESC) NAME
> ------------------------------------------------------- ----
> (2) a
> (2, 1) a
> (3) b
> (6) c
> (6, 5) c
> (6, 5, 4) c
> rows: 6
SELECT
ARRAY_AGG(ID ORDER /**/ BY ID) OVER(PARTITION BY NAME ORDER /**/ BY ID DESC) A,
ARRAY_AGG(ID) OVER(PARTITION BY NAME ORDER /**/ BY ID DESC) D,
NAME FROM TEST;
> A D NAME
> --------- --------- ----
> (1, 2) (2, 1) a
> (2) (2) a
> (3) (3) b
> (4, 5, 6) (6, 5, 4) c
> (5, 6) (6, 5) c
> (6) (6) c
> rows: 6
SELECT ARRAY_AGG(SUM(ID)) OVER(ORDER /**/ BY ID) FROM TEST GROUP BY ID;
> ARRAY_AGG(SUM(ID)) OVER ( ORDER BY ID)
> --------------------------------------
> (1)
> (1, 2)
> (1, 2, 3)
> (1, 2, 3, 4)
> (1, 2, 3, 4, 5)
> (1, 2, 3, 4, 5, 6)
> rows: 6
DROP TABLE TEST;
> ok
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论