提交 2a200352 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Add MODE aggregate function

上级 57c070e7
...@@ -3416,6 +3416,18 @@ Aggregates are only allowed in select statements. ...@@ -3416,6 +3416,18 @@ Aggregates are only allowed in select statements.
MEDIAN(X) MEDIAN(X)
" "
"Functions (Aggregate)","MODE","
MODE( value ) [ FILTER ( WHERE expression ) ]
","
Returns the value that occurs with the greatest frequency.
If there are multiple values with the same frequency, only one of them will be returned.
NULL values are ignored in the calculation.
If no rows are selected, the result is NULL.
Aggregates are only allowed in select statements.
","
MODE(X)
"
"Functions (Numeric)","ABS"," "Functions (Numeric)","ABS","
ABS ( { numeric } ) ABS ( { numeric } )
"," ","
......
...@@ -2916,40 +2916,35 @@ public class Parser { ...@@ -2916,40 +2916,35 @@ public class Parser {
} }
currentSelect.setGroupQuery(); currentSelect.setGroupQuery();
Aggregate r; Aggregate r;
if (aggregateType == AggregateType.COUNT) { switch (aggregateType) {
case COUNT:
if (readIf(ASTERISK)) { if (readIf(ASTERISK)) {
r = new Aggregate(AggregateType.COUNT_ALL, null, currentSelect, r = new Aggregate(AggregateType.COUNT_ALL, null, currentSelect, false);
false);
} else { } else {
boolean distinct = readIf(DISTINCT); boolean distinct = readIf(DISTINCT);
Expression on = readExpression(); Expression on = readExpression();
if (on instanceof Wildcard && !distinct) { if (on instanceof Wildcard && !distinct) {
// PostgreSQL compatibility: count(t.*) // PostgreSQL compatibility: count(t.*)
r = new Aggregate(AggregateType.COUNT_ALL, null, currentSelect, r = new Aggregate(AggregateType.COUNT_ALL, null, currentSelect, false);
false);
} else { } else {
r = new Aggregate(AggregateType.COUNT, on, currentSelect, r = new Aggregate(AggregateType.COUNT, on, currentSelect, distinct);
distinct);
} }
} }
} else if (aggregateType == AggregateType.GROUP_CONCAT) { break;
case GROUP_CONCAT: {
boolean distinct = readIf(DISTINCT); boolean distinct = readIf(DISTINCT);
if (equalsToken("GROUP_CONCAT", aggregateName)) { if (equalsToken("GROUP_CONCAT", aggregateName)) {
r = new Aggregate(AggregateType.GROUP_CONCAT, r = new Aggregate(AggregateType.GROUP_CONCAT, readExpression(), currentSelect, distinct);
readExpression(), currentSelect, distinct);
if (readIf(ORDER)) { if (readIf(ORDER)) {
read("BY"); read("BY");
r.setOrderByList(parseSimpleOrderList()); r.setOrderByList(parseSimpleOrderList());
} }
if (readIf("SEPARATOR")) { if (readIf("SEPARATOR")) {
r.setGroupConcatSeparator(readExpression()); r.setGroupConcatSeparator(readExpression());
} }
} else if (equalsToken("STRING_AGG", aggregateName)) { } else if (equalsToken("STRING_AGG", aggregateName)) {
// PostgreSQL compatibility: string_agg(expression, delimiter) // PostgreSQL compatibility: string_agg(expression, delimiter)
r = new Aggregate(AggregateType.GROUP_CONCAT, r = new Aggregate(AggregateType.GROUP_CONCAT, readExpression(), currentSelect, distinct);
readExpression(), currentSelect, distinct);
read(COMMA); read(COMMA);
r.setGroupConcatSeparator(readExpression()); r.setGroupConcatSeparator(readExpression());
if (readIf(ORDER)) { if (readIf(ORDER)) {
...@@ -2959,19 +2954,24 @@ public class Parser { ...@@ -2959,19 +2954,24 @@ public class Parser {
} else { } else {
r = null; r = null;
} }
} else if (aggregateType == AggregateType.ARRAY_AGG) { break;
}
case ARRAY_AGG: {
boolean distinct = readIf(DISTINCT); boolean distinct = readIf(DISTINCT);
r = new Aggregate(AggregateType.ARRAY_AGG, readExpression(), currentSelect, distinct);
r = new Aggregate(AggregateType.ARRAY_AGG,
readExpression(), currentSelect, distinct);
if (readIf(ORDER)) { if (readIf(ORDER)) {
read("BY"); read("BY");
r.setOrderByList(parseSimpleOrderList()); r.setOrderByList(parseSimpleOrderList());
} }
} else { break;
}
case MODE:
r = new Aggregate(aggregateType, readExpression(), currentSelect, false);
break;
default:
boolean distinct = readIf(DISTINCT); boolean distinct = readIf(DISTINCT);
r = new Aggregate(aggregateType, readExpression(), currentSelect, r = new Aggregate(aggregateType, readExpression(), currentSelect, distinct);
distinct); break;
} }
read(CLOSE_PAREN); read(CLOSE_PAREN);
if (r != null) { if (r != null) {
......
...@@ -132,10 +132,16 @@ public class Aggregate extends Expression { ...@@ -132,10 +132,16 @@ public class Aggregate extends Expression {
* The aggregate type for MEDIAN(expression). * The aggregate type for MEDIAN(expression).
*/ */
MEDIAN, MEDIAN,
/** /**
* The aggregate type for ARRAY_AGG(expression). * The aggregate type for ARRAY_AGG(expression).
*/ */
ARRAY_AGG ARRAY_AGG,
/**
* The aggregate type for MODE(expression).
*/
MODE,
} }
private static final HashMap<String, AggregateType> AGGREGATES = new HashMap<>(64); private static final HashMap<String, AggregateType> AGGREGATES = new HashMap<>(64);
...@@ -203,6 +209,7 @@ public class Aggregate extends Expression { ...@@ -203,6 +209,7 @@ public class Aggregate extends Expression {
addAggregate("BIT_AND", AggregateType.BIT_AND); addAggregate("BIT_AND", AggregateType.BIT_AND);
addAggregate("MEDIAN", AggregateType.MEDIAN); addAggregate("MEDIAN", AggregateType.MEDIAN);
addAggregate("ARRAY_AGG", AggregateType.ARRAY_AGG); addAggregate("ARRAY_AGG", AggregateType.ARRAY_AGG);
addAggregate("MODE", AggregateType.MODE);
} }
private static void addAggregate(String name, AggregateType type) { private static void addAggregate(String name, AggregateType type) {
...@@ -506,6 +513,7 @@ public class Aggregate extends Expression { ...@@ -506,6 +513,7 @@ public class Aggregate extends Expression {
case MIN: case MIN:
case MAX: case MAX:
case MEDIAN: case MEDIAN:
case MODE:
break; break;
case STDDEV_POP: case STDDEV_POP:
case STDDEV_SAMP: case STDDEV_SAMP:
...@@ -676,6 +684,9 @@ public class Aggregate extends Expression { ...@@ -676,6 +684,9 @@ public class Aggregate extends Expression {
break; break;
case ARRAY_AGG: case ARRAY_AGG:
return getSQLArrayAggregate(); return getSQLArrayAggregate();
case MODE:
text = "MODE";
break;
default: default:
throw DbException.throwInternalError("type=" + type); throw DbException.throwInternalError("type=" + type);
} }
......
...@@ -35,6 +35,8 @@ abstract class AggregateData { ...@@ -35,6 +35,8 @@ abstract class AggregateData {
return new AggregateDataHistogram(); return new AggregateDataHistogram();
case MEDIAN: case MEDIAN:
return new AggregateDataMedian(); return new AggregateDataMedian();
case MODE:
return new AggregateDataMode();
default: default:
return new AggregateDataDefault(aggregateType); return new AggregateDataDefault(aggregateType);
} }
......
/*
* Copyright 2004-2018 H2 Group. Multiple-Licensed under the MPL 2.0,
* and the EPL 1.0 (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.expression.aggregate;
import java.util.Map.Entry;
import org.h2.engine.Database;
import org.h2.util.ValueHashMap;
import org.h2.value.Value;
import org.h2.value.ValueNull;
/**
 * Accumulator for the MODE aggregate: tallies how often each non-NULL value
 * is added and reports the value with the highest tally.
 */
class AggregateDataMode extends AggregateData {

    /**
     * Per-value occurrence counters. Stays {@code null} until the first
     * non-NULL value is added.
     */
    private ValueHashMap<LongDataCounter> distinctValues;

    /**
     * Registers one more occurrence of the given value. NULL values are
     * skipped; the {@code database}, {@code dataType} and {@code distinct}
     * arguments are not consulted here.
     *
     * @param database the database (unused)
     * @param dataType the data type of the aggregate (unused)
     * @param distinct the DISTINCT flag (unused)
     * @param v the value to count
     */
    @Override
    void add(Database database, int dataType, boolean distinct, Value v) {
        if (v != ValueNull.INSTANCE) {
            // Allocate the map lazily so that an aggregate that never sees
            // a non-NULL value does not allocate one.
            if (distinctValues == null) {
                distinctValues = ValueHashMap.newInstance();
            }
            LongDataCounter counter = distinctValues.get(v);
            if (counter == null) {
                counter = new LongDataCounter();
                distinctValues.put(v, counter);
            }
            counter.count++;
        }
    }

    /**
     * Picks the value with the largest occurrence count and converts it to
     * the requested data type. When nothing was counted the NULL value is
     * converted and returned instead. On a tie the entry met first while
     * iterating the map wins, because only a strictly larger count replaces
     * the current candidate.
     *
     * @param database the database (unused)
     * @param dataType the data type to convert the result to
     * @param distinct the DISTINCT flag (unused)
     * @return the most frequent value, or NULL, converted to {@code dataType}
     */
    @Override
    Value getValue(Database database, int dataType, boolean distinct) {
        Value mostFrequent = ValueNull.INSTANCE;
        if (distinctValues != null) {
            long highest = 0L;
            for (Entry<Value, LongDataCounter> candidate : distinctValues.entries()) {
                long occurrences = candidate.getValue().count;
                if (occurrences <= highest) {
                    continue;
                }
                highest = occurrences;
                mostFrequent = candidate.getKey();
            }
        }
        return mostFrequent.convertTo(dataType);
    }
}
...@@ -137,7 +137,7 @@ public class TestScript extends TestDb { ...@@ -137,7 +137,7 @@ public class TestScript extends TestDb {
testScript("other/" + s + ".sql"); testScript("other/" + s + ".sql");
} }
for (String s : new String[] { "avg", "bit-and", "bit-or", "count", for (String s : new String[] { "avg", "bit-and", "bit-or", "count",
"group-concat", "max", "median", "min", "selectivity", "stddev-pop", "group-concat", "max", "median", "min", "mode", "selectivity", "stddev-pop",
"stddev-samp", "sum", "var-pop", "var-samp", "array-agg" }) { "stddev-samp", "sum", "var-pop", "var-samp", "array-agg" }) {
testScript("functions/aggregate/" + s + ".sql"); testScript("functions/aggregate/" + s + ".sql");
} }
......
-- Copyright 2004-2018 H2 Group. Multiple-Licensed under the MPL 2.0,
-- and the EPL 1.0 (http://h2database.com/html/license.html).
-- Initial Developer: H2 Group
--
-- Test fixture: a single nullable INT column.
CREATE TABLE TEST(V INT);
> ok
-- With no rows at all, MODE is expected to yield NULL.
SELECT MODE(V) FROM TEST;
>> null
-- DISTINCT is not accepted inside MODE and is expected to raise a syntax error.
SELECT MODE(DISTINCT V) FROM TEST;
> exception SYNTAX_ERROR_2
-- A table holding only a NULL row: the expected result is still NULL.
INSERT INTO TEST VALUES (NULL);
> update count: 1
SELECT MODE(V) FROM TEST;
>> null
-- Data set: 1 occurs three times, 2 twice, 3 once (plus the earlier NULL row).
INSERT INTO TEST VALUES (1), (2), (3), (1), (2), (1);
> update count: 6
-- Unfiltered mode is 1; restricted to V > 1 it is 2; a filter that matches no rows gives NULL.
SELECT MODE(V), MODE(V) FILTER (WHERE (V > 1)), MODE(V) FILTER (WHERE (V < 0)) FROM TEST;
> MODE(V) MODE(V) FILTER (WHERE (V > 1)) MODE(V) FILTER (WHERE (V < 0))
> ------- ------------------------------ ------------------------------
> 1 2 null
> rows: 1
-- Clean up the fixture.
DROP TABLE TEST;
> ok
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论