Unverified 提交 3ac73142 authored 作者: Noel Grandin's avatar Noel Grandin 提交者: GitHub

Merge pull request #967 from deepy/array-aggregate

Adds ARRAY_AGG function.
......@@ -2949,6 +2949,19 @@ Aggregates are only allowed in select statements.
GROUP_CONCAT(NAME ORDER BY ID SEPARATOR ', ')
"
"Functions (Aggregate)","ARRAY_AGG","
ARRAY_AGG ( [ DISTINCT ] string
[ ORDER BY { expression [ ASC | DESC ] } [,...] ] )
[ FILTER ( WHERE expression ) ]
","
Aggregate the value into an array.
This method returns an array.
If no rows are selected, the result is NULL.
Aggregates are only allowed in select statements.
","
ARRAY_AGG(NAME ORDER BY ID)
"
"Functions (Aggregate)","MAX","
MAX(value) [ FILTER ( WHERE expression ) ]
","
......
......@@ -21,6 +21,8 @@ Change Log
<h2>Next Version (unreleased)</h2>
<ul>
<li>PR #967: Adds ARRAY_AGG()
</li>
<li>PR #806: Implement setBytes() and setString() with offset and len
</li>
<li>PR #805: Improve support of TIMESTAMP WITH TIME ZONE
......
......@@ -2651,6 +2651,15 @@ public class Parser {
} else {
r = null;
}
} else if (aggregateType == AggregateType.ARRAY_AGG) {
boolean distinct = readIf("DISTINCT");
r = new Aggregate(AggregateType.ARRAY_AGG,
readExpression(), currentSelect, distinct);
if (readIf("ORDER")) {
read("BY");
r.setArrayAggOrder(parseSimpleOrderList());
}
} else {
boolean distinct = readIf("DISTINCT");
r = new Aggregate(aggregateType, readExpression(), currentSelect,
......
......@@ -128,10 +128,14 @@ public class Aggregate extends Expression {
/**
* The aggregate type for MEDIAN(expression).
*/
MEDIAN
MEDIAN,
/**
* The aggregate type for ARRAY_AGG(expression).
*/
ARRAY_AGG
}
private static final HashMap<String, AggregateType> AGGREGATES = new HashMap<>(25);
private static final HashMap<String, AggregateType> AGGREGATES = new HashMap<>(26);
private final AggregateType type;
private final Select select;
......@@ -140,7 +144,9 @@ public class Aggregate extends Expression {
private Expression on;
private Expression groupConcatSeparator;
private ArrayList<SelectOrderBy> groupConcatOrderList;
private ArrayList<SelectOrderBy> arrayAggOrderList;
private SortOrder groupConcatSort;
private SortOrder arrayOrderSort;
private int dataType, scale;
private long precision;
private int displaySize;
......@@ -195,6 +201,7 @@ public class Aggregate extends Expression {
addAggregate("BIT_OR", AggregateType.BIT_OR);
addAggregate("BIT_AND", AggregateType.BIT_AND);
addAggregate("MEDIAN", AggregateType.MEDIAN);
addAggregate("ARRAY_AGG", AggregateType.ARRAY_AGG);
}
private static void addAggregate(String name, AggregateType type) {
......@@ -221,6 +228,15 @@ public class Aggregate extends Expression {
this.groupConcatOrderList = orderBy;
}
/**
* Set the order for ARRAY_AGG() aggregate.
*
* @param orderBy the order by list
*/
public void setArrayAggOrder(ArrayList<SelectOrderBy> orderBy) {
this.arrayAggOrderList = orderBy;
}
/**
* Set the separator for the GROUP_CONCAT() aggregate.
*
......@@ -239,12 +255,12 @@ public class Aggregate extends Expression {
this.filterCondition = filterCondition;
}
private SortOrder initOrder(Session session) {
int size = groupConcatOrderList.size();
private SortOrder initOrder(ArrayList<SelectOrderBy> orderList, Session session) {
int size = orderList.size();
int[] index = new int[size];
int[] sortType = new int[size];
for (int i = 0; i < size; i++) {
SelectOrderBy o = groupConcatOrderList.get(i);
SelectOrderBy o = orderList.get(i);
index[i] = i + 1;
int order = o.descending ? SortOrder.DESCENDING : SortOrder.ASCENDING;
sortType[i] = order;
......@@ -292,6 +308,20 @@ public class Aggregate extends Expression {
}
}
}
if (type == AggregateType.ARRAY_AGG) {
if (v != ValueNull.INSTANCE) {
if (arrayAggOrderList != null) {
int size = arrayAggOrderList.size();
Value[] array = new Value[1 + size];
array[0] = v;
for (int i = 0; i < size; i++) {
SelectOrderBy o = arrayAggOrderList.get(i);
array[i + 1] = o.expression.getValue(session);
}
v = ValueArray.get(array);
}
}
}
if (filterCondition != null) {
if (!filterCondition.getBooleanValue(session)) {
return;
......@@ -343,7 +373,7 @@ public class Aggregate extends Expression {
}
Value v = data.getValue(session.getDatabase(), dataType, distinct);
if (type == AggregateType.GROUP_CONCAT) {
ArrayList<Value> list = ((AggregateDataGroupConcat) data).getList();
ArrayList<Value> list = ((AggregateDataArrayCollecting) data).getList();
if (list == null || list.isEmpty()) {
return ValueNull.INSTANCE;
}
......@@ -377,6 +407,23 @@ public class Aggregate extends Expression {
buff.append(s);
}
v = ValueString.get(buff.toString());
} else if (type == AggregateType.ARRAY_AGG) {
ArrayList<Value> list = ((AggregateDataArrayCollecting) data).getList();
if (list == null || list.isEmpty()) {
return ValueNull.INSTANCE;
}
if (arrayAggOrderList != null) {
final SortOrder sortOrder = arrayOrderSort;
Collections.sort(list, new Comparator<Value>() {
@Override
public int compare(Value v1, Value v2) {
Value[] a1 = ((ValueArray) v1).getList();
Value[] a2 = ((ValueArray) v2).getList();
return sortOrder.compare(a1, a2);
}
});
}
v = ValueArray.get(list.toArray(new Value[list.size()]));
}
return v;
}
......@@ -396,6 +443,11 @@ public class Aggregate extends Expression {
o.expression.mapColumns(resolver, level);
}
}
if (arrayAggOrderList != null) {
for (SelectOrderBy o : arrayAggOrderList) {
o.expression.mapColumns(resolver, level);
}
}
if (groupConcatSeparator != null) {
groupConcatSeparator.mapColumns(resolver, level);
}
......@@ -417,7 +469,13 @@ public class Aggregate extends Expression {
for (SelectOrderBy o : groupConcatOrderList) {
o.expression = o.expression.optimize(session);
}
groupConcatSort = initOrder(session);
groupConcatSort = initOrder(groupConcatOrderList, session);
}
if (arrayAggOrderList != null) {
for (SelectOrderBy o : arrayAggOrderList) {
o.expression = o.expression.optimize(session);
}
arrayOrderSort = initOrder(arrayAggOrderList, session);
}
if (groupConcatSeparator != null) {
groupConcatSeparator = groupConcatSeparator.optimize(session);
......@@ -490,6 +548,11 @@ public class Aggregate extends Expression {
throw DbException.get(ErrorCode.SUM_OR_AVG_ON_WRONG_DATATYPE_1, getSQL());
}
break;
case ARRAY_AGG:
dataType = Value.ARRAY;
scale = 0;
precision = displaySize = Integer.MAX_VALUE;
break;
default:
DbException.throwInternalError("type=" + type);
}
......@@ -506,6 +569,11 @@ public class Aggregate extends Expression {
o.expression.setEvaluatable(tableFilter, b);
}
}
if (arrayAggOrderList != null) {
for (SelectOrderBy o : arrayAggOrderList) {
o.expression.setEvaluatable(tableFilter, b);
}
}
if (groupConcatSeparator != null) {
groupConcatSeparator.setEvaluatable(tableFilter, b);
}
......@@ -555,6 +623,29 @@ public class Aggregate extends Expression {
return buff.toString();
}
private String getSQLArrayAggregate() {
StatementBuilder buff = new StatementBuilder("ARRAY_AGG(");
if (distinct) {
buff.append("DISTINCT ");
}
buff.append(on.getSQL());
if (arrayAggOrderList != null) {
buff.append(" ORDER BY ");
for (SelectOrderBy o : arrayAggOrderList) {
buff.appendExceptFirst(", ");
buff.append(o.expression.getSQL());
if (o.descending) {
buff.append(" DESC");
}
}
}
buff.append(')');
if (filterCondition != null) {
buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
}
return buff.toString();
}
@Override
public String getSQL() {
String text;
......@@ -611,6 +702,8 @@ public class Aggregate extends Expression {
case MEDIAN:
text = "MEDIAN";
break;
case ARRAY_AGG:
return getSQLArrayAggregate();
default:
throw DbException.throwInternalError("type=" + type);
}
......@@ -681,6 +774,14 @@ public class Aggregate extends Expression {
}
}
}
if (arrayAggOrderList != null) {
for (int i = 0, size = arrayAggOrderList.size(); i < size; i++) {
SelectOrderBy o = arrayAggOrderList.get(i);
if (!o.expression.isEverything(visitor)) {
return false;
}
}
}
return true;
}
......
......@@ -24,7 +24,9 @@ abstract class AggregateData {
if (aggregateType == AggregateType.SELECTIVITY) {
return new AggregateDataSelectivity();
} else if (aggregateType == AggregateType.GROUP_CONCAT) {
return new AggregateDataGroupConcat();
return new AggregateDataArrayCollecting();
} else if (aggregateType == AggregateType.ARRAY_AGG) {
return new AggregateDataArrayCollecting();
} else if (aggregateType == AggregateType.COUNT_ALL) {
return new AggregateDataCountAll();
} else if (aggregateType == AggregateType.COUNT) {
......
......@@ -13,11 +13,11 @@ import org.h2.value.Value;
import org.h2.value.ValueNull;
/**
* Data stored while calculating a GROUP_CONCAT aggregate.
* Data stored while calculating a GROUP_CONCAT/ARRAY_AGG aggregate.
*/
class AggregateDataGroupConcat extends AggregateData {
class AggregateDataArrayCollecting extends AggregateData {
private ArrayList<Value> list;
private ValueHashMap<AggregateDataGroupConcat> distinctValues;
private ValueHashMap<AggregateDataArrayCollecting> distinctValues;
@Override
void add(Database database, int dataType, boolean distinct, Value v) {
......@@ -40,7 +40,7 @@ class AggregateDataGroupConcat extends AggregateData {
@Override
Value getValue(Database database, int dataType, boolean distinct) {
if (distinct) {
groupDistinct(database, dataType);
distinct(database, dataType);
}
return null;
}
......@@ -49,7 +49,7 @@ class AggregateDataGroupConcat extends AggregateData {
return list;
}
private void groupDistinct(Database database, int dataType) {
private void distinct(Database database, int dataType) {
if (distinctValues == null) {
return;
}
......
......@@ -110,7 +110,7 @@ public class TestScript extends TestBase {
}
for (String s : new String[] { "avg", "bit-and", "bit-or", "count",
"group-concat", "max", "median", "min", "selectivity", "stddev-pop",
"stddev-samp", "sum", "var-pop", "var-samp" }) {
"stddev-samp", "sum", "var-pop", "var-samp", "array-agg" }) {
testScript("functions/aggregate/" + s + ".sql");
}
for (String s : new String[] { "abs", "acos", "asin", "atan", "atan2",
......
-- Copyright 2004-2018 H2 Group. Multiple-Licensed under the MPL 2.0,
-- and the EPL 1.0 (http://h2database.com/html/license.html).
-- Initial Developer: Alex Nordlund
--
-- with filter condition
create table test(v varchar);
> ok
insert into test values ('1'), ('2'), ('3'), ('4'), ('5'), ('6'), ('7'), ('8'), ('9');
> update count: 9
select array_agg(v order by v asc),
array_agg(v order by v desc) filter (where v >= '4')
from test where v >= '2';
> ARRAY_AGG(V ORDER BY V) ARRAY_AGG(V ORDER BY V DESC) FILTER (WHERE (V >= '4'))
> ---------------------------------------------------------------- ------------------------------------------------------
------------------------------
> (2, 3, 4, 5, 6, 7, 8, 9) (9, 8, 7, 6, 5, 4)
> rows (ordered): 1
create index test_idx on test(v);
select ARRAY_AGG(v order by v asc),
ARRAY_AGG(v order by v desc) filter (where v >= '4')
from test where v >= '2';
> ARRAY_AGG(V ORDER BY V) ARRAY_AGG(V ORDER BY V DESC) FILTER (WHERE (V >= '4'))
> ---------------------------------------------------------------- ------------------------------------------------------
------------------------------
> (2, 3, 4, 5, 6, 7, 8, 9) (9, 8, 7, 6, 5, 4)
> rows (ordered): 1
select ARRAY_AGG(v order by v asc),
ARRAY_AGG(v order by v desc) filter (where v >= '4')
from test;
> ARRAY_AGG(V ORDER BY V) ARRAY_AGG(V ORDER BY V DESC) FILTER (WHERE (V >= '4'))
> ------------------------------------------------------------------------ ------------------------------------------------------
------------------------------
> (1, 2, 3, 4, 5, 6, 7, 8, 9) (9, 8, 7, 6, 5, 4)
> rows (ordered): 1
drop table test;
> ok
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论