提交 c5cb7339 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Use AggregateDataCollecting instead of AggregateDataDefault for distinct aggregates

上级 f2584f9a
......@@ -12,6 +12,7 @@ import java.util.HashMap;
import org.h2.api.ErrorCode;
import org.h2.command.dml.Select;
import org.h2.command.dml.SelectOrderBy;
import org.h2.engine.Database;
import org.h2.engine.Session;
import org.h2.expression.Expression;
import org.h2.expression.ExpressionColumn;
......@@ -417,12 +418,30 @@ public class Aggregate extends AbstractAggregate {
data = (AggregateData) createAggregateData();
}
switch (type) {
case COUNT: {
if (!distinct) {
return data.getValue(session.getDatabase(), dataType, distinct);
}
case COUNT:
if (distinct) {
return ValueLong.get(((AggregateDataCollecting) data).getCount());
}
break;
case SUM:
case AVG:
case STDDEV_POP:
case STDDEV_SAMP:
case VAR_POP:
case VAR_SAMP:
if (distinct) {
AggregateDataCollecting c = ((AggregateDataCollecting) data);
if (c.getCount() == 0) {
return ValueNull.INSTANCE;
}
AggregateDataDefault d = new AggregateDataDefault(type);
Database db = session.getDatabase();
for (Value v : c) {
d.add(db, dataType, false, v);
}
return d.getValue(db, dataType, false);
}
break;
case GROUP_CONCAT: {
Value[] array = ((AggregateDataCollecting) data).getArray();
if (array == null) {
......@@ -479,8 +498,9 @@ public class Aggregate extends AbstractAggregate {
}
//$FALL-THROUGH$
default:
return data.getValue(session.getDatabase(), dataType, distinct);
// Avoid compiler warning
}
return data.getValue(session.getDatabase(), dataType, distinct);
}
@Override
......
......@@ -7,6 +7,7 @@ package org.h2.expression.aggregate;
import org.h2.engine.Database;
import org.h2.expression.aggregate.Aggregate.AggregateType;
import org.h2.message.DbException;
import org.h2.value.Value;
/**
......@@ -23,12 +24,6 @@ abstract class AggregateData {
*/
static AggregateData create(AggregateType aggregateType, boolean distinct) {
switch (aggregateType) {
case SELECTIVITY:
return new AggregateDataSelectivity();
case GROUP_CONCAT:
case ARRAY_AGG:
case MEDIAN:
break;
case COUNT_ALL:
return new AggregateDataCountAll();
case COUNT:
......@@ -36,6 +31,29 @@ abstract class AggregateData {
return new AggregateDataCount();
}
break;
case GROUP_CONCAT:
case ARRAY_AGG:
case MEDIAN:
break;
case MIN:
case MAX:
case BIT_OR:
case BIT_AND:
case BOOL_OR:
case BOOL_AND:
return new AggregateDataDefault(aggregateType);
case SUM:
case AVG:
case STDDEV_POP:
case STDDEV_SAMP:
case VAR_POP:
case VAR_SAMP:
if (!distinct) {
return new AggregateDataDefault(aggregateType);
}
break;
case SELECTIVITY:
return new AggregateDataSelectivity();
case HISTOGRAM:
return new AggregateDataHistogram();
case MODE:
......@@ -43,7 +61,7 @@ abstract class AggregateData {
case ENVELOPE:
return new AggregateDataEnvelope();
default:
return new AggregateDataDefault(aggregateType);
throw DbException.throwInternalError("type=" + aggregateType);
}
return new AggregateDataCollecting();
}
......
......@@ -7,7 +7,9 @@ package org.h2.expression.aggregate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import org.h2.engine.Database;
import org.h2.value.Value;
......@@ -23,7 +25,8 @@ import org.h2.value.ValueNull;
* class instead.
* </p>
*/
class AggregateDataCollecting extends AggregateData {
class AggregateDataCollecting extends AggregateData implements Iterable<Value> {
Collection<Value> values;
@Override
......@@ -64,4 +67,10 @@ class AggregateDataCollecting extends AggregateData {
}
return values.toArray(new Value[0]);
}
@Override
public Iterator<Value> iterator() {
return values != null ? values.iterator() : Collections.<Value>emptyIterator();
}
}
......@@ -8,7 +8,6 @@ package org.h2.expression.aggregate;
import org.h2.engine.Database;
import org.h2.expression.aggregate.Aggregate.AggregateType;
import org.h2.message.DbException;
import org.h2.util.ValueHashMap;
import org.h2.value.DataType;
import org.h2.value.Value;
import org.h2.value.ValueBoolean;
......@@ -20,9 +19,9 @@ import org.h2.value.ValueNull;
* Data stored while calculating an aggregate.
*/
class AggregateDataDefault extends AggregateData {
private final AggregateType aggregateType;
private long count;
private ValueHashMap<AggregateDataDefault> distinctValues;
private Value value;
private double m2, mean;
......@@ -39,13 +38,6 @@ class AggregateDataDefault extends AggregateData {
return;
}
count++;
if (distinct) {
if (distinctValues == null) {
distinctValues = ValueHashMap.newInstance();
}
distinctValues.put(v, this);
return;
}
switch (aggregateType) {
case SUM:
if (value == null) {
......@@ -128,9 +120,6 @@ class AggregateDataDefault extends AggregateData {
@Override
Value getValue(Database database, int dataType, boolean distinct) {
if (distinct) {
return getDistinct(database, dataType);
}
Value v = null;
switch (aggregateType) {
case SUM:
......@@ -191,15 +180,4 @@ class AggregateDataDefault extends AggregateData {
return a;
}
private Value getDistinct(Database database, int dataType) {
if (distinctValues == null) {
return ValueNull.INSTANCE;
}
AggregateDataDefault d = new AggregateDataDefault(aggregateType);
for (Value v : distinctValues.keys()) {
d.add(database, dataType, false, v);
}
return d.getValue(database, dataType, false);
}
}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论