提交 eef18e96 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Use array of expressions in Aggregate

上级 b27b12be
...@@ -3041,21 +3041,22 @@ public class Parser { ...@@ -3041,21 +3041,22 @@ public class Parser {
switch (aggregateType) { switch (aggregateType) {
case COUNT: case COUNT:
if (readIf(ASTERISK)) { if (readIf(ASTERISK)) {
r = new Aggregate(AggregateType.COUNT_ALL, null, currentSelect, false); r = new Aggregate(AggregateType.COUNT_ALL, new Expression[0], currentSelect, false);
} else { } else {
boolean distinct = readDistinctAgg(); boolean distinct = readDistinctAgg();
Expression on = readExpression(); Expression on = readExpression();
if (on instanceof Wildcard && !distinct) { if (on instanceof Wildcard && !distinct) {
// PostgreSQL compatibility: count(t.*) // PostgreSQL compatibility: count(t.*)
r = new Aggregate(AggregateType.COUNT_ALL, null, currentSelect, false); r = new Aggregate(AggregateType.COUNT_ALL, new Expression[0], currentSelect, false);
} else { } else {
r = new Aggregate(AggregateType.COUNT, on, currentSelect, distinct); r = new Aggregate(AggregateType.COUNT, new Expression[] { on }, currentSelect, distinct);
} }
} }
break; break;
case GROUP_CONCAT: { case GROUP_CONCAT: {
boolean distinct = readDistinctAgg(); boolean distinct = readDistinctAgg();
r = new Aggregate(AggregateType.GROUP_CONCAT, readExpression(), currentSelect, distinct); r = new Aggregate(AggregateType.GROUP_CONCAT, new Expression[] { readExpression() }, currentSelect,
distinct);
if (equalsToken("STRING_AGG", aggregateName)) { if (equalsToken("STRING_AGG", aggregateName)) {
// PostgreSQL compatibility: string_agg(expression, delimiter) // PostgreSQL compatibility: string_agg(expression, delimiter)
read(COMMA); read(COMMA);
...@@ -3077,7 +3078,7 @@ public class Parser { ...@@ -3077,7 +3078,7 @@ public class Parser {
} }
case ARRAY_AGG: { case ARRAY_AGG: {
boolean distinct = readDistinctAgg(); boolean distinct = readDistinctAgg();
r = new Aggregate(AggregateType.ARRAY_AGG, readExpression(), currentSelect, distinct); r = new Aggregate(AggregateType.ARRAY_AGG, new Expression[] { readExpression() }, currentSelect, distinct);
if (readIf(ORDER)) { if (readIf(ORDER)) {
read("BY"); read("BY");
r.setOrderByList(parseSimpleOrderList()); r.setOrderByList(parseSimpleOrderList());
...@@ -3088,15 +3089,15 @@ public class Parser { ...@@ -3088,15 +3089,15 @@ public class Parser {
case PERCENTILE_DISC: { case PERCENTILE_DISC: {
Expression num = readExpression(); Expression num = readExpression();
read(CLOSE_PAREN); read(CLOSE_PAREN);
r = readWithinGroup(aggregateType, num); r = readWithinGroup(aggregateType, new Expression[] { num });
break; break;
} }
case MODE: { case MODE: {
if (readIf(CLOSE_PAREN)) { if (readIf(CLOSE_PAREN)) {
r = readWithinGroup(AggregateType.MODE, null); r = readWithinGroup(AggregateType.MODE, new Expression[0]);
} else { } else {
Expression expr = readExpression(); Expression expr = readExpression();
r = new Aggregate(aggregateType, null, currentSelect, false); r = new Aggregate(aggregateType, new Expression[0], currentSelect, false);
if (readIf(ORDER)) { if (readIf(ORDER)) {
read("BY"); read("BY");
Expression expr2 = readExpression(); Expression expr2 = readExpression();
...@@ -3114,7 +3115,7 @@ public class Parser { ...@@ -3114,7 +3115,7 @@ public class Parser {
} }
default: default:
boolean distinct = readDistinctAgg(); boolean distinct = readDistinctAgg();
r = new Aggregate(aggregateType, readExpression(), currentSelect, distinct); r = new Aggregate(aggregateType, new Expression[] { readExpression() }, currentSelect, distinct);
break; break;
} }
read(CLOSE_PAREN); read(CLOSE_PAREN);
...@@ -3122,7 +3123,7 @@ public class Parser { ...@@ -3122,7 +3123,7 @@ public class Parser {
return r; return r;
} }
private Aggregate readWithinGroup(AggregateType aggregateType, Expression argument) { private Aggregate readWithinGroup(AggregateType aggregateType, Expression[] args) {
Aggregate r; Aggregate r;
read("WITHIN"); read("WITHIN");
read(GROUP); read(GROUP);
...@@ -3130,7 +3131,7 @@ public class Parser { ...@@ -3130,7 +3131,7 @@ public class Parser {
read(ORDER); read(ORDER);
read("BY"); read("BY");
Expression expr = readExpression(); Expression expr = readExpression();
r = new Aggregate(aggregateType, argument, currentSelect, false); r = new Aggregate(aggregateType, args, currentSelect, false);
readAggregateOrder(r, expr, true); readAggregateOrder(r, expr, true);
return r; return r;
} }
......
...@@ -53,7 +53,7 @@ public class Aggregate extends AbstractAggregate { ...@@ -53,7 +53,7 @@ public class Aggregate extends AbstractAggregate {
private final AggregateType aggregateType; private final AggregateType aggregateType;
private Expression on; private final Expression[] args;
private Expression groupConcatSeparator; private Expression groupConcatSeparator;
private ArrayList<SelectOrderBy> orderByList; private ArrayList<SelectOrderBy> orderByList;
private SortOrder orderBySort; private SortOrder orderBySort;
...@@ -64,20 +64,20 @@ public class Aggregate extends AbstractAggregate { ...@@ -64,20 +64,20 @@ public class Aggregate extends AbstractAggregate {
* *
* @param aggregateType * @param aggregateType
* the aggregate type * the aggregate type
* @param on * @param args
* the aggregated expression * the aggregated expressions
* @param select * @param select
* the select statement * the select statement
* @param distinct * @param distinct
* if distinct is used * if distinct is used
*/ */
public Aggregate(AggregateType aggregateType, Expression on, Select select, boolean distinct) { public Aggregate(AggregateType aggregateType, Expression[] args, Select select, boolean distinct) {
super(select, distinct); super(select, distinct);
if (distinct && aggregateType == AggregateType.COUNT_ALL) { if (distinct && aggregateType == AggregateType.COUNT_ALL) {
throw DbException.throwInternalError(); throw DbException.throwInternalError();
} }
this.aggregateType = aggregateType; this.aggregateType = aggregateType;
this.on = on; this.args = args;
} }
static { static {
...@@ -185,7 +185,7 @@ public class Aggregate extends AbstractAggregate { ...@@ -185,7 +185,7 @@ public class Aggregate extends AbstractAggregate {
@Override @Override
protected void updateAggregate(Session session, Object aggregateData) { protected void updateAggregate(Session session, Object aggregateData) {
AggregateData data = (AggregateData) aggregateData; AggregateData data = (AggregateData) aggregateData;
Value v = on == null ? null : on.getValue(session); Value v = args.length == 0 ? null : args[0].getValue(session);
updateData(session, data, v, null); updateData(session, data, v, null);
} }
...@@ -218,8 +218,8 @@ public class Aggregate extends AbstractAggregate { ...@@ -218,8 +218,8 @@ public class Aggregate extends AbstractAggregate {
@Override @Override
protected void updateGroupAggregates(Session session, int stage) { protected void updateGroupAggregates(Session session, int stage) {
super.updateGroupAggregates(session, stage); super.updateGroupAggregates(session, stage);
if (on != null) { for (Expression arg : args) {
on.updateAggregate(session, stage); arg.updateAggregate(session, stage);
} }
if (orderByList != null) { if (orderByList != null) {
for (SelectOrderBy orderBy : orderByList) { for (SelectOrderBy orderBy : orderByList) {
...@@ -248,7 +248,7 @@ public class Aggregate extends AbstractAggregate { ...@@ -248,7 +248,7 @@ public class Aggregate extends AbstractAggregate {
@Override @Override
protected int getNumExpressions() { protected int getNumExpressions() {
int n = on != null ? 1 : 0; int n = args.length;
if (orderByList != null) { if (orderByList != null) {
n += orderByList.size(); n += orderByList.size();
} }
...@@ -261,8 +261,8 @@ public class Aggregate extends AbstractAggregate { ...@@ -261,8 +261,8 @@ public class Aggregate extends AbstractAggregate {
@Override @Override
protected void rememberExpressions(Session session, Value[] array) { protected void rememberExpressions(Session session, Value[] array) {
int offset = 0; int offset = 0;
if (on != null) { for (Expression arg : args) {
array[offset++] = on.getValue(session); array[offset++] = arg.getValue(session);
} }
if (orderByList != null) { if (orderByList != null) {
for (SelectOrderBy o : orderByList) { for (SelectOrderBy o : orderByList) {
...@@ -278,7 +278,7 @@ public class Aggregate extends AbstractAggregate { ...@@ -278,7 +278,7 @@ public class Aggregate extends AbstractAggregate {
protected void updateFromExpressions(Session session, Object aggregateData, Value[] array) { protected void updateFromExpressions(Session session, Object aggregateData, Value[] array) {
if (filterCondition == null || array[getNumExpressions() - 1].getBoolean()) { if (filterCondition == null || array[getNumExpressions() - 1].getBoolean()) {
AggregateData data = (AggregateData) aggregateData; AggregateData data = (AggregateData) aggregateData;
Value v = on == null ? null : array[0]; Value v = args.length == 0 ? null : array[0];
updateData(session, data, v, array); updateData(session, data, v, array);
} }
} }
...@@ -319,7 +319,7 @@ public class Aggregate extends AbstractAggregate { ...@@ -319,7 +319,7 @@ public class Aggregate extends AbstractAggregate {
} }
case PERCENTILE_CONT: case PERCENTILE_CONT:
case PERCENTILE_DISC: { case PERCENTILE_DISC: {
Value v = on.getValue(session); Value v = args[0].getValue(session);
if (v == ValueNull.INSTANCE) { if (v == ValueNull.INSTANCE) {
return ValueNull.INSTANCE; return ValueNull.INSTANCE;
} }
...@@ -333,9 +333,9 @@ public class Aggregate extends AbstractAggregate { ...@@ -333,9 +333,9 @@ public class Aggregate extends AbstractAggregate {
} }
} }
case MEDIAN: case MEDIAN:
return Percentile.getFromIndex(session, on, type.getValueType(), orderByList, Percentile.HALF, true); return Percentile.getFromIndex(session, args[0], type.getValueType(), orderByList, Percentile.HALF, true);
case ENVELOPE: case ENVELOPE:
return ((MVSpatialIndex) AggregateDataEnvelope.getGeometryColumnIndex(on)).getBounds(session); return ((MVSpatialIndex) AggregateDataEnvelope.getGeometryColumnIndex(args[0])).getBounds(session);
default: default:
throw DbException.throwInternalError("type=" + aggregateType); throw DbException.throwInternalError("type=" + aggregateType);
} }
...@@ -528,8 +528,8 @@ public class Aggregate extends AbstractAggregate { ...@@ -528,8 +528,8 @@ public class Aggregate extends AbstractAggregate {
@Override @Override
public void mapColumnsAnalysis(ColumnResolver resolver, int level, int innerState) { public void mapColumnsAnalysis(ColumnResolver resolver, int level, int innerState) {
if (on != null) { for (Expression arg : args) {
on.mapColumns(resolver, level, innerState); arg.mapColumns(resolver, level, innerState);
} }
if (orderByList != null) { if (orderByList != null) {
for (SelectOrderBy o : orderByList) { for (SelectOrderBy o : orderByList) {
...@@ -545,9 +545,11 @@ public class Aggregate extends AbstractAggregate { ...@@ -545,9 +545,11 @@ public class Aggregate extends AbstractAggregate {
@Override @Override
public Expression optimize(Session session) { public Expression optimize(Session session) {
super.optimize(session); super.optimize(session);
if (on != null) { for (int i = 0; i < args.length; i++) {
on = on.optimize(session); args[i] = args[i].optimize(session);
type = on.getType(); }
if (args.length == 1) {
type = args[0].getType();
} }
if (orderByList != null) { if (orderByList != null) {
for (SelectOrderBy o : orderByList) { for (SelectOrderBy o : orderByList) {
...@@ -642,8 +644,8 @@ public class Aggregate extends AbstractAggregate { ...@@ -642,8 +644,8 @@ public class Aggregate extends AbstractAggregate {
@Override @Override
public void setEvaluatable(TableFilter tableFilter, boolean b) { public void setEvaluatable(TableFilter tableFilter, boolean b) {
if (on != null) { for (Expression arg : args) {
on.setEvaluatable(tableFilter, b); arg.setEvaluatable(tableFilter, b);
} }
if (orderByList != null) { if (orderByList != null) {
for (SelectOrderBy o : orderByList) { for (SelectOrderBy o : orderByList) {
...@@ -661,7 +663,7 @@ public class Aggregate extends AbstractAggregate { ...@@ -661,7 +663,7 @@ public class Aggregate extends AbstractAggregate {
if (distinct) { if (distinct) {
builder.append("DISTINCT "); builder.append("DISTINCT ");
} }
on.getSQL(builder); args[0].getSQL(builder);
Window.appendOrderBy(builder, orderByList); Window.appendOrderBy(builder, orderByList);
if (groupConcatSeparator != null) { if (groupConcatSeparator != null) {
builder.append(" SEPARATOR "); builder.append(" SEPARATOR ");
...@@ -676,7 +678,7 @@ public class Aggregate extends AbstractAggregate { ...@@ -676,7 +678,7 @@ public class Aggregate extends AbstractAggregate {
if (distinct) { if (distinct) {
builder.append("DISTINCT "); builder.append("DISTINCT ");
} }
on.getSQL(builder); args[0].getSQL(builder);
Window.appendOrderBy(builder, orderByList); Window.appendOrderBy(builder, orderByList);
builder.append(')'); builder.append(')');
return appendTailConditions(builder); return appendTailConditions(builder);
...@@ -758,14 +760,14 @@ public class Aggregate extends AbstractAggregate { ...@@ -758,14 +760,14 @@ public class Aggregate extends AbstractAggregate {
builder.append(text); builder.append(text);
if (distinct) { if (distinct) {
builder.append("(DISTINCT "); builder.append("(DISTINCT ");
on.getSQL(builder).append(')'); args[0].getSQL(builder).append(')');
} else { } else {
builder.append('('); builder.append('(');
if (on != null) { for (Expression arg : args) {
if (on instanceof Subquery) { if (arg instanceof Subquery) {
on.getSQL(builder); arg.getSQL(builder);
} else { } else {
on.getUnenclosedSQL(builder); arg.getUnenclosedSQL(builder);
} }
} }
builder.append(')'); builder.append(')');
...@@ -779,8 +781,9 @@ public class Aggregate extends AbstractAggregate { ...@@ -779,8 +781,9 @@ public class Aggregate extends AbstractAggregate {
} }
private Index getMinMaxColumnIndex() { private Index getMinMaxColumnIndex() {
if (on instanceof ExpressionColumn) { Expression arg = args[0];
ExpressionColumn col = (ExpressionColumn) on; if (arg instanceof ExpressionColumn) {
ExpressionColumn col = (ExpressionColumn) arg;
Column column = col.getColumn(); Column column = col.getColumn();
TableFilter filter = col.getTableFilter(); TableFilter filter = col.getTableFilter();
if (filter != null) { if (filter != null) {
...@@ -802,7 +805,7 @@ public class Aggregate extends AbstractAggregate { ...@@ -802,7 +805,7 @@ public class Aggregate extends AbstractAggregate {
if (visitor.getType() == ExpressionVisitor.OPTIMIZABLE_AGGREGATE) { if (visitor.getType() == ExpressionVisitor.OPTIMIZABLE_AGGREGATE) {
switch (aggregateType) { switch (aggregateType) {
case COUNT: case COUNT:
if (!distinct && on.getNullable() == Column.NOT_NULLABLE) { if (!distinct && args[0].getNullable() == Column.NOT_NULLABLE) {
return visitor.getTable().canGetRowCount(); return visitor.getTable().canGetRowCount();
} }
return false; return false;
...@@ -814,20 +817,22 @@ public class Aggregate extends AbstractAggregate { ...@@ -814,20 +817,22 @@ public class Aggregate extends AbstractAggregate {
return index != null; return index != null;
case PERCENTILE_CONT: case PERCENTILE_CONT:
case PERCENTILE_DISC: case PERCENTILE_DISC:
return on.isConstant() && Percentile.getColumnIndex(orderByList.get(0).expression) != null; return args[0].isConstant() && Percentile.getColumnIndex(orderByList.get(0).expression) != null;
case MEDIAN: case MEDIAN:
if (distinct) { if (distinct) {
return false; return false;
} }
return Percentile.getColumnIndex(on) != null; return Percentile.getColumnIndex(args[0]) != null;
case ENVELOPE: case ENVELOPE:
return AggregateDataEnvelope.getGeometryColumnIndex(on) != null; return AggregateDataEnvelope.getGeometryColumnIndex(args[0]) != null;
default: default:
return false; return false;
} }
} }
if (on != null && !on.isEverything(visitor)) { for (Expression arg : args) {
return false; if (!arg.isEverything(visitor)) {
return false;
}
} }
if (groupConcatSeparator != null && !groupConcatSeparator.isEverything(visitor)) { if (groupConcatSeparator != null && !groupConcatSeparator.isEverything(visitor)) {
return false; return false;
...@@ -845,8 +850,16 @@ public class Aggregate extends AbstractAggregate { ...@@ -845,8 +850,16 @@ public class Aggregate extends AbstractAggregate {
@Override @Override
public int getCost() { public int getCost() {
int cost = 1; int cost = 1;
if (on != null) { for (Expression arg : args) {
cost += on.getCost(); cost += arg.getCost();
}
if (groupConcatSeparator != null) {
cost += groupConcatSeparator.getCost();
}
if (orderByList != null) {
for (SelectOrderBy o : orderByList) {
cost += o.expression.getCost();
}
} }
if (filterCondition != null) { if (filterCondition != null) {
cost += filterCondition.getCost(); cost += filterCondition.getCost();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论