提交 9e6dbf3b authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Add optional FILTER clause to aggregates

上级 aeecde02
......@@ -2636,7 +2636,7 @@ GEOMETRY
"
"Functions (Aggregate)","AVG","
AVG ( [ DISTINCT ] { numeric } )
AVG ( [ DISTINCT ] { numeric } ) [ FILTER ( WHERE expression ) ]
","
The average (mean) value.
If no rows are selected, the result is NULL.
......@@ -2647,7 +2647,7 @@ AVG(X)
"
"Functions (Aggregate)","BIT_AND","
BIT_AND(expression)
BIT_AND(expression) [ FILTER ( WHERE expression ) ]
","
The bitwise AND of all non-null values.
If no rows are selected, the result is NULL.
......@@ -2657,7 +2657,7 @@ BIT_AND(ID)
"
"Functions (Aggregate)","BIT_OR","
BIT_OR(expression)
BIT_OR(expression) [ FILTER ( WHERE expression ) ]
","
The bitwise OR of all non-null values.
If no rows are selected, the result is NULL.
......@@ -2667,7 +2667,7 @@ BIT_OR(ID)
"
"Functions (Aggregate)","BOOL_AND","
BOOL_AND(boolean)
BOOL_AND(boolean) [ FILTER ( WHERE expression ) ]
","
Returns true if all expressions are true.
If no rows are selected, the result is NULL.
......@@ -2677,7 +2677,7 @@ BOOL_AND(ID>10)
"
"Functions (Aggregate)","BOOL_OR","
BOOL_OR(boolean)
BOOL_OR(boolean) [ FILTER ( WHERE expression ) ]
","
Returns true if any expression is true.
If no rows are selected, the result is NULL.
......@@ -2687,7 +2687,7 @@ BOOL_OR(NAME LIKE 'W%')
"
"Functions (Aggregate)","COUNT","
COUNT( { * | { [ DISTINCT ] expression } } )
COUNT( { * | { [ DISTINCT ] expression } } ) [ FILTER ( WHERE expression ) ]
","
The count of all row, or of the non-null values.
This method returns a long.
......@@ -2700,7 +2700,7 @@ COUNT(*)
"Functions (Aggregate)","GROUP_CONCAT","
GROUP_CONCAT ( [ DISTINCT ] string
[ ORDER BY { expression [ ASC | DESC ] } [,...] ]
[ SEPARATOR expression ] )
[ SEPARATOR expression ] ) [ FILTER ( WHERE expression ) ]
","
Concatenates strings with a separator.
The default separator is a ',' (without space).
......@@ -2712,7 +2712,7 @@ GROUP_CONCAT(NAME ORDER BY ID SEPARATOR ', ')
"
"Functions (Aggregate)","MAX","
MAX(value)
MAX(value) [ FILTER ( WHERE expression ) ]
","
The highest value.
If no rows are selected, the result is NULL.
......@@ -2723,7 +2723,7 @@ MAX(NAME)
"
"Functions (Aggregate)","MIN","
MIN(value)
MIN(value) [ FILTER ( WHERE expression ) ]
","
The lowest value.
If no rows are selected, the result is NULL.
......@@ -2734,7 +2734,7 @@ MIN(NAME)
"
"Functions (Aggregate)","SUM","
SUM( [ DISTINCT ] { numeric } )
SUM( [ DISTINCT ] { numeric } ) [ FILTER ( WHERE expression ) ]
","
The sum of all values.
If no rows are selected, the result is NULL.
......@@ -2746,7 +2746,7 @@ SUM(X)
"
"Functions (Aggregate)","SELECTIVITY","
SELECTIVITY(value)
SELECTIVITY(value) [ FILTER ( WHERE expression ) ]
","
Estimates the selectivity (0-100) of a value.
The value is defined as (100 * distinctCount / rowCount).
......@@ -2758,7 +2758,7 @@ SELECT SELECTIVITY(FIRSTNAME), SELECTIVITY(NAME) FROM TEST WHERE ROWNUM()<20000
"
"Functions (Aggregate)","STDDEV_POP","
STDDEV_POP( [ DISTINCT ] numeric )
STDDEV_POP( [ DISTINCT ] numeric ) [ FILTER ( WHERE expression ) ]
","
The population standard deviation.
This method returns a double.
......@@ -2769,7 +2769,7 @@ STDDEV_POP(X)
"
"Functions (Aggregate)","STDDEV_SAMP","
STDDEV_SAMP( [ DISTINCT ] numeric )
STDDEV_SAMP( [ DISTINCT ] numeric ) [ FILTER ( WHERE expression ) ]
","
The sample standard deviation.
This method returns a double.
......@@ -2780,7 +2780,7 @@ STDDEV(X)
"
"Functions (Aggregate)","VAR_POP","
VAR_POP( [ DISTINCT ] numeric )
VAR_POP( [ DISTINCT ] numeric ) [ FILTER ( WHERE expression ) ]
","
The population variance (square of the population standard deviation).
This method returns a double.
......@@ -2791,7 +2791,7 @@ VAR_POP(X)
"
"Functions (Aggregate)","VAR_SAMP","
VAR_SAMP( [ DISTINCT ] numeric )
VAR_SAMP( [ DISTINCT ] numeric ) [ FILTER ( WHERE expression ) ]
","
The sample variance (square of the sample standard deviation).
This method returns a double.
......@@ -2802,7 +2802,7 @@ VAR_SAMP(X)
"
"Functions (Aggregate)","MEDIAN","
MEDIAN( [ DISTINCT ] value )
MEDIAN( [ DISTINCT ] value ) [ FILTER ( WHERE expression ) ]
","
The value separating the higher half of a values from the lower half.
Returns the middle value or an interpolated value between two middle values if number of values is even.
......
......@@ -2627,7 +2627,7 @@ public class Parser {
throw getSyntaxError();
}
currentSelect.setGroupQuery();
Expression r;
Aggregate r;
if (aggregateType == AggregateType.COUNT) {
if (readIf("*")) {
r = new Aggregate(AggregateType.COUNT_ALL, null, currentSelect,
......@@ -2645,38 +2645,45 @@ public class Parser {
}
}
} else if (aggregateType == AggregateType.GROUP_CONCAT) {
Aggregate agg = null;
boolean distinct = readIf("DISTINCT");
if (equalsToken("GROUP_CONCAT", aggregateName)) {
agg = new Aggregate(AggregateType.GROUP_CONCAT,
r = new Aggregate(AggregateType.GROUP_CONCAT,
readExpression(), currentSelect, distinct);
if (readIf("ORDER")) {
read("BY");
agg.setGroupConcatOrder(parseSimpleOrderList());
r.setGroupConcatOrder(parseSimpleOrderList());
}
if (readIf("SEPARATOR")) {
agg.setGroupConcatSeparator(readExpression());
r.setGroupConcatSeparator(readExpression());
}
} else if (equalsToken("STRING_AGG", aggregateName)) {
// PostgreSQL compatibility: string_agg(expression, delimiter)
agg = new Aggregate(AggregateType.GROUP_CONCAT,
r = new Aggregate(AggregateType.GROUP_CONCAT,
readExpression(), currentSelect, distinct);
read(",");
agg.setGroupConcatSeparator(readExpression());
r.setGroupConcatSeparator(readExpression());
if (readIf("ORDER")) {
read("BY");
agg.setGroupConcatOrder(parseSimpleOrderList());
r.setGroupConcatOrder(parseSimpleOrderList());
}
} else {
r = null;
}
r = agg;
} else {
boolean distinct = readIf("DISTINCT");
r = new Aggregate(aggregateType, readExpression(), currentSelect,
distinct);
}
read(")");
if (r != null && readIf("FILTER")) {
read("(");
read("WHERE");
Expression condition = readExpression();
read(")");
r.setFilterCondition(condition);
}
return r;
}
......@@ -2727,8 +2734,17 @@ public class Parser {
params.add(readExpression());
} while (readIf(","));
read(")");
Expression filterCondition;
if (readIf("FILTER")) {
read("(");
read("WHERE");
filterCondition = readExpression();
read(")");
} else {
filterCondition = null;
}
Expression[] list = params.toArray(new Expression[0]);
JavaAggregate agg = new JavaAggregate(aggregate, list, currentSelect);
JavaAggregate agg = new JavaAggregate(aggregate, list, currentSelect, filterCondition);
currentSelect.setGroupQuery();
return agg;
}
......
......@@ -146,6 +146,8 @@ public class Aggregate extends Expression {
private int displaySize;
private int lastGroupRowId;
private Expression filterCondition;
/**
* Create a new aggregate object.
*
......@@ -228,6 +230,15 @@ public class Aggregate extends Expression {
this.groupConcatSeparator = separator;
}
/**
* Sets the FILTER condition.
*
* @param filterCondition condition
*/
public void setFilterCondition(Expression filterCondition) {
this.filterCondition = filterCondition;
}
private SortOrder initOrder(Session session) {
int size = groupConcatOrderList.size();
int[] index = new int[size];
......@@ -281,6 +292,11 @@ public class Aggregate extends Expression {
}
}
}
if (filterCondition != null) {
if (!filterCondition.getBooleanValue(session)) {
return;
}
}
data.add(session.getDatabase(), dataType, distinct, v);
}
......@@ -383,6 +399,9 @@ public class Aggregate extends Expression {
if (groupConcatSeparator != null) {
groupConcatSeparator.mapColumns(resolver, level);
}
if (filterCondition != null) {
filterCondition.mapColumns(resolver, level);
}
}
@Override
......@@ -403,6 +422,9 @@ public class Aggregate extends Expression {
if (groupConcatSeparator != null) {
groupConcatSeparator = groupConcatSeparator.optimize(session);
}
if (filterCondition != null) {
filterCondition = filterCondition.optimize(session);
}
switch (type) {
case GROUP_CONCAT:
dataType = Value.STRING;
......@@ -487,6 +509,9 @@ public class Aggregate extends Expression {
if (groupConcatSeparator != null) {
groupConcatSeparator.setEvaluatable(tableFilter, b);
}
if (filterCondition != null) {
filterCondition.setEvaluatable(tableFilter, b);
}
}
@Override
......@@ -523,7 +548,11 @@ public class Aggregate extends Expression {
if (groupConcatSeparator != null) {
buff.append(" SEPARATOR ").append(groupConcatSeparator.getSQL());
}
return buff.append(')').toString();
buff.append(')');
if (filterCondition != null) {
buff.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
}
return buff.toString();
}
@Override
......@@ -586,9 +615,14 @@ public class Aggregate extends Expression {
throw DbException.throwInternalError("type=" + type);
}
if (distinct) {
return text + "(DISTINCT " + on.getSQL() + ")";
text += "(DISTINCT " + on.getSQL() + ')';
} else {
text += StringUtils.enclose(on.getSQL());
}
return text + StringUtils.enclose(on.getSQL());
if (filterCondition != null) {
text += " FILTER (WHERE " + filterCondition.getSQL() + ')';
}
return text;
}
private Index getMinMaxColumnIndex() {
......@@ -607,6 +641,9 @@ public class Aggregate extends Expression {
@Override
public boolean isEverything(ExpressionVisitor visitor) {
if (filterCondition != null && !filterCondition.isEverything(visitor)) {
return false;
}
if (visitor.getType() == ExpressionVisitor.OPTIMIZABLE_MIN_MAX_COUNT_ALL) {
switch (type) {
case COUNT:
......@@ -649,7 +686,14 @@ public class Aggregate extends Expression {
@Override
public int getCost() {
return (on == null) ? 1 : on.getCost() + 1;
int cost = 1;
if (on != null) {
cost += on.getCost();
}
if (filterCondition != null) {
cost += filterCondition.getCost();
}
return cost;
}
}
......@@ -31,15 +31,17 @@ public class JavaAggregate extends Expression {
private final Select select;
private final Expression[] args;
private int[] argTypes;
private Expression filterCondition;
private int dataType;
private Connection userConnection;
private int lastGroupRowId;
public JavaAggregate(UserAggregate userAggregate, Expression[] args,
Select select) {
Select select, Expression filterCondition) {
this.userAggregate = userAggregate;
this.args = args;
this.select = select;
this.filterCondition = filterCondition;
}
@Override
......@@ -48,6 +50,9 @@ public class JavaAggregate extends Expression {
for (Expression e : args) {
cost += e.getCost();
}
if (filterCondition != null) {
cost += filterCondition.getCost();
}
return cost;
}
......@@ -101,6 +106,9 @@ public class JavaAggregate extends Expression {
return false;
}
}
if (filterCondition != null && !filterCondition.isEverything(visitor)) {
return false;
}
return true;
}
......@@ -109,6 +117,9 @@ public class JavaAggregate extends Expression {
for (Expression arg : args) {
arg.mapColumns(resolver, level);
}
if (filterCondition != null) {
filterCondition.mapColumns(resolver, level);
}
}
@Override
......@@ -128,6 +139,9 @@ public class JavaAggregate extends Expression {
} catch (SQLException e) {
throw DbException.convert(e);
}
if (filterCondition != null) {
filterCondition = filterCondition.optimize(session);
}
return this;
}
......@@ -136,6 +150,9 @@ public class JavaAggregate extends Expression {
for (Expression e : args) {
e.setEvaluatable(tableFilter, b);
}
if (filterCondition != null) {
filterCondition.setEvaluatable(tableFilter, b);
}
}
private Aggregate getInstance() throws SQLException {
......@@ -180,6 +197,12 @@ public class JavaAggregate extends Expression {
}
lastGroupRowId = groupRowId;
if (filterCondition != null) {
if (!filterCondition.getBooleanValue(session)) {
return;
}
}
Aggregate agg = (Aggregate) group.get(this);
try {
if (agg == null) {
......
......@@ -747,6 +747,10 @@ public class TestFunctions extends TestBase implements AggregateFunction {
"SELECT SIMPLE_MEDIAN(X) FROM SYSTEM_RANGE(1, 9)");
rs.next();
assertEquals("5", rs.getString(1));
rs = stat.executeQuery(
"SELECT SIMPLE_MEDIAN(X) FILTER (WHERE X > 2) FROM SYSTEM_RANGE(1, 9)");
rs.next();
assertEquals("6", rs.getString(1));
conn.close();
if (config.memory) {
......
......@@ -2,3 +2,28 @@
-- and the EPL 1.0 (http://h2database.com/html/license.html).
-- Initial Developer: H2 Group
--
-- with filter condition
create table test(v int);
> ok
insert into test values (10), (20), (30), (40), (50), (60), (70), (80), (90), (100), (110), (120);
> update count: 12
select avg(v), avg(v) filter (where v >= 40) from test where v <= 100;
> AVG(V) AVG(V) FILTER (WHERE (V >= 40))
> ------ -------------------------------
> 55 70
> rows: 1
create index test_idx on test(v);
select avg(v), avg(v) filter (where v >= 40) from test where v <= 100;
> AVG(V) AVG(V) FILTER (WHERE (V >= 40))
> ------ -------------------------------
> 55 70
> rows: 1
drop table test;
> ok
......@@ -2,3 +2,31 @@
-- and the EPL 1.0 (http://h2database.com/html/license.html).
-- Initial Developer: H2 Group
--
-- with filter condition
create table test(v bigint);
> ok
insert into test values
(0xfffffffffff0), (0xffffffffff0f), (0xfffffffff0ff), (0xffffffff0fff),
(0xfffffff0ffff), (0xffffff0fffff), (0xfffff0ffffff), (0xffff0fffffff),
(0xfff0ffffffff), (0xff0fffffffff), (0xf0ffffffffff), (0x0fffffffffff);
> update count: 12
select bit_and(v), bit_and(v) filter (where v <= 0xffffffff0fff) from test where v >= 0xff0fffffffff;
> BIT_AND(V) BIT_AND(V) FILTER (WHERE (V <= 281474976649215))
> --------------- ------------------------------------------------
> 280375465082880 280375465086975
> rows: 1
create index test_idx on test(v);
select bit_and(v), bit_and(v) filter (where v <= 0xffffffff0fff) from test where v >= 0xff0fffffffff;
> BIT_AND(V) BIT_AND(V) FILTER (WHERE (V <= 281474976649215))
> --------------- ------------------------------------------------
> 280375465082880 280375465086975
> rows: 1
drop table test;
> ok
......@@ -2,3 +2,30 @@
-- and the EPL 1.0 (http://h2database.com/html/license.html).
-- Initial Developer: H2 Group
--
-- with filter condition
-- with filter condition
create table test(v bigint);
> ok
insert into test values (1), (2), (4), (8), (16), (32), (64), (128), (256), (512), (1024), (2048);
> update count: 12
select bit_or(v), bit_or(v) filter (where v >= 8) from test where v <= 512;
> BIT_OR(V) BIT_OR(V) FILTER (WHERE (V >= 8))
> --------- ---------------------------------
> 1023 1016
> rows: 1
create index test_idx on test(v);
select bit_or(v), bit_or(v) filter (where v >= 8) from test where v <= 512;
> BIT_OR(V) BIT_OR(V) FILTER (WHERE (V >= 8))
> --------- ---------------------------------
> 1023 1016
> rows: 1
drop table test;
> ok
......@@ -2,3 +2,34 @@
-- and the EPL 1.0 (http://h2database.com/html/license.html).
-- Initial Developer: H2 Group
--
-- with filter condition
create table test(v int);
> ok
insert into test values (1), (2), (3), (4), (5), (6), (7), (8), (9), (10), (11), (12);
> update count: 12
select count(v), count(v) filter (where v >= 4) from test where v <= 10;
> COUNT(V) COUNT(V) FILTER (WHERE (V >= 4))
> -------- --------------------------------
> 10 7
> rows: 1
create index test_idx on test(v);
select count(v), count(v) filter (where v >= 4) from test where v <= 10;
> COUNT(V) COUNT(V) FILTER (WHERE (V >= 4))
> -------- --------------------------------
> 10 7
> rows: 1
select count(v), count(v) filter (where v >= 4) from test;
> COUNT(V) COUNT(V) FILTER (WHERE (V >= 4))
> -------- --------------------------------
> 12 9
> rows: 1
drop table test;
> ok
......@@ -2,3 +2,41 @@
-- and the EPL 1.0 (http://h2database.com/html/license.html).
-- Initial Developer: H2 Group
--
-- with filter condition
create table test(v varchar);
> ok
insert into test values ('1'), ('2'), ('3'), ('4'), ('5'), ('6'), ('7'), ('8'), ('9');
> update count: 9
select group_concat(v order by v asc separator '-'),
group_concat(v order by v desc separator '-') filter (where v >= '4')
from test where v >= '2';
> GROUP_CONCAT(V ORDER BY V SEPARATOR '-') GROUP_CONCAT(V ORDER BY V DESC SEPARATOR '-') FILTER (WHERE (V >= '4'))
> ---------------------------------------- -----------------------------------------------------------------------
> 2-3-4-5-6-7-8-9 9-8-7-6-5-4
> rows (ordered): 1
create index test_idx on test(v);
select group_concat(v order by v asc separator '-'),
group_concat(v order by v desc separator '-') filter (where v >= '4')
from test where v >= '2';
> GROUP_CONCAT(V ORDER BY V SEPARATOR '-') GROUP_CONCAT(V ORDER BY V DESC SEPARATOR '-') FILTER (WHERE (V >= '4'))
> ---------------------------------------- -----------------------------------------------------------------------
> 2-3-4-5-6-7-8-9 9-8-7-6-5-4
> rows (ordered): 1
select group_concat(v order by v asc separator '-'),
group_concat(v order by v desc separator '-') filter (where v >= '4')
from test;
> GROUP_CONCAT(V ORDER BY V SEPARATOR '-') GROUP_CONCAT(V ORDER BY V DESC SEPARATOR '-') FILTER (WHERE (V >= '4'))
> ---------------------------------------- -----------------------------------------------------------------------
> 1-2-3-4-5-6-7-8-9 9-8-7-6-5-4
> rows (ordered): 1
drop table test;
> ok
......@@ -2,3 +2,34 @@
-- and the EPL 1.0 (http://h2database.com/html/license.html).
-- Initial Developer: H2 Group
--
-- with filter condition
create table test(v int);
> ok
insert into test values (1), (2), (3), (4), (5), (6), (7), (8), (9), (10), (11), (12);
> update count: 12
select max(v), max(v) filter (where v <= 8) from test where v <= 10;
> MAX(V) MAX(V) FILTER (WHERE (V <= 8))
> ------ ------------------------------
> 10 8
> rows: 1
create index test_idx on test(v);
select max(v), max(v) filter (where v <= 8) from test where v <= 10;
> MAX(V) MAX(V) FILTER (WHERE (V <= 8))
> ------ ------------------------------
> 10 8
> rows: 1
select max(v), max(v) filter (where v <= 8) from test;
> MAX(V) MAX(V) FILTER (WHERE (V <= 8))
> ------ ------------------------------
> 12 8
> rows: 1
drop table test;
> ok
......@@ -734,3 +734,34 @@ select median(v) from test;
drop table test;
> ok
-- with filter condition
create table test(v int);
> ok
insert into test values (10), (20), (30), (40), (50), (60), (70), (80), (90), (100), (110), (120);
> update count: 12
select median(v), median(v) filter (where v >= 40) from test where v <= 100;
> MEDIAN(V) MEDIAN(V) FILTER (WHERE (V >= 40))
> --------- ----------------------------------
> 55 70
> rows: 1
create index test_idx on test(v);
select median(v), median(v) filter (where v >= 40) from test where v <= 100;
> MEDIAN(V) MEDIAN(V) FILTER (WHERE (V >= 40))
> --------- ----------------------------------
> 55 70
> rows: 1
select median(v), median(v) filter (where v >= 40) from test;
> MEDIAN(V) MEDIAN(V) FILTER (WHERE (V >= 40))
> --------- ----------------------------------
> 65 80
> rows: 1
drop table test;
> ok
......@@ -2,3 +2,34 @@
-- and the EPL 1.0 (http://h2database.com/html/license.html).
-- Initial Developer: H2 Group
--
-- with filter condition
create table test(v int);
> ok
insert into test values (1), (2), (3), (4), (5), (6), (7), (8), (9), (10), (11), (12);
> update count: 12
select min(v), min(v) filter (where v >= 4) from test where v >= 2;
> MIN(V) MIN(V) FILTER (WHERE (V >= 4))
> ------ ------------------------------
> 2 4
> rows: 1
create index test_idx on test(v);
select min(v), min(v) filter (where v >= 4) from test where v >= 2;
> MIN(V) MIN(V) FILTER (WHERE (V >= 4))
> ------ ------------------------------
> 2 4
> rows: 1
select min(v), min(v) filter (where v >= 4) from test;
> MIN(V) MIN(V) FILTER (WHERE (V >= 4))
> ------ ------------------------------
> 1 4
> rows: 1
drop table test;
> ok
......@@ -2,3 +2,28 @@
-- and the EPL 1.0 (http://h2database.com/html/license.html).
-- Initial Developer: H2 Group
--
-- with filter condition
create table test(v int);
> ok
insert into test values (1), (2), (3), (4), (5), (6), (7), (8), (9), (10), (11), (12);
> update count: 12
select sum(v), sum(v) filter (where v >= 4) from test where v <= 10;
> SUM(V) SUM(V) FILTER (WHERE (V >= 4))
> ------ ------------------------------
> 55 49
> rows: 1
create index test_idx on test(v);
select sum(v), sum(v) filter (where v >= 4) from test where v <= 10;
> SUM(V) SUM(V) FILTER (WHERE (V >= 4))
> ------ ------------------------------
> 55 49
> rows: 1
drop table test;
> ok
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论