Unverified 提交 3b296bb7 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov 提交者: GitHub

Merge pull request #1449 from katzyn/aggregate

Move more code from Aggregate and JavaAggregate to AbstractAggregate
...@@ -217,6 +217,23 @@ ...@@ -217,6 +217,23 @@
</testResource> </testResource>
</testResources> </testResources>
<plugins> <plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<configuration>
<archive>
<manifest>
<addDefaultImplementationEntries>true</addDefaultImplementationEntries>
<mainClass>org.h2.tools.Console</mainClass>
</manifest>
<manifestEntries>
<Automatic-Module-Name>com.h2database</Automatic-Module-Name>
<Multi-Release>true</Multi-Release>
<Premain-Class>org.h2.util.Profiler</Premain-Class>
</manifestEntries>
</archive>
</configuration>
</plugin>
<!-- Add tools folder to test sources but consider moving them to src/test --> <!-- Add tools folder to test sources but consider moving them to src/test -->
<plugin> <plugin>
<groupId>org.codehaus.mojo</groupId> <groupId>org.codehaus.mojo</groupId>
......
...@@ -21,6 +21,8 @@ Change Log ...@@ -21,6 +21,8 @@ Change Log
<h2>Next Version (unreleased)</h2> <h2>Next Version (unreleased)</h2>
<ul> <ul>
<li>PR #1449: Move more code from Aggregate and JavaAggregate to AbstractAggregate
</li>
<li>PR #1448: Add experimental implementation of grouped window queries <li>PR #1448: Add experimental implementation of grouped window queries
</li> </li>
<li>PR #1447: Refactor OVER() processing code and fix some issues <li>PR #1447: Refactor OVER() processing code and fix some issues
......
...@@ -5,19 +5,35 @@ ...@@ -5,19 +5,35 @@
*/ */
package org.h2.expression.aggregate; package org.h2.expression.aggregate;
import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups;
import org.h2.engine.Session;
import org.h2.expression.Expression; import org.h2.expression.Expression;
import org.h2.table.ColumnResolver; import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter; import org.h2.table.TableFilter;
import org.h2.util.ValueHashMap;
import org.h2.value.ValueArray;
/** /**
* A base class for aggregates. * A base class for aggregates.
*/ */
public abstract class AbstractAggregate extends Expression { public abstract class AbstractAggregate extends Expression {
protected final Select select;
protected final boolean distinct;
protected Expression filterCondition; protected Expression filterCondition;
protected Window over; protected Window over;
private int lastGroupRowId;
AbstractAggregate(Select select, boolean distinct) {
this.select = select;
this.distinct = distinct;
}
/** /**
* Sets the FILTER condition. * Sets the FILTER condition.
* *
...@@ -58,6 +74,103 @@ public abstract class AbstractAggregate extends Expression { ...@@ -58,6 +74,103 @@ public abstract class AbstractAggregate extends Expression {
} }
} }
@Override
public void updateAggregate(Session session, boolean window) {
if (window != (over != null)) {
if (!window && select.isWindowQuery()) {
updateGroupAggregates(session);
if (filterCondition != null) {
filterCondition.updateAggregate(session, false);
}
over.updateAggregate(session, false);
}
return;
}
// TODO aggregates: check nested MIN(MAX(ID)) and so on
// if (on != null) {
// on.updateAggregate();
// }
SelectGroups groupData = select.getGroupDataIfCurrent(window);
if (groupData == null) {
// this is a different level (the enclosing query)
return;
}
int groupRowId = groupData.getCurrentGroupRowId();
if (lastGroupRowId == groupRowId) {
// already visited
return;
}
lastGroupRowId = groupRowId;
if (over != null) {
if (!select.isGroupQuery()) {
over.updateAggregate(session, true);
}
}
if (filterCondition != null) {
if (!filterCondition.getBooleanValue(session)) {
return;
}
}
updateAggregate(session, getData(session, groupData, false));
}
/**
* Updates an aggregate value.
*
* @param session
* the session
* @param aggregateData
* aggregate data
*/
protected abstract void updateAggregate(Session session, Object aggregateData);
/**
* Invoked when processing group stage of grouped window queries to update
* arguments of this aggregate.
*
* @param session
* the session
*/
protected abstract void updateGroupAggregates(Session session);
protected Object getData(Session session, SelectGroups groupData, boolean ifExists) {
Object data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getCurrentGroupExprData(this, true);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
data = map.get(key);
if (data == null) {
if (ifExists) {
return null;
}
data = createAggregateData();
map.put(key, data);
}
} else {
data = groupData.getCurrentGroupExprData(this, over != null);
if (data == null) {
if (ifExists) {
return null;
}
data = createAggregateData();
groupData.setCurrentGroupExprData(this, data, over != null);
}
}
return data;
}
protected abstract Object createAggregateData();
protected StringBuilder appendTailConditions(StringBuilder builder) { protected StringBuilder appendTailConditions(StringBuilder builder) {
if (filterCondition != null) { if (filterCondition != null) {
builder.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')'); builder.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
......
...@@ -29,7 +29,6 @@ import org.h2.table.Table; ...@@ -29,7 +29,6 @@ import org.h2.table.Table;
import org.h2.table.TableFilter; import org.h2.table.TableFilter;
import org.h2.util.StatementBuilder; import org.h2.util.StatementBuilder;
import org.h2.util.StringUtils; import org.h2.util.StringUtils;
import org.h2.util.ValueHashMap;
import org.h2.value.DataType; import org.h2.value.DataType;
import org.h2.value.Value; import org.h2.value.Value;
import org.h2.value.ValueArray; import org.h2.value.ValueArray;
...@@ -155,8 +154,6 @@ public class Aggregate extends AbstractAggregate { ...@@ -155,8 +154,6 @@ public class Aggregate extends AbstractAggregate {
private static final HashMap<String, AggregateType> AGGREGATES = new HashMap<>(64); private static final HashMap<String, AggregateType> AGGREGATES = new HashMap<>(64);
private final AggregateType type; private final AggregateType type;
private final Select select;
private final boolean distinct;
private Expression on; private Expression on;
private Expression groupConcatSeparator; private Expression groupConcatSeparator;
...@@ -165,7 +162,6 @@ public class Aggregate extends AbstractAggregate { ...@@ -165,7 +162,6 @@ public class Aggregate extends AbstractAggregate {
private int dataType, scale; private int dataType, scale;
private long precision; private long precision;
private int displaySize; private int displaySize;
private int lastGroupRowId;
/** /**
* Create a new aggregate object. * Create a new aggregate object.
...@@ -180,10 +176,9 @@ public class Aggregate extends AbstractAggregate { ...@@ -180,10 +176,9 @@ public class Aggregate extends AbstractAggregate {
* if distinct is used * if distinct is used
*/ */
public Aggregate(AggregateType type, Expression on, Select select, boolean distinct) { public Aggregate(AggregateType type, Expression on, Select select, boolean distinct) {
super(select, distinct);
this.type = type; this.type = type;
this.on = on; this.on = on;
this.select = select;
this.distinct = distinct;
} }
static { static {
...@@ -289,52 +284,8 @@ public class Aggregate extends AbstractAggregate { ...@@ -289,52 +284,8 @@ public class Aggregate extends AbstractAggregate {
} }
@Override @Override
public void updateAggregate(Session session, boolean window) { protected void updateAggregate(Session session, Object aggregateData) {
if (window != (over != null)) { AggregateData data = (AggregateData) aggregateData;
if (!window && select.isWindowQuery()) {
if (on != null) {
on.updateAggregate(session, false);
}
if (orderByList != null) {
for (SelectOrderBy orderBy : orderByList) {
orderBy.expression.updateAggregate(session, false);
}
}
if (filterCondition != null) {
filterCondition.updateAggregate(session, false);
}
over.updateAggregate(session, false);
}
return;
}
// TODO aggregates: check nested MIN(MAX(ID)) and so on
// if (on != null) {
// on.updateAggregate();
// }
SelectGroups groupData = select.getGroupDataIfCurrent(window);
if (groupData == null) {
// this is a different level (the enclosing query)
return;
}
int groupRowId = groupData.getCurrentGroupRowId();
if (lastGroupRowId == groupRowId) {
// already visited
return;
}
lastGroupRowId = groupRowId;
if (over != null) {
if (!select.isGroupQuery()) {
over.updateAggregate(session, true);
}
}
if (filterCondition != null) {
if (!filterCondition.getBooleanValue(session)) {
return;
}
}
AggregateData data = getData(session, groupData);
Value v = on == null ? null : on.getValue(session); Value v = on == null ? null : on.getValue(session);
if (type == AggregateType.GROUP_CONCAT) { if (type == AggregateType.GROUP_CONCAT) {
if (v != ValueNull.INSTANCE) { if (v != ValueNull.INSTANCE) {
...@@ -348,6 +299,18 @@ public class Aggregate extends AbstractAggregate { ...@@ -348,6 +299,18 @@ public class Aggregate extends AbstractAggregate {
data.add(session.getDatabase(), dataType, distinct, v); data.add(session.getDatabase(), dataType, distinct, v);
} }
@Override
protected void updateGroupAggregates(Session session) {
if (on != null) {
on.updateAggregate(session, false);
}
if (orderByList != null) {
for (SelectOrderBy orderBy : orderByList) {
orderBy.expression.updateAggregate(session, false);
}
}
}
private Value updateCollecting(Session session, Value v) { private Value updateCollecting(Session session, Value v) {
if (orderByList != null) { if (orderByList != null) {
int size = orderByList.size(); int size = orderByList.size();
...@@ -362,6 +325,11 @@ public class Aggregate extends AbstractAggregate { ...@@ -362,6 +325,11 @@ public class Aggregate extends AbstractAggregate {
return v; return v;
} }
@Override
protected Object createAggregateData() {
return AggregateData.create(type);
}
@Override @Override
public Value getValue(Session session) { public Value getValue(Session session) {
if (select.isQuickAggregateQuery()) { if (select.isQuickAggregateQuery()) {
...@@ -400,7 +368,10 @@ public class Aggregate extends AbstractAggregate { ...@@ -400,7 +368,10 @@ public class Aggregate extends AbstractAggregate {
if (groupData == null) { if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL()); throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
} }
AggregateData data = getData(session, groupData); AggregateData data = (AggregateData) getData(session, groupData, true);
if (data == null) {
data = (AggregateData) createAggregateData();
}
switch (type) { switch (type) {
case GROUP_CONCAT: { case GROUP_CONCAT: {
Value[] array = ((AggregateDataCollecting) data).getArray(); Value[] array = ((AggregateDataCollecting) data).getArray();
...@@ -455,32 +426,6 @@ public class Aggregate extends AbstractAggregate { ...@@ -455,32 +426,6 @@ public class Aggregate extends AbstractAggregate {
} }
} }
private AggregateData getData(Session session, SelectGroups groupData) {
AggregateData data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<AggregateData> map = (ValueHashMap<AggregateData>) groupData.getCurrentGroupExprData(this,
true);
if (map == null) {
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
data = map.get(key);
if (data == null) {
data = AggregateData.create(type);
map.put(key, data);
}
} else {
data = (AggregateData) groupData.getCurrentGroupExprData(this, over != null);
if (data == null) {
data = AggregateData.create(type);
groupData.setCurrentGroupExprData(this, data, over != null);
}
}
return data;
}
@Override @Override
public int getType() { public int getType() {
return dataType; return dataType;
......
...@@ -20,7 +20,6 @@ import org.h2.message.DbException; ...@@ -20,7 +20,6 @@ import org.h2.message.DbException;
import org.h2.table.ColumnResolver; import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter; import org.h2.table.TableFilter;
import org.h2.util.StatementBuilder; import org.h2.util.StatementBuilder;
import org.h2.util.ValueHashMap;
import org.h2.value.DataType; import org.h2.value.DataType;
import org.h2.value.Value; import org.h2.value.Value;
import org.h2.value.ValueArray; import org.h2.value.ValueArray;
...@@ -32,19 +31,15 @@ import org.h2.value.ValueNull; ...@@ -32,19 +31,15 @@ import org.h2.value.ValueNull;
public class JavaAggregate extends AbstractAggregate { public class JavaAggregate extends AbstractAggregate {
private final UserAggregate userAggregate; private final UserAggregate userAggregate;
private final Select select;
private final Expression[] args; private final Expression[] args;
private int[] argTypes; private int[] argTypes;
private final boolean distinct;
private int dataType; private int dataType;
private Connection userConnection; private Connection userConnection;
private int lastGroupRowId;
public JavaAggregate(UserAggregate userAggregate, Expression[] args, Select select, boolean distinct) { public JavaAggregate(UserAggregate userAggregate, Expression[] args, Select select, boolean distinct) {
super(select, distinct);
this.userAggregate = userAggregate; this.userAggregate = userAggregate;
this.args = args; this.args = args;
this.select = select;
this.distinct = distinct;
} }
@Override @Override
...@@ -152,9 +147,13 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -152,9 +147,13 @@ public class JavaAggregate extends AbstractAggregate {
super.setEvaluatable(tableFilter, b); super.setEvaluatable(tableFilter, b);
} }
private Aggregate getInstance() throws SQLException { private Aggregate getInstance() {
Aggregate agg = userAggregate.getInstance(); Aggregate agg = userAggregate.getInstance();
agg.init(userConnection); try {
agg.init(userConnection);
} catch (SQLException ex) {
throw DbException.convert(ex);
}
return agg; return agg;
} }
...@@ -168,7 +167,7 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -168,7 +167,7 @@ public class JavaAggregate extends AbstractAggregate {
Aggregate agg; Aggregate agg;
if (distinct) { if (distinct) {
agg = getInstance(); agg = getInstance();
AggregateDataCollecting data = getDataDistinct(session, groupData, true); AggregateDataCollecting data = (AggregateDataCollecting) getData(session, groupData, true);
if (data != null) { if (data != null) {
for (Value value : data.values) { for (Value value : data.values) {
if (args.length == 1) { if (args.length == 1) {
...@@ -184,7 +183,7 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -184,7 +183,7 @@ public class JavaAggregate extends AbstractAggregate {
} }
} }
} else { } else {
agg = getData(session, groupData, true); agg = (Aggregate) getData(session, groupData, true);
if (agg == null) { if (agg == null) {
agg = getInstance(); agg = getInstance();
} }
...@@ -200,35 +199,10 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -200,35 +199,10 @@ public class JavaAggregate extends AbstractAggregate {
} }
@Override @Override
public void updateAggregate(Session session, boolean window) { protected void updateAggregate(Session session, Object aggregateData) {
if (window != (over != null)) {
return;
}
SelectGroups groupData = select.getGroupDataIfCurrent(window);
if (groupData == null) {
// this is a different level (the enclosing query)
return;
}
int groupRowId = groupData.getCurrentGroupRowId();
if (lastGroupRowId == groupRowId) {
// already visited
return;
}
lastGroupRowId = groupRowId;
if (over != null) {
over.updateAggregate(session, true);
}
if (filterCondition != null) {
if (!filterCondition.getBooleanValue(session)) {
return;
}
}
try { try {
if (distinct) { if (distinct) {
AggregateDataCollecting data = getDataDistinct(session, groupData, false); AggregateDataCollecting data = (AggregateDataCollecting) aggregateData;
Value[] argValues = new Value[args.length]; Value[] argValues = new Value[args.length];
Value arg = null; Value arg = null;
for (int i = 0, len = args.length; i < len; i++) { for (int i = 0, len = args.length; i < len; i++) {
...@@ -238,7 +212,7 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -238,7 +212,7 @@ public class JavaAggregate extends AbstractAggregate {
} }
data.add(session.getDatabase(), dataType, true, args.length == 1 ? arg : ValueArray.get(argValues)); data.add(session.getDatabase(), dataType, true, args.length == 1 ? arg : ValueArray.get(argValues));
} else { } else {
Aggregate agg = getData(session, groupData, false); Aggregate agg = (Aggregate) aggregateData;
Object[] argValues = new Object[args.length]; Object[] argValues = new Object[args.length];
Object arg = null; Object arg = null;
for (int i = 0, len = args.length; i < len; i++) { for (int i = 0, len = args.length; i < len; i++) {
...@@ -254,73 +228,16 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -254,73 +228,16 @@ public class JavaAggregate extends AbstractAggregate {
} }
} }
private Aggregate getData(Session session, SelectGroups groupData, boolean ifExists) throws SQLException { @Override
Aggregate data; protected void updateGroupAggregates(Session session) {
ValueArray key; for (Expression expr : args) {
if (over != null && (key = over.getCurrentKey(session)) != null) { expr.updateAggregate(session, false);
@SuppressWarnings("unchecked")
ValueHashMap<Aggregate> map = (ValueHashMap<Aggregate>) groupData.getCurrentGroupExprData(this, true);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
data = map.get(key);
if (data == null) {
if (ifExists) {
return null;
}
data = getInstance();
map.put(key, data);
}
} else {
data = (Aggregate) groupData.getCurrentGroupExprData(this, over != null);
if (data == null) {
if (ifExists) {
return null;
}
data = getInstance();
groupData.setCurrentGroupExprData(this, data, over != null);
}
} }
return data;
} }
private AggregateDataCollecting getDataDistinct(Session session, SelectGroups groupData, boolean ifExists) { @Override
AggregateDataCollecting data; protected Object createAggregateData() {
ValueArray key; return distinct ? new AggregateDataCollecting() : getInstance();
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<AggregateDataCollecting> map = (ValueHashMap<AggregateDataCollecting>) groupData
.getCurrentGroupExprData(this, true);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
data = map.get(key);
if (data == null) {
if (ifExists) {
return null;
}
data = new AggregateDataCollecting();
map.put(key, data);
}
} else {
data = (AggregateDataCollecting) groupData.getCurrentGroupExprData(this, over != null);
if (data == null) {
if (ifExists) {
return null;
}
data = new AggregateDataCollecting();
groupData.setCurrentGroupExprData(this, data, over != null);
}
}
return data;
} }
} }
...@@ -48,3 +48,17 @@ SELECT X, COUNT(*), SUM(COUNT(*)) OVER() FROM VALUES (1), (1), (1), (1), (2), (2 ...@@ -48,3 +48,17 @@ SELECT X, COUNT(*), SUM(COUNT(*)) OVER() FROM VALUES (1), (1), (1), (1), (2), (2
> 2 2 7 > 2 2 7
> 3 1 7 > 3 1 7
> rows: 3 > rows: 3
CREATE TABLE TEST(ID INT);
> ok
SELECT SUM(ID) FROM TEST;
>> null
SELECT SUM(ID) OVER () FROM TEST;
> SUM(ID) OVER ()
> ---------------
> rows: 0
DROP TABLE TEST;
> ok
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论