Unverified 提交 280e1972 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov 提交者: GitHub

Merge pull request #1053 from katzyn/aggregate

Support DISTINCT in Java aggregates
...@@ -21,6 +21,18 @@ Change Log ...@@ -21,6 +21,18 @@ Change Log
<h2>Next Version (unreleased)</h2> <h2>Next Version (unreleased)</h2>
<ul> <ul>
<li>Issue #1047: Support DISTINCT in custom aggregate functions
</li>
<li>PR #1046: Split off Transaction TransactionMap VersionedValue
</li>
<li>PR #1045: TransactionStore move into separate org.h2.mvstore.tx package
</li>
<li>PR #1044: Encapsulate TransactionStore.store field in preparation to a move
</li>
<li>PR #1040: generate less garbage for String substring+trim
</li>
<li>PR #1035: Minor free space accounting changes
</li>
<li>Issue #1034: MERGE USING should not require the same column count in tables <li>Issue #1034: MERGE USING should not require the same column count in tables
</li> </li>
<li>PR #1033: Fix issues with BUILTIN_ALIAS_OVERRIDE=1 <li>PR #1033: Fix issues with BUILTIN_ALIAS_OVERRIDE=1
......
...@@ -2736,6 +2736,7 @@ public class Parser { ...@@ -2736,6 +2736,7 @@ public class Parser {
} }
private JavaAggregate readJavaAggregate(UserAggregate aggregate) { private JavaAggregate readJavaAggregate(UserAggregate aggregate) {
boolean distinct = readIf("DISTINCT");
ArrayList<Expression> params = New.arrayList(); ArrayList<Expression> params = New.arrayList();
do { do {
params.add(readExpression()); params.add(readExpression());
...@@ -2750,7 +2751,7 @@ public class Parser { ...@@ -2750,7 +2751,7 @@ public class Parser {
filterCondition = null; filterCondition = null;
} }
Expression[] list = params.toArray(new Expression[0]); Expression[] list = params.toArray(new Expression[0]);
JavaAggregate agg = new JavaAggregate(aggregate, list, currentSelect, filterCondition); JavaAggregate agg = new JavaAggregate(aggregate, list, currentSelect, distinct, filterCondition);
currentSelect.setGroupQuery(); currentSelect.setGroupQuery();
return agg; return agg;
} }
......
...@@ -20,6 +20,7 @@ import org.h2.table.TableFilter; ...@@ -20,6 +20,7 @@ import org.h2.table.TableFilter;
import org.h2.util.StatementBuilder; import org.h2.util.StatementBuilder;
import org.h2.value.DataType; import org.h2.value.DataType;
import org.h2.value.Value; import org.h2.value.Value;
import org.h2.value.ValueArray;
import org.h2.value.ValueNull; import org.h2.value.ValueNull;
/** /**
...@@ -31,16 +32,18 @@ public class JavaAggregate extends Expression { ...@@ -31,16 +32,18 @@ public class JavaAggregate extends Expression {
private final Select select; private final Select select;
private final Expression[] args; private final Expression[] args;
private int[] argTypes; private int[] argTypes;
private final boolean distinct;
private Expression filterCondition; private Expression filterCondition;
private int dataType; private int dataType;
private Connection userConnection; private Connection userConnection;
private int lastGroupRowId; private int lastGroupRowId;
public JavaAggregate(UserAggregate userAggregate, Expression[] args, public JavaAggregate(UserAggregate userAggregate, Expression[] args,
Select select, Expression filterCondition) { Select select, boolean distinct, Expression filterCondition) {
this.userAggregate = userAggregate; this.userAggregate = userAggregate;
this.args = args; this.args = args;
this.select = select; this.select = select;
this.distinct = distinct;
this.filterCondition = filterCondition; this.filterCondition = filterCondition;
} }
...@@ -169,9 +172,29 @@ public class JavaAggregate extends Expression { ...@@ -169,9 +172,29 @@ public class JavaAggregate extends Expression {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL()); throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
} }
try { try {
Aggregate agg = (Aggregate) group.get(this); Aggregate agg;
if (agg == null) { if (distinct) {
agg = getInstance(); agg = getInstance();
AggregateDataCollecting data = (AggregateDataCollecting) group.get(this);
if (data != null) {
for (Value value : data.values) {
if (args.length == 1) {
agg.add(value.getObject());
} else {
Value[] values = ((ValueArray) value).getList();
Object[] argValues = new Object[args.length];
for (int i = 0, len = args.length; i < len; i++) {
argValues[i] = values[i].getObject();
}
agg.add(argValues);
}
}
}
} else {
agg = (Aggregate) group.get(this);
if (agg == null) {
agg = getInstance();
}
} }
Object obj = agg.getResult(); Object obj = agg.getResult();
if (obj == null) { if (obj == null) {
...@@ -204,24 +227,36 @@ public class JavaAggregate extends Expression { ...@@ -204,24 +227,36 @@ public class JavaAggregate extends Expression {
} }
} }
Aggregate agg = (Aggregate) group.get(this);
try { try {
if (agg == null) { if (distinct) {
agg = getInstance(); AggregateDataCollecting data = (AggregateDataCollecting) group.get(this);
group.put(this, agg); if (data == null) {
} data = new AggregateDataCollecting();
Object[] argValues = new Object[args.length]; group.put(this, data);
Object arg = null; }
for (int i = 0, len = args.length; i < len; i++) { Value[] argValues = new Value[args.length];
Value v = args[i].getValue(session); Value arg = null;
v = v.convertTo(argTypes[i]); for (int i = 0, len = args.length; i < len; i++) {
arg = v.getObject(); arg = args[i].getValue(session);
argValues[i] = arg; arg = arg.convertTo(argTypes[i]);
} argValues[i] = arg;
if (args.length == 1) { }
agg.add(arg); data.add(session.getDatabase(), dataType, true, args.length == 1 ? arg : ValueArray.get(argValues));
} else { } else {
agg.add(argValues); Aggregate agg = (Aggregate) group.get(this);
if (agg == null) {
agg = getInstance();
group.put(this, agg);
}
Object[] argValues = new Object[args.length];
Object arg = null;
for (int i = 0, len = args.length; i < len; i++) {
Value v = args[i].getValue(session);
v = v.convertTo(argTypes[i]);
arg = v.getObject();
argValues[i] = arg;
}
agg.add(args.length == 1 ? arg : argValues);
} }
} catch (SQLException e) { } catch (SQLException e) {
throw DbException.convert(e); throw DbException.convert(e);
......
...@@ -32,6 +32,7 @@ import java.text.ParseException; ...@@ -32,6 +32,7 @@ import java.text.ParseException;
import java.text.SimpleDateFormat; import java.text.SimpleDateFormat;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Calendar; import java.util.Calendar;
import java.util.Collections;
import java.util.Currency; import java.util.Currency;
import java.util.Date; import java.util.Date;
import java.util.GregorianCalendar; import java.util.GregorianCalendar;
...@@ -694,6 +695,7 @@ public class TestFunctions extends TestBase implements AggregateFunction { ...@@ -694,6 +695,7 @@ public class TestFunctions extends TestBase implements AggregateFunction {
@Override @Override
public Object getResult() { public Object getResult() {
Collections.sort(list);
return list.get(list.size() / 2); return list.get(list.size() / 2);
} }
...@@ -793,6 +795,15 @@ public class TestFunctions extends TestBase implements AggregateFunction { ...@@ -793,6 +795,15 @@ public class TestFunctions extends TestBase implements AggregateFunction {
"SELECT SIMPLE_MEDIAN(X) FROM SYSTEM_RANGE(1, 9)"); "SELECT SIMPLE_MEDIAN(X) FROM SYSTEM_RANGE(1, 9)");
rs.next(); rs.next();
assertEquals("5", rs.getString(1)); assertEquals("5", rs.getString(1));
stat.execute("CREATE TABLE DATA(V INT)");
stat.execute("INSERT INTO DATA VALUES (1), (3), (2), (1), (1), (2), (1), (1), (1), (1), (1)");
rs = stat.executeQuery(
"SELECT SIMPLE_MEDIAN(V), SIMPLE_MEDIAN(DISTINCT V) FROM DATA");
rs.next();
assertEquals("1", rs.getString(1));
assertEquals("2", rs.getString(2));
conn.close(); conn.close();
if (config.memory) { if (config.memory) {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论