Unverified 提交 19820759 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov 提交者: GitHub

Merge pull request #1450 from katzyn/aggregate

Evaluate window aggregates only once for each partition
...@@ -5,13 +5,16 @@ ...@@ -5,13 +5,16 @@
*/ */
package org.h2.expression.aggregate; package org.h2.expression.aggregate;
import org.h2.api.ErrorCode;
import org.h2.command.dml.Select; import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups; import org.h2.command.dml.SelectGroups;
import org.h2.engine.Session; import org.h2.engine.Session;
import org.h2.expression.Expression; import org.h2.expression.Expression;
import org.h2.message.DbException;
import org.h2.table.ColumnResolver; import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter; import org.h2.table.TableFilter;
import org.h2.util.ValueHashMap; import org.h2.util.ValueHashMap;
import org.h2.value.Value;
import org.h2.value.ValueArray; import org.h2.value.ValueArray;
/** /**
...@@ -137,8 +140,9 @@ public abstract class AbstractAggregate extends Expression { ...@@ -137,8 +140,9 @@ public abstract class AbstractAggregate extends Expression {
protected Object getData(Session session, SelectGroups groupData, boolean ifExists) { protected Object getData(Session session, SelectGroups groupData, boolean ifExists) {
Object data; Object data;
ValueArray key; if (over != null) {
if (over != null && (key = over.getCurrentKey(session)) != null) { ValueArray key = over.getCurrentKey(session);
if (key != null) {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getCurrentGroupExprData(this, true); ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getCurrentGroupExprData(this, true);
if (map == null) { if (map == null) {
...@@ -148,22 +152,36 @@ public abstract class AbstractAggregate extends Expression { ...@@ -148,22 +152,36 @@ public abstract class AbstractAggregate extends Expression {
map = new ValueHashMap<>(); map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true); groupData.setCurrentGroupExprData(this, map, true);
} }
data = map.get(key); PartitionData partition = (PartitionData) map.get(key);
if (data == null) { if (partition == null) {
if (ifExists) {
return null;
}
data = createAggregateData();
map.put(key, new PartitionData(data));
} else {
data = partition.getData();
}
} else {
PartitionData partition = (PartitionData) groupData.getCurrentGroupExprData(this, true);
if (partition == null) {
if (ifExists) { if (ifExists) {
return null; return null;
} }
data = createAggregateData(); data = createAggregateData();
map.put(key, data); groupData.setCurrentGroupExprData(this, new PartitionData(data), true);
} else {
data = partition.getData();
}
} }
} else { } else {
data = groupData.getCurrentGroupExprData(this, over != null); data = groupData.getCurrentGroupExprData(this, false);
if (data == null) { if (data == null) {
if (ifExists) { if (ifExists) {
return null; return null;
} }
data = createAggregateData(); data = createAggregateData();
groupData.setCurrentGroupExprData(this, data, over != null); groupData.setCurrentGroupExprData(this, data, false);
} }
} }
return data; return data;
...@@ -171,6 +189,55 @@ public abstract class AbstractAggregate extends Expression { ...@@ -171,6 +189,55 @@ public abstract class AbstractAggregate extends Expression {
protected abstract Object createAggregateData(); protected abstract Object createAggregateData();
@Override
public Value getValue(Session session) {
SelectGroups groupData = select.getGroupDataIfCurrent(over != null);
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
return over == null ? getAggregatedValue(session, getData(session, groupData, true))
: getWindowResult(session, groupData);
}
private Value getWindowResult(Session session, SelectGroups groupData) {
PartitionData partition;
Object data;
ValueArray key = over.getCurrentKey(session);
if (key != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getCurrentGroupExprData(this, true);
if (map == null) {
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
partition = (PartitionData) map.get(key);
if (partition == null) {
data = createAggregateData();
partition = new PartitionData(data);
map.put(key, partition);
} else {
data = partition.getData();
}
} else {
partition = (PartitionData) groupData.getCurrentGroupExprData(this, true);
if (partition == null) {
data = createAggregateData();
partition = new PartitionData(data);
groupData.setCurrentGroupExprData(this, partition, true);
} else {
data = partition.getData();
}
}
Value result = partition.getResult();
if (result == null) {
result = getAggregatedValue(session, data);
partition.setResult(result);
}
return result;
}
protected abstract Value getAggregatedValue(Session session, Object aggregateData);
protected StringBuilder appendTailConditions(StringBuilder builder) { protected StringBuilder appendTailConditions(StringBuilder builder) {
if (filterCondition != null) { if (filterCondition != null) {
builder.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')'); builder.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
......
...@@ -11,7 +11,6 @@ import java.util.Comparator; ...@@ -11,7 +11,6 @@ import java.util.Comparator;
import java.util.HashMap; import java.util.HashMap;
import org.h2.api.ErrorCode; import org.h2.api.ErrorCode;
import org.h2.command.dml.Select; import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups;
import org.h2.command.dml.SelectOrderBy; import org.h2.command.dml.SelectOrderBy;
import org.h2.engine.Session; import org.h2.engine.Session;
import org.h2.expression.Expression; import org.h2.expression.Expression;
...@@ -332,7 +331,10 @@ public class Aggregate extends AbstractAggregate { ...@@ -332,7 +331,10 @@ public class Aggregate extends AbstractAggregate {
@Override @Override
public Value getValue(Session session) { public Value getValue(Session session) {
if (select.isQuickAggregateQuery()) { return select.isQuickAggregateQuery() ? getValueQuick(session) : super.getValue(session);
}
private Value getValueQuick(Session session) {
switch (type) { switch (type) {
case COUNT: case COUNT:
case COUNT_ALL: case COUNT_ALL:
...@@ -361,14 +363,13 @@ public class Aggregate extends AbstractAggregate { ...@@ -361,14 +363,13 @@ public class Aggregate extends AbstractAggregate {
case ENVELOPE: case ENVELOPE:
return ((MVSpatialIndex) AggregateDataEnvelope.getGeometryColumnIndex(on)).getBounds(session); return ((MVSpatialIndex) AggregateDataEnvelope.getGeometryColumnIndex(on)).getBounds(session);
default: default:
DbException.throwInternalError("type=" + type); throw DbException.throwInternalError("type=" + type);
}
} }
SelectGroups groupData = select.getGroupDataIfCurrent(over != null);
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
} }
AggregateData data = (AggregateData) getData(session, groupData, true);
@Override
public Value getAggregatedValue(Session session, Object aggregateData) {
AggregateData data = (AggregateData) aggregateData;
if (data == null) { if (data == null) {
data = (AggregateData) createAggregateData(); data = (AggregateData) createAggregateData();
} }
......
...@@ -8,10 +8,8 @@ package org.h2.expression.aggregate; ...@@ -8,10 +8,8 @@ package org.h2.expression.aggregate;
import java.sql.Connection; import java.sql.Connection;
import java.sql.SQLException; import java.sql.SQLException;
import org.h2.api.Aggregate; import org.h2.api.Aggregate;
import org.h2.api.ErrorCode;
import org.h2.command.Parser; import org.h2.command.Parser;
import org.h2.command.dml.Select; import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups;
import org.h2.engine.Session; import org.h2.engine.Session;
import org.h2.engine.UserAggregate; import org.h2.engine.UserAggregate;
import org.h2.expression.Expression; import org.h2.expression.Expression;
...@@ -158,16 +156,12 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -158,16 +156,12 @@ public class JavaAggregate extends AbstractAggregate {
} }
@Override @Override
public Value getValue(Session session) { public Value getAggregatedValue(Session session, Object aggregateData) {
SelectGroups groupData = select.getGroupDataIfCurrent(over != null);
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
try { try {
Aggregate agg; Aggregate agg;
if (distinct) { if (distinct) {
agg = getInstance(); agg = getInstance();
AggregateDataCollecting data = (AggregateDataCollecting) getData(session, groupData, true); AggregateDataCollecting data = (AggregateDataCollecting) aggregateData;
if (data != null) { if (data != null) {
for (Value value : data.values) { for (Value value : data.values) {
if (args.length == 1) { if (args.length == 1) {
...@@ -183,7 +177,7 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -183,7 +177,7 @@ public class JavaAggregate extends AbstractAggregate {
} }
} }
} else { } else {
agg = (Aggregate) getData(session, groupData, true); agg = (Aggregate) aggregateData;
if (agg == null) { if (agg == null) {
agg = getInstance(); agg = getInstance();
} }
......
/*
* Copyright 2004-2018 H2 Group. Multiple-Licensed under the MPL 2.0,
* and the EPL 1.0 (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.expression.aggregate;
import org.h2.value.Value;
/**
* Partition data of a window aggregate.
*/
final class PartitionData {
/**
* Aggregate data.
*/
private final Object data;
/**
* Evaluated result.
*/
private Value result;
/**
* Creates new instance of partition data.
*
* @param data
* aggregate data
*/
PartitionData(Object data) {
this.data = data;
}
/**
* Returns the aggregate data.
*
* @return the aggregate data
*/
Object getData() {
return data;
}
/**
* Returns the result.
*
* @return the result
*/
Value getResult() {
return result;
}
/**
* Sets the result.
*
* @param result
* the result to set
*/
void setResult(Value result) {
this.result = result;
}
}
...@@ -84,6 +84,7 @@ public class LocalResultImpl implements LocalResult { ...@@ -84,6 +84,7 @@ public class LocalResultImpl implements LocalResult {
return false; return false;
} }
@Override
public void setMaxMemoryRows(int maxValue) { public void setMaxMemoryRows(int maxValue) {
this.maxMemoryRows = maxValue; this.maxMemoryRows = maxValue;
} }
...@@ -134,6 +135,7 @@ public class LocalResultImpl implements LocalResult { ...@@ -134,6 +135,7 @@ public class LocalResultImpl implements LocalResult {
* *
* @param sort the sort order * @param sort the sort order
*/ */
@Override
public void setSortOrder(SortOrder sort) { public void setSortOrder(SortOrder sort) {
this.sort = sort; this.sort = sort;
} }
...@@ -141,6 +143,7 @@ public class LocalResultImpl implements LocalResult { ...@@ -141,6 +143,7 @@ public class LocalResultImpl implements LocalResult {
/** /**
* Remove duplicate rows. * Remove duplicate rows.
*/ */
@Override
public void setDistinct() { public void setDistinct() {
assert distinctIndexes == null; assert distinctIndexes == null;
distinct = true; distinct = true;
...@@ -152,6 +155,7 @@ public class LocalResultImpl implements LocalResult { ...@@ -152,6 +155,7 @@ public class LocalResultImpl implements LocalResult {
* *
* @param distinctIndexes distinct indexes * @param distinctIndexes distinct indexes
*/ */
@Override
public void setDistinct(int[] distinctIndexes) { public void setDistinct(int[] distinctIndexes) {
assert !distinct; assert !distinct;
this.distinctIndexes = distinctIndexes; this.distinctIndexes = distinctIndexes;
...@@ -170,6 +174,7 @@ public class LocalResultImpl implements LocalResult { ...@@ -170,6 +174,7 @@ public class LocalResultImpl implements LocalResult {
* *
* @param values the row * @param values the row
*/ */
@Override
public void removeDistinct(Value[] values) { public void removeDistinct(Value[] values) {
if (!distinct) { if (!distinct) {
DbException.throwInternalError(); DbException.throwInternalError();
...@@ -329,6 +334,7 @@ public class LocalResultImpl implements LocalResult { ...@@ -329,6 +334,7 @@ public class LocalResultImpl implements LocalResult {
/** /**
* This method is called after all rows have been added. * This method is called after all rows have been added.
*/ */
@Override
public void done() { public void done() {
if (external != null) { if (external != null) {
addRowsToDisk(); addRowsToDisk();
...@@ -455,6 +461,7 @@ public class LocalResultImpl implements LocalResult { ...@@ -455,6 +461,7 @@ public class LocalResultImpl implements LocalResult {
* *
* @param limit the limit (-1 means no limit, 0 means no rows) * @param limit the limit (-1 means no limit, 0 means no rows)
*/ */
@Override
public void setLimit(int limit) { public void setLimit(int limit) {
this.limit = limit; this.limit = limit;
} }
...@@ -462,6 +469,7 @@ public class LocalResultImpl implements LocalResult { ...@@ -462,6 +469,7 @@ public class LocalResultImpl implements LocalResult {
/** /**
* @param fetchPercent whether limit expression specifies percentage of rows * @param fetchPercent whether limit expression specifies percentage of rows
*/ */
@Override
public void setFetchPercent(boolean fetchPercent) { public void setFetchPercent(boolean fetchPercent) {
this.fetchPercent = fetchPercent; this.fetchPercent = fetchPercent;
} }
...@@ -469,6 +477,7 @@ public class LocalResultImpl implements LocalResult { ...@@ -469,6 +477,7 @@ public class LocalResultImpl implements LocalResult {
/** /**
* @param withTies whether tied rows should be included in result too * @param withTies whether tied rows should be included in result too
*/ */
@Override
public void setWithTies(boolean withTies) { public void setWithTies(boolean withTies) {
this.withTies = withTies; this.withTies = withTies;
} }
...@@ -542,6 +551,7 @@ public class LocalResultImpl implements LocalResult { ...@@ -542,6 +551,7 @@ public class LocalResultImpl implements LocalResult {
* *
* @param offset the offset * @param offset the offset
*/ */
@Override
public void setOffset(int offset) { public void setOffset(int offset) {
this.offset = offset; this.offset = offset;
} }
......
...@@ -151,13 +151,48 @@ SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER /**/ BY ID)) OVER (PARTITION BY NAME), NAME ...@@ -151,13 +151,48 @@ SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER /**/ BY ID)) OVER (PARTITION BY NAME), NAME
> ((4, 5, 6)) c > ((4, 5, 6)) c
> rows: 3 > rows: 3
SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER /**/ BY ID)) OVER (PARTITION BY NAME), NAME FROM TEST GROUP BY NAME ORDER /**/ BY NAME OFFSET 1 ROW; SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER /**/ BY ID)) OVER (PARTITION BY NAME), NAME FROM TEST
GROUP BY NAME ORDER /**/ BY NAME OFFSET 1 ROW;
> ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) OVER (PARTITION BY NAME) NAME > ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) OVER (PARTITION BY NAME) NAME
> ------------------------------------------------------------- ---- > ------------------------------------------------------------- ----
> ((3)) b > ((3)) b
> ((4, 5, 6)) c > ((4, 5, 6)) c
> rows: 2 > rows: 2
SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE NAME > 'b') OVER (PARTITION BY NAME), NAME FROM TEST
GROUP BY NAME ORDER BY NAME;
> ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE (NAME > 'b')) OVER (PARTITION BY NAME) NAME
> ----------------------------------------------------------------------------------------- ----
> null a
> null b
> ((4, 5, 6)) c
> rows (ordered): 3
SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE NAME > 'c') OVER (PARTITION BY NAME), NAME FROM TEST
GROUP BY NAME ORDER BY NAME;
> ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE (NAME > 'c')) OVER (PARTITION BY NAME) NAME
> ----------------------------------------------------------------------------------------- ----
> null a
> null b
> null c
> rows (ordered): 3
SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE NAME > 'b') OVER () FROM TEST GROUP BY NAME ORDER BY NAME;
> ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE (NAME > 'b')) OVER ()
> ------------------------------------------------------------------------
> ((4, 5, 6))
> ((4, 5, 6))
> ((4, 5, 6))
> rows (ordered): 3
SELECT ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE NAME > 'c') OVER () FROM TEST GROUP BY NAME ORDER BY NAME;
> ARRAY_AGG(ARRAY_AGG(ID ORDER BY ID)) FILTER (WHERE (NAME > 'c')) OVER ()
> ------------------------------------------------------------------------
> null
> null
> null
> rows (ordered): 3
SELECT ARRAY_AGG(ID) OVER() FROM TEST GROUP BY NAME; SELECT ARRAY_AGG(ID) OVER() FROM TEST GROUP BY NAME;
> exception MUST_GROUP_BY_COLUMN_1 > exception MUST_GROUP_BY_COLUMN_1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论