Unverified 提交 ca7eaa69 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov 提交者: GitHub

Merge pull request #1009 from katzyn/aggregate

Fix ARRAY_AGG with ORDER BY and refactor aggregates
......@@ -2628,7 +2628,7 @@ public class Parser {
readExpression(), currentSelect, distinct);
if (readIf("ORDER")) {
read("BY");
r.setGroupConcatOrder(parseSimpleOrderList());
r.setOrderByList(parseSimpleOrderList());
}
if (readIf("SEPARATOR")) {
......@@ -2642,7 +2642,7 @@ public class Parser {
r.setGroupConcatSeparator(readExpression());
if (readIf("ORDER")) {
read("BY");
r.setGroupConcatOrder(parseSimpleOrderList());
r.setOrderByList(parseSimpleOrderList());
}
} else {
r = null;
......@@ -2654,7 +2654,7 @@ public class Parser {
readExpression(), currentSelect, distinct);
if (readIf("ORDER")) {
read("BY");
r.setArrayAggOrder(parseSimpleOrderList());
r.setOrderByList(parseSimpleOrderList());
}
} else {
boolean distinct = readIf("DISTINCT");
......
......@@ -21,21 +21,21 @@ abstract class AggregateData {
* @return the aggregate data object of the specified type
*/
static AggregateData create(AggregateType aggregateType) {
if (aggregateType == AggregateType.SELECTIVITY) {
switch (aggregateType) {
case SELECTIVITY:
return new AggregateDataSelectivity();
} else if (aggregateType == AggregateType.GROUP_CONCAT) {
return new AggregateDataArrayCollecting();
} else if (aggregateType == AggregateType.ARRAY_AGG) {
return new AggregateDataArrayCollecting();
} else if (aggregateType == AggregateType.COUNT_ALL) {
case GROUP_CONCAT:
case ARRAY_AGG:
return new AggregateDataCollecting();
case COUNT_ALL:
return new AggregateDataCountAll();
} else if (aggregateType == AggregateType.COUNT) {
case COUNT:
return new AggregateDataCount();
} else if (aggregateType == AggregateType.HISTOGRAM) {
case HISTOGRAM:
return new AggregateDataHistogram();
} else if (aggregateType == AggregateType.MEDIAN) {
case MEDIAN:
return new AggregateDataMedian();
} else {
default:
return new AggregateDataDefault(aggregateType);
}
}
......
......@@ -6,55 +6,54 @@
package org.h2.expression;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import org.h2.engine.Database;
import org.h2.util.New;
import org.h2.util.ValueHashMap;
import org.h2.value.Value;
import org.h2.value.ValueNull;
/**
* Data stored while calculating a GROUP_CONCAT/ARRAY_AGG aggregate.
* Data stored while calculating an aggregate that needs collecting of all
* values.
*
* <p>
* NULL values are not collected. {@link #getValue(Database, int, boolean)}
* method returns {@code null}. Use {@link #getArray()} for instances of this
* class instead. Notice that subclasses like {@link AggregateDataMedian} may
* override {@link #getValue(Database, int, boolean)} to return useful result.
* </p>
*/
class AggregateDataArrayCollecting extends AggregateData {
private ArrayList<Value> list;
private ValueHashMap<AggregateDataArrayCollecting> distinctValues;
class AggregateDataCollecting extends AggregateData {
Collection<Value> values;
@Override
void add(Database database, int dataType, boolean distinct, Value v) {
if (v == ValueNull.INSTANCE) {
return;
}
if (distinct) {
if (distinctValues == null) {
distinctValues = ValueHashMap.newInstance();
}
distinctValues.put(v, this);
return;
}
if (list == null) {
list = New.arrayList();
Collection<Value> c = values;
if (c == null) {
values = c = distinct ? new HashSet<Value>() : new ArrayList<Value>();
}
list.add(v);
c.add(v);
}
@Override
Value getValue(Database database, int dataType, boolean distinct) {
if (distinct) {
distinct(database, dataType);
}
return null;
}
ArrayList<Value> getList() {
return list;
}
private void distinct(Database database, int dataType) {
if (distinctValues == null) {
return;
}
for (Value v : distinctValues.keys()) {
add(database, dataType, false, v);
/**
* Returns array with values or {@code null}.
*
* @return array with values or {@code null}
*/
Value[] getArray() {
Collection<Value> values = this.values;
if (values == null) {
return null;
}
return values.toArray(new Value[0]);
}
}
......@@ -8,9 +8,7 @@ package org.h2.expression;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import org.h2.engine.Database;
import org.h2.engine.Session;
......@@ -39,9 +37,7 @@ import org.h2.value.ValueTimestampTimeZone;
/**
* Data stored while calculating a MEDIAN aggregate.
*/
class AggregateDataMedian extends AggregateData {
private Collection<Value> values;
class AggregateDataMedian extends AggregateDataCollecting {
private static boolean isNullsLast(Index index) {
IndexColumn ic = index.getIndexColumns()[0];
int sortType = ic.sortType;
......@@ -168,29 +164,12 @@ class AggregateDataMedian extends AggregateData {
return v;
}
@Override
void add(Database database, int dataType, boolean distinct, Value v) {
if (v == ValueNull.INSTANCE) {
return;
}
Collection<Value> c = values;
if (c == null) {
values = c = distinct ? new HashSet<Value>() : new ArrayList<Value>();
}
c.add(v);
}
@Override
Value getValue(Database database, int dataType, boolean distinct) {
Collection<Value> c = values;
// Non-null collection cannot be empty here
if (c == null) {
Value[] a = getArray();
if (a == null) {
return ValueNull.INSTANCE;
}
if (distinct && c instanceof ArrayList) {
c = new HashSet<>(c);
}
Value[] a = c.toArray(new Value[0]);
final CompareMode mode = database.getCompareMode();
Arrays.sort(a, new Comparator<Value>() {
@Override
......
......@@ -14,10 +14,9 @@ insert into test values ('1'), ('2'), ('3'), ('4'), ('5'), ('6'), ('7'), ('8'),
select array_agg(v order by v asc),
array_agg(v order by v desc) filter (where v >= '4')
from test where v >= '2';
> ARRAY_AGG(V ORDER BY V) ARRAY_AGG(V ORDER BY V DESC) FILTER (WHERE (V >= '4'))
> ---------------------------------------------------------------- ------------------------------------------------------
------------------------------
> (2, 3, 4, 5, 6, 7, 8, 9) (9, 8, 7, 6, 5, 4)
> ARRAY_AGG(V ORDER BY V) ARRAY_AGG(V ORDER BY V DESC) FILTER (WHERE (V >= '4'))
> ------------------------ ------------------------------------------------------
> (2, 3, 4, 5, 6, 7, 8, 9) (9, 8, 7, 6, 5, 4)
> rows (ordered): 1
create index test_idx on test(v);
......@@ -25,21 +24,45 @@ create index test_idx on test(v);
select ARRAY_AGG(v order by v asc),
ARRAY_AGG(v order by v desc) filter (where v >= '4')
from test where v >= '2';
> ARRAY_AGG(V ORDER BY V) ARRAY_AGG(V ORDER BY V DESC) FILTER (WHERE (V >= '4'))
> ---------------------------------------------------------------- ------------------------------------------------------
------------------------------
> (2, 3, 4, 5, 6, 7, 8, 9) (9, 8, 7, 6, 5, 4)
> ARRAY_AGG(V ORDER BY V) ARRAY_AGG(V ORDER BY V DESC) FILTER (WHERE (V >= '4'))
> ------------------------ ------------------------------------------------------
> (2, 3, 4, 5, 6, 7, 8, 9) (9, 8, 7, 6, 5, 4)
> rows (ordered): 1
select ARRAY_AGG(v order by v asc),
ARRAY_AGG(v order by v desc) filter (where v >= '4')
from test;
> ARRAY_AGG(V ORDER BY V) ARRAY_AGG(V ORDER BY V DESC) FILTER (WHERE (V >= '4'))
> ------------------------------------------------------------------------ ------------------------------------------------------
------------------------------
> (1, 2, 3, 4, 5, 6, 7, 8, 9) (9, 8, 7, 6, 5, 4)
> ARRAY_AGG(V ORDER BY V) ARRAY_AGG(V ORDER BY V DESC) FILTER (WHERE (V >= '4'))
> --------------------------- ------------------------------------------------------
> (1, 2, 3, 4, 5, 6, 7, 8, 9) (9, 8, 7, 6, 5, 4)
> rows (ordered): 1
drop table test;
> ok
create table test (id int auto_increment primary key, v int);
> ok
insert into test(v) values (7), (2), (8), (3), (7), (3), (9), (-1);
> update count: 8
select array_agg(v) from test;
> ARRAY_AGG(V)
> -------------------------
> (7, 2, 8, 3, 7, 3, 9, -1)
> rows: 1
select array_agg(distinct v) from test;
> ARRAY_AGG(DISTINCT V)
> ---------------------
> (-1, 2, 3, 7, 8, 9)
> rows: 1
select array_agg(distinct v order by v desc) from test;
> ARRAY_AGG(DISTINCT V ORDER BY V DESC)
> -------------------------------------
> (9, 8, 7, 3, 2, -1)
> rows (ordered): 1
drop table test;
> ok
......@@ -37,6 +37,32 @@ select group_concat(v order by v asc separator '-'),
> 1-2-3-4-5-6-7-8-9 9-8-7-6-5-4
> rows (ordered): 1
drop table test;
> ok
create table test (id int auto_increment primary key, v int);
> ok
insert into test(v) values (7), (2), (8), (3), (7), (3), (9), (-1);
> update count: 8
select group_concat(v) from test;
> GROUP_CONCAT(V)
> ----------------
> 7,2,8,3,7,3,9,-1
> rows: 1
select group_concat(distinct v) from test;
> GROUP_CONCAT(DISTINCT V)
> ------------------------
> -1,2,3,7,8,9
> rows: 1
select group_concat(distinct v order by v desc) from test;
> GROUP_CONCAT(DISTINCT V ORDER BY V DESC)
> ----------------------------------------
> 9,8,7,3,2,-1
> rows (ordered): 1
drop table test;
> ok
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论