Unverified 提交 19820759 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov 提交者: GitHub

Merge pull request #1450 from katzyn/aggregate

Evaluate window aggregates only once for each partition
......@@ -5,13 +5,16 @@
*/
package org.h2.expression.aggregate;
import org.h2.api.ErrorCode;
import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups;
import org.h2.engine.Session;
import org.h2.expression.Expression;
import org.h2.message.DbException;
import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter;
import org.h2.util.ValueHashMap;
import org.h2.value.Value;
import org.h2.value.ValueArray;
/**
......@@ -137,8 +140,9 @@ public abstract class AbstractAggregate extends Expression {
protected Object getData(Session session, SelectGroups groupData, boolean ifExists) {
Object data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
if (over != null) {
ValueArray key = over.getCurrentKey(session);
if (key != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getCurrentGroupExprData(this, true);
if (map == null) {
......@@ -148,22 +152,36 @@ public abstract class AbstractAggregate extends Expression {
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
data = map.get(key);
if (data == null) {
PartitionData partition = (PartitionData) map.get(key);
if (partition == null) {
if (ifExists) {
return null;
}
data = createAggregateData();
map.put(key, new PartitionData(data));
} else {
data = partition.getData();
}
} else {
PartitionData partition = (PartitionData) groupData.getCurrentGroupExprData(this, true);
if (partition == null) {
if (ifExists) {
return null;
}
data = createAggregateData();
map.put(key, data);
groupData.setCurrentGroupExprData(this, new PartitionData(data), true);
} else {
data = partition.getData();
}
}
} else {
data = groupData.getCurrentGroupExprData(this, over != null);
data = groupData.getCurrentGroupExprData(this, false);
if (data == null) {
if (ifExists) {
return null;
}
data = createAggregateData();
groupData.setCurrentGroupExprData(this, data, over != null);
groupData.setCurrentGroupExprData(this, data, false);
}
}
return data;
......@@ -171,6 +189,55 @@ public abstract class AbstractAggregate extends Expression {
protected abstract Object createAggregateData();
@Override
public Value getValue(Session session) {
SelectGroups groupData = select.getGroupDataIfCurrent(over != null);
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
return over == null ? getAggregatedValue(session, getData(session, groupData, true))
: getWindowResult(session, groupData);
}
private Value getWindowResult(Session session, SelectGroups groupData) {
PartitionData partition;
Object data;
ValueArray key = over.getCurrentKey(session);
if (key != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getCurrentGroupExprData(this, true);
if (map == null) {
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
partition = (PartitionData) map.get(key);
if (partition == null) {
data = createAggregateData();
partition = new PartitionData(data);
map.put(key, partition);
} else {
data = partition.getData();
}
} else {
partition = (PartitionData) groupData.getCurrentGroupExprData(this, true);
if (partition == null) {
data = createAggregateData();
partition = new PartitionData(data);
groupData.setCurrentGroupExprData(this, partition, true);
} else {
data = partition.getData();
}
}
Value result = partition.getResult();
if (result == null) {
result = getAggregatedValue(session, data);
partition.setResult(result);
}
return result;
}
protected abstract Value getAggregatedValue(Session session, Object aggregateData);
protected StringBuilder appendTailConditions(StringBuilder builder) {
if (filterCondition != null) {
builder.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
......
......@@ -11,7 +11,6 @@ import java.util.Comparator;
import java.util.HashMap;
import org.h2.api.ErrorCode;
import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups;
import org.h2.command.dml.SelectOrderBy;
import org.h2.engine.Session;
import org.h2.expression.Expression;
......@@ -332,7 +331,10 @@ public class Aggregate extends AbstractAggregate {
@Override
public Value getValue(Session session) {
if (select.isQuickAggregateQuery()) {
return select.isQuickAggregateQuery() ? getValueQuick(session) : super.getValue(session);
}
private Value getValueQuick(Session session) {
switch (type) {
case COUNT:
case COUNT_ALL:
......@@ -361,14 +363,13 @@ public class Aggregate extends AbstractAggregate {
case ENVELOPE:
return ((MVSpatialIndex) AggregateDataEnvelope.getGeometryColumnIndex(on)).getBounds(session);
default:
DbException.throwInternalError("type=" + type);
}
throw DbException.throwInternalError("type=" + type);
}
SelectGroups groupData = select.getGroupDataIfCurrent(over != null);
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
AggregateData data = (AggregateData) getData(session, groupData, true);
@Override
public Value getAggregatedValue(Session session, Object aggregateData) {
AggregateData data = (AggregateData) aggregateData;
if (data == null) {
data = (AggregateData) createAggregateData();
}
......
......@@ -8,10 +8,8 @@ package org.h2.expression.aggregate;
import java.sql.Connection;
import java.sql.SQLException;
import org.h2.api.Aggregate;
import org.h2.api.ErrorCode;
import org.h2.command.Parser;
import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups;
import org.h2.engine.Session;
import org.h2.engine.UserAggregate;
import org.h2.expression.Expression;
......@@ -158,16 +156,12 @@ public class JavaAggregate extends AbstractAggregate {
}
@Override
public Value getValue(Session session) {
SelectGroups groupData = select.getGroupDataIfCurrent(over != null);
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
public Value getAggregatedValue(Session session, Object aggregateData) {
try {
Aggregate agg;
if (distinct) {
agg = getInstance();
AggregateDataCollecting data = (AggregateDataCollecting) getData(session, groupData, true);
AggregateDataCollecting data = (AggregateDataCollecting) aggregateData;
if (data != null) {
for (Value value : data.values) {
if (args.length == 1) {
......@@ -183,7 +177,7 @@ public class JavaAggregate extends AbstractAggregate {
}
}
} else {
agg = (Aggregate) getData(session, groupData, true);
agg = (Aggregate) aggregateData;
if (agg == null) {
agg = getInstance();
}
......
/*
* Copyright 2004-2018 H2 Group. Multiple-Licensed under the MPL 2.0,
* and the EPL 1.0 (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.expression.aggregate;
import org.h2.value.Value;
/**
* Partition data of a window aggregate.
*/
final class PartitionData {
/**
* Aggregate data.
*/
private final Object data;
/**
* Evaluated result.
*/
private Value result;
/**
* Creates new instance of partition data.
*
* @param data
* aggregate data
*/
PartitionData(Object data) {
this.data = data;
}
/**
* Returns the aggregate data.
*
* @return the aggregate data
*/
Object getData() {
return data;
}
/**
* Returns the result.
*
* @return the result
*/
Value getResult() {
return result;
}
/**
* Sets the result.
*
* @param result
* the result to set
*/
void setResult(Value result) {
this.result = result;
}
}
......@@ -84,6 +84,7 @@ public class LocalResultImpl implements LocalResult {
return false;
}
@Override
public void setMaxMemoryRows(int maxValue) {
this.maxMemoryRows = maxValue;
}
......@@ -134,6 +135,7 @@ public class LocalResultImpl implements LocalResult {
*
* @param sort the sort order
*/
@Override
public void setSortOrder(SortOrder sort) {
this.sort = sort;
}
......@@ -141,6 +143,7 @@ public class LocalResultImpl implements LocalResult {
/**
* Remove duplicate rows.
*/
@Override
public void setDistinct() {
assert distinctIndexes == null;
distinct = true;
......@@ -152,6 +155,7 @@ public class LocalResultImpl implements LocalResult {
*
* @param distinctIndexes distinct indexes
*/
@Override
public void setDistinct(int[] distinctIndexes) {
assert !distinct;
this.distinctIndexes = distinctIndexes;
......@@ -170,6 +174,7 @@ public class LocalResultImpl implements LocalResult {
*
* @param values the row
*/
@Override
public void removeDistinct(Value[] values) {
if (!distinct) {
DbException.throwInternalError();
......@@ -329,6 +334,7 @@ public class LocalResultImpl implements LocalResult {
/**
* This method is called after all rows have been added.
*/
@Override
public void done() {
if (external != null) {
addRowsToDisk();
......@@ -455,6 +461,7 @@ public class LocalResultImpl implements LocalResult {
*
* @param limit the limit (-1 means no limit, 0 means no rows)
*/
@Override
public void setLimit(int limit) {
this.limit = limit;
}
......@@ -462,6 +469,7 @@ public class LocalResultImpl implements LocalResult {
/**
* @param fetchPercent whether limit expression specifies percentage of rows
*/
@Override
public void setFetchPercent(boolean fetchPercent) {
this.fetchPercent = fetchPercent;
}
......@@ -469,6 +477,7 @@ public class LocalResultImpl implements LocalResult {
/**
* @param withTies whether tied rows should be included in result too
*/
@Override
public void setWithTies(boolean withTies) {
this.withTies = withTies;
}
......@@ -542,6 +551,7 @@ public class LocalResultImpl implements LocalResult {
*
* @param offset the offset
*/
@Override
public void setOffset(int offset) {
this.offset = offset;
}
......
......@@ -151,13 +151,48 @@ SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER /**/ BY ID)) OVER (PARTITION BY NAME), NAME
> ((4, 5, 6)) c
> rows: 3
SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER /**/ BY ID)) OVER (PARTITION BY NAME), NAME FROM TEST GROUP BY NAME ORDER /**/ BY NAME OFFSET 1 ROW;
SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER /**/ BY ID)) OVER (PARTITION BY NAME), NAME FROM TEST
GROUP BY NAME ORDER /**/ BY NAME OFFSET 1 ROW;
> ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) OVER (PARTITION BY NAME) NAME
> ------------------------------------------------------------- ----
> ((3)) b
> ((4, 5, 6)) c
> rows: 2
SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE NAME > 'b') OVER (PARTITION BY NAME), NAME FROM TEST
GROUP BY NAME ORDER BY NAME;
> ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE (NAME > 'b')) OVER (PARTITION BY NAME) NAME
> ----------------------------------------------------------------------------------------- ----
> null a
> null b
> ((4, 5, 6)) c
> rows (ordered): 3
SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE NAME > 'c') OVER (PARTITION BY NAME), NAME FROM TEST
GROUP BY NAME ORDER BY NAME;
> ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE (NAME > 'c')) OVER (PARTITION BY NAME) NAME
> ----------------------------------------------------------------------------------------- ----
> null a
> null b
> null c
> rows (ordered): 3
SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE NAME > 'b') OVER () FROM TEST GROUP BY NAME ORDER BY NAME;
> ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE (NAME > 'b')) OVER ()
> ------------------------------------------------------------------------
> ((4, 5, 6))
> ((4, 5, 6))
> ((4, 5, 6))
> rows (ordered): 3
SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE NAME > 'c') OVER () FROM TEST GROUP BY NAME ORDER BY NAME;
> ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE (NAME > 'c')) OVER ()
> ------------------------------------------------------------------------
> null
> null
> null
> rows (ordered): 3
SELECT ARRAY_AGG(ID) OVER() FROM TEST GROUP BY NAME;
> exception MUST_GROUP_BY_COLUMN_1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论