提交 ee65a1bb authored 作者: Alex Nordlund's avatar Alex Nordlund

Adds ARRAY_AGG function.

上级 606eebb5
...@@ -2949,6 +2949,19 @@ Aggregates are only allowed in select statements. ...@@ -2949,6 +2949,19 @@ Aggregates are only allowed in select statements.
GROUP_CONCAT(NAME ORDER BY ID SEPARATOR ', ') GROUP_CONCAT(NAME ORDER BY ID SEPARATOR ', ')
" "
"Functions (Aggregate)","ARRAY_AGG","
ARRAY_AGG ( [ DISTINCT ] string
[ ORDER BY { expression [ ASC | DESC ] } [,...] ] )
[ FILTER ( WHERE expression ) ]
","
Aggregate the value into an array.
This method returns an array.
If no rows are selected, the result is NULL.
Aggregates are only allowed in select statements.
","
ARRAY_AGG(NAME ORDER BY ID)
"
"Functions (Aggregate)","MAX"," "Functions (Aggregate)","MAX","
MAX(value) [ FILTER ( WHERE expression ) ] MAX(value) [ FILTER ( WHERE expression ) ]
"," ","
......
...@@ -2651,6 +2651,15 @@ public class Parser { ...@@ -2651,6 +2651,15 @@ public class Parser {
} else { } else {
r = null; r = null;
} }
} else if (aggregateType == AggregateType.ARRAY_AGG) {
boolean distinct = readIf("DISTINCT");
r = new Aggregate(AggregateType.ARRAY_AGG,
readExpression(), currentSelect, distinct);
if (readIf("ORDER")) {
read("BY");
r.setArrayAggOrder(parseSimpleOrderList());
}
} else { } else {
boolean distinct = readIf("DISTINCT"); boolean distinct = readIf("DISTINCT");
r = new Aggregate(aggregateType, readExpression(), currentSelect, r = new Aggregate(aggregateType, readExpression(), currentSelect,
......
...@@ -128,7 +128,11 @@ public class Aggregate extends Expression { ...@@ -128,7 +128,11 @@ public class Aggregate extends Expression {
/** /**
* The aggregate type for MEDIAN(expression). * The aggregate type for MEDIAN(expression).
*/ */
MEDIAN MEDIAN,
/**
* The aggregate type for ARRAY_AGG(expression).
*/
ARRAY_AGG
} }
private static final HashMap<String, AggregateType> AGGREGATES = new HashMap<>(25); private static final HashMap<String, AggregateType> AGGREGATES = new HashMap<>(25);
...@@ -140,7 +144,9 @@ public class Aggregate extends Expression { ...@@ -140,7 +144,9 @@ public class Aggregate extends Expression {
private Expression on; private Expression on;
private Expression groupConcatSeparator; private Expression groupConcatSeparator;
private ArrayList<SelectOrderBy> groupConcatOrderList; private ArrayList<SelectOrderBy> groupConcatOrderList;
private ArrayList<SelectOrderBy> arrayAggOrderList;
private SortOrder groupConcatSort; private SortOrder groupConcatSort;
private SortOrder arrayOrderSort;
private int dataType, scale; private int dataType, scale;
private long precision; private long precision;
private int displaySize; private int displaySize;
...@@ -195,6 +201,7 @@ public class Aggregate extends Expression { ...@@ -195,6 +201,7 @@ public class Aggregate extends Expression {
addAggregate("BIT_OR", AggregateType.BIT_OR); addAggregate("BIT_OR", AggregateType.BIT_OR);
addAggregate("BIT_AND", AggregateType.BIT_AND); addAggregate("BIT_AND", AggregateType.BIT_AND);
addAggregate("MEDIAN", AggregateType.MEDIAN); addAggregate("MEDIAN", AggregateType.MEDIAN);
addAggregate("ARRAY_AGG", AggregateType.ARRAY_AGG);
} }
private static void addAggregate(String name, AggregateType type) { private static void addAggregate(String name, AggregateType type) {
...@@ -221,6 +228,15 @@ public class Aggregate extends Expression { ...@@ -221,6 +228,15 @@ public class Aggregate extends Expression {
this.groupConcatOrderList = orderBy; this.groupConcatOrderList = orderBy;
} }
/**
* Set the order for ARRAY_AGG() aggregate.
*
* @param orderBy the order by list
*/
public void setArrayAggOrder(ArrayList<SelectOrderBy> orderBy) {
this.arrayAggOrderList = orderBy;
}
/** /**
* Set the separator for the GROUP_CONCAT() aggregate. * Set the separator for the GROUP_CONCAT() aggregate.
* *
...@@ -239,12 +255,12 @@ public class Aggregate extends Expression { ...@@ -239,12 +255,12 @@ public class Aggregate extends Expression {
this.filterCondition = filterCondition; this.filterCondition = filterCondition;
} }
private SortOrder initOrder(Session session) { private SortOrder initOrder(ArrayList<SelectOrderBy> orderList, Session session) {
int size = groupConcatOrderList.size(); int size = orderList.size();
int[] index = new int[size]; int[] index = new int[size];
int[] sortType = new int[size]; int[] sortType = new int[size];
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
SelectOrderBy o = groupConcatOrderList.get(i); SelectOrderBy o = orderList.get(i);
index[i] = i + 1; index[i] = i + 1;
int order = o.descending ? SortOrder.DESCENDING : SortOrder.ASCENDING; int order = o.descending ? SortOrder.DESCENDING : SortOrder.ASCENDING;
sortType[i] = order; sortType[i] = order;
...@@ -292,6 +308,20 @@ public class Aggregate extends Expression { ...@@ -292,6 +308,20 @@ public class Aggregate extends Expression {
} }
} }
} }
if (type == AggregateType.ARRAY_AGG) {
if (v != ValueNull.INSTANCE) {
if (arrayAggOrderList != null) {
int size = arrayAggOrderList.size();
Value[] array = new Value[1 + size];
array[0] = v;
for (int i = 0; i < size; i++) {
SelectOrderBy o = arrayAggOrderList.get(i);
array[i + 1] = o.expression.getValue(session);
}
v = ValueArray.get(array);
}
}
}
if (filterCondition != null) { if (filterCondition != null) {
if (!filterCondition.getBooleanValue(session)) { if (!filterCondition.getBooleanValue(session)) {
return; return;
...@@ -343,7 +373,7 @@ public class Aggregate extends Expression { ...@@ -343,7 +373,7 @@ public class Aggregate extends Expression {
} }
Value v = data.getValue(session.getDatabase(), dataType, distinct); Value v = data.getValue(session.getDatabase(), dataType, distinct);
if (type == AggregateType.GROUP_CONCAT) { if (type == AggregateType.GROUP_CONCAT) {
ArrayList<Value> list = ((AggregateDataGroupConcat) data).getList(); ArrayList<Value> list = ((AggregateDataArrayCollecting) data).getList();
if (list == null || list.isEmpty()) { if (list == null || list.isEmpty()) {
return ValueNull.INSTANCE; return ValueNull.INSTANCE;
} }
...@@ -377,6 +407,23 @@ public class Aggregate extends Expression { ...@@ -377,6 +407,23 @@ public class Aggregate extends Expression {
buff.append(s); buff.append(s);
} }
v = ValueString.get(buff.toString()); v = ValueString.get(buff.toString());
} else if (type == AggregateType.ARRAY_AGG) {
ArrayList<Value> list = ((AggregateDataArrayCollecting) data).getList();
if (list == null || list.isEmpty()) {
return ValueNull.INSTANCE;
}
if (arrayAggOrderList != null) {
final SortOrder sortOrder = arrayOrderSort;
Collections.sort(list, new Comparator<Value>() {
@Override
public int compare(Value v1, Value v2) {
Value[] a1 = ((ValueArray) v1).getList();
Value[] a2 = ((ValueArray) v2).getList();
return sortOrder.compare(a1, a2);
}
});
}
v = ValueArray.get(list.toArray(new Value[list.size()]));
} }
return v; return v;
} }
...@@ -396,6 +443,11 @@ public class Aggregate extends Expression { ...@@ -396,6 +443,11 @@ public class Aggregate extends Expression {
o.expression.mapColumns(resolver, level); o.expression.mapColumns(resolver, level);
} }
} }
if (arrayAggOrderList != null) {
for (SelectOrderBy o : arrayAggOrderList) {
o.expression.mapColumns(resolver, level);
}
}
if (groupConcatSeparator != null) { if (groupConcatSeparator != null) {
groupConcatSeparator.mapColumns(resolver, level); groupConcatSeparator.mapColumns(resolver, level);
} }
...@@ -417,7 +469,13 @@ public class Aggregate extends Expression { ...@@ -417,7 +469,13 @@ public class Aggregate extends Expression {
for (SelectOrderBy o : groupConcatOrderList) { for (SelectOrderBy o : groupConcatOrderList) {
o.expression = o.expression.optimize(session); o.expression = o.expression.optimize(session);
} }
groupConcatSort = initOrder(session); groupConcatSort = initOrder(groupConcatOrderList, session);
}
if (arrayAggOrderList != null) {
for (SelectOrderBy o : arrayAggOrderList) {
o.expression = o.expression.optimize(session);
}
arrayOrderSort = initOrder(arrayAggOrderList, session);
} }
if (groupConcatSeparator != null) { if (groupConcatSeparator != null) {
groupConcatSeparator = groupConcatSeparator.optimize(session); groupConcatSeparator = groupConcatSeparator.optimize(session);
...@@ -490,6 +548,11 @@ public class Aggregate extends Expression { ...@@ -490,6 +548,11 @@ public class Aggregate extends Expression {
throw DbException.get(ErrorCode.SUM_OR_AVG_ON_WRONG_DATATYPE_1, getSQL()); throw DbException.get(ErrorCode.SUM_OR_AVG_ON_WRONG_DATATYPE_1, getSQL());
} }
break; break;
case ARRAY_AGG:
dataType = Value.ARRAY;
scale = 0;
precision = displaySize = Integer.MAX_VALUE;
break;
default: default:
DbException.throwInternalError("type=" + type); DbException.throwInternalError("type=" + type);
} }
...@@ -506,6 +569,11 @@ public class Aggregate extends Expression { ...@@ -506,6 +569,11 @@ public class Aggregate extends Expression {
o.expression.setEvaluatable(tableFilter, b); o.expression.setEvaluatable(tableFilter, b);
} }
} }
if (arrayAggOrderList != null) {
for (SelectOrderBy o : arrayAggOrderList) {
o.expression.setEvaluatable(tableFilter, b);
}
}
if (groupConcatSeparator != null) { if (groupConcatSeparator != null) {
groupConcatSeparator.setEvaluatable(tableFilter, b); groupConcatSeparator.setEvaluatable(tableFilter, b);
} }
...@@ -555,6 +623,29 @@ public class Aggregate extends Expression { ...@@ -555,6 +623,29 @@ public class Aggregate extends Expression {
return buff.toString(); return buff.toString();
} }
private String getSQLArrayAggregate() {
StatementBuilder buff = new StatementBuilder("ARRAY_AGG(");
if (distinct) {
buff.append("DISTINCT ");
}
buff.append(on.getSQL());
if (arrayAggOrderList != null) {
buff.append(" ORDER BY ");
for (SelectOrderBy o : arrayAggOrderList) {
buff.appendExceptFirst(", ");
buff.append(o.expression.getSQL());
if (o.descending) {
buff.append(" DESC");
}
}
}
buff.append(')');
if (filterCondition != null) {
buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
}
return buff.toString();
}
@Override @Override
public String getSQL() { public String getSQL() {
String text; String text;
...@@ -611,6 +702,8 @@ public class Aggregate extends Expression { ...@@ -611,6 +702,8 @@ public class Aggregate extends Expression {
case MEDIAN: case MEDIAN:
text = "MEDIAN"; text = "MEDIAN";
break; break;
case ARRAY_AGG:
return getSQLArrayAggregate();
default: default:
throw DbException.throwInternalError("type=" + type); throw DbException.throwInternalError("type=" + type);
} }
...@@ -681,6 +774,14 @@ public class Aggregate extends Expression { ...@@ -681,6 +774,14 @@ public class Aggregate extends Expression {
} }
} }
} }
if (arrayAggOrderList != null) {
for (int i = 0, size = arrayAggOrderList.size(); i < size; i++) {
SelectOrderBy o = arrayAggOrderList.get(i);
if (!o.expression.isEverything(visitor)) {
return false;
}
}
}
return true; return true;
} }
......
...@@ -24,7 +24,9 @@ abstract class AggregateData { ...@@ -24,7 +24,9 @@ abstract class AggregateData {
if (aggregateType == AggregateType.SELECTIVITY) { if (aggregateType == AggregateType.SELECTIVITY) {
return new AggregateDataSelectivity(); return new AggregateDataSelectivity();
} else if (aggregateType == AggregateType.GROUP_CONCAT) { } else if (aggregateType == AggregateType.GROUP_CONCAT) {
return new AggregateDataGroupConcat(); return new AggregateDataArrayCollecting();
} else if (aggregateType == AggregateType.ARRAY_AGG) {
return new AggregateDataArrayCollecting();
} else if (aggregateType == AggregateType.COUNT_ALL) { } else if (aggregateType == AggregateType.COUNT_ALL) {
return new AggregateDataCountAll(); return new AggregateDataCountAll();
} else if (aggregateType == AggregateType.COUNT) { } else if (aggregateType == AggregateType.COUNT) {
......
...@@ -13,11 +13,11 @@ import org.h2.value.Value; ...@@ -13,11 +13,11 @@ import org.h2.value.Value;
import org.h2.value.ValueNull; import org.h2.value.ValueNull;
/** /**
* Data stored while calculating a GROUP_CONCAT aggregate. * Data stored while calculating a GROUP_CONCAT/ARRAY_AGG aggregate.
*/ */
class AggregateDataGroupConcat extends AggregateData { class AggregateDataArrayCollecting extends AggregateData {
private ArrayList<Value> list; private ArrayList<Value> list;
private ValueHashMap<AggregateDataGroupConcat> distinctValues; private ValueHashMap<AggregateDataArrayCollecting> distinctValues;
@Override @Override
void add(Database database, int dataType, boolean distinct, Value v) { void add(Database database, int dataType, boolean distinct, Value v) {
...@@ -40,7 +40,7 @@ class AggregateDataGroupConcat extends AggregateData { ...@@ -40,7 +40,7 @@ class AggregateDataGroupConcat extends AggregateData {
@Override @Override
Value getValue(Database database, int dataType, boolean distinct) { Value getValue(Database database, int dataType, boolean distinct) {
if (distinct) { if (distinct) {
groupDistinct(database, dataType); distinct(database, dataType);
} }
return null; return null;
} }
...@@ -49,7 +49,7 @@ class AggregateDataGroupConcat extends AggregateData { ...@@ -49,7 +49,7 @@ class AggregateDataGroupConcat extends AggregateData {
return list; return list;
} }
private void groupDistinct(Database database, int dataType) { private void distinct(Database database, int dataType) {
if (distinctValues == null) { if (distinctValues == null) {
return; return;
} }
......
...@@ -110,7 +110,7 @@ public class TestScript extends TestBase { ...@@ -110,7 +110,7 @@ public class TestScript extends TestBase {
} }
for (String s : new String[] { "avg", "bit-and", "bit-or", "count", for (String s : new String[] { "avg", "bit-and", "bit-or", "count",
"group-concat", "max", "median", "min", "selectivity", "stddev-pop", "group-concat", "max", "median", "min", "selectivity", "stddev-pop",
"stddev-samp", "sum", "var-pop", "var-samp" }) { "stddev-samp", "sum", "var-pop", "var-samp", "array-agg" }) {
testScript("functions/aggregate/" + s + ".sql"); testScript("functions/aggregate/" + s + ".sql");
} }
for (String s : new String[] { "abs", "acos", "asin", "atan", "atan2", for (String s : new String[] { "abs", "acos", "asin", "atan", "atan2",
......
-- Copyright 2004-2018 H2 Group. Multiple-Licensed under the MPL 2.0,
-- and the EPL 1.0 (http://h2database.com/html/license.html).
-- Initial Developer: Alex Nordlund
--
-- with filter condition
create table test(v varchar);
> ok
insert into test values ('1'), ('2'), ('3'), ('4'), ('5'), ('6'), ('7'), ('8'), ('9');
> update count: 9
select array_agg(v order by v asc),
array_agg(v order by v desc) filter (where v >= '4')
from test where v >= '2';
> ARRAY_AGG(V ORDER BY V) ARRAY_AGG(V ORDER BY V DESC) FILTER (WHERE (V >= '4'))
> ---------------------------------------------------------------- ------------------------------------------------------
------------------------------
> (2, 3, 4, 5, 6, 7, 8, 9) (9, 8, 7, 6, 5, 4)
> rows (ordered): 1
create index test_idx on test(v);
select ARRAY_AGG(v order by v asc),
ARRAY_AGG(v order by v desc) filter (where v >= '4')
from test where v >= '2';
> ARRAY_AGG(V ORDER BY V) ARRAY_AGG(V ORDER BY V DESC) FILTER (WHERE (V >= '4'))
> ---------------------------------------------------------------- ------------------------------------------------------
------------------------------
> (2, 3, 4, 5, 6, 7, 8, 9) (9, 8, 7, 6, 5, 4)
> rows (ordered): 1
select ARRAY_AGG(v order by v asc),
ARRAY_AGG(v order by v desc) filter (where v >= '4')
from test;
> ARRAY_AGG(V ORDER BY V) ARRAY_AGG(V ORDER BY V DESC) FILTER (WHERE (V >= '4'))
> ------------------------------------------------------------------------ ------------------------------------------------------
------------------------------
> (1, 2, 3, 4, 5, 6, 7, 8, 9) (9, 8, 7, 6, 5, 4)
> rows (ordered): 1
drop table test;
> ok
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论