Unverified 提交 3b296bb7 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov 提交者: GitHub

Merge pull request #1449 from katzyn/aggregate

Move more code from Aggregate and JavaAggregate to AbstractAggregate
......@@ -217,6 +217,23 @@
</testResource>
</testResources>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<configuration>
<archive>
<manifest>
<addDefaultImplementationEntries>true</addDefaultImplementationEntries>
<mainClass>org.h2.tools.Console</mainClass>
</manifest>
<manifestEntries>
<Automatic-Module-Name>com.h2database</Automatic-Module-Name>
<Multi-Release>true</Multi-Release>
<Premain-Class>org.h2.util.Profiler</Premain-Class>
</manifestEntries>
</archive>
</configuration>
</plugin>
<!-- Add tools folder to test sources but consider moving them to src/test -->
<plugin>
<groupId>org.codehaus.mojo</groupId>
......
......@@ -21,6 +21,8 @@ Change Log
<h2>Next Version (unreleased)</h2>
<ul>
<li>PR #1449: Move more code from Aggregate and JavaAggregate to AbstractAggregate
</li>
<li>PR #1448: Add experimental implementation of grouped window queries
</li>
<li>PR #1447: Refactor OVER() processing code and fix some issues
......
......@@ -5,19 +5,35 @@
*/
package org.h2.expression.aggregate;
import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups;
import org.h2.engine.Session;
import org.h2.expression.Expression;
import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter;
import org.h2.util.ValueHashMap;
import org.h2.value.ValueArray;
/**
* A base class for aggregates.
*/
public abstract class AbstractAggregate extends Expression {
protected final Select select;
protected final boolean distinct;
protected Expression filterCondition;
protected Window over;
private int lastGroupRowId;
AbstractAggregate(Select select, boolean distinct) {
this.select = select;
this.distinct = distinct;
}
/**
* Sets the FILTER condition.
*
......@@ -58,6 +74,103 @@ public abstract class AbstractAggregate extends Expression {
}
}
@Override
public void updateAggregate(Session session, boolean window) {
if (window != (over != null)) {
if (!window && select.isWindowQuery()) {
updateGroupAggregates(session);
if (filterCondition != null) {
filterCondition.updateAggregate(session, false);
}
over.updateAggregate(session, false);
}
return;
}
// TODO aggregates: check nested MIN(MAX(ID)) and so on
// if (on != null) {
// on.updateAggregate();
// }
SelectGroups groupData = select.getGroupDataIfCurrent(window);
if (groupData == null) {
// this is a different level (the enclosing query)
return;
}
int groupRowId = groupData.getCurrentGroupRowId();
if (lastGroupRowId == groupRowId) {
// already visited
return;
}
lastGroupRowId = groupRowId;
if (over != null) {
if (!select.isGroupQuery()) {
over.updateAggregate(session, true);
}
}
if (filterCondition != null) {
if (!filterCondition.getBooleanValue(session)) {
return;
}
}
updateAggregate(session, getData(session, groupData, false));
}
/**
* Updates an aggregate value.
*
* @param session
* the session
* @param aggregateData
* aggregate data
*/
protected abstract void updateAggregate(Session session, Object aggregateData);
/**
* Invoked when processing group stage of grouped window queries to update
* arguments of this aggregate.
*
* @param session
* the session
*/
protected abstract void updateGroupAggregates(Session session);
protected Object getData(Session session, SelectGroups groupData, boolean ifExists) {
Object data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getCurrentGroupExprData(this, true);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
data = map.get(key);
if (data == null) {
if (ifExists) {
return null;
}
data = createAggregateData();
map.put(key, data);
}
} else {
data = groupData.getCurrentGroupExprData(this, over != null);
if (data == null) {
if (ifExists) {
return null;
}
data = createAggregateData();
groupData.setCurrentGroupExprData(this, data, over != null);
}
}
return data;
}
protected abstract Object createAggregateData();
protected StringBuilder appendTailConditions(StringBuilder builder) {
if (filterCondition != null) {
builder.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
......
......@@ -29,7 +29,6 @@ import org.h2.table.Table;
import org.h2.table.TableFilter;
import org.h2.util.StatementBuilder;
import org.h2.util.StringUtils;
import org.h2.util.ValueHashMap;
import org.h2.value.DataType;
import org.h2.value.Value;
import org.h2.value.ValueArray;
......@@ -155,8 +154,6 @@ public class Aggregate extends AbstractAggregate {
private static final HashMap<String, AggregateType> AGGREGATES = new HashMap<>(64);
private final AggregateType type;
private final Select select;
private final boolean distinct;
private Expression on;
private Expression groupConcatSeparator;
......@@ -165,7 +162,6 @@ public class Aggregate extends AbstractAggregate {
private int dataType, scale;
private long precision;
private int displaySize;
private int lastGroupRowId;
/**
* Create a new aggregate object.
......@@ -180,10 +176,9 @@ public class Aggregate extends AbstractAggregate {
* if distinct is used
*/
public Aggregate(AggregateType type, Expression on, Select select, boolean distinct) {
super(select, distinct);
this.type = type;
this.on = on;
this.select = select;
this.distinct = distinct;
}
static {
......@@ -289,52 +284,8 @@ public class Aggregate extends AbstractAggregate {
}
@Override
public void updateAggregate(Session session, boolean window) {
if (window != (over != null)) {
if (!window && select.isWindowQuery()) {
if (on != null) {
on.updateAggregate(session, false);
}
if (orderByList != null) {
for (SelectOrderBy orderBy : orderByList) {
orderBy.expression.updateAggregate(session, false);
}
}
if (filterCondition != null) {
filterCondition.updateAggregate(session, false);
}
over.updateAggregate(session, false);
}
return;
}
// TODO aggregates: check nested MIN(MAX(ID)) and so on
// if (on != null) {
// on.updateAggregate();
// }
SelectGroups groupData = select.getGroupDataIfCurrent(window);
if (groupData == null) {
// this is a different level (the enclosing query)
return;
}
int groupRowId = groupData.getCurrentGroupRowId();
if (lastGroupRowId == groupRowId) {
// already visited
return;
}
lastGroupRowId = groupRowId;
if (over != null) {
if (!select.isGroupQuery()) {
over.updateAggregate(session, true);
}
}
if (filterCondition != null) {
if (!filterCondition.getBooleanValue(session)) {
return;
}
}
AggregateData data = getData(session, groupData);
protected void updateAggregate(Session session, Object aggregateData) {
AggregateData data = (AggregateData) aggregateData;
Value v = on == null ? null : on.getValue(session);
if (type == AggregateType.GROUP_CONCAT) {
if (v != ValueNull.INSTANCE) {
......@@ -348,6 +299,18 @@ public class Aggregate extends AbstractAggregate {
data.add(session.getDatabase(), dataType, distinct, v);
}
@Override
protected void updateGroupAggregates(Session session) {
if (on != null) {
on.updateAggregate(session, false);
}
if (orderByList != null) {
for (SelectOrderBy orderBy : orderByList) {
orderBy.expression.updateAggregate(session, false);
}
}
}
private Value updateCollecting(Session session, Value v) {
if (orderByList != null) {
int size = orderByList.size();
......@@ -362,6 +325,11 @@ public class Aggregate extends AbstractAggregate {
return v;
}
@Override
protected Object createAggregateData() {
return AggregateData.create(type);
}
@Override
public Value getValue(Session session) {
if (select.isQuickAggregateQuery()) {
......@@ -400,7 +368,10 @@ public class Aggregate extends AbstractAggregate {
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
AggregateData data = getData(session, groupData);
AggregateData data = (AggregateData) getData(session, groupData, true);
if (data == null) {
data = (AggregateData) createAggregateData();
}
switch (type) {
case GROUP_CONCAT: {
Value[] array = ((AggregateDataCollecting) data).getArray();
......@@ -455,32 +426,6 @@ public class Aggregate extends AbstractAggregate {
}
}
private AggregateData getData(Session session, SelectGroups groupData) {
AggregateData data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<AggregateData> map = (ValueHashMap<AggregateData>) groupData.getCurrentGroupExprData(this,
true);
if (map == null) {
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
data = map.get(key);
if (data == null) {
data = AggregateData.create(type);
map.put(key, data);
}
} else {
data = (AggregateData) groupData.getCurrentGroupExprData(this, over != null);
if (data == null) {
data = AggregateData.create(type);
groupData.setCurrentGroupExprData(this, data, over != null);
}
}
return data;
}
@Override
public int getType() {
return dataType;
......
......@@ -20,7 +20,6 @@ import org.h2.message.DbException;
import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter;
import org.h2.util.StatementBuilder;
import org.h2.util.ValueHashMap;
import org.h2.value.DataType;
import org.h2.value.Value;
import org.h2.value.ValueArray;
......@@ -32,19 +31,15 @@ import org.h2.value.ValueNull;
public class JavaAggregate extends AbstractAggregate {
private final UserAggregate userAggregate;
private final Select select;
private final Expression[] args;
private int[] argTypes;
private final boolean distinct;
private int dataType;
private Connection userConnection;
private int lastGroupRowId;
public JavaAggregate(UserAggregate userAggregate, Expression[] args, Select select, boolean distinct) {
super(select, distinct);
this.userAggregate = userAggregate;
this.args = args;
this.select = select;
this.distinct = distinct;
}
@Override
......@@ -152,9 +147,13 @@ public class JavaAggregate extends AbstractAggregate {
super.setEvaluatable(tableFilter, b);
}
private Aggregate getInstance() throws SQLException {
private Aggregate getInstance() {
Aggregate agg = userAggregate.getInstance();
agg.init(userConnection);
try {
agg.init(userConnection);
} catch (SQLException ex) {
throw DbException.convert(ex);
}
return agg;
}
......@@ -168,7 +167,7 @@ public class JavaAggregate extends AbstractAggregate {
Aggregate agg;
if (distinct) {
agg = getInstance();
AggregateDataCollecting data = getDataDistinct(session, groupData, true);
AggregateDataCollecting data = (AggregateDataCollecting) getData(session, groupData, true);
if (data != null) {
for (Value value : data.values) {
if (args.length == 1) {
......@@ -184,7 +183,7 @@ public class JavaAggregate extends AbstractAggregate {
}
}
} else {
agg = getData(session, groupData, true);
agg = (Aggregate) getData(session, groupData, true);
if (agg == null) {
agg = getInstance();
}
......@@ -200,35 +199,10 @@ public class JavaAggregate extends AbstractAggregate {
}
@Override
public void updateAggregate(Session session, boolean window) {
if (window != (over != null)) {
return;
}
SelectGroups groupData = select.getGroupDataIfCurrent(window);
if (groupData == null) {
// this is a different level (the enclosing query)
return;
}
int groupRowId = groupData.getCurrentGroupRowId();
if (lastGroupRowId == groupRowId) {
// already visited
return;
}
lastGroupRowId = groupRowId;
if (over != null) {
over.updateAggregate(session, true);
}
if (filterCondition != null) {
if (!filterCondition.getBooleanValue(session)) {
return;
}
}
protected void updateAggregate(Session session, Object aggregateData) {
try {
if (distinct) {
AggregateDataCollecting data = getDataDistinct(session, groupData, false);
AggregateDataCollecting data = (AggregateDataCollecting) aggregateData;
Value[] argValues = new Value[args.length];
Value arg = null;
for (int i = 0, len = args.length; i < len; i++) {
......@@ -238,7 +212,7 @@ public class JavaAggregate extends AbstractAggregate {
}
data.add(session.getDatabase(), dataType, true, args.length == 1 ? arg : ValueArray.get(argValues));
} else {
Aggregate agg = getData(session, groupData, false);
Aggregate agg = (Aggregate) aggregateData;
Object[] argValues = new Object[args.length];
Object arg = null;
for (int i = 0, len = args.length; i < len; i++) {
......@@ -254,73 +228,16 @@ public class JavaAggregate extends AbstractAggregate {
}
}
private Aggregate getData(Session session, SelectGroups groupData, boolean ifExists) throws SQLException {
Aggregate data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Aggregate> map = (ValueHashMap<Aggregate>) groupData.getCurrentGroupExprData(this, true);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
data = map.get(key);
if (data == null) {
if (ifExists) {
return null;
}
data = getInstance();
map.put(key, data);
}
} else {
data = (Aggregate) groupData.getCurrentGroupExprData(this, over != null);
if (data == null) {
if (ifExists) {
return null;
}
data = getInstance();
groupData.setCurrentGroupExprData(this, data, over != null);
}
@Override
protected void updateGroupAggregates(Session session) {
for (Expression expr : args) {
expr.updateAggregate(session, false);
}
return data;
}
private AggregateDataCollecting getDataDistinct(Session session, SelectGroups groupData, boolean ifExists) {
AggregateDataCollecting data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<AggregateDataCollecting> map = (ValueHashMap<AggregateDataCollecting>) groupData
.getCurrentGroupExprData(this, true);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
data = map.get(key);
if (data == null) {
if (ifExists) {
return null;
}
data = new AggregateDataCollecting();
map.put(key, data);
}
} else {
data = (AggregateDataCollecting) groupData.getCurrentGroupExprData(this, over != null);
if (data == null) {
if (ifExists) {
return null;
}
data = new AggregateDataCollecting();
groupData.setCurrentGroupExprData(this, data, over != null);
}
}
return data;
@Override
protected Object createAggregateData() {
return distinct ? new AggregateDataCollecting() : getInstance();
}
}
......@@ -48,3 +48,17 @@ SELECT X, COUNT(*), SUM(COUNT(*)) OVER() FROM VALUES (1), (1), (1), (1), (2), (2
> 2 2 7
> 3 1 7
> rows: 3
CREATE TABLE TEST(ID INT);
> ok
SELECT SUM(ID) FROM TEST;
>> null
SELECT SUM(ID) OVER () FROM TEST;
> SUM(ID) OVER ()
> ---------------
> rows: 0
DROP TABLE TEST;
> ok
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论