提交 cbb47698 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Add OVER (PARTITION BY *) clause to aggregates

上级 5a5f919f
...@@ -3041,8 +3041,17 @@ public class Parser { ...@@ -3041,8 +3041,17 @@ public class Parser {
} }
if (readIf("OVER")) { if (readIf("OVER")) {
read(OPEN_PAREN); read(OPEN_PAREN);
ArrayList<Expression> partitionBy = null;
if (readIf("PARTITION")) {
read("BY");
partitionBy = Utils.newSmallArrayList();
do {
Expression expr = readExpression();
partitionBy.add(expr);
} while (readIf(COMMA));
}
read(CLOSE_PAREN); read(CLOSE_PAREN);
aggregate.setOverCondition(new Window()); aggregate.setOverCondition(new Window(partitionBy));
currentSelect.setWindowQuery(); currentSelect.setWindowQuery();
} else { } else {
currentSelect.setGroupQuery(); currentSelect.setGroupQuery();
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
package org.h2.expression.aggregate; package org.h2.expression.aggregate;
import org.h2.expression.Expression; import org.h2.expression.Expression;
import org.h2.table.ColumnResolver;
/** /**
* A base class for aggregates. * A base class for aggregates.
...@@ -36,4 +37,14 @@ public abstract class AbstractAggregate extends Expression { ...@@ -36,4 +37,14 @@ public abstract class AbstractAggregate extends Expression {
this.over = over; this.over = over;
} }
@Override
public void mapColumns(ColumnResolver resolver, int level) {
if (filterCondition != null) {
filterCondition.mapColumns(resolver, level);
}
if (over != null) {
over.mapColumns(resolver, level);
}
}
} }
...@@ -29,6 +29,7 @@ import org.h2.table.Table; ...@@ -29,6 +29,7 @@ import org.h2.table.Table;
import org.h2.table.TableFilter; import org.h2.table.TableFilter;
import org.h2.util.StatementBuilder; import org.h2.util.StatementBuilder;
import org.h2.util.StringUtils; import org.h2.util.StringUtils;
import org.h2.util.ValueHashMap;
import org.h2.value.DataType; import org.h2.value.DataType;
import org.h2.value.Value; import org.h2.value.Value;
import org.h2.value.ValueArray; import org.h2.value.ValueArray;
...@@ -303,11 +304,7 @@ public class Aggregate extends AbstractAggregate { ...@@ -303,11 +304,7 @@ public class Aggregate extends AbstractAggregate {
return; return;
} }
} }
AggregateData data = (AggregateData) groupData.getCurrentGroupExprData(this); AggregateData data = getData(session, groupData);
if (data == null) {
data = AggregateData.create(type);
groupData.setCurrentGroupExprData(this, data);
}
Value v = on == null ? null : on.getValue(session); Value v = on == null ? null : on.getValue(session);
if (type == AggregateType.GROUP_CONCAT) { if (type == AggregateType.GROUP_CONCAT) {
if (v != ValueNull.INSTANCE) { if (v != ValueNull.INSTANCE) {
...@@ -373,11 +370,7 @@ public class Aggregate extends AbstractAggregate { ...@@ -373,11 +370,7 @@ public class Aggregate extends AbstractAggregate {
if (groupData == null) { if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL()); throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
} }
AggregateData data = (AggregateData) groupData.getCurrentGroupExprData(this); AggregateData data = getData(session, groupData);
if (data == null) {
data = AggregateData.create(type);
groupData.setCurrentGroupExprData(this, data);
}
switch (type) { switch (type) {
case GROUP_CONCAT: { case GROUP_CONCAT: {
Value[] array = ((AggregateDataCollecting) data).getArray(); Value[] array = ((AggregateDataCollecting) data).getArray();
...@@ -433,6 +426,31 @@ public class Aggregate extends AbstractAggregate { ...@@ -433,6 +426,31 @@ public class Aggregate extends AbstractAggregate {
} }
} }
private AggregateData getData(Session session, SelectGroups groupData) {
AggregateData data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<AggregateData> map = (ValueHashMap<AggregateData>) groupData.getCurrentGroupExprData(this);
if (map == null) {
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map);
}
data = map.get(key);
if (data == null) {
data = AggregateData.create(type);
map.put(key, data);
}
} else {
data = (AggregateData) groupData.getCurrentGroupExprData(this);
if (data == null) {
data = AggregateData.create(type);
groupData.setCurrentGroupExprData(this, data);
}
}
return data;
}
@Override @Override
public int getType() { public int getType() {
return dataType; return dataType;
...@@ -451,9 +469,7 @@ public class Aggregate extends AbstractAggregate { ...@@ -451,9 +469,7 @@ public class Aggregate extends AbstractAggregate {
if (groupConcatSeparator != null) { if (groupConcatSeparator != null) {
groupConcatSeparator.mapColumns(resolver, level); groupConcatSeparator.mapColumns(resolver, level);
} }
if (filterCondition != null) { super.mapColumns(resolver, level);
filterCondition.mapColumns(resolver, level);
}
} }
@Override @Override
...@@ -614,7 +630,7 @@ public class Aggregate extends AbstractAggregate { ...@@ -614,7 +630,7 @@ public class Aggregate extends AbstractAggregate {
buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')'); buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
} }
if (over != null) { if (over != null) {
buff.append(" OVER()"); buff.append(' ').append(over.getSQL());
} }
return buff.toString(); return buff.toString();
} }
...@@ -638,7 +654,7 @@ public class Aggregate extends AbstractAggregate { ...@@ -638,7 +654,7 @@ public class Aggregate extends AbstractAggregate {
buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')'); buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
} }
if (over != null) { if (over != null) {
buff.append(" OVER()"); buff.append(' ').append(over.getSQL());
} }
return buff.toString(); return buff.toString();
} }
...@@ -719,7 +735,7 @@ public class Aggregate extends AbstractAggregate { ...@@ -719,7 +735,7 @@ public class Aggregate extends AbstractAggregate {
text += " FILTER (WHERE " + filterCondition.getSQL() + ')'; text += " FILTER (WHERE " + filterCondition.getSQL() + ')';
} }
if (over != null) { if (over != null) {
text += " OVER()"; text += ' ' + over.getSQL();
} }
return text; return text;
} }
......
...@@ -20,6 +20,7 @@ import org.h2.message.DbException; ...@@ -20,6 +20,7 @@ import org.h2.message.DbException;
import org.h2.table.ColumnResolver; import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter; import org.h2.table.TableFilter;
import org.h2.util.StatementBuilder; import org.h2.util.StatementBuilder;
import org.h2.util.ValueHashMap;
import org.h2.value.DataType; import org.h2.value.DataType;
import org.h2.value.Value; import org.h2.value.Value;
import org.h2.value.ValueArray; import org.h2.value.ValueArray;
...@@ -86,7 +87,7 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -86,7 +87,7 @@ public class JavaAggregate extends AbstractAggregate {
buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')'); buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
} }
if (over != null) { if (over != null) {
buff.append(" OVER()"); buff.append(' ').append(over.getSQL());
} }
return buff.toString(); return buff.toString();
} }
...@@ -123,9 +124,7 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -123,9 +124,7 @@ public class JavaAggregate extends AbstractAggregate {
for (Expression arg : args) { for (Expression arg : args) {
arg.mapColumns(resolver, level); arg.mapColumns(resolver, level);
} }
if (filterCondition != null) { super.mapColumns(resolver, level);
filterCondition.mapColumns(resolver, level);
}
} }
@Override @Override
...@@ -177,7 +176,7 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -177,7 +176,7 @@ public class JavaAggregate extends AbstractAggregate {
Aggregate agg; Aggregate agg;
if (distinct) { if (distinct) {
agg = getInstance(); agg = getInstance();
AggregateDataCollecting data = (AggregateDataCollecting) groupData.getCurrentGroupExprData(this); AggregateDataCollecting data = getDataDistinct(session, groupData, true);
if (data != null) { if (data != null) {
for (Value value : data.values) { for (Value value : data.values) {
if (args.length == 1) { if (args.length == 1) {
...@@ -193,7 +192,7 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -193,7 +192,7 @@ public class JavaAggregate extends AbstractAggregate {
} }
} }
} else { } else {
agg = (Aggregate) groupData.getCurrentGroupExprData(this); agg = getData(session, groupData, true);
if (agg == null) { if (agg == null) {
agg = getInstance(); agg = getInstance();
} }
...@@ -231,11 +230,7 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -231,11 +230,7 @@ public class JavaAggregate extends AbstractAggregate {
try { try {
if (distinct) { if (distinct) {
AggregateDataCollecting data = (AggregateDataCollecting) groupData.getCurrentGroupExprData(this); AggregateDataCollecting data = getDataDistinct(session, groupData, false);
if (data == null) {
data = new AggregateDataCollecting();
groupData.setCurrentGroupExprData(this, data);
}
Value[] argValues = new Value[args.length]; Value[] argValues = new Value[args.length];
Value arg = null; Value arg = null;
for (int i = 0, len = args.length; i < len; i++) { for (int i = 0, len = args.length; i < len; i++) {
...@@ -245,11 +240,7 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -245,11 +240,7 @@ public class JavaAggregate extends AbstractAggregate {
} }
data.add(session.getDatabase(), dataType, true, args.length == 1 ? arg : ValueArray.get(argValues)); data.add(session.getDatabase(), dataType, true, args.length == 1 ? arg : ValueArray.get(argValues));
} else { } else {
Aggregate agg = (Aggregate) groupData.getCurrentGroupExprData(this); Aggregate agg = getData(session, groupData, false);
if (agg == null) {
agg = getInstance();
groupData.setCurrentGroupExprData(this, agg);
}
Object[] argValues = new Object[args.length]; Object[] argValues = new Object[args.length];
Object arg = null; Object arg = null;
for (int i = 0, len = args.length; i < len; i++) { for (int i = 0, len = args.length; i < len; i++) {
...@@ -265,4 +256,73 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -265,4 +256,73 @@ public class JavaAggregate extends AbstractAggregate {
} }
} }
private Aggregate getData(Session session, SelectGroups groupData, boolean ifExists) throws SQLException {
Aggregate data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Aggregate> map = (ValueHashMap<Aggregate>) groupData.getCurrentGroupExprData(this);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map);
}
data = map.get(key);
if (data == null) {
if (ifExists) {
return null;
}
data = getInstance();
map.put(key, data);
}
} else {
data = (Aggregate) groupData.getCurrentGroupExprData(this);
if (data == null) {
if (ifExists) {
return null;
}
data = getInstance();
groupData.setCurrentGroupExprData(this, data);
}
}
return data;
}
private AggregateDataCollecting getDataDistinct(Session session, SelectGroups groupData, boolean ifExists) {
AggregateDataCollecting data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<AggregateDataCollecting> map =
(ValueHashMap<AggregateDataCollecting>) groupData.getCurrentGroupExprData(this);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map);
}
data = map.get(key);
if (data == null) {
if (ifExists) {
return null;
}
data = new AggregateDataCollecting();
map.put(key, data);
}
} else {
data = (AggregateDataCollecting) groupData.getCurrentGroupExprData(this);
if (data == null) {
if (ifExists) {
return null;
}
data = new AggregateDataCollecting();
groupData.setCurrentGroupExprData(this, data);
}
}
return data;
}
} }
...@@ -5,9 +5,91 @@ ...@@ -5,9 +5,91 @@
*/ */
package org.h2.expression.aggregate; package org.h2.expression.aggregate;
import java.util.ArrayList;
import org.h2.engine.Session;
import org.h2.expression.Expression;
import org.h2.table.ColumnResolver;
import org.h2.util.StringUtils;
import org.h2.value.Value;
import org.h2.value.ValueArray;
/** /**
* Window clause. * Window clause.
*/ */
public final class Window { public final class Window {
private final ArrayList<Expression> partitionBy;
/**
* Creates a new instance of window clause.
*
* @param partitionBy
* PARTITION BY clause, or null
*/
public Window(ArrayList<Expression> partitionBy) {
this.partitionBy = partitionBy;
}
/**
* Map the columns of the resolver to expression columns.
*
* @param resolver
* the column resolver
* @param level
* the subquery nesting level
*/
public void mapColumns(ColumnResolver resolver, int level) {
if (partitionBy != null) {
for (Expression e : partitionBy) {
e.mapColumns(resolver, level);
}
}
}
/**
* Returns the key for the current group.
*
* @param session
* session
* @return key for the current group, or null
*/
public ValueArray getCurrentKey(Session session) {
if (partitionBy == null) {
return null;
}
int len = partitionBy.size();
Value[] keyValues = new Value[len];
// update group
for (int i = 0; i < len; i++) {
Expression expr = partitionBy.get(i);
keyValues[i] = expr.getValue(session);
}
return ValueArray.get(keyValues);
}
/**
* Returns SQL representation.
*
* @return SQL representation.
*/
public String getSQL() {
if (partitionBy == null) {
return "OVER ()";
}
StringBuilder builder = new StringBuilder().append("OVER (PARTITION BY ");
for (int i = 0; i < partitionBy.size(); i++) {
if (i > 0) {
builder.append(", ");
}
builder.append(StringUtils.unEnclose(partitionBy.get(i).getSQL()));
}
return builder.append(')').toString();
}
@Override
public String toString() {
return getSQL();
}
} }
...@@ -757,6 +757,16 @@ public class TestFunctions extends TestDb implements AggregateFunction { ...@@ -757,6 +757,16 @@ public class TestFunctions extends TestDb implements AggregateFunction {
"SELECT SIMPLE_MEDIAN(X) FILTER (WHERE X > 2) FROM SYSTEM_RANGE(1, 9)"); "SELECT SIMPLE_MEDIAN(X) FILTER (WHERE X > 2) FROM SYSTEM_RANGE(1, 9)");
rs.next(); rs.next();
assertEquals("6", rs.getString(1)); assertEquals("6", rs.getString(1));
rs = stat.executeQuery("SELECT SIMPLE_MEDIAN(X) OVER () FROM SYSTEM_RANGE(1, 9)");
for (int i = 1; i < 9; i++) {
assertTrue(rs.next());
assertEquals("5", rs.getString(1));
}
rs = stat.executeQuery("SELECT SIMPLE_MEDIAN(X) OVER (PARTITION BY X) FROM SYSTEM_RANGE(1, 9)");
for (int i = 1; i < 9; i++) {
assertTrue(rs.next());
assertEquals(Integer.toString(i), rs.getString(1));
}
conn.close(); conn.close();
if (config.memory) { if (config.memory) {
......
...@@ -86,14 +86,25 @@ SELECT ARRAY_AGG(ID), NAME FROM TEST GROUP BY NAME; ...@@ -86,14 +86,25 @@ SELECT ARRAY_AGG(ID), NAME FROM TEST GROUP BY NAME;
> rows: 3 > rows: 3
SELECT ARRAY_AGG(ID) OVER (), NAME FROM TEST; SELECT ARRAY_AGG(ID) OVER (), NAME FROM TEST;
> ARRAY_AGG(ID) OVER() NAME > ARRAY_AGG(ID) OVER () NAME
> -------------------- ---- > --------------------- ----
> (1, 2, 3, 4, 5, 6) a > (1, 2, 3, 4, 5, 6) a
> (1, 2, 3, 4, 5, 6) a > (1, 2, 3, 4, 5, 6) a
> (1, 2, 3, 4, 5, 6) b > (1, 2, 3, 4, 5, 6) b
> (1, 2, 3, 4, 5, 6) c > (1, 2, 3, 4, 5, 6) c
> (1, 2, 3, 4, 5, 6) c > (1, 2, 3, 4, 5, 6) c
> (1, 2, 3, 4, 5, 6) c > (1, 2, 3, 4, 5, 6) c
> rows: 6
SELECT ARRAY_AGG(ID) OVER (PARTITION BY NAME), NAME FROM TEST;
> ARRAY_AGG(ID) OVER (PARTITION BY NAME) NAME
> -------------------------------------- ----
> (1, 2) a
> (1, 2) a
> (3) b
> (4, 5, 6) c
> (4, 5, 6) c
> (4, 5, 6) c
> rows: 6 > rows: 6
DROP TABLE TEST; DROP TABLE TEST;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论