提交 aee4cfaa authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Add partial support for OVER() without conditions to aggregate functions

上级 f6e1781a
......@@ -173,6 +173,7 @@ import org.h2.expression.Wildcard;
import org.h2.expression.aggregate.Aggregate;
import org.h2.expression.aggregate.Aggregate.AggregateType;
import org.h2.expression.aggregate.JavaAggregate;
import org.h2.expression.aggregate.Window;
import org.h2.index.Index;
import org.h2.message.DbException;
import org.h2.result.SortOrder;
......@@ -2883,7 +2884,6 @@ public class Parser {
if (currentSelect == null) {
throw getSyntaxError();
}
currentSelect.setGroupQuery();
Aggregate r;
switch (aggregateType) {
case COUNT:
......@@ -2968,10 +2968,26 @@ public class Parser {
read(CLOSE_PAREN);
if (r != null) {
r.setFilterCondition(readFilterCondition());
Window over = readOver();
if (over != null) {
r.setOverCondition(over);
currentSelect.setWindowQuery();
} else {
currentSelect.setGroupQuery();
}
}
return r;
}
private Window readOver() {
if (readIf("OVER")) {
read(OPEN_PAREN);
read(CLOSE_PAREN);
return new Window();
}
return null;
}
private void setModeAggOrder(Aggregate r, Expression expr) {
ArrayList<SelectOrderBy> orderList = new ArrayList<>(1);
SelectOrderBy order = new SelectOrderBy();
......@@ -3027,7 +3043,13 @@ public class Parser {
Expression filterCondition = readFilterCondition();
Expression[] list = params.toArray(new Expression[0]);
JavaAggregate agg = new JavaAggregate(aggregate, list, currentSelect, distinct, filterCondition);
currentSelect.setGroupQuery();
Window over = readOver();
if (over != null) {
agg.setOverCondition(over);
currentSelect.setWindowQuery();
} else {
currentSelect.setGroupQuery();
}
return agg;
}
......
......@@ -8,6 +8,7 @@ package org.h2.command.dml;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import org.h2.api.ErrorCode;
import org.h2.api.Trigger;
......@@ -105,6 +106,7 @@ public class Select extends Query {
private int havingIndex;
private boolean isGroupQuery, isGroupSortedQuery;
private boolean isWindowQuery;
private boolean isForUpdate, isForUpdateMvcc;
private double cost;
private boolean isQuickAggregateQuery, isDistinctQuery;
......@@ -160,6 +162,13 @@ public class Select extends Query {
isGroupQuery = true;
}
/**
* Called if this query contains window functions.
*/
public void setWindowQuery() {
isWindowQuery = true;
}
public void setGroupBy(ArrayList<Expression> group) {
this.group = group;
}
......@@ -168,8 +177,8 @@ public class Select extends Query {
return group;
}
public SelectGroups getGroupDataIfCurrent() {
return groupData != null && groupData.isCurrentGroup() ? groupData : null;
public SelectGroups getGroupDataIfCurrent(boolean forAggregate) {
return groupData != null && (forAggregate || !isWindowQuery) && groupData.isCurrentGroup() ? groupData : null;
}
@Override
......@@ -346,11 +355,12 @@ public class Select extends Query {
return condition == null || condition.getBooleanValue(session);
}
private void queryGroup(int columnCount, LocalResult result, long offset, boolean quickOffset) {
private void queryWindow(int columnCount, LocalResult result, long offset, boolean quickOffset) {
if (groupData == null) {
groupData = new SelectGroups(session, expressions, groupIndex);
}
groupData.reset();
HashMap<ValueArray, ArrayList<Row>> rows = new HashMap<>();
try {
int rowNumber = 0;
setCurrentRowNumber(0);
......@@ -359,13 +369,14 @@ public class Select extends Query {
setCurrentRowNumber(rowNumber + 1);
if (isConditionMet()) {
rowNumber++;
groupData.nextSource();
for (int i = 0; i < columnCount; i++) {
if (groupByExpression == null || !groupByExpression[i]) {
Expression expr = expressions.get(i);
expr.updateAggregate(session);
}
ValueArray key = groupData.nextSource();
ArrayList<Row> groupRows = rows.get(key);
if (groupRows == null) {
groupRows = Utils.newSmallArrayList();
rows.put(key, groupRows);
}
groupRows.add(topTableFilter.get());
updateAgg(columnCount);
if (sampleSize > 0 && rowNumber >= sampleSize) {
break;
}
......@@ -373,33 +384,80 @@ public class Select extends Query {
}
groupData.done();
for (ValueArray currentGroupsKey; (currentGroupsKey = groupData.next()) != null;) {
Value[] keyValues = currentGroupsKey.getList();
Value[] row = new Value[columnCount];
for (int j = 0; groupIndex != null && j < groupIndex.length; j++) {
row[groupIndex[j]] = keyValues[j];
for (Row originalRow : rows.get(currentGroupsKey)) {
topTableFilter.set(originalRow);
offset = processGroupedRow(columnCount, result, offset, quickOffset, currentGroupsKey);
}
for (int j = 0; j < columnCount; j++) {
if (groupByExpression != null && groupByExpression[j]) {
continue;
}
} finally {
groupData.reset();
}
}
private void queryGroup(int columnCount, LocalResult result, long offset, boolean quickOffset) {
if (groupData == null) {
groupData = new SelectGroups(session, expressions, groupIndex);
}
groupData.reset();
try {
int rowNumber = 0;
setCurrentRowNumber(0);
int sampleSize = getSampleSizeValue(session);
while (topTableFilter.next()) {
setCurrentRowNumber(rowNumber + 1);
if (isConditionMet()) {
rowNumber++;
groupData.nextSource();
updateAgg(columnCount);
if (sampleSize > 0 && rowNumber >= sampleSize) {
break;
}
Expression expr = expressions.get(j);
row[j] = expr.getValue(session);
}
if (isHavingNullOrFalse(row)) {
continue;
}
if (quickOffset && offset > 0) {
offset--;
continue;
}
row = keepOnlyDistinct(row, columnCount);
result.addRow(row);
}
groupData.done();
for (ValueArray currentGroupsKey; (currentGroupsKey = groupData.next()) != null;) {
offset = processGroupedRow(columnCount, result, offset, quickOffset, currentGroupsKey);
}
} finally {
groupData.reset();
}
}
private void updateAgg(int columnCount) {
for (int i = 0; i < columnCount; i++) {
if (groupByExpression == null || !groupByExpression[i]) {
Expression expr = expressions.get(i);
expr.updateAggregate(session);
}
}
}
private long processGroupedRow(int columnCount, LocalResult result, long offset, boolean quickOffset,
ValueArray currentGroupsKey) {
Value[] keyValues = currentGroupsKey.getList();
Value[] row = new Value[columnCount];
for (int j = 0; groupIndex != null && j < groupIndex.length; j++) {
row[groupIndex[j]] = keyValues[j];
}
for (int j = 0; j < columnCount; j++) {
if (groupByExpression != null && groupByExpression[j]) {
continue;
}
Expression expr = expressions.get(j);
row[j] = expr.getValue(session);
}
if (isHavingNullOrFalse(row)) {
return offset;
}
if (quickOffset && offset > 0) {
offset--;
return offset;
}
row = keepOnlyDistinct(row, columnCount);
result.addRow(row);
return offset;
}
/**
* Get the index that matches the ORDER BY list, if one exists. This is to
* avoid running a separate ORDER BY if an index can be used. This is
......@@ -663,7 +721,7 @@ public class Select extends Query {
result = createLocalResult(result);
result.setDistinct(distinctIndexes);
}
if (isGroupQuery && !isGroupSortedQuery) {
if (isWindowQuery || isGroupQuery && !isGroupSortedQuery) {
result = createLocalResult(result);
}
if (!lazy && (limitRows >= 0 || offset > 0)) {
......@@ -697,6 +755,8 @@ public class Select extends Query {
try {
if (isQuickAggregateQuery) {
queryQuick(columnCount, to, quickOffset && offset > 0);
} else if (isWindowQuery) {
queryWindow(columnCount, result, offset, quickOffset);
} else if (isGroupQuery) {
if (isGroupSortedQuery) {
lazyResult = queryGroupSorted(columnCount, to, offset, quickOffset);
......
......@@ -167,8 +167,10 @@ public final class SelectGroups {
/**
* Invoked for each source row to evaluate group key and setup all necessary
* data for aggregates.
*
* @return key of the current group
*/
public void nextSource() {
public ValueArray nextSource() {
if (groupIndex == null) {
currentGroupsKey = defaultGroup;
} else {
......@@ -188,6 +190,7 @@ public final class SelectGroups {
}
currentGroupByExprData = values;
currentGroupRowId++;
return currentGroupsKey;
}
/**
......
......@@ -159,7 +159,7 @@ public class ExpressionColumn extends Expression {
if (select == null) {
throw DbException.get(ErrorCode.MUST_GROUP_BY_COLUMN_1, getSQL());
}
SelectGroups groupData = select.getGroupDataIfCurrent();
SelectGroups groupData = select.getGroupDataIfCurrent(false);
if (groupData == null) {
// this is a different level (the enclosing query)
return;
......@@ -178,7 +178,7 @@ public class ExpressionColumn extends Expression {
public Value getValue(Session session) {
Select select = columnResolver.getSelect();
if (select != null) {
SelectGroups groupData = select.getGroupDataIfCurrent();
SelectGroups groupData = select.getGroupDataIfCurrent(false);
if (groupData != null) {
Value v = (Value) groupData.getCurrentGroupExprData(this);
if (v != null) {
......
......@@ -168,6 +168,8 @@ public class Aggregate extends Expression {
private Expression filterCondition;
private Window over;
/**
* Create a new aggregate object.
*
......@@ -264,6 +266,15 @@ public class Aggregate extends Expression {
this.filterCondition = filterCondition;
}
/**
* Sets the OVER condition.
*
* @param over OVER condition
*/
public void setOverCondition(Window over) {
this.over = over;
}
private SortOrder initOrder(Session session) {
int size = orderByList.size();
int[] index = new int[size];
......@@ -296,7 +307,7 @@ public class Aggregate extends Expression {
// if (on != null) {
// on.updateAggregate();
// }
SelectGroups groupData = select.getGroupDataIfCurrent();
SelectGroups groupData = select.getGroupDataIfCurrent(true);
if (groupData == null) {
// this is a different level (the enclosing query)
return;
......@@ -380,7 +391,7 @@ public class Aggregate extends Expression {
DbException.throwInternalError("type=" + type);
}
}
SelectGroups groupData = select.getGroupDataIfCurrent();
SelectGroups groupData = select.getGroupDataIfCurrent(true);
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
......@@ -624,6 +635,9 @@ public class Aggregate extends Expression {
if (filterCondition != null) {
buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
}
if (over != null) {
buff.append(" OVER()");
}
return buff.toString();
}
......@@ -645,6 +659,9 @@ public class Aggregate extends Expression {
if (filterCondition != null) {
buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
}
if (over != null) {
buff.append(" OVER()");
}
return buff.toString();
}
......@@ -723,6 +740,9 @@ public class Aggregate extends Expression {
if (filterCondition != null) {
text += " FILTER (WHERE " + filterCondition.getSQL() + ')';
}
if (over != null) {
text += " OVER()";
}
return text;
}
......
......@@ -36,6 +36,7 @@ public class JavaAggregate extends Expression {
private int[] argTypes;
private final boolean distinct;
private Expression filterCondition;
private Window over;
private int dataType;
private Connection userConnection;
private int lastGroupRowId;
......@@ -49,6 +50,15 @@ public class JavaAggregate extends Expression {
this.filterCondition = filterCondition;
}
/**
* Sets the OVER condition.
*
* @param over OVER condition
*/
public void setOverCondition(Window over) {
this.over = over;
}
@Override
public int getCost() {
int cost = 5;
......@@ -88,6 +98,9 @@ public class JavaAggregate extends Expression {
if (filterCondition != null) {
buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
}
if (over != null) {
buff.append(" OVER()");
}
return buff.toString();
}
......@@ -169,7 +182,7 @@ public class JavaAggregate extends Expression {
@Override
public Value getValue(Session session) {
SelectGroups groupData = select.getGroupDataIfCurrent();
SelectGroups groupData = select.getGroupDataIfCurrent(true);
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
......@@ -210,7 +223,7 @@ public class JavaAggregate extends Expression {
@Override
public void updateAggregate(Session session) {
SelectGroups groupData = select.getGroupDataIfCurrent();
SelectGroups groupData = select.getGroupDataIfCurrent(true);
if (groupData == null) {
// this is a different level (the enclosing query)
return;
......
/*
* Copyright 2004-2018 H2 Group. Multiple-Licensed under the MPL 2.0,
* and the EPL 1.0 (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.expression.aggregate;
/**
* Window clause.
*/
public final class Window {
}
......@@ -67,3 +67,34 @@ select array_agg(distinct v order by v desc) from test;
drop table test;
> ok
CREATE TABLE TEST (ID INT PRIMARY KEY, NAME VARCHAR);
> ok
INSERT INTO TEST VALUES (1, 'a'), (2, 'a'), (3, 'b'), (4, 'c'), (5, 'c'), (6, 'c');
> update count: 6
SELECT ARRAY_AGG(ID), NAME FROM TEST;
> exception MUST_GROUP_BY_COLUMN_1
SELECT ARRAY_AGG(ID), NAME FROM TEST GROUP BY NAME;
> ARRAY_AGG(ID) NAME
> ------------- ----
> (1, 2) a
> (3) b
> (4, 5, 6) c
> rows: 3
SELECT ARRAY_AGG(ID) OVER (), NAME FROM TEST;
> ARRAY_AGG(ID) OVER() NAME
> -------------------- ----
> (1, 2, 3, 4, 5, 6) a
> (1, 2, 3, 4, 5, 6) a
> (1, 2, 3, 4, 5, 6) b
> (1, 2, 3, 4, 5, 6) c
> (1, 2, 3, 4, 5, 6) c
> (1, 2, 3, 4, 5, 6) c
> rows: 6
DROP TABLE TEST;
> ok
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论