提交 1a30a9a7 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Support DISTINCT in Java aggregates

上级 048f834c
...@@ -2736,6 +2736,7 @@ public class Parser { ...@@ -2736,6 +2736,7 @@ public class Parser {
} }
private JavaAggregate readJavaAggregate(UserAggregate aggregate) { private JavaAggregate readJavaAggregate(UserAggregate aggregate) {
boolean distinct = readIf("DISTINCT");
ArrayList<Expression> params = New.arrayList(); ArrayList<Expression> params = New.arrayList();
do { do {
params.add(readExpression()); params.add(readExpression());
...@@ -2750,7 +2751,7 @@ public class Parser { ...@@ -2750,7 +2751,7 @@ public class Parser {
filterCondition = null; filterCondition = null;
} }
Expression[] list = params.toArray(new Expression[0]); Expression[] list = params.toArray(new Expression[0]);
JavaAggregate agg = new JavaAggregate(aggregate, list, currentSelect, filterCondition); JavaAggregate agg = new JavaAggregate(aggregate, list, currentSelect, distinct, filterCondition);
currentSelect.setGroupQuery(); currentSelect.setGroupQuery();
return agg; return agg;
} }
......
...@@ -20,6 +20,7 @@ import org.h2.table.TableFilter; ...@@ -20,6 +20,7 @@ import org.h2.table.TableFilter;
import org.h2.util.StatementBuilder; import org.h2.util.StatementBuilder;
import org.h2.value.DataType; import org.h2.value.DataType;
import org.h2.value.Value; import org.h2.value.Value;
import org.h2.value.ValueArray;
import org.h2.value.ValueNull; import org.h2.value.ValueNull;
/** /**
...@@ -31,16 +32,18 @@ public class JavaAggregate extends Expression { ...@@ -31,16 +32,18 @@ public class JavaAggregate extends Expression {
private final Select select; private final Select select;
private final Expression[] args; private final Expression[] args;
private int[] argTypes; private int[] argTypes;
private final boolean distinct;
private Expression filterCondition; private Expression filterCondition;
private int dataType; private int dataType;
private Connection userConnection; private Connection userConnection;
private int lastGroupRowId; private int lastGroupRowId;
public JavaAggregate(UserAggregate userAggregate, Expression[] args, public JavaAggregate(UserAggregate userAggregate, Expression[] args,
Select select, Expression filterCondition) { Select select, boolean distinct, Expression filterCondition) {
this.userAggregate = userAggregate; this.userAggregate = userAggregate;
this.args = args; this.args = args;
this.select = select; this.select = select;
this.distinct = distinct;
this.filterCondition = filterCondition; this.filterCondition = filterCondition;
} }
...@@ -169,10 +172,30 @@ public class JavaAggregate extends Expression { ...@@ -169,10 +172,30 @@ public class JavaAggregate extends Expression {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL()); throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
} }
try { try {
Aggregate agg = (Aggregate) group.get(this); Aggregate agg;
if (distinct) {
agg = getInstance();
AggregateDataCollecting data = (AggregateDataCollecting) group.get(this);
if (data != null) {
for (Value value : data.values) {
if (args.length == 1) {
agg.add(value.getObject());
} else {
Value[] values = ((ValueArray) value).getList();
Object[] argValues = new Object[args.length];
for (int i = 0, len = args.length; i < len; i++) {
argValues[i] = values[i].getObject();
}
agg.add(argValues);
}
}
}
} else {
agg = (Aggregate) group.get(this);
if (agg == null) { if (agg == null) {
agg = getInstance(); agg = getInstance();
} }
}
Object obj = agg.getResult(); Object obj = agg.getResult();
if (obj == null) { if (obj == null) {
return ValueNull.INSTANCE; return ValueNull.INSTANCE;
...@@ -204,8 +227,23 @@ public class JavaAggregate extends Expression { ...@@ -204,8 +227,23 @@ public class JavaAggregate extends Expression {
} }
} }
Aggregate agg = (Aggregate) group.get(this);
try { try {
if (distinct) {
AggregateDataCollecting data = (AggregateDataCollecting) group.get(this);
if (data == null) {
data = new AggregateDataCollecting();
group.put(this, data);
}
Value[] argValues = new Value[args.length];
Value arg = null;
for (int i = 0, len = args.length; i < len; i++) {
arg = args[i].getValue(session);
arg = arg.convertTo(argTypes[i]);
argValues[i] = arg;
}
data.add(session.getDatabase(), dataType, true, args.length == 1 ? arg : ValueArray.get(argValues));
} else {
Aggregate agg = (Aggregate) group.get(this);
if (agg == null) { if (agg == null) {
agg = getInstance(); agg = getInstance();
group.put(this, agg); group.put(this, agg);
...@@ -218,10 +256,7 @@ public class JavaAggregate extends Expression { ...@@ -218,10 +256,7 @@ public class JavaAggregate extends Expression {
arg = v.getObject(); arg = v.getObject();
argValues[i] = arg; argValues[i] = arg;
} }
if (args.length == 1) { agg.add(args.length == 1 ? arg : argValues);
agg.add(arg);
} else {
agg.add(argValues);
} }
} catch (SQLException e) { } catch (SQLException e) {
throw DbException.convert(e); throw DbException.convert(e);
......
...@@ -32,6 +32,7 @@ import java.text.ParseException; ...@@ -32,6 +32,7 @@ import java.text.ParseException;
import java.text.SimpleDateFormat; import java.text.SimpleDateFormat;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Calendar; import java.util.Calendar;
import java.util.Collections;
import java.util.Currency; import java.util.Currency;
import java.util.Date; import java.util.Date;
import java.util.GregorianCalendar; import java.util.GregorianCalendar;
...@@ -694,6 +695,7 @@ public class TestFunctions extends TestBase implements AggregateFunction { ...@@ -694,6 +695,7 @@ public class TestFunctions extends TestBase implements AggregateFunction {
@Override @Override
public Object getResult() { public Object getResult() {
Collections.sort(list);
return list.get(list.size() / 2); return list.get(list.size() / 2);
} }
...@@ -793,6 +795,15 @@ public class TestFunctions extends TestBase implements AggregateFunction { ...@@ -793,6 +795,15 @@ public class TestFunctions extends TestBase implements AggregateFunction {
"SELECT SIMPLE_MEDIAN(X) FROM SYSTEM_RANGE(1, 9)"); "SELECT SIMPLE_MEDIAN(X) FROM SYSTEM_RANGE(1, 9)");
rs.next(); rs.next();
assertEquals("5", rs.getString(1)); assertEquals("5", rs.getString(1));
stat.execute("CREATE TABLE DATA(V INT)");
stat.execute("INSERT INTO DATA VALUES (1), (3), (2), (1), (1), (2), (1), (1), (1), (1), (1)");
rs = stat.executeQuery(
"SELECT SIMPLE_MEDIAN(V), SIMPLE_MEDIAN(DISTINCT V) FROM DATA");
rs.next();
assertEquals("1", rs.getString(1));
assertEquals("2", rs.getString(2));
conn.close(); conn.close();
if (config.memory) { if (config.memory) {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论