提交 afa9a0b5 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Move more common code to AbstractAggregate

上级 bd864e33
......@@ -11,6 +11,8 @@ import org.h2.engine.Session;
import org.h2.expression.Expression;
import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter;
import org.h2.util.ValueHashMap;
import org.h2.value.ValueArray;
/**
* A base class for aggregates.
......@@ -111,7 +113,7 @@ public abstract class AbstractAggregate extends Expression {
return;
}
}
updateAggregate(session, groupData);
updateAggregate(session, getData(session, groupData, false));
}
/**
......@@ -119,10 +121,10 @@ public abstract class AbstractAggregate extends Expression {
*
* @param session
* the session
* @param groupData
* group data from the select
* @param aggregateData
* aggregate data
*/
protected abstract void updateAggregate(Session session, SelectGroups groupData);
protected abstract void updateAggregate(Session session, Object aggregateData);
/**
* Invoked when processing group stage of grouped window queries to update
......@@ -133,6 +135,42 @@ public abstract class AbstractAggregate extends Expression {
*/
protected abstract void updateGroupAggregates(Session session);
protected Object getData(Session session, SelectGroups groupData, boolean ifExists) {
Object data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getCurrentGroupExprData(this, true);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
data = map.get(key);
if (data == null) {
if (ifExists) {
return null;
}
data = createAggregateData();
map.put(key, data);
}
} else {
data = groupData.getCurrentGroupExprData(this, over != null);
if (data == null) {
if (ifExists) {
return null;
}
data = createAggregateData();
groupData.setCurrentGroupExprData(this, data, over != null);
}
}
return data;
}
protected abstract Object createAggregateData();
protected StringBuilder appendTailConditions(StringBuilder builder) {
if (filterCondition != null) {
builder.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
......
......@@ -29,7 +29,6 @@ import org.h2.table.Table;
import org.h2.table.TableFilter;
import org.h2.util.StatementBuilder;
import org.h2.util.StringUtils;
import org.h2.util.ValueHashMap;
import org.h2.value.DataType;
import org.h2.value.Value;
import org.h2.value.ValueArray;
......@@ -285,8 +284,8 @@ public class Aggregate extends AbstractAggregate {
}
@Override
protected void updateAggregate(Session session, SelectGroups groupData) {
AggregateData data = getData(session, groupData);
protected void updateAggregate(Session session, Object aggregateData) {
AggregateData data = (AggregateData) aggregateData;
Value v = on == null ? null : on.getValue(session);
if (type == AggregateType.GROUP_CONCAT) {
if (v != ValueNull.INSTANCE) {
......@@ -326,6 +325,11 @@ public class Aggregate extends AbstractAggregate {
return v;
}
@Override
protected Object createAggregateData() {
return AggregateData.create(type);
}
@Override
public Value getValue(Session session) {
if (select.isQuickAggregateQuery()) {
......@@ -364,7 +368,10 @@ public class Aggregate extends AbstractAggregate {
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
AggregateData data = getData(session, groupData);
AggregateData data = (AggregateData) getData(session, groupData, true);
if (data == null) {
data = (AggregateData) createAggregateData();
}
switch (type) {
case GROUP_CONCAT: {
Value[] array = ((AggregateDataCollecting) data).getArray();
......@@ -419,32 +426,6 @@ public class Aggregate extends AbstractAggregate {
}
}
private AggregateData getData(Session session, SelectGroups groupData) {
AggregateData data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<AggregateData> map = (ValueHashMap<AggregateData>) groupData.getCurrentGroupExprData(this,
true);
if (map == null) {
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
data = map.get(key);
if (data == null) {
data = AggregateData.create(type);
map.put(key, data);
}
} else {
data = (AggregateData) groupData.getCurrentGroupExprData(this, over != null);
if (data == null) {
data = AggregateData.create(type);
groupData.setCurrentGroupExprData(this, data, over != null);
}
}
return data;
}
@Override
public int getType() {
return dataType;
......
......@@ -20,7 +20,6 @@ import org.h2.message.DbException;
import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter;
import org.h2.util.StatementBuilder;
import org.h2.util.ValueHashMap;
import org.h2.value.DataType;
import org.h2.value.Value;
import org.h2.value.ValueArray;
......@@ -148,9 +147,13 @@ public class JavaAggregate extends AbstractAggregate {
super.setEvaluatable(tableFilter, b);
}
private Aggregate getInstance() throws SQLException {
private Aggregate getInstance() {
Aggregate agg = userAggregate.getInstance();
try {
agg.init(userConnection);
} catch (SQLException ex) {
throw DbException.convert(ex);
}
return agg;
}
......@@ -164,7 +167,7 @@ public class JavaAggregate extends AbstractAggregate {
Aggregate agg;
if (distinct) {
agg = getInstance();
AggregateDataCollecting data = getDataDistinct(session, groupData, true);
AggregateDataCollecting data = (AggregateDataCollecting) getData(session, groupData, true);
if (data != null) {
for (Value value : data.values) {
if (args.length == 1) {
......@@ -180,7 +183,7 @@ public class JavaAggregate extends AbstractAggregate {
}
}
} else {
agg = getData(session, groupData, true);
agg = (Aggregate) getData(session, groupData, true);
if (agg == null) {
agg = getInstance();
}
......@@ -196,10 +199,10 @@ public class JavaAggregate extends AbstractAggregate {
}
@Override
protected void updateAggregate(Session session, SelectGroups groupData) {
protected void updateAggregate(Session session, Object aggregateData) {
try {
if (distinct) {
AggregateDataCollecting data = getDataDistinct(session, groupData, false);
AggregateDataCollecting data = (AggregateDataCollecting) aggregateData;
Value[] argValues = new Value[args.length];
Value arg = null;
for (int i = 0, len = args.length; i < len; i++) {
......@@ -209,7 +212,7 @@ public class JavaAggregate extends AbstractAggregate {
}
data.add(session.getDatabase(), dataType, true, args.length == 1 ? arg : ValueArray.get(argValues));
} else {
Aggregate agg = getData(session, groupData, false);
Aggregate agg = (Aggregate) aggregateData;
Object[] argValues = new Object[args.length];
Object arg = null;
for (int i = 0, len = args.length; i < len; i++) {
......@@ -232,73 +235,9 @@ public class JavaAggregate extends AbstractAggregate {
}
}
private Aggregate getData(Session session, SelectGroups groupData, boolean ifExists) throws SQLException {
Aggregate data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Aggregate> map = (ValueHashMap<Aggregate>) groupData.getCurrentGroupExprData(this, true);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
data = map.get(key);
if (data == null) {
if (ifExists) {
return null;
}
data = getInstance();
map.put(key, data);
}
} else {
data = (Aggregate) groupData.getCurrentGroupExprData(this, over != null);
if (data == null) {
if (ifExists) {
return null;
}
data = getInstance();
groupData.setCurrentGroupExprData(this, data, over != null);
}
}
return data;
}
private AggregateDataCollecting getDataDistinct(Session session, SelectGroups groupData, boolean ifExists) {
AggregateDataCollecting data;
ValueArray key;
if (over != null && (key = over.getCurrentKey(session)) != null) {
@SuppressWarnings("unchecked")
ValueHashMap<AggregateDataCollecting> map = (ValueHashMap<AggregateDataCollecting>) groupData
.getCurrentGroupExprData(this, true);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
data = map.get(key);
if (data == null) {
if (ifExists) {
return null;
}
data = new AggregateDataCollecting();
map.put(key, data);
}
} else {
data = (AggregateDataCollecting) groupData.getCurrentGroupExprData(this, over != null);
if (data == null) {
if (ifExists) {
return null;
}
data = new AggregateDataCollecting();
groupData.setCurrentGroupExprData(this, data, over != null);
}
}
return data;
@Override
protected Object createAggregateData() {
return distinct ? new AggregateDataCollecting() : getInstance();
}
}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论