提交 ba358f48 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Evaluate window aggregates only once for each partition

上级 3b296bb7
......@@ -5,13 +5,16 @@
*/
package org.h2.expression.aggregate;
import org.h2.api.ErrorCode;
import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups;
import org.h2.engine.Session;
import org.h2.expression.Expression;
import org.h2.message.DbException;
import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter;
import org.h2.util.ValueHashMap;
import org.h2.value.Value;
import org.h2.value.ValueArray;
/**
......@@ -137,33 +140,48 @@ public abstract class AbstractAggregate extends Expression {
protected Object getData(Session session, SelectGroups groupData, boolean ifExists) {
Object data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getCurrentGroupExprData(this, true);
if (map == null) {
if (ifExists) {
return null;
if (over != null) {
ValueArray key = over.getCurrentKey(session);
if (key != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getCurrentGroupExprData(this, true);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
data = map.get(key);
if (data == null) {
if (ifExists) {
return null;
PartitionData partition = (PartitionData) map.get(key);
if (partition == null) {
if (ifExists) {
return null;
}
data = createAggregateData();
map.put(key, new PartitionData(data));
} else {
data = partition.getData();
}
} else {
PartitionData partition = (PartitionData) groupData.getCurrentGroupExprData(this, true);
if (partition == null) {
if (ifExists) {
return null;
}
data = createAggregateData();
groupData.setCurrentGroupExprData(this, new PartitionData(data), true);
} else {
data = partition.getData();
}
data = createAggregateData();
map.put(key, data);
}
} else {
data = groupData.getCurrentGroupExprData(this, over != null);
data = groupData.getCurrentGroupExprData(this, false);
if (data == null) {
if (ifExists) {
return null;
}
data = createAggregateData();
groupData.setCurrentGroupExprData(this, data, over != null);
groupData.setCurrentGroupExprData(this, data, false);
}
}
return data;
......@@ -171,6 +189,55 @@ public abstract class AbstractAggregate extends Expression {
protected abstract Object createAggregateData();
@Override
public Value getValue(Session session) {
SelectGroups groupData = select.getGroupDataIfCurrent(over != null);
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
return over == null ? getAggregatedValue(session, getData(session, groupData, true))
: getWindowResult(session, groupData);
}
private Value getWindowResult(Session session, SelectGroups groupData) {
PartitionData partition;
Object data;
ValueArray key = over.getCurrentKey(session);
if (key != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getCurrentGroupExprData(this, true);
if (map == null) {
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
partition = (PartitionData) map.get(key);
if (partition == null) {
data = createAggregateData();
partition = new PartitionData(data);
map.put(key, partition);
} else {
data = partition.getData();
}
} else {
partition = (PartitionData) groupData.getCurrentGroupExprData(this, true);
if (partition == null) {
data = createAggregateData();
partition = new PartitionData(data);
groupData.setCurrentGroupExprData(this, partition, true);
} else {
data = partition.getData();
}
}
Value result = partition.getResult();
if (result == null) {
result = getAggregatedValue(session, data);
partition.setResult(result);
}
return result;
}
protected abstract Value getAggregatedValue(Session session, Object aggregateData);
protected StringBuilder appendTailConditions(StringBuilder builder) {
if (filterCondition != null) {
builder.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
......
......@@ -11,7 +11,6 @@ import java.util.Comparator;
import java.util.HashMap;
import org.h2.api.ErrorCode;
import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups;
import org.h2.command.dml.SelectOrderBy;
import org.h2.engine.Session;
import org.h2.expression.Expression;
......@@ -332,43 +331,45 @@ public class Aggregate extends AbstractAggregate {
@Override
public Value getValue(Session session) {
if (select.isQuickAggregateQuery()) {
switch (type) {
case COUNT:
case COUNT_ALL:
Table table = select.getTopTableFilter().getTable();
return ValueLong.get(table.getRowCount(session));
case MIN:
case MAX: {
boolean first = type == AggregateType.MIN;
Index index = getMinMaxColumnIndex();
int sortType = index.getIndexColumns()[0].sortType;
if ((sortType & SortOrder.DESCENDING) != 0) {
first = !first;
}
Cursor cursor = index.findFirstOrLast(session, first);
SearchRow row = cursor.getSearchRow();
Value v;
if (row == null) {
v = ValueNull.INSTANCE;
} else {
v = row.getValue(index.getColumns()[0].getColumnId());
}
return v;
return select.isQuickAggregateQuery() ? getValueQuick(session) : super.getValue(session);
}
private Value getValueQuick(Session session) {
switch (type) {
case COUNT:
case COUNT_ALL:
Table table = select.getTopTableFilter().getTable();
return ValueLong.get(table.getRowCount(session));
case MIN:
case MAX: {
boolean first = type == AggregateType.MIN;
Index index = getMinMaxColumnIndex();
int sortType = index.getIndexColumns()[0].sortType;
if ((sortType & SortOrder.DESCENDING) != 0) {
first = !first;
}
case MEDIAN:
return AggregateDataMedian.getResultFromIndex(session, on, dataType);
case ENVELOPE:
return ((MVSpatialIndex) AggregateDataEnvelope.getGeometryColumnIndex(on)).getBounds(session);
default:
DbException.throwInternalError("type=" + type);
Cursor cursor = index.findFirstOrLast(session, first);
SearchRow row = cursor.getSearchRow();
Value v;
if (row == null) {
v = ValueNull.INSTANCE;
} else {
v = row.getValue(index.getColumns()[0].getColumnId());
}
return v;
}
SelectGroups groupData = select.getGroupDataIfCurrent(over != null);
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
case MEDIAN:
return AggregateDataMedian.getResultFromIndex(session, on, dataType);
case ENVELOPE:
return ((MVSpatialIndex) AggregateDataEnvelope.getGeometryColumnIndex(on)).getBounds(session);
default:
throw DbException.throwInternalError("type=" + type);
}
AggregateData data = (AggregateData) getData(session, groupData, true);
}
@Override
public Value getAggregatedValue(Session session, Object aggregateData) {
AggregateData data = (AggregateData) aggregateData;
if (data == null) {
data = (AggregateData) createAggregateData();
}
......
......@@ -8,7 +8,6 @@ package org.h2.expression.aggregate;
import java.sql.Connection;
import java.sql.SQLException;
import org.h2.api.Aggregate;
import org.h2.api.ErrorCode;
import org.h2.command.Parser;
import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups;
......@@ -158,16 +157,12 @@ public class JavaAggregate extends AbstractAggregate {
}
@Override
public Value getValue(Session session) {
SelectGroups groupData = select.getGroupDataIfCurrent(over != null);
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
public Value getAggregatedValue(Session session, Object aggregateData) {
try {
Aggregate agg;
if (distinct) {
agg = getInstance();
AggregateDataCollecting data = (AggregateDataCollecting) getData(session, groupData, true);
AggregateDataCollecting data = (AggregateDataCollecting) aggregateData;
if (data != null) {
for (Value value : data.values) {
if (args.length == 1) {
......@@ -183,7 +178,7 @@ public class JavaAggregate extends AbstractAggregate {
}
}
} else {
agg = (Aggregate) getData(session, groupData, true);
agg = (Aggregate) aggregateData;
if (agg == null) {
agg = getInstance();
}
......
/*
* Copyright 2004-2018 H2 Group. Multiple-Licensed under the MPL 2.0,
* and the EPL 1.0 (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.expression.aggregate;
import org.h2.value.Value;
/**
* Partition data of a window aggregate.
*/
final class PartitionData {
/**
* Aggregate data.
*/
private final Object data;
/**
* Evaluated result.
*/
private Value result;
/**
* Creates new instance of partition data.
*
* @param data
* aggregate data
*/
PartitionData(Object data) {
this.data = data;
}
/**
* Returns the aggregate data.
*
* @return the aggregate data
*/
Object getData() {
return data;
}
/**
* Returns the result.
*
* @return the result
*/
Value getResult() {
return result;
}
/**
* Sets the result.
*
* @param result
* the result to set
*/
void setResult(Value result) {
this.result = result;
}
}
......@@ -151,13 +151,48 @@ SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER /**/ BY ID)) OVER (PARTITION BY NAME), NAME
> ((4, 5, 6)) c
> rows: 3
SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER /**/ BY ID)) OVER (PARTITION BY NAME), NAME FROM TEST GROUP BY NAME ORDER /**/ BY NAME OFFSET 1 ROW;
SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER /**/ BY ID)) OVER (PARTITION BY NAME), NAME FROM TEST
GROUP BY NAME ORDER /**/ BY NAME OFFSET 1 ROW;
> ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) OVER (PARTITION BY NAME) NAME
> ------------------------------------------------------------- ----
> ((3)) b
> ((4, 5, 6)) c
> rows: 2
SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE NAME > 'b') OVER (PARTITION BY NAME), NAME FROM TEST
GROUP BY NAME ORDER BY NAME;
> ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE (NAME > 'b')) OVER (PARTITION BY NAME) NAME
> ----------------------------------------------------------------------------------------- ----
> null a
> null b
> ((4, 5, 6)) c
> rows (ordered): 3
SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE NAME > 'c') OVER (PARTITION BY NAME), NAME FROM TEST
GROUP BY NAME ORDER BY NAME;
> ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE (NAME > 'c')) OVER (PARTITION BY NAME) NAME
> ----------------------------------------------------------------------------------------- ----
> null a
> null b
> null c
> rows (ordered): 3
SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE NAME > 'b') OVER () FROM TEST GROUP BY NAME ORDER BY NAME;
> ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE (NAME > 'b')) OVER ()
> ------------------------------------------------------------------------
> ((4, 5, 6))
> ((4, 5, 6))
> ((4, 5, 6))
> rows (ordered): 3
SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE NAME > 'c') OVER () FROM TEST GROUP BY NAME ORDER BY NAME;
> ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE (NAME > 'c')) OVER ()
> ------------------------------------------------------------------------
> null
> null
> null
> rows (ordered): 3
SELECT ARRAY_AGG(ID) OVER() FROM TEST GROUP BY NAME;
> exception MUST_GROUP_BY_COLUMN_1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论