提交 cbb47698 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Add OVER (PARTITION BY *) clause to aggregates

上级 5a5f919f
......@@ -3041,8 +3041,17 @@ public class Parser {
}
if (readIf("OVER")) {
read(OPEN_PAREN);
ArrayList<Expression> partitionBy = null;
if (readIf("PARTITION")) {
read("BY");
partitionBy = Utils.newSmallArrayList();
do {
Expression expr = readExpression();
partitionBy.add(expr);
} while (readIf(COMMA));
}
read(CLOSE_PAREN);
aggregate.setOverCondition(new Window());
aggregate.setOverCondition(new Window(partitionBy));
currentSelect.setWindowQuery();
} else {
currentSelect.setGroupQuery();
......
......@@ -6,6 +6,7 @@
package org.h2.expression.aggregate;
import org.h2.expression.Expression;
import org.h2.table.ColumnResolver;
/**
* A base class for aggregates.
......@@ -36,4 +37,14 @@ public abstract class AbstractAggregate extends Expression {
this.over = over;
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
if (filterCondition != null) {
filterCondition.mapColumns(resolver, level);
}
if (over != null) {
over.mapColumns(resolver, level);
}
}
}
......@@ -29,6 +29,7 @@ import org.h2.table.Table;
import org.h2.table.TableFilter;
import org.h2.util.StatementBuilder;
import org.h2.util.StringUtils;
import org.h2.util.ValueHashMap;
import org.h2.value.DataType;
import org.h2.value.Value;
import org.h2.value.ValueArray;
......@@ -303,11 +304,7 @@ public class Aggregate extends AbstractAggregate {
return;
}
}
AggregateData data = (AggregateData) groupData.getCurrentGroupExprData(this);
if (data == null) {
data = AggregateData.create(type);
groupData.setCurrentGroupExprData(this, data);
}
AggregateData data = getData(session, groupData);
Value v = on == null ? null : on.getValue(session);
if (type == AggregateType.GROUP_CONCAT) {
if (v != ValueNull.INSTANCE) {
......@@ -373,11 +370,7 @@ public class Aggregate extends AbstractAggregate {
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
AggregateData data = (AggregateData) groupData.getCurrentGroupExprData(this);
if (data == null) {
data = AggregateData.create(type);
groupData.setCurrentGroupExprData(this, data);
}
AggregateData data = getData(session, groupData);
switch (type) {
case GROUP_CONCAT: {
Value[] array = ((AggregateDataCollecting) data).getArray();
......@@ -433,6 +426,31 @@ public class Aggregate extends AbstractAggregate {
}
}
private AggregateData getData(Session session, SelectGroups groupData) {
AggregateData data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<AggregateData> map = (ValueHashMap<AggregateData>) groupData.getCurrentGroupExprData(this);
if (map == null) {
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map);
}
data = map.get(key);
if (data == null) {
data = AggregateData.create(type);
map.put(key, data);
}
} else {
data = (AggregateData) groupData.getCurrentGroupExprData(this);
if (data == null) {
data = AggregateData.create(type);
groupData.setCurrentGroupExprData(this, data);
}
}
return data;
}
@Override
public int getType() {
return dataType;
......@@ -451,9 +469,7 @@ public class Aggregate extends AbstractAggregate {
if (groupConcatSeparator != null) {
groupConcatSeparator.mapColumns(resolver, level);
}
if (filterCondition != null) {
filterCondition.mapColumns(resolver, level);
}
super.mapColumns(resolver, level);
}
@Override
......@@ -614,7 +630,7 @@ public class Aggregate extends AbstractAggregate {
buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
}
if (over != null) {
buff.append(" OVER()");
buff.append(' ').append(over.getSQL());
}
return buff.toString();
}
......@@ -638,7 +654,7 @@ public class Aggregate extends AbstractAggregate {
buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
}
if (over != null) {
buff.append(" OVER()");
buff.append(' ').append(over.getSQL());
}
return buff.toString();
}
......@@ -719,7 +735,7 @@ public class Aggregate extends AbstractAggregate {
text += " FILTER (WHERE " + filterCondition.getSQL() + ')';
}
if (over != null) {
text += " OVER()";
text += ' ' + over.getSQL();
}
return text;
}
......
......@@ -20,6 +20,7 @@ import org.h2.message.DbException;
import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter;
import org.h2.util.StatementBuilder;
import org.h2.util.ValueHashMap;
import org.h2.value.DataType;
import org.h2.value.Value;
import org.h2.value.ValueArray;
......@@ -86,7 +87,7 @@ public class JavaAggregate extends AbstractAggregate {
buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
}
if (over != null) {
buff.append(" OVER()");
buff.append(' ').append(over.getSQL());
}
return buff.toString();
}
......@@ -123,9 +124,7 @@ public class JavaAggregate extends AbstractAggregate {
for (Expression arg : args) {
arg.mapColumns(resolver, level);
}
if (filterCondition != null) {
filterCondition.mapColumns(resolver, level);
}
super.mapColumns(resolver, level);
}
@Override
......@@ -177,7 +176,7 @@ public class JavaAggregate extends AbstractAggregate {
Aggregate agg;
if (distinct) {
agg = getInstance();
AggregateDataCollecting data = (AggregateDataCollecting) groupData.getCurrentGroupExprData(this);
AggregateDataCollecting data = getDataDistinct(session, groupData, true);
if (data != null) {
for (Value value : data.values) {
if (args.length == 1) {
......@@ -193,7 +192,7 @@ public class JavaAggregate extends AbstractAggregate {
}
}
} else {
agg = (Aggregate) groupData.getCurrentGroupExprData(this);
agg = getData(session, groupData, true);
if (agg == null) {
agg = getInstance();
}
......@@ -231,11 +230,7 @@ public class JavaAggregate extends AbstractAggregate {
try {
if (distinct) {
AggregateDataCollecting data = (AggregateDataCollecting) groupData.getCurrentGroupExprData(this);
if (data == null) {
data = new AggregateDataCollecting();
groupData.setCurrentGroupExprData(this, data);
}
AggregateDataCollecting data = getDataDistinct(session, groupData, false);
Value[] argValues = new Value[args.length];
Value arg = null;
for (int i = 0, len = args.length; i < len; i++) {
......@@ -245,11 +240,7 @@ public class JavaAggregate extends AbstractAggregate {
}
data.add(session.getDatabase(), dataType, true, args.length == 1 ? arg : ValueArray.get(argValues));
} else {
Aggregate agg = (Aggregate) groupData.getCurrentGroupExprData(this);
if (agg == null) {
agg = getInstance();
groupData.setCurrentGroupExprData(this, agg);
}
Aggregate agg = getData(session, groupData, false);
Object[] argValues = new Object[args.length];
Object arg = null;
for (int i = 0, len = args.length; i < len; i++) {
......@@ -265,4 +256,73 @@ public class JavaAggregate extends AbstractAggregate {
}
}
private Aggregate getData(Session session, SelectGroups groupData, boolean ifExists) throws SQLException {
Aggregate data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Aggregate> map = (ValueHashMap<Aggregate>) groupData.getCurrentGroupExprData(this);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map);
}
data = map.get(key);
if (data == null) {
if (ifExists) {
return null;
}
data = getInstance();
map.put(key, data);
}
} else {
data = (Aggregate) groupData.getCurrentGroupExprData(this);
if (data == null) {
if (ifExists) {
return null;
}
data = getInstance();
groupData.setCurrentGroupExprData(this, data);
}
}
return data;
}
private AggregateDataCollecting getDataDistinct(Session session, SelectGroups groupData, boolean ifExists) {
AggregateDataCollecting data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<AggregateDataCollecting> map =
(ValueHashMap<AggregateDataCollecting>) groupData.getCurrentGroupExprData(this);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map);
}
data = map.get(key);
if (data == null) {
if (ifExists) {
return null;
}
data = new AggregateDataCollecting();
map.put(key, data);
}
} else {
data = (AggregateDataCollecting) groupData.getCurrentGroupExprData(this);
if (data == null) {
if (ifExists) {
return null;
}
data = new AggregateDataCollecting();
groupData.setCurrentGroupExprData(this, data);
}
}
return data;
}
}
......@@ -5,9 +5,91 @@
*/
package org.h2.expression.aggregate;
import java.util.ArrayList;
import org.h2.engine.Session;
import org.h2.expression.Expression;
import org.h2.table.ColumnResolver;
import org.h2.util.StringUtils;
import org.h2.value.Value;
import org.h2.value.ValueArray;
/**
* Window clause.
*/
public final class Window {
private final ArrayList<Expression> partitionBy;
/**
* Creates a new instance of window clause.
*
* @param partitionBy
* PARTITION BY clause, or null
*/
public Window(ArrayList<Expression> partitionBy) {
this.partitionBy = partitionBy;
}
/**
* Map the columns of the resolver to expression columns.
*
* @param resolver
* the column resolver
* @param level
* the subquery nesting level
*/
public void mapColumns(ColumnResolver resolver, int level) {
if (partitionBy != null) {
for (Expression e : partitionBy) {
e.mapColumns(resolver, level);
}
}
}
/**
* Returns the key for the current group.
*
* @param session
* session
* @return key for the current group, or null
*/
public ValueArray getCurrentKey(Session session) {
if (partitionBy == null) {
return null;
}
int len = partitionBy.size();
Value[] keyValues = new Value[len];
// update group
for (int i = 0; i < len; i++) {
Expression expr = partitionBy.get(i);
keyValues[i] = expr.getValue(session);
}
return ValueArray.get(keyValues);
}
/**
* Returns SQL representation.
*
* @return SQL representation.
*/
public String getSQL() {
if (partitionBy == null) {
return "OVER ()";
}
StringBuilder builder = new StringBuilder().append("OVER (PARTITION BY ");
for (int i = 0; i < partitionBy.size(); i++) {
if (i > 0) {
builder.append(", ");
}
builder.append(StringUtils.unEnclose(partitionBy.get(i).getSQL()));
}
return builder.append(')').toString();
}
@Override
public String toString() {
return getSQL();
}
}
......@@ -757,6 +757,16 @@ public class TestFunctions extends TestDb implements AggregateFunction {
"SELECT SIMPLE_MEDIAN(X) FILTER (WHERE X > 2) FROM SYSTEM_RANGE(1, 9)");
rs.next();
assertEquals("6", rs.getString(1));
rs = stat.executeQuery("SELECT SIMPLE_MEDIAN(X) OVER () FROM SYSTEM_RANGE(1, 9)");
for (int i = 1; i < 9; i++) {
assertTrue(rs.next());
assertEquals("5", rs.getString(1));
}
rs = stat.executeQuery("SELECT SIMPLE_MEDIAN(X) OVER (PARTITION BY X) FROM SYSTEM_RANGE(1, 9)");
for (int i = 1; i < 9; i++) {
assertTrue(rs.next());
assertEquals(Integer.toString(i), rs.getString(1));
}
conn.close();
if (config.memory) {
......
......@@ -86,8 +86,8 @@ SELECT ARRAY_AGG(ID), NAME FROM TEST GROUP BY NAME;
> rows: 3
SELECT ARRAY_AGG(ID) OVER (), NAME FROM TEST;
> ARRAY_AGG(ID) OVER() NAME
> -------------------- ----
> ARRAY_AGG(ID) OVER () NAME
> --------------------- ----
> (1, 2, 3, 4, 5, 6) a
> (1, 2, 3, 4, 5, 6) a
> (1, 2, 3, 4, 5, 6) b
......@@ -96,5 +96,16 @@ SELECT ARRAY_AGG(ID) OVER (), NAME FROM TEST;
> (1, 2, 3, 4, 5, 6) c
> rows: 6
SELECT ARRAY_AGG(ID) OVER (PARTITION BY NAME), NAME FROM TEST;
> ARRAY_AGG(ID) OVER (PARTITION BY NAME) NAME
> -------------------------------------- ----
> (1, 2) a
> (1, 2) a
> (3) b
> (4, 5, 6) c
> (4, 5, 6) c
> (4, 5, 6) c
> rows: 6
DROP TABLE TEST;
> ok
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论