提交 afa9a0b5 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Move more common code to AbstractAggregate

上级 bd864e33
...@@ -11,6 +11,8 @@ import org.h2.engine.Session; ...@@ -11,6 +11,8 @@ import org.h2.engine.Session;
import org.h2.expression.Expression; import org.h2.expression.Expression;
import org.h2.table.ColumnResolver; import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter; import org.h2.table.TableFilter;
import org.h2.util.ValueHashMap;
import org.h2.value.ValueArray;
/** /**
* A base class for aggregates. * A base class for aggregates.
...@@ -111,7 +113,7 @@ public abstract class AbstractAggregate extends Expression { ...@@ -111,7 +113,7 @@ public abstract class AbstractAggregate extends Expression {
return; return;
} }
} }
updateAggregate(session, groupData); updateAggregate(session, getData(session, groupData, false));
} }
/** /**
...@@ -119,10 +121,10 @@ public abstract class AbstractAggregate extends Expression { ...@@ -119,10 +121,10 @@ public abstract class AbstractAggregate extends Expression {
* *
* @param session * @param session
* the session * the session
* @param groupData * @param aggregateData
* group data from the select * aggregate data
*/ */
protected abstract void updateAggregate(Session session, SelectGroups groupData); protected abstract void updateAggregate(Session session, Object aggregateData);
/** /**
* Invoked when processing group stage of grouped window queries to update * Invoked when processing group stage of grouped window queries to update
...@@ -133,6 +135,42 @@ public abstract class AbstractAggregate extends Expression { ...@@ -133,6 +135,42 @@ public abstract class AbstractAggregate extends Expression {
*/ */
protected abstract void updateGroupAggregates(Session session); protected abstract void updateGroupAggregates(Session session);
protected Object getData(Session session, SelectGroups groupData, boolean ifExists) {
Object data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getCurrentGroupExprData(this, true);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
data = map.get(key);
if (data == null) {
if (ifExists) {
return null;
}
data = createAggregateData();
map.put(key, data);
}
} else {
data = groupData.getCurrentGroupExprData(this, over != null);
if (data == null) {
if (ifExists) {
return null;
}
data = createAggregateData();
groupData.setCurrentGroupExprData(this, data, over != null);
}
}
return data;
}
protected abstract Object createAggregateData();
protected StringBuilder appendTailConditions(StringBuilder builder) { protected StringBuilder appendTailConditions(StringBuilder builder) {
if (filterCondition != null) { if (filterCondition != null) {
builder.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')'); builder.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
......
...@@ -29,7 +29,6 @@ import org.h2.table.Table; ...@@ -29,7 +29,6 @@ import org.h2.table.Table;
import org.h2.table.TableFilter; import org.h2.table.TableFilter;
import org.h2.util.StatementBuilder; import org.h2.util.StatementBuilder;
import org.h2.util.StringUtils; import org.h2.util.StringUtils;
import org.h2.util.ValueHashMap;
import org.h2.value.DataType; import org.h2.value.DataType;
import org.h2.value.Value; import org.h2.value.Value;
import org.h2.value.ValueArray; import org.h2.value.ValueArray;
...@@ -285,8 +284,8 @@ public class Aggregate extends AbstractAggregate { ...@@ -285,8 +284,8 @@ public class Aggregate extends AbstractAggregate {
} }
@Override @Override
protected void updateAggregate(Session session, SelectGroups groupData) { protected void updateAggregate(Session session, Object aggregateData) {
AggregateData data = getData(session, groupData); AggregateData data = (AggregateData) aggregateData;
Value v = on == null ? null : on.getValue(session); Value v = on == null ? null : on.getValue(session);
if (type == AggregateType.GROUP_CONCAT) { if (type == AggregateType.GROUP_CONCAT) {
if (v != ValueNull.INSTANCE) { if (v != ValueNull.INSTANCE) {
...@@ -326,6 +325,11 @@ public class Aggregate extends AbstractAggregate { ...@@ -326,6 +325,11 @@ public class Aggregate extends AbstractAggregate {
return v; return v;
} }
@Override
protected Object createAggregateData() {
return AggregateData.create(type);
}
@Override @Override
public Value getValue(Session session) { public Value getValue(Session session) {
if (select.isQuickAggregateQuery()) { if (select.isQuickAggregateQuery()) {
...@@ -364,7 +368,10 @@ public class Aggregate extends AbstractAggregate { ...@@ -364,7 +368,10 @@ public class Aggregate extends AbstractAggregate {
if (groupData == null) { if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL()); throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
} }
AggregateData data = getData(session, groupData); AggregateData data = (AggregateData) getData(session, groupData, true);
if (data == null) {
data = (AggregateData) createAggregateData();
}
switch (type) { switch (type) {
case GROUP_CONCAT: { case GROUP_CONCAT: {
Value[] array = ((AggregateDataCollecting) data).getArray(); Value[] array = ((AggregateDataCollecting) data).getArray();
...@@ -419,32 +426,6 @@ public class Aggregate extends AbstractAggregate { ...@@ -419,32 +426,6 @@ public class Aggregate extends AbstractAggregate {
} }
} }
private AggregateData getData(Session session, SelectGroups groupData) {
AggregateData data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<AggregateData> map = (ValueHashMap<AggregateData>) groupData.getCurrentGroupExprData(this,
true);
if (map == null) {
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
data = map.get(key);
if (data == null) {
data = AggregateData.create(type);
map.put(key, data);
}
} else {
data = (AggregateData) groupData.getCurrentGroupExprData(this, over != null);
if (data == null) {
data = AggregateData.create(type);
groupData.setCurrentGroupExprData(this, data, over != null);
}
}
return data;
}
@Override @Override
public int getType() { public int getType() {
return dataType; return dataType;
......
...@@ -20,7 +20,6 @@ import org.h2.message.DbException; ...@@ -20,7 +20,6 @@ import org.h2.message.DbException;
import org.h2.table.ColumnResolver; import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter; import org.h2.table.TableFilter;
import org.h2.util.StatementBuilder; import org.h2.util.StatementBuilder;
import org.h2.util.ValueHashMap;
import org.h2.value.DataType; import org.h2.value.DataType;
import org.h2.value.Value; import org.h2.value.Value;
import org.h2.value.ValueArray; import org.h2.value.ValueArray;
...@@ -148,9 +147,13 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -148,9 +147,13 @@ public class JavaAggregate extends AbstractAggregate {
super.setEvaluatable(tableFilter, b); super.setEvaluatable(tableFilter, b);
} }
private Aggregate getInstance() throws SQLException { private Aggregate getInstance() {
Aggregate agg = userAggregate.getInstance(); Aggregate agg = userAggregate.getInstance();
agg.init(userConnection); try {
agg.init(userConnection);
} catch (SQLException ex) {
throw DbException.convert(ex);
}
return agg; return agg;
} }
...@@ -164,7 +167,7 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -164,7 +167,7 @@ public class JavaAggregate extends AbstractAggregate {
Aggregate agg; Aggregate agg;
if (distinct) { if (distinct) {
agg = getInstance(); agg = getInstance();
AggregateDataCollecting data = getDataDistinct(session, groupData, true); AggregateDataCollecting data = (AggregateDataCollecting) getData(session, groupData, true);
if (data != null) { if (data != null) {
for (Value value : data.values) { for (Value value : data.values) {
if (args.length == 1) { if (args.length == 1) {
...@@ -180,7 +183,7 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -180,7 +183,7 @@ public class JavaAggregate extends AbstractAggregate {
} }
} }
} else { } else {
agg = getData(session, groupData, true); agg = (Aggregate) getData(session, groupData, true);
if (agg == null) { if (agg == null) {
agg = getInstance(); agg = getInstance();
} }
...@@ -196,10 +199,10 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -196,10 +199,10 @@ public class JavaAggregate extends AbstractAggregate {
} }
@Override @Override
protected void updateAggregate(Session session, SelectGroups groupData) { protected void updateAggregate(Session session, Object aggregateData) {
try { try {
if (distinct) { if (distinct) {
AggregateDataCollecting data = getDataDistinct(session, groupData, false); AggregateDataCollecting data = (AggregateDataCollecting) aggregateData;
Value[] argValues = new Value[args.length]; Value[] argValues = new Value[args.length];
Value arg = null; Value arg = null;
for (int i = 0, len = args.length; i < len; i++) { for (int i = 0, len = args.length; i < len; i++) {
...@@ -209,7 +212,7 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -209,7 +212,7 @@ public class JavaAggregate extends AbstractAggregate {
} }
data.add(session.getDatabase(), dataType, true, args.length == 1 ? arg : ValueArray.get(argValues)); data.add(session.getDatabase(), dataType, true, args.length == 1 ? arg : ValueArray.get(argValues));
} else { } else {
Aggregate agg = getData(session, groupData, false); Aggregate agg = (Aggregate) aggregateData;
Object[] argValues = new Object[args.length]; Object[] argValues = new Object[args.length];
Object arg = null; Object arg = null;
for (int i = 0, len = args.length; i < len; i++) { for (int i = 0, len = args.length; i < len; i++) {
...@@ -232,73 +235,9 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -232,73 +235,9 @@ public class JavaAggregate extends AbstractAggregate {
} }
} }
private Aggregate getData(Session session, SelectGroups groupData, boolean ifExists) throws SQLException { @Override
Aggregate data; protected Object createAggregateData() {
ValueArray key; return distinct ? new AggregateDataCollecting() : getInstance();
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Aggregate> map = (ValueHashMap<Aggregate>) groupData.getCurrentGroupExprData(this, true);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
data = map.get(key);
if (data == null) {
if (ifExists) {
return null;
}
data = getInstance();
map.put(key, data);
}
} else {
data = (Aggregate) groupData.getCurrentGroupExprData(this, over != null);
if (data == null) {
if (ifExists) {
return null;
}
data = getInstance();
groupData.setCurrentGroupExprData(this, data, over != null);
}
}
return data;
}
private AggregateDataCollecting getDataDistinct(Session session, SelectGroups groupData, boolean ifExists) {
AggregateDataCollecting data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<AggregateDataCollecting> map = (ValueHashMap<AggregateDataCollecting>) groupData
.getCurrentGroupExprData(this, true);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
data = map.get(key);
if (data == null) {
if (ifExists) {
return null;
}
data = new AggregateDataCollecting();
map.put(key, data);
}
} else {
data = (AggregateDataCollecting) groupData.getCurrentGroupExprData(this, over != null);
if (data == null) {
if (ifExists) {
return null;
}
data = new AggregateDataCollecting();
groupData.setCurrentGroupExprData(this, data, over != null);
}
}
return data;
} }
} }
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论