提交 c5cb7339 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Use AggregateDataCollecting instead of AggregateDataDefault for distinct aggregates

上级 f2584f9a
...@@ -12,6 +12,7 @@ import java.util.HashMap; ...@@ -12,6 +12,7 @@ import java.util.HashMap;
import org.h2.api.ErrorCode; import org.h2.api.ErrorCode;
import org.h2.command.dml.Select; import org.h2.command.dml.Select;
import org.h2.command.dml.SelectOrderBy; import org.h2.command.dml.SelectOrderBy;
import org.h2.engine.Database;
import org.h2.engine.Session; import org.h2.engine.Session;
import org.h2.expression.Expression; import org.h2.expression.Expression;
import org.h2.expression.ExpressionColumn; import org.h2.expression.ExpressionColumn;
...@@ -417,12 +418,30 @@ public class Aggregate extends AbstractAggregate { ...@@ -417,12 +418,30 @@ public class Aggregate extends AbstractAggregate {
data = (AggregateData) createAggregateData(); data = (AggregateData) createAggregateData();
} }
switch (type) { switch (type) {
case COUNT: { case COUNT:
if (!distinct) { if (distinct) {
return data.getValue(session.getDatabase(), dataType, distinct); return ValueLong.get(((AggregateDataCollecting) data).getCount());
} }
return ValueLong.get(((AggregateDataCollecting) data).getCount()); break;
} case SUM:
case AVG:
case STDDEV_POP:
case STDDEV_SAMP:
case VAR_POP:
case VAR_SAMP:
if (distinct) {
AggregateDataCollecting c = ((AggregateDataCollecting) data);
if (c.getCount() == 0) {
return ValueNull.INSTANCE;
}
AggregateDataDefault d = new AggregateDataDefault(type);
Database db = session.getDatabase();
for (Value v : c) {
d.add(db, dataType, false, v);
}
return d.getValue(db, dataType, false);
}
break;
case GROUP_CONCAT: { case GROUP_CONCAT: {
Value[] array = ((AggregateDataCollecting) data).getArray(); Value[] array = ((AggregateDataCollecting) data).getArray();
if (array == null) { if (array == null) {
...@@ -479,8 +498,9 @@ public class Aggregate extends AbstractAggregate { ...@@ -479,8 +498,9 @@ public class Aggregate extends AbstractAggregate {
} }
//$FALL-THROUGH$ //$FALL-THROUGH$
default: default:
return data.getValue(session.getDatabase(), dataType, distinct); // Avoid compiler warning
} }
return data.getValue(session.getDatabase(), dataType, distinct);
} }
@Override @Override
......
...@@ -7,6 +7,7 @@ package org.h2.expression.aggregate; ...@@ -7,6 +7,7 @@ package org.h2.expression.aggregate;
import org.h2.engine.Database; import org.h2.engine.Database;
import org.h2.expression.aggregate.Aggregate.AggregateType; import org.h2.expression.aggregate.Aggregate.AggregateType;
import org.h2.message.DbException;
import org.h2.value.Value; import org.h2.value.Value;
/** /**
...@@ -23,12 +24,6 @@ abstract class AggregateData { ...@@ -23,12 +24,6 @@ abstract class AggregateData {
*/ */
static AggregateData create(AggregateType aggregateType, boolean distinct) { static AggregateData create(AggregateType aggregateType, boolean distinct) {
switch (aggregateType) { switch (aggregateType) {
case SELECTIVITY:
return new AggregateDataSelectivity();
case GROUP_CONCAT:
case ARRAY_AGG:
case MEDIAN:
break;
case COUNT_ALL: case COUNT_ALL:
return new AggregateDataCountAll(); return new AggregateDataCountAll();
case COUNT: case COUNT:
...@@ -36,6 +31,29 @@ abstract class AggregateData { ...@@ -36,6 +31,29 @@ abstract class AggregateData {
return new AggregateDataCount(); return new AggregateDataCount();
} }
break; break;
case GROUP_CONCAT:
case ARRAY_AGG:
case MEDIAN:
break;
case MIN:
case MAX:
case BIT_OR:
case BIT_AND:
case BOOL_OR:
case BOOL_AND:
return new AggregateDataDefault(aggregateType);
case SUM:
case AVG:
case STDDEV_POP:
case STDDEV_SAMP:
case VAR_POP:
case VAR_SAMP:
if (!distinct) {
return new AggregateDataDefault(aggregateType);
}
break;
case SELECTIVITY:
return new AggregateDataSelectivity();
case HISTOGRAM: case HISTOGRAM:
return new AggregateDataHistogram(); return new AggregateDataHistogram();
case MODE: case MODE:
...@@ -43,7 +61,7 @@ abstract class AggregateData { ...@@ -43,7 +61,7 @@ abstract class AggregateData {
case ENVELOPE: case ENVELOPE:
return new AggregateDataEnvelope(); return new AggregateDataEnvelope();
default: default:
return new AggregateDataDefault(aggregateType); throw DbException.throwInternalError("type=" + aggregateType);
} }
return new AggregateDataCollecting(); return new AggregateDataCollecting();
} }
......
...@@ -7,7 +7,9 @@ package org.h2.expression.aggregate; ...@@ -7,7 +7,9 @@ package org.h2.expression.aggregate;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.Iterator;
import org.h2.engine.Database; import org.h2.engine.Database;
import org.h2.value.Value; import org.h2.value.Value;
...@@ -23,7 +25,8 @@ import org.h2.value.ValueNull; ...@@ -23,7 +25,8 @@ import org.h2.value.ValueNull;
* class instead. * class instead.
* </p> * </p>
*/ */
class AggregateDataCollecting extends AggregateData { class AggregateDataCollecting extends AggregateData implements Iterable<Value> {
Collection<Value> values; Collection<Value> values;
@Override @Override
...@@ -64,4 +67,10 @@ class AggregateDataCollecting extends AggregateData { ...@@ -64,4 +67,10 @@ class AggregateDataCollecting extends AggregateData {
} }
return values.toArray(new Value[0]); return values.toArray(new Value[0]);
} }
@Override
public Iterator<Value> iterator() {
return values != null ? values.iterator() : Collections.<Value>emptyIterator();
}
} }
...@@ -8,7 +8,6 @@ package org.h2.expression.aggregate; ...@@ -8,7 +8,6 @@ package org.h2.expression.aggregate;
import org.h2.engine.Database; import org.h2.engine.Database;
import org.h2.expression.aggregate.Aggregate.AggregateType; import org.h2.expression.aggregate.Aggregate.AggregateType;
import org.h2.message.DbException; import org.h2.message.DbException;
import org.h2.util.ValueHashMap;
import org.h2.value.DataType; import org.h2.value.DataType;
import org.h2.value.Value; import org.h2.value.Value;
import org.h2.value.ValueBoolean; import org.h2.value.ValueBoolean;
...@@ -20,9 +19,9 @@ import org.h2.value.ValueNull; ...@@ -20,9 +19,9 @@ import org.h2.value.ValueNull;
* Data stored while calculating an aggregate. * Data stored while calculating an aggregate.
*/ */
class AggregateDataDefault extends AggregateData { class AggregateDataDefault extends AggregateData {
private final AggregateType aggregateType; private final AggregateType aggregateType;
private long count; private long count;
private ValueHashMap<AggregateDataDefault> distinctValues;
private Value value; private Value value;
private double m2, mean; private double m2, mean;
...@@ -39,13 +38,6 @@ class AggregateDataDefault extends AggregateData { ...@@ -39,13 +38,6 @@ class AggregateDataDefault extends AggregateData {
return; return;
} }
count++; count++;
if (distinct) {
if (distinctValues == null) {
distinctValues = ValueHashMap.newInstance();
}
distinctValues.put(v, this);
return;
}
switch (aggregateType) { switch (aggregateType) {
case SUM: case SUM:
if (value == null) { if (value == null) {
...@@ -128,9 +120,6 @@ class AggregateDataDefault extends AggregateData { ...@@ -128,9 +120,6 @@ class AggregateDataDefault extends AggregateData {
@Override @Override
Value getValue(Database database, int dataType, boolean distinct) { Value getValue(Database database, int dataType, boolean distinct) {
if (distinct) {
return getDistinct(database, dataType);
}
Value v = null; Value v = null;
switch (aggregateType) { switch (aggregateType) {
case SUM: case SUM:
...@@ -191,15 +180,4 @@ class AggregateDataDefault extends AggregateData { ...@@ -191,15 +180,4 @@ class AggregateDataDefault extends AggregateData {
return a; return a;
} }
private Value getDistinct(Database database, int dataType) {
if (distinctValues == null) {
return ValueNull.INSTANCE;
}
AggregateDataDefault d = new AggregateDataDefault(aggregateType);
for (Value v : distinctValues.keys()) {
d.add(database, dataType, false, v);
}
return d.getValue(database, dataType, false);
}
} }
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论