Unverified 提交 771013f8 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov 提交者: GitHub

Merge pull request #1444 from katzyn/aggregate

Add experimental unoptimized support for OVER ([PARTITION BY ...]) in aggregates
...@@ -170,9 +170,11 @@ import org.h2.expression.UnaryOperation; ...@@ -170,9 +170,11 @@ import org.h2.expression.UnaryOperation;
import org.h2.expression.ValueExpression; import org.h2.expression.ValueExpression;
import org.h2.expression.Variable; import org.h2.expression.Variable;
import org.h2.expression.Wildcard; import org.h2.expression.Wildcard;
import org.h2.expression.aggregate.AbstractAggregate;
import org.h2.expression.aggregate.Aggregate; import org.h2.expression.aggregate.Aggregate;
import org.h2.expression.aggregate.Aggregate.AggregateType; import org.h2.expression.aggregate.Aggregate.AggregateType;
import org.h2.expression.aggregate.JavaAggregate; import org.h2.expression.aggregate.JavaAggregate;
import org.h2.expression.aggregate.Window;
import org.h2.index.Index; import org.h2.index.Index;
import org.h2.message.DbException; import org.h2.message.DbException;
import org.h2.result.SortOrder; import org.h2.result.SortOrder;
...@@ -2883,7 +2885,6 @@ public class Parser { ...@@ -2883,7 +2885,6 @@ public class Parser {
if (currentSelect == null) { if (currentSelect == null) {
throw getSyntaxError(); throw getSyntaxError();
} }
currentSelect.setGroupQuery();
Aggregate r; Aggregate r;
switch (aggregateType) { switch (aggregateType) {
case COUNT: case COUNT:
...@@ -2967,7 +2968,7 @@ public class Parser { ...@@ -2967,7 +2968,7 @@ public class Parser {
} }
read(CLOSE_PAREN); read(CLOSE_PAREN);
if (r != null) { if (r != null) {
r.setFilterCondition(readFilterCondition()); readFilterAndOver(r);
} }
return r; return r;
} }
...@@ -3024,22 +3025,37 @@ public class Parser { ...@@ -3024,22 +3025,37 @@ public class Parser {
do { do {
params.add(readExpression()); params.add(readExpression());
} while (readIfMore(true)); } while (readIfMore(true));
Expression filterCondition = readFilterCondition();
Expression[] list = params.toArray(new Expression[0]); Expression[] list = params.toArray(new Expression[0]);
JavaAggregate agg = new JavaAggregate(aggregate, list, currentSelect, distinct, filterCondition); JavaAggregate agg = new JavaAggregate(aggregate, list, currentSelect, distinct);
currentSelect.setGroupQuery(); readFilterAndOver(agg);
return agg; return agg;
} }
private Expression readFilterCondition() { private void readFilterAndOver(AbstractAggregate aggregate) {
if (readIf("FILTER")) { if (readIf("FILTER")) {
read(OPEN_PAREN); read(OPEN_PAREN);
read(WHERE); read(WHERE);
Expression filterCondition = readExpression(); Expression filterCondition = readExpression();
read(CLOSE_PAREN); read(CLOSE_PAREN);
return filterCondition; aggregate.setFilterCondition(filterCondition);
}
if (readIf("OVER")) {
read(OPEN_PAREN);
ArrayList<Expression> partitionBy = null;
if (readIf("PARTITION")) {
read("BY");
partitionBy = Utils.newSmallArrayList();
do {
Expression expr = readExpression();
partitionBy.add(expr);
} while (readIf(COMMA));
}
read(CLOSE_PAREN);
aggregate.setOverCondition(new Window(partitionBy));
currentSelect.setWindowQuery();
} else {
currentSelect.setGroupQuery();
} }
return null;
} }
private AggregateType getAggregateType(String name) { private AggregateType getAggregateType(String name) {
......
/*
* Copyright 2004-2018 H2 Group. Multiple-Licensed under the MPL 2.0,
* and the EPL 1.0 (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.command.dml;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import org.h2.engine.Session;
import org.h2.expression.Expression;
import org.h2.value.Value;
import org.h2.value.ValueArray;
/**
* Grouped data for aggregates.
*
* <p>
* Call sequence:
* </p>
* <ul>
* <li>{@link #reset()} (not required before the first execution).</li>
* <li>For each source row {@link #nextSource()} should be invoked.</li>
* <li>{@link #done()}.</li>
* <li>{@link #next()} is invoked inside a loop until it returns null.</li>
* </ul>
* <p>
* Call sequence for lazy group sorted result:
* </p>
* <ul>
* <li>{@link #resetLazy()} (not required before the first execution).</li>
* <li>For each source group {@link #nextLazyGroup()} should be invoked.</li>
* <li>For each source row {@link #nextLazyRow()} should be invoked. Each group
* can have one or more rows.</li>
* </ul>
*/
public final class SelectGroups {
private final Session session;
private final ArrayList<Expression> expressions;
private final int[] groupIndex;
/**
* The array of current group-by expression data e.g. AggregateData.
*/
private Object[] currentGroupByExprData;
/**
* Maps an expression object to an index, to use in accessing the Object[]
* pointed to by groupByData.
*/
private final HashMap<Expression, Integer> exprToIndexInGroupByData = new HashMap<>();
/**
* Map of group-by key to group-by expression data e.g. AggregateData
*/
private HashMap<ValueArray, Object[]> groupByData;
/**
* Key into groupByData that produces currentGroupByExprData. Not used in
* lazy mode.
*/
private ValueArray currentGroupsKey;
/**
* The id of the current group.
*/
private int currentGroupRowId;
/**
* The key for the default group.
*/
// Can be static, but TestClearReferences complains about it
private ValueArray defaultGroup = ValueArray.get(new Value[0]);
/**
* Cursor for {@link #next()} method.
*/
private Iterator<Entry<ValueArray, Object[]>> cursor;
/**
* Creates new instance of grouped data.
*
* @param session
* the session
* @param expressions
* the expressions
* @param groupIndex
* the indexes of group expressions, or null
*/
public SelectGroups(Session session, ArrayList<Expression> expressions, int[] groupIndex) {
this.session = session;
this.expressions = expressions;
this.groupIndex = groupIndex;
}
/**
* Is there currently a group-by active
*/
public boolean isCurrentGroup() {
return currentGroupByExprData != null;
}
/**
* Get the group-by data for the current group and the passed in expression.
*/
public Object getCurrentGroupExprData(Expression expr) {
Integer index = exprToIndexInGroupByData.get(expr);
if (index == null) {
return null;
}
return currentGroupByExprData[index];
}
/**
* Set the group-by data for the current group and the passed in expression.
*/
public void setCurrentGroupExprData(Expression expr, Object obj) {
Integer index = exprToIndexInGroupByData.get(expr);
if (index != null) {
assert currentGroupByExprData[index] == null;
currentGroupByExprData[index] = obj;
return;
}
index = exprToIndexInGroupByData.size();
exprToIndexInGroupByData.put(expr, index);
if (index >= currentGroupByExprData.length) {
currentGroupByExprData = Arrays.copyOf(currentGroupByExprData, currentGroupByExprData.length * 2);
// this can be null in lazy mode
if (currentGroupsKey != null) {
// since we changed the size of the array, update the object in
// the groups map
groupByData.put(currentGroupsKey, currentGroupByExprData);
}
}
currentGroupByExprData[index] = obj;
}
/**
* Returns identity of the current row. Used by aggregates to check whether
* they already processed this row or not.
*
* @return identity of the current row
*/
public int getCurrentGroupRowId() {
return currentGroupRowId;
}
/**
* Resets this group data for reuse.
*/
public void reset() {
groupByData = new HashMap<>();
currentGroupByExprData = null;
currentGroupsKey = null;
exprToIndexInGroupByData.clear();
cursor = null;
}
/**
* Invoked for each source row to evaluate group key and setup all necessary
* data for aggregates.
*
* @return key of the current group
*/
public ValueArray nextSource() {
if (groupIndex == null) {
currentGroupsKey = defaultGroup;
} else {
Value[] keyValues = new Value[groupIndex.length];
// update group
for (int i = 0; i < groupIndex.length; i++) {
int idx = groupIndex[i];
Expression expr = expressions.get(idx);
keyValues[i] = expr.getValue(session);
}
currentGroupsKey = ValueArray.get(keyValues);
}
Object[] values = groupByData.get(currentGroupsKey);
if (values == null) {
values = new Object[Math.max(exprToIndexInGroupByData.size(), expressions.size())];
groupByData.put(currentGroupsKey, values);
}
currentGroupByExprData = values;
currentGroupRowId++;
return currentGroupsKey;
}
/**
* Invoked after all source rows are evaluated.
*/
public void done() {
if (groupIndex == null && groupByData.size() == 0) {
groupByData.put(defaultGroup, new Object[Math.max(exprToIndexInGroupByData.size(), expressions.size())]);
}
cursor = groupByData.entrySet().iterator();
}
/**
* Returns the key of the next group.
*
* @return the key of the next group, or null
*/
public ValueArray next() {
if (cursor.hasNext()) {
Map.Entry<ValueArray, Object[]> entry = cursor.next();
currentGroupByExprData = entry.getValue();
return entry.getKey();
}
return null;
}
/**
* Resets this group data for reuse in lazy mode.
*/
public void resetLazy() {
currentGroupByExprData = null;
currentGroupsKey = null;
}
/**
* Moves group data to the next group in lazy mode.
*/
public void nextLazyGroup() {
currentGroupByExprData = new Object[Math.max(exprToIndexInGroupByData.size(), expressions.size())];
}
/**
* Moves group data to the next row in lazy mode.
*/
public void nextLazyRow() {
currentGroupRowId++;
}
}
...@@ -8,6 +8,7 @@ package org.h2.expression; ...@@ -8,6 +8,7 @@ package org.h2.expression;
import org.h2.api.ErrorCode; import org.h2.api.ErrorCode;
import org.h2.command.Parser; import org.h2.command.Parser;
import org.h2.command.dml.Select; import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups;
import org.h2.command.dml.SelectListColumnResolver; import org.h2.command.dml.SelectListColumnResolver;
import org.h2.engine.Database; import org.h2.engine.Database;
import org.h2.engine.Session; import org.h2.engine.Session;
...@@ -158,13 +159,14 @@ public class ExpressionColumn extends Expression { ...@@ -158,13 +159,14 @@ public class ExpressionColumn extends Expression {
if (select == null) { if (select == null) {
throw DbException.get(ErrorCode.MUST_GROUP_BY_COLUMN_1, getSQL()); throw DbException.get(ErrorCode.MUST_GROUP_BY_COLUMN_1, getSQL());
} }
if (!select.isCurrentGroup()) { SelectGroups groupData = select.getGroupDataIfCurrent(false);
if (groupData == null) {
// this is a different level (the enclosing query) // this is a different level (the enclosing query)
return; return;
} }
Value v = (Value) select.getCurrentGroupExprData(this); Value v = (Value) groupData.getCurrentGroupExprData(this);
if (v == null) { if (v == null) {
select.setCurrentGroupExprData(this, now); groupData.setCurrentGroupExprData(this, now);
} else { } else {
if (!database.areEqual(now, v)) { if (!database.areEqual(now, v)) {
throw DbException.get(ErrorCode.MUST_GROUP_BY_COLUMN_1, getSQL()); throw DbException.get(ErrorCode.MUST_GROUP_BY_COLUMN_1, getSQL());
...@@ -176,8 +178,9 @@ public class ExpressionColumn extends Expression { ...@@ -176,8 +178,9 @@ public class ExpressionColumn extends Expression {
public Value getValue(Session session) { public Value getValue(Session session) {
Select select = columnResolver.getSelect(); Select select = columnResolver.getSelect();
if (select != null) { if (select != null) {
if (select.isCurrentGroup()) { SelectGroups groupData = select.getGroupDataIfCurrent(false);
Value v = (Value) select.getCurrentGroupExprData(this); if (groupData != null) {
Value v = (Value) groupData.getCurrentGroupExprData(this);
if (v != null) { if (v != null) {
return v; return v;
} }
......
/*
* Copyright 2004-2018 H2 Group. Multiple-Licensed under the MPL 2.0,
* and the EPL 1.0 (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.expression.aggregate;
import org.h2.expression.Expression;
import org.h2.table.ColumnResolver;
/**
* A base class for aggregates.
*/
public abstract class AbstractAggregate extends Expression {
protected Expression filterCondition;
protected Window over;
/**
* Sets the FILTER condition.
*
* @param filterCondition
* FILTER condition
*/
public void setFilterCondition(Expression filterCondition) {
this.filterCondition = filterCondition;
}
/**
* Sets the OVER condition.
*
* @param over
* OVER condition
*/
public void setOverCondition(Window over) {
this.over = over;
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
if (filterCondition != null) {
filterCondition.mapColumns(resolver, level);
}
if (over != null) {
over.mapColumns(resolver, level);
}
}
}
...@@ -11,6 +11,7 @@ import java.util.Comparator; ...@@ -11,6 +11,7 @@ import java.util.Comparator;
import java.util.HashMap; import java.util.HashMap;
import org.h2.api.ErrorCode; import org.h2.api.ErrorCode;
import org.h2.command.dml.Select; import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups;
import org.h2.command.dml.SelectOrderBy; import org.h2.command.dml.SelectOrderBy;
import org.h2.engine.Session; import org.h2.engine.Session;
import org.h2.expression.Expression; import org.h2.expression.Expression;
...@@ -28,6 +29,7 @@ import org.h2.table.Table; ...@@ -28,6 +29,7 @@ import org.h2.table.Table;
import org.h2.table.TableFilter; import org.h2.table.TableFilter;
import org.h2.util.StatementBuilder; import org.h2.util.StatementBuilder;
import org.h2.util.StringUtils; import org.h2.util.StringUtils;
import org.h2.util.ValueHashMap;
import org.h2.value.DataType; import org.h2.value.DataType;
import org.h2.value.Value; import org.h2.value.Value;
import org.h2.value.ValueArray; import org.h2.value.ValueArray;
...@@ -41,7 +43,7 @@ import org.h2.value.ValueString; ...@@ -41,7 +43,7 @@ import org.h2.value.ValueString;
/** /**
* Implements the integrated aggregate functions, such as COUNT, MAX, SUM. * Implements the integrated aggregate functions, such as COUNT, MAX, SUM.
*/ */
public class Aggregate extends Expression { public class Aggregate extends AbstractAggregate {
public enum AggregateType { public enum AggregateType {
/** /**
...@@ -165,8 +167,6 @@ public class Aggregate extends Expression { ...@@ -165,8 +167,6 @@ public class Aggregate extends Expression {
private int displaySize; private int displaySize;
private int lastGroupRowId; private int lastGroupRowId;
private Expression filterCondition;
/** /**
* Create a new aggregate object. * Create a new aggregate object.
* *
...@@ -254,15 +254,6 @@ public class Aggregate extends Expression { ...@@ -254,15 +254,6 @@ public class Aggregate extends Expression {
this.groupConcatSeparator = separator; this.groupConcatSeparator = separator;
} }
/**
* Sets the FILTER condition.
*
* @param filterCondition condition
*/
public void setFilterCondition(Expression filterCondition) {
this.filterCondition = filterCondition;
}
private SortOrder initOrder(Session session) { private SortOrder initOrder(Session session) {
int size = orderByList.size(); int size = orderByList.size();
int[] index = new int[size]; int[] index = new int[size];
...@@ -295,12 +286,13 @@ public class Aggregate extends Expression { ...@@ -295,12 +286,13 @@ public class Aggregate extends Expression {
// if (on != null) { // if (on != null) {
// on.updateAggregate(); // on.updateAggregate();
// } // }
if (!select.isCurrentGroup()) { SelectGroups groupData = select.getGroupDataIfCurrent(true);
if (groupData == null) {
// this is a different level (the enclosing query) // this is a different level (the enclosing query)
return; return;
} }
int groupRowId = select.getCurrentGroupRowId(); int groupRowId = groupData.getCurrentGroupRowId();
if (lastGroupRowId == groupRowId) { if (lastGroupRowId == groupRowId) {
// already visited // already visited
return; return;
...@@ -312,11 +304,7 @@ public class Aggregate extends Expression { ...@@ -312,11 +304,7 @@ public class Aggregate extends Expression {
return; return;
} }
} }
AggregateData data = (AggregateData) select.getCurrentGroupExprData(this); AggregateData data = getData(session, groupData);
if (data == null) {
data = AggregateData.create(type);
select.setCurrentGroupExprData(this, data);
}
Value v = on == null ? null : on.getValue(session); Value v = on == null ? null : on.getValue(session);
if (type == AggregateType.GROUP_CONCAT) { if (type == AggregateType.GROUP_CONCAT) {
if (v != ValueNull.INSTANCE) { if (v != ValueNull.INSTANCE) {
...@@ -378,14 +366,11 @@ public class Aggregate extends Expression { ...@@ -378,14 +366,11 @@ public class Aggregate extends Expression {
DbException.throwInternalError("type=" + type); DbException.throwInternalError("type=" + type);
} }
} }
if (!select.isCurrentGroup()) { SelectGroups groupData = select.getGroupDataIfCurrent(true);
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL()); throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
} }
AggregateData data = (AggregateData)select.getCurrentGroupExprData(this); AggregateData data = getData(session, groupData);
if (data == null) {
data = AggregateData.create(type);
select.setCurrentGroupExprData(this, data);
}
switch (type) { switch (type) {
case GROUP_CONCAT: { case GROUP_CONCAT: {
Value[] array = ((AggregateDataCollecting) data).getArray(); Value[] array = ((AggregateDataCollecting) data).getArray();
...@@ -441,6 +426,31 @@ public class Aggregate extends Expression { ...@@ -441,6 +426,31 @@ public class Aggregate extends Expression {
} }
} }
private AggregateData getData(Session session, SelectGroups groupData) {
AggregateData data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<AggregateData> map = (ValueHashMap<AggregateData>) groupData.getCurrentGroupExprData(this);
if (map == null) {
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map);
}
data = map.get(key);
if (data == null) {
data = AggregateData.create(type);
map.put(key, data);
}
} else {
data = (AggregateData) groupData.getCurrentGroupExprData(this);
if (data == null) {
data = AggregateData.create(type);
groupData.setCurrentGroupExprData(this, data);
}
}
return data;
}
@Override @Override
public int getType() { public int getType() {
return dataType; return dataType;
...@@ -459,9 +469,7 @@ public class Aggregate extends Expression { ...@@ -459,9 +469,7 @@ public class Aggregate extends Expression {
if (groupConcatSeparator != null) { if (groupConcatSeparator != null) {
groupConcatSeparator.mapColumns(resolver, level); groupConcatSeparator.mapColumns(resolver, level);
} }
if (filterCondition != null) { super.mapColumns(resolver, level);
filterCondition.mapColumns(resolver, level);
}
} }
@Override @Override
...@@ -621,6 +629,9 @@ public class Aggregate extends Expression { ...@@ -621,6 +629,9 @@ public class Aggregate extends Expression {
if (filterCondition != null) { if (filterCondition != null) {
buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')'); buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
} }
if (over != null) {
buff.append(' ').append(over.getSQL());
}
return buff.toString(); return buff.toString();
} }
...@@ -642,6 +653,9 @@ public class Aggregate extends Expression { ...@@ -642,6 +653,9 @@ public class Aggregate extends Expression {
if (filterCondition != null) { if (filterCondition != null) {
buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')'); buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
} }
if (over != null) {
buff.append(' ').append(over.getSQL());
}
return buff.toString(); return buff.toString();
} }
...@@ -720,6 +734,9 @@ public class Aggregate extends Expression { ...@@ -720,6 +734,9 @@ public class Aggregate extends Expression {
if (filterCondition != null) { if (filterCondition != null) {
text += " FILTER (WHERE " + filterCondition.getSQL() + ')'; text += " FILTER (WHERE " + filterCondition.getSQL() + ')';
} }
if (over != null) {
text += ' ' + over.getSQL();
}
return text; return text;
} }
......
...@@ -11,6 +11,7 @@ import org.h2.api.Aggregate; ...@@ -11,6 +11,7 @@ import org.h2.api.Aggregate;
import org.h2.api.ErrorCode; import org.h2.api.ErrorCode;
import org.h2.command.Parser; import org.h2.command.Parser;
import org.h2.command.dml.Select; import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups;
import org.h2.engine.Session; import org.h2.engine.Session;
import org.h2.engine.UserAggregate; import org.h2.engine.UserAggregate;
import org.h2.expression.Expression; import org.h2.expression.Expression;
...@@ -19,6 +20,7 @@ import org.h2.message.DbException; ...@@ -19,6 +20,7 @@ import org.h2.message.DbException;
import org.h2.table.ColumnResolver; import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter; import org.h2.table.TableFilter;
import org.h2.util.StatementBuilder; import org.h2.util.StatementBuilder;
import org.h2.util.ValueHashMap;
import org.h2.value.DataType; import org.h2.value.DataType;
import org.h2.value.Value; import org.h2.value.Value;
import org.h2.value.ValueArray; import org.h2.value.ValueArray;
...@@ -27,25 +29,22 @@ import org.h2.value.ValueNull; ...@@ -27,25 +29,22 @@ import org.h2.value.ValueNull;
/** /**
* This class wraps a user-defined aggregate. * This class wraps a user-defined aggregate.
*/ */
public class JavaAggregate extends Expression { public class JavaAggregate extends AbstractAggregate {
private final UserAggregate userAggregate; private final UserAggregate userAggregate;
private final Select select; private final Select select;
private final Expression[] args; private final Expression[] args;
private int[] argTypes; private int[] argTypes;
private final boolean distinct; private final boolean distinct;
private Expression filterCondition;
private int dataType; private int dataType;
private Connection userConnection; private Connection userConnection;
private int lastGroupRowId; private int lastGroupRowId;
public JavaAggregate(UserAggregate userAggregate, Expression[] args, public JavaAggregate(UserAggregate userAggregate, Expression[] args, Select select, boolean distinct) {
Select select, boolean distinct, Expression filterCondition) {
this.userAggregate = userAggregate; this.userAggregate = userAggregate;
this.args = args; this.args = args;
this.select = select; this.select = select;
this.distinct = distinct; this.distinct = distinct;
this.filterCondition = filterCondition;
} }
@Override @Override
...@@ -87,6 +86,9 @@ public class JavaAggregate extends Expression { ...@@ -87,6 +86,9 @@ public class JavaAggregate extends Expression {
if (filterCondition != null) { if (filterCondition != null) {
buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')'); buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
} }
if (over != null) {
buff.append(' ').append(over.getSQL());
}
return buff.toString(); return buff.toString();
} }
...@@ -122,9 +124,7 @@ public class JavaAggregate extends Expression { ...@@ -122,9 +124,7 @@ public class JavaAggregate extends Expression {
for (Expression arg : args) { for (Expression arg : args) {
arg.mapColumns(resolver, level); arg.mapColumns(resolver, level);
} }
if (filterCondition != null) { super.mapColumns(resolver, level);
filterCondition.mapColumns(resolver, level);
}
} }
@Override @Override
...@@ -168,14 +168,15 @@ public class JavaAggregate extends Expression { ...@@ -168,14 +168,15 @@ public class JavaAggregate extends Expression {
@Override @Override
public Value getValue(Session session) { public Value getValue(Session session) {
if (!select.isCurrentGroup()) { SelectGroups groupData = select.getGroupDataIfCurrent(true);
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL()); throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
} }
try { try {
Aggregate agg; Aggregate agg;
if (distinct) { if (distinct) {
agg = getInstance(); agg = getInstance();
AggregateDataCollecting data = (AggregateDataCollecting) select.getCurrentGroupExprData(this); AggregateDataCollecting data = getDataDistinct(session, groupData, true);
if (data != null) { if (data != null) {
for (Value value : data.values) { for (Value value : data.values) {
if (args.length == 1) { if (args.length == 1) {
...@@ -191,7 +192,7 @@ public class JavaAggregate extends Expression { ...@@ -191,7 +192,7 @@ public class JavaAggregate extends Expression {
} }
} }
} else { } else {
agg = (Aggregate) select.getCurrentGroupExprData(this); agg = getData(session, groupData, true);
if (agg == null) { if (agg == null) {
agg = getInstance(); agg = getInstance();
} }
...@@ -208,12 +209,13 @@ public class JavaAggregate extends Expression { ...@@ -208,12 +209,13 @@ public class JavaAggregate extends Expression {
@Override @Override
public void updateAggregate(Session session) { public void updateAggregate(Session session) {
if (!select.isCurrentGroup()) { SelectGroups groupData = select.getGroupDataIfCurrent(true);
if (groupData == null) {
// this is a different level (the enclosing query) // this is a different level (the enclosing query)
return; return;
} }
int groupRowId = select.getCurrentGroupRowId(); int groupRowId = groupData.getCurrentGroupRowId();
if (lastGroupRowId == groupRowId) { if (lastGroupRowId == groupRowId) {
// already visited // already visited
return; return;
...@@ -228,11 +230,7 @@ public class JavaAggregate extends Expression { ...@@ -228,11 +230,7 @@ public class JavaAggregate extends Expression {
try { try {
if (distinct) { if (distinct) {
AggregateDataCollecting data = (AggregateDataCollecting) select.getCurrentGroupExprData(this); AggregateDataCollecting data = getDataDistinct(session, groupData, false);
if (data == null) {
data = new AggregateDataCollecting();
select.setCurrentGroupExprData(this, data);
}
Value[] argValues = new Value[args.length]; Value[] argValues = new Value[args.length];
Value arg = null; Value arg = null;
for (int i = 0, len = args.length; i < len; i++) { for (int i = 0, len = args.length; i < len; i++) {
...@@ -242,11 +240,7 @@ public class JavaAggregate extends Expression { ...@@ -242,11 +240,7 @@ public class JavaAggregate extends Expression {
} }
data.add(session.getDatabase(), dataType, true, args.length == 1 ? arg : ValueArray.get(argValues)); data.add(session.getDatabase(), dataType, true, args.length == 1 ? arg : ValueArray.get(argValues));
} else { } else {
Aggregate agg = (Aggregate) select.getCurrentGroupExprData(this); Aggregate agg = getData(session, groupData, false);
if (agg == null) {
agg = getInstance();
select.setCurrentGroupExprData(this, agg);
}
Object[] argValues = new Object[args.length]; Object[] argValues = new Object[args.length];
Object arg = null; Object arg = null;
for (int i = 0, len = args.length; i < len; i++) { for (int i = 0, len = args.length; i < len; i++) {
...@@ -262,4 +256,73 @@ public class JavaAggregate extends Expression { ...@@ -262,4 +256,73 @@ public class JavaAggregate extends Expression {
} }
} }
private Aggregate getData(Session session, SelectGroups groupData, boolean ifExists) throws SQLException {
Aggregate data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Aggregate> map = (ValueHashMap<Aggregate>) groupData.getCurrentGroupExprData(this);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map);
}
data = map.get(key);
if (data == null) {
if (ifExists) {
return null;
}
data = getInstance();
map.put(key, data);
}
} else {
data = (Aggregate) groupData.getCurrentGroupExprData(this);
if (data == null) {
if (ifExists) {
return null;
}
data = getInstance();
groupData.setCurrentGroupExprData(this, data);
}
}
return data;
}
private AggregateDataCollecting getDataDistinct(Session session, SelectGroups groupData, boolean ifExists) {
AggregateDataCollecting data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<AggregateDataCollecting> map =
(ValueHashMap<AggregateDataCollecting>) groupData.getCurrentGroupExprData(this);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map);
}
data = map.get(key);
if (data == null) {
if (ifExists) {
return null;
}
data = new AggregateDataCollecting();
map.put(key, data);
}
} else {
data = (AggregateDataCollecting) groupData.getCurrentGroupExprData(this);
if (data == null) {
if (ifExists) {
return null;
}
data = new AggregateDataCollecting();
groupData.setCurrentGroupExprData(this, data);
}
}
return data;
}
} }
/*
* Copyright 2004-2018 H2 Group. Multiple-Licensed under the MPL 2.0,
* and the EPL 1.0 (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.expression.aggregate;
import java.util.ArrayList;
import org.h2.engine.Session;
import org.h2.expression.Expression;
import org.h2.table.ColumnResolver;
import org.h2.util.StringUtils;
import org.h2.value.Value;
import org.h2.value.ValueArray;
/**
* Window clause.
*/
public final class Window {
private final ArrayList<Expression> partitionBy;
/**
* Creates a new instance of window clause.
*
* @param partitionBy
* PARTITION BY clause, or null
*/
public Window(ArrayList<Expression> partitionBy) {
this.partitionBy = partitionBy;
}
/**
* Map the columns of the resolver to expression columns.
*
* @param resolver
* the column resolver
* @param level
* the subquery nesting level
*/
public void mapColumns(ColumnResolver resolver, int level) {
if (partitionBy != null) {
for (Expression e : partitionBy) {
e.mapColumns(resolver, level);
}
}
}
/**
* Returns the key for the current group.
*
* @param session
* session
* @return key for the current group, or null
*/
public ValueArray getCurrentKey(Session session) {
if (partitionBy == null) {
return null;
}
int len = partitionBy.size();
Value[] keyValues = new Value[len];
// update group
for (int i = 0; i < len; i++) {
Expression expr = partitionBy.get(i);
keyValues[i] = expr.getValue(session);
}
return ValueArray.get(keyValues);
}
/**
* Returns SQL representation.
*
* @return SQL representation.
*/
public String getSQL() {
if (partitionBy == null) {
return "OVER ()";
}
StringBuilder builder = new StringBuilder().append("OVER (PARTITION BY ");
for (int i = 0; i < partitionBy.size(); i++) {
if (i > 0) {
builder.append(", ");
}
builder.append(StringUtils.unEnclose(partitionBy.get(i).getSQL()));
}
return builder.append(')').toString();
}
@Override
public String toString() {
return getSQL();
}
}
...@@ -757,6 +757,16 @@ public class TestFunctions extends TestDb implements AggregateFunction { ...@@ -757,6 +757,16 @@ public class TestFunctions extends TestDb implements AggregateFunction {
"SELECT SIMPLE_MEDIAN(X) FILTER (WHERE X > 2) FROM SYSTEM_RANGE(1, 9)"); "SELECT SIMPLE_MEDIAN(X) FILTER (WHERE X > 2) FROM SYSTEM_RANGE(1, 9)");
rs.next(); rs.next();
assertEquals("6", rs.getString(1)); assertEquals("6", rs.getString(1));
rs = stat.executeQuery("SELECT SIMPLE_MEDIAN(X) OVER () FROM SYSTEM_RANGE(1, 9)");
for (int i = 1; i < 9; i++) {
assertTrue(rs.next());
assertEquals("5", rs.getString(1));
}
rs = stat.executeQuery("SELECT SIMPLE_MEDIAN(X) OVER (PARTITION BY X) FROM SYSTEM_RANGE(1, 9)");
for (int i = 1; i < 9; i++) {
assertTrue(rs.next());
assertEquals(Integer.toString(i), rs.getString(1));
}
conn.close(); conn.close();
if (config.memory) { if (config.memory) {
......
...@@ -67,3 +67,51 @@ select array_agg(distinct v order by v desc) from test; ...@@ -67,3 +67,51 @@ select array_agg(distinct v order by v desc) from test;
drop table test; drop table test;
> ok > ok
CREATE TABLE TEST (ID INT PRIMARY KEY, NAME VARCHAR);
> ok
INSERT INTO TEST VALUES (1, 'a'), (2, 'a'), (3, 'b'), (4, 'c'), (5, 'c'), (6, 'c');
> update count: 6
SELECT ARRAY_AGG(ID), NAME FROM TEST;
> exception MUST_GROUP_BY_COLUMN_1
SELECT ARRAY_AGG(ID), NAME FROM TEST GROUP BY NAME;
> ARRAY_AGG(ID) NAME
> ------------- ----
> (1, 2) a
> (3) b
> (4, 5, 6) c
> rows: 3
SELECT ARRAY_AGG(ID) OVER (), NAME FROM TEST;
> ARRAY_AGG(ID) OVER () NAME
> --------------------- ----
> (1, 2, 3, 4, 5, 6) a
> (1, 2, 3, 4, 5, 6) a
> (1, 2, 3, 4, 5, 6) b
> (1, 2, 3, 4, 5, 6) c
> (1, 2, 3, 4, 5, 6) c
> (1, 2, 3, 4, 5, 6) c
> rows: 6
SELECT ARRAY_AGG(ID) OVER (PARTITION BY NAME), NAME FROM TEST;
> ARRAY_AGG(ID) OVER (PARTITION BY NAME) NAME
> -------------------------------------- ----
> (1, 2) a
> (1, 2) a
> (3) b
> (4, 5, 6) c
> (4, 5, 6) c
> (4, 5, 6) c
> rows: 6
SELECT ARRAY_AGG(SUM(ID)) OVER () FROM TEST;
> exception FEATURE_NOT_SUPPORTED_1
SELECT ARRAY_AGG(ID) OVER() FROM TEST GROUP BY ID;
> exception FEATURE_NOT_SUPPORTED_1
DROP TABLE TEST;
> ok
...@@ -795,4 +795,4 @@ minxd maxxd minyd maxyd bminxd bmaxxd bminyd bmaxyd ...@@ -795,4 +795,4 @@ minxd maxxd minyd maxyd bminxd bmaxxd bminyd bmaxyd
interior envelopes multilinestring multipoint packed exterior normalization awkward determination subgeometries interior envelopes multilinestring multipoint packed exterior normalization awkward determination subgeometries
xym normalizes coord setz xyzm geometrycollection multipolygon mixup rings polygons rejection finite xym normalizes coord setz xyzm geometrycollection multipolygon mixup rings polygons rejection finite
pointzm pointz pointm dimensionality redefine forum measures pointzm pointz pointm dimensionality redefine forum measures
mpg casted pzm mls constrained mpg casted pzm mls constrained subtypes complains
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论