提交 1a30a9a7 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Support DISTINCT in Java aggregates

上级 048f834c
......@@ -2736,6 +2736,7 @@ public class Parser {
}
private JavaAggregate readJavaAggregate(UserAggregate aggregate) {
boolean distinct = readIf("DISTINCT");
ArrayList<Expression> params = New.arrayList();
do {
params.add(readExpression());
......@@ -2750,7 +2751,7 @@ public class Parser {
filterCondition = null;
}
Expression[] list = params.toArray(new Expression[0]);
JavaAggregate agg = new JavaAggregate(aggregate, list, currentSelect, filterCondition);
JavaAggregate agg = new JavaAggregate(aggregate, list, currentSelect, distinct, filterCondition);
currentSelect.setGroupQuery();
return agg;
}
......
......@@ -20,6 +20,7 @@ import org.h2.table.TableFilter;
import org.h2.util.StatementBuilder;
import org.h2.value.DataType;
import org.h2.value.Value;
import org.h2.value.ValueArray;
import org.h2.value.ValueNull;
/**
......@@ -31,16 +32,18 @@ public class JavaAggregate extends Expression {
private final Select select;
private final Expression[] args;
private int[] argTypes;
private final boolean distinct;
private Expression filterCondition;
private int dataType;
private Connection userConnection;
private int lastGroupRowId;
public JavaAggregate(UserAggregate userAggregate, Expression[] args,
Select select, Expression filterCondition) {
Select select, boolean distinct, Expression filterCondition) {
this.userAggregate = userAggregate;
this.args = args;
this.select = select;
this.distinct = distinct;
this.filterCondition = filterCondition;
}
......@@ -169,10 +172,30 @@ public class JavaAggregate extends Expression {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
try {
Aggregate agg = (Aggregate) group.get(this);
Aggregate agg;
if (distinct) {
agg = getInstance();
AggregateDataCollecting data = (AggregateDataCollecting) group.get(this);
if (data != null) {
for (Value value : data.values) {
if (args.length == 1) {
agg.add(value.getObject());
} else {
Value[] values = ((ValueArray) value).getList();
Object[] argValues = new Object[args.length];
for (int i = 0, len = args.length; i < len; i++) {
argValues[i] = values[i].getObject();
}
agg.add(argValues);
}
}
}
} else {
agg = (Aggregate) group.get(this);
if (agg == null) {
agg = getInstance();
}
}
Object obj = agg.getResult();
if (obj == null) {
return ValueNull.INSTANCE;
......@@ -204,8 +227,23 @@ public class JavaAggregate extends Expression {
}
}
Aggregate agg = (Aggregate) group.get(this);
try {
if (distinct) {
AggregateDataCollecting data = (AggregateDataCollecting) group.get(this);
if (data == null) {
data = new AggregateDataCollecting();
group.put(this, data);
}
Value[] argValues = new Value[args.length];
Value arg = null;
for (int i = 0, len = args.length; i < len; i++) {
arg = args[i].getValue(session);
arg = arg.convertTo(argTypes[i]);
argValues[i] = arg;
}
data.add(session.getDatabase(), dataType, true, args.length == 1 ? arg : ValueArray.get(argValues));
} else {
Aggregate agg = (Aggregate) group.get(this);
if (agg == null) {
agg = getInstance();
group.put(this, agg);
......@@ -218,10 +256,7 @@ public class JavaAggregate extends Expression {
arg = v.getObject();
argValues[i] = arg;
}
if (args.length == 1) {
agg.add(arg);
} else {
agg.add(argValues);
agg.add(args.length == 1 ? arg : argValues);
}
} catch (SQLException e) {
throw DbException.convert(e);
......
......@@ -32,6 +32,7 @@ import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Collections;
import java.util.Currency;
import java.util.Date;
import java.util.GregorianCalendar;
......@@ -694,6 +695,7 @@ public class TestFunctions extends TestBase implements AggregateFunction {
@Override
public Object getResult() {
Collections.sort(list);
return list.get(list.size() / 2);
}
......@@ -793,6 +795,15 @@ public class TestFunctions extends TestBase implements AggregateFunction {
"SELECT SIMPLE_MEDIAN(X) FROM SYSTEM_RANGE(1, 9)");
rs.next();
assertEquals("5", rs.getString(1));
stat.execute("CREATE TABLE DATA(V INT)");
stat.execute("INSERT INTO DATA VALUES (1), (3), (2), (1), (1), (2), (1), (1), (1), (1), (1)");
rs = stat.executeQuery(
"SELECT SIMPLE_MEDIAN(V), SIMPLE_MEDIAN(DISTINCT V) FROM DATA");
rs.next();
assertEquals("1", rs.getString(1));
assertEquals("2", rs.getString(2));
conn.close();
if (config.memory) {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论