提交 08913f79 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Extract DataAnalysisOperation from AbstractAggregate

上级 3814c595
...@@ -171,6 +171,7 @@ import org.h2.expression.UnaryOperation; ...@@ -171,6 +171,7 @@ import org.h2.expression.UnaryOperation;
import org.h2.expression.ValueExpression; import org.h2.expression.ValueExpression;
import org.h2.expression.Variable; import org.h2.expression.Variable;
import org.h2.expression.Wildcard; import org.h2.expression.Wildcard;
import org.h2.expression.aggregate.DataAnalysisOperation;
import org.h2.expression.aggregate.AbstractAggregate; import org.h2.expression.aggregate.AbstractAggregate;
import org.h2.expression.aggregate.Aggregate; import org.h2.expression.aggregate.Aggregate;
import org.h2.expression.aggregate.Aggregate.AggregateType; import org.h2.expression.aggregate.Aggregate.AggregateType;
...@@ -3053,23 +3054,24 @@ public class Parser { ...@@ -3053,23 +3054,24 @@ public class Parser {
} }
private void readFilterAndOver(AbstractAggregate aggregate) { private void readFilterAndOver(AbstractAggregate aggregate) {
boolean isAggregate = aggregate.isAggregate(); if (readIf("FILTER")) {
if (isAggregate && readIf("FILTER")) {
read(OPEN_PAREN); read(OPEN_PAREN);
read(WHERE); read(WHERE);
Expression filterCondition = readExpression(); Expression filterCondition = readExpression();
read(CLOSE_PAREN); read(CLOSE_PAREN);
aggregate.setFilterCondition(filterCondition); aggregate.setFilterCondition(filterCondition);
} }
Window over = null; readOver(aggregate);
}
private void readOver(DataAnalysisOperation operation) {
if (readIf("OVER")) { if (readIf("OVER")) {
over = readWindowNameOrSpecification(); operation.setOverCondition(readWindowNameOrSpecification());
aggregate.setOverCondition(over);
currentSelect.setWindowQuery(); currentSelect.setWindowQuery();
} else if (!isAggregate) { } else if (operation.isAggregate()) {
throw getSyntaxError();
} else {
currentSelect.setGroupQuery(); currentSelect.setGroupQuery();
} else {
throw getSyntaxError();
} }
} }
...@@ -3440,7 +3442,7 @@ public class Parser { ...@@ -3440,7 +3442,7 @@ public class Parser {
default: default:
// Avoid warning // Avoid warning
} }
readFilterAndOver(function); readOver(function);
return function; return function;
} }
......
...@@ -14,6 +14,7 @@ import java.util.Map.Entry; ...@@ -14,6 +14,7 @@ import java.util.Map.Entry;
import org.h2.engine.Session; import org.h2.engine.Session;
import org.h2.expression.Expression; import org.h2.expression.Expression;
import org.h2.expression.aggregate.DataAnalysisOperation;
import org.h2.value.Value; import org.h2.value.Value;
import org.h2.value.ValueArray; import org.h2.value.ValueArray;
...@@ -210,7 +211,7 @@ public abstract class SelectGroups { ...@@ -210,7 +211,7 @@ public abstract class SelectGroups {
/** /**
* Maps an expression object to its data. * Maps an expression object to its data.
*/ */
private final HashMap<Expression, Object> windowData = new HashMap<>(); private final HashMap<DataAnalysisOperation, Object> windowData = new HashMap<>();
/** /**
* The id of the current group. * The id of the current group.
...@@ -292,7 +293,7 @@ public abstract class SelectGroups { ...@@ -292,7 +293,7 @@ public abstract class SelectGroups {
* expression * expression
* @return expression data or null * @return expression data or null
*/ */
public final Object getWindowExprData(Expression expr) { public final Object getWindowExprData(DataAnalysisOperation expr) {
return windowData.get(expr); return windowData.get(expr);
} }
...@@ -304,7 +305,7 @@ public abstract class SelectGroups { ...@@ -304,7 +305,7 @@ public abstract class SelectGroups {
* @param object * @param object
* expression data to set * expression data to set
*/ */
public final void setWindowExprData(Expression expr, Object obj) { public final void setWindowExprData(DataAnalysisOperation expr, Object obj) {
Object old = windowData.put(expr, obj); Object old = windowData.put(expr, obj);
assert old == null; assert old == null;
} }
......
...@@ -6,60 +6,34 @@ ...@@ -6,60 +6,34 @@
package org.h2.expression.aggregate; package org.h2.expression.aggregate;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import org.h2.api.ErrorCode;
import org.h2.command.dml.Select; import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups; import org.h2.command.dml.SelectGroups;
import org.h2.command.dml.SelectOrderBy; import org.h2.command.dml.SelectOrderBy;
import org.h2.engine.Session; import org.h2.engine.Session;
import org.h2.expression.Expression; import org.h2.expression.Expression;
import org.h2.expression.ExpressionVisitor;
import org.h2.message.DbException;
import org.h2.result.SortOrder;
import org.h2.table.ColumnResolver; import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter; import org.h2.table.TableFilter;
import org.h2.util.ValueHashMap;
import org.h2.value.Value;
import org.h2.value.ValueArray;
import org.h2.value.ValueInt;
/** /**
* A base class for aggregates and window functions. * A base class for aggregate functions.
*/ */
public abstract class AbstractAggregate extends Expression { public abstract class AbstractAggregate extends DataAnalysisOperation {
protected final Select select;
protected final boolean distinct; protected final boolean distinct;
protected Expression filterCondition; protected Expression filterCondition;
protected Window over;
protected SortOrder overOrderBySort;
private int lastGroupRowId;
protected static SortOrder createOrder(Session session, ArrayList<SelectOrderBy> orderBy, int offset) {
int size = orderBy.size();
int[] index = new int[size];
int[] sortType = new int[size];
for (int i = 0; i < size; i++) {
SelectOrderBy o = orderBy.get(i);
index[i] = i + offset;
sortType[i] = o.sortType;
}
return new SortOrder(session.getDatabase(), index, sortType, null);
}
AbstractAggregate(Select select, boolean distinct) { AbstractAggregate(Select select, boolean distinct) {
this.select = select; super(select);
this.distinct = distinct; this.distinct = distinct;
} }
@Override
public final boolean isAggregate() {
return true;
}
/** /**
* Sets the FILTER condition. * Sets the FILTER condition.
* *
...@@ -67,38 +41,7 @@ public abstract class AbstractAggregate extends Expression { ...@@ -67,38 +41,7 @@ public abstract class AbstractAggregate extends Expression {
* FILTER condition * FILTER condition
*/ */
public void setFilterCondition(Expression filterCondition) { public void setFilterCondition(Expression filterCondition) {
if (isAggregate()) { this.filterCondition = filterCondition;
this.filterCondition = filterCondition;
} else {
throw DbException.getUnsupportedException("Window function");
}
}
/**
* Sets the OVER condition.
*
* @param over
* OVER condition
*/
public void setOverCondition(Window over) {
this.over = over;
}
/**
* Checks whether this expression is an aggregate function.
*
* @return true if this is an aggregate function (including aggregates with
* OVER clause), false if this is a window function
*/
public abstract boolean isAggregate();
/**
* Returns the sort order for OVER clause.
*
* @return the sort order for OVER clause
*/
SortOrder getOverOrderBySort() {
return overOrderBySort;
} }
@Override @Override
...@@ -106,23 +49,15 @@ public abstract class AbstractAggregate extends Expression { ...@@ -106,23 +49,15 @@ public abstract class AbstractAggregate extends Expression {
if (filterCondition != null) { if (filterCondition != null) {
filterCondition.mapColumns(resolver, level); filterCondition.mapColumns(resolver, level);
} }
if (over != null) { super.mapColumns(resolver, level);
over.mapColumns(resolver, level);
}
} }
@Override @Override
public Expression optimize(Session session) { public Expression optimize(Session session) {
if (over != null) { if (filterCondition != null) {
over.optimize(session); filterCondition = filterCondition.optimize(session);
ArrayList<SelectOrderBy> orderBy = over.getOrderBy();
if (orderBy != null) {
overOrderBySort = createOrder(session, orderBy, getNumExpressions());
} else if (!isAggregate()) {
overOrderBySort = new SortOrder(session.getDatabase(), new int[getNumExpressions()], new int[0], null);
}
} }
return this; return super.optimize(session);
} }
@Override @Override
...@@ -130,70 +65,19 @@ public abstract class AbstractAggregate extends Expression { ...@@ -130,70 +65,19 @@ public abstract class AbstractAggregate extends Expression {
if (filterCondition != null) { if (filterCondition != null) {
filterCondition.setEvaluatable(tableFilter, b); filterCondition.setEvaluatable(tableFilter, b);
} }
if (over != null) { super.setEvaluatable(tableFilter, b);
over.setEvaluatable(tableFilter, b);
}
} }
@Override @Override
public void updateAggregate(Session session, int stage) { protected void updateAggregate(Session session, SelectGroups groupData, int groupRowId) {
if (stage == Aggregate.STAGE_RESET) { if (filterCondition == null || filterCondition.getBooleanValue(session)) {
updateSubAggregates(session, Aggregate.STAGE_RESET); ArrayList<SelectOrderBy> orderBy;
lastGroupRowId = 0; if (over != null && (orderBy = over.getOrderBy()) != null) {
return;
}
boolean window = stage == Aggregate.STAGE_WINDOW;
if (window != (over != null)) {
if (!window && select.isWindowQuery()) {
updateSubAggregates(session, stage);
}
return;
}
// TODO aggregates: check nested MIN(MAX(ID)) and so on
// if (on != null) {
// on.updateAggregate();
// }
SelectGroups groupData = select.getGroupDataIfCurrent(window);
if (groupData == null) {
// this is a different level (the enclosing query)
return;
}
int groupRowId = groupData.getCurrentGroupRowId();
if (lastGroupRowId == groupRowId) {
// already visited
return;
}
lastGroupRowId = groupRowId;
if (over != null) {
if (!select.isGroupQuery()) {
over.updateAggregate(session, stage);
}
}
if (filterCondition != null) {
if (!filterCondition.getBooleanValue(session)) {
return;
}
}
if (over != null) {
ArrayList<SelectOrderBy> orderBy = over.getOrderBy();
if (orderBy != null || !isAggregate()) {
updateOrderedAggregate(session, groupData, groupRowId, orderBy); updateOrderedAggregate(session, groupData, groupRowId, orderBy);
return; } else {
updateAggregate(session, getData(session, groupData, false, false));
} }
} }
updateAggregate(session, getData(session, groupData, false, false));
}
private void updateSubAggregates(Session session, int stage) {
updateGroupAggregates(session, stage);
if (filterCondition != null) {
filterCondition.updateAggregate(session, stage);
}
if (over != null) {
over.updateAggregate(session, stage);
}
} }
/** /**
...@@ -206,267 +90,20 @@ public abstract class AbstractAggregate extends Expression { ...@@ -206,267 +90,20 @@ public abstract class AbstractAggregate extends Expression {
*/ */
protected abstract void updateAggregate(Session session, Object aggregateData); protected abstract void updateAggregate(Session session, Object aggregateData);
/**
* Invoked when processing group stage of grouped window queries to update
* arguments of this aggregate.
*
* @param session
* the session
* @param stage
* select stage
*/
protected abstract void updateGroupAggregates(Session session, int stage);
/**
* Returns the number of expressions, excluding FILTER and OVER clauses.
*
* @return the number of expressions
*/
protected abstract int getNumExpressions();
/**
* Stores current values of expressions into the specified array.
*
* @param session
* the session
* @param array
* array to store values of expressions
*/
protected abstract void rememberExpressions(Session session, Value[] array);
/**
* Updates the provided aggregate data from the remembered expressions.
*
* @param session
* the session
* @param aggregateData
* aggregate data
* @param array
* values of expressions
*/
protected abstract void updateFromExpressions(Session session, Object aggregateData, Value[] array);
protected Object getData(Session session, SelectGroups groupData, boolean ifExists, boolean forOrderBy) {
Object data;
if (over != null) {
ValueArray key = over.getCurrentKey(session);
if (key != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getWindowExprData(this);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setWindowExprData(this, map);
}
PartitionData partition = (PartitionData) map.get(key);
if (partition == null) {
if (ifExists) {
return null;
}
data = forOrderBy ? new ArrayList<>() : createAggregateData();
map.put(key, new PartitionData(data));
} else {
data = partition.getData();
}
} else {
PartitionData partition = (PartitionData) groupData.getWindowExprData(this);
if (partition == null) {
if (ifExists) {
return null;
}
data = forOrderBy ? new ArrayList<>() : createAggregateData();
groupData.setWindowExprData(this, new PartitionData(data));
} else {
data = partition.getData();
}
}
} else {
data = groupData.getCurrentGroupExprData(this);
if (data == null) {
if (ifExists) {
return null;
}
data = forOrderBy ? new ArrayList<>() : createAggregateData();
groupData.setCurrentGroupExprData(this, data);
}
}
return data;
}
protected abstract Object createAggregateData();
@Override @Override
public boolean isEverything(ExpressionVisitor visitor) { protected void updateGroupAggregates(Session session, int stage) {
if (over == null) { if (filterCondition != null) {
return true; filterCondition.updateAggregate(session, stage);
}
switch (visitor.getType()) {
case ExpressionVisitor.QUERY_COMPARABLE:
case ExpressionVisitor.OPTIMIZABLE_MIN_MAX_COUNT_ALL:
case ExpressionVisitor.DETERMINISTIC:
case ExpressionVisitor.INDEPENDENT:
return false;
case ExpressionVisitor.EVALUATABLE:
case ExpressionVisitor.READONLY:
case ExpressionVisitor.NOT_FROM_RESOLVER:
case ExpressionVisitor.GET_DEPENDENCIES:
case ExpressionVisitor.SET_MAX_DATA_MODIFICATION_ID:
case ExpressionVisitor.GET_COLUMNS1:
case ExpressionVisitor.GET_COLUMNS2:
return true;
default:
throw DbException.throwInternalError("type=" + visitor.getType());
} }
super.updateGroupAggregates(session, stage);
} }
@Override @Override
public Value getValue(Session session) {
SelectGroups groupData = select.getGroupDataIfCurrent(over != null);
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
return over == null ? getAggregatedValue(session, getData(session, groupData, true, false))
: getWindowResult(session, groupData);
}
private Value getWindowResult(Session session, SelectGroups groupData) {
PartitionData partition;
Object data;
boolean forOrderBy = over.getOrderBy() != null;
ValueArray key = over.getCurrentKey(session);
if (key != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getWindowExprData(this);
if (map == null) {
map = new ValueHashMap<>();
groupData.setWindowExprData(this, map);
}
partition = (PartitionData) map.get(key);
if (partition == null) {
data = forOrderBy ? new ArrayList<>() : createAggregateData();
partition = new PartitionData(data);
map.put(key, partition);
} else {
data = partition.getData();
}
} else {
partition = (PartitionData) groupData.getWindowExprData(this);
if (partition == null) {
data = forOrderBy ? new ArrayList<>() : createAggregateData();
partition = new PartitionData(data);
groupData.setWindowExprData(this, partition);
} else {
data = partition.getData();
}
}
if (over.getOrderBy() != null || !isAggregate()) {
return getOrderedResult(session, groupData, partition, data);
}
Value result = partition.getResult();
if (result == null) {
result = getAggregatedValue(session, data);
partition.setResult(result);
}
return result;
}
/***
* Returns aggregated value.
*
* @param session
* the session
* @param aggregateData
* the aggregate data
* @return aggregated value.
*/
protected abstract Value getAggregatedValue(Session session, Object aggregateData);
private void updateOrderedAggregate(Session session, SelectGroups groupData, int groupRowId,
ArrayList<SelectOrderBy> orderBy) {
int ne = getNumExpressions();
int size = orderBy != null ? orderBy.size() : 0;
Value[] array = new Value[ne + size + 1];
rememberExpressions(session, array);
for (int i = 0; i < size; i++) {
@SuppressWarnings("null")
SelectOrderBy o = orderBy.get(i);
array[ne++] = o.expression.getValue(session);
}
array[ne] = ValueInt.get(groupRowId);
@SuppressWarnings("unchecked")
ArrayList<Value[]> data = (ArrayList<Value[]>) getData(session, groupData, false, true);
data.add(array);
}
private Value getOrderedResult(Session session, SelectGroups groupData, PartitionData partition, Object data) {
HashMap<Integer, Value> result = partition.getOrderedResult();
if (result == null) {
result = new HashMap<>();
@SuppressWarnings("unchecked")
ArrayList<Value[]> orderedData = (ArrayList<Value[]>) data;
int rowIdColumn = getNumExpressions();
ArrayList<SelectOrderBy> orderBy = over.getOrderBy();
if (orderBy != null) {
rowIdColumn += orderBy.size();
Collections.sort(orderedData, overOrderBySort);
}
getOrderedResultLoop(session, result, orderedData, rowIdColumn);
partition.setOrderedResult(result);
}
return result.get(groupData.getCurrentGroupRowId());
}
/**
* @param session
* the session
* @param result
* the map to append result to
* @param ordered
* ordered data
* @param rowIdColumn
* the index of row id value
*/
protected void getOrderedResultLoop(Session session, HashMap<Integer, Value> result, ArrayList<Value[]> ordered,
int rowIdColumn) {
WindowFrame frame = over.getWindowFrame();
if (frame == null || frame.isDefault()) {
Object aggregateData = createAggregateData();
for (Value[] row : ordered) {
updateFromExpressions(session, aggregateData, row);
result.put(row[rowIdColumn].getInt(), getAggregatedValue(session, aggregateData));
}
} else if (frame.isFullPartition()) {
Object aggregateData = createAggregateData();
for (Value[] row : ordered) {
updateFromExpressions(session, aggregateData, row);
}
Value value = getAggregatedValue(session, aggregateData);
for (Value[] row : ordered) {
result.put(row[rowIdColumn].getInt(), value);
}
} else {
int size = ordered.size();
for (int i = 0; i < size; i++) {
Object aggregateData = createAggregateData();
for (Iterator<Value[]> iter = frame.iterator(session, ordered, getOverOrderBySort(), i, false); iter
.hasNext();) {
updateFromExpressions(session, aggregateData, iter.next());
}
result.put(ordered.get(i)[rowIdColumn].getInt(), getAggregatedValue(session, aggregateData));
}
}
}
protected StringBuilder appendTailConditions(StringBuilder builder) { protected StringBuilder appendTailConditions(StringBuilder builder) {
if (filterCondition != null) { if (filterCondition != null) {
builder.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')'); builder.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
} }
if (over != null) { return super.appendTailConditions(builder);
builder.append(' ').append(over.getSQL());
}
return builder;
} }
} }
...@@ -251,11 +251,6 @@ public class Aggregate extends AbstractAggregate { ...@@ -251,11 +251,6 @@ public class Aggregate extends AbstractAggregate {
return AGGREGATES.get(name); return AGGREGATES.get(name);
} }
@Override
public boolean isAggregate() {
return true;
}
/** /**
* Set the order for ARRAY_AGG() or GROUP_CONCAT() aggregate. * Set the order for ARRAY_AGG() or GROUP_CONCAT() aggregate.
* *
...@@ -312,6 +307,7 @@ public class Aggregate extends AbstractAggregate { ...@@ -312,6 +307,7 @@ public class Aggregate extends AbstractAggregate {
@Override @Override
protected void updateGroupAggregates(Session session, int stage) { protected void updateGroupAggregates(Session session, int stage) {
super.updateGroupAggregates(session, stage);
if (on != null) { if (on != null) {
on.updateAggregate(session, stage); on.updateAggregate(session, stage);
} }
...@@ -514,9 +510,6 @@ public class Aggregate extends AbstractAggregate { ...@@ -514,9 +510,6 @@ public class Aggregate extends AbstractAggregate {
if (groupConcatSeparator != null) { if (groupConcatSeparator != null) {
groupConcatSeparator = groupConcatSeparator.optimize(session); groupConcatSeparator = groupConcatSeparator.optimize(session);
} }
if (filterCondition != null) {
filterCondition = filterCondition.optimize(session);
}
switch (type) { switch (type) {
case GROUP_CONCAT: case GROUP_CONCAT:
dataType = Value.STRING; dataType = Value.STRING;
......
/*
* Copyright 2004-2018 H2 Group. Multiple-Licensed under the MPL 2.0,
* and the EPL 1.0 (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.expression.aggregate;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import org.h2.api.ErrorCode;
import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups;
import org.h2.command.dml.SelectOrderBy;
import org.h2.engine.Session;
import org.h2.expression.Expression;
import org.h2.expression.ExpressionVisitor;
import org.h2.message.DbException;
import org.h2.result.SortOrder;
import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter;
import org.h2.util.ValueHashMap;
import org.h2.value.Value;
import org.h2.value.ValueArray;
import org.h2.value.ValueInt;
/**
* A base class for data analysis operations such as aggregates and window
* functions.
*/
public abstract class DataAnalysisOperation extends Expression {
protected final Select select;
protected Window over;
protected SortOrder overOrderBySort;
private int lastGroupRowId;
protected static SortOrder createOrder(Session session, ArrayList<SelectOrderBy> orderBy, int offset) {
int size = orderBy.size();
int[] index = new int[size];
int[] sortType = new int[size];
for (int i = 0; i < size; i++) {
SelectOrderBy o = orderBy.get(i);
index[i] = i + offset;
sortType[i] = o.sortType;
}
return new SortOrder(session.getDatabase(), index, sortType, null);
}
DataAnalysisOperation(Select select) {
this.select = select;
}
/**
* Sets the OVER condition.
*
* @param over
* OVER condition
*/
public void setOverCondition(Window over) {
this.over = over;
}
/**
* Checks whether this expression is an aggregate function.
*
* @return true if this is an aggregate function (including aggregates with
* OVER clause), false if this is a window function
*/
public abstract boolean isAggregate();
/**
* Returns the sort order for OVER clause.
*
* @return the sort order for OVER clause
*/
SortOrder getOverOrderBySort() {
return overOrderBySort;
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
if (over != null) {
over.mapColumns(resolver, level);
}
}
@Override
public Expression optimize(Session session) {
if (over != null) {
over.optimize(session);
ArrayList<SelectOrderBy> orderBy = over.getOrderBy();
if (orderBy != null) {
overOrderBySort = createOrder(session, orderBy, getNumExpressions());
} else if (!isAggregate()) {
overOrderBySort = new SortOrder(session.getDatabase(), new int[getNumExpressions()], new int[0], null);
}
}
return this;
}
@Override
public void setEvaluatable(TableFilter tableFilter, boolean b) {
if (over != null) {
over.setEvaluatable(tableFilter, b);
}
}
@Override
public void updateAggregate(Session session, int stage) {
if (stage == Aggregate.STAGE_RESET) {
updateGroupAggregates(session, Aggregate.STAGE_RESET);
lastGroupRowId = 0;
return;
}
boolean window = stage == Aggregate.STAGE_WINDOW;
if (window != (over != null)) {
if (!window && select.isWindowQuery()) {
updateGroupAggregates(session, stage);
}
return;
}
// TODO aggregates: check nested MIN(MAX(ID)) and so on
// if (on != null) {
// on.updateAggregate();
// }
SelectGroups groupData = select.getGroupDataIfCurrent(window);
if (groupData == null) {
// this is a different level (the enclosing query)
return;
}
int groupRowId = groupData.getCurrentGroupRowId();
if (lastGroupRowId == groupRowId) {
// already visited
return;
}
lastGroupRowId = groupRowId;
if (over != null) {
if (!select.isGroupQuery()) {
over.updateAggregate(session, stage);
}
}
updateAggregate(session, groupData, groupRowId);
}
protected abstract void updateAggregate(Session session, SelectGroups groupData, int groupRowId);
/**
* Invoked when processing group stage of grouped window queries to update
* arguments of this aggregate.
*
* @param session
* the session
* @param stage
* select stage
*/
protected void updateGroupAggregates(Session session, int stage) {
if (over != null) {
over.updateAggregate(session, stage);
}
}
/**
* Returns the number of expressions, excluding FILTER and OVER clauses.
*
* @return the number of expressions
*/
protected abstract int getNumExpressions();
/**
* Stores current values of expressions into the specified array.
*
* @param session
* the session
* @param array
* array to store values of expressions
*/
protected abstract void rememberExpressions(Session session, Value[] array);
/**
* Updates the provided aggregate data from the remembered expressions.
*
* @param session
* the session
* @param aggregateData
* aggregate data
* @param array
* values of expressions
*/
protected abstract void updateFromExpressions(Session session, Object aggregateData, Value[] array);
protected Object getData(Session session, SelectGroups groupData, boolean ifExists, boolean forOrderBy) {
Object data;
if (over != null) {
ValueArray key = over.getCurrentKey(session);
if (key != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getWindowExprData(this);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setWindowExprData(this, map);
}
PartitionData partition = (PartitionData) map.get(key);
if (partition == null) {
if (ifExists) {
return null;
}
data = forOrderBy ? new ArrayList<>() : createAggregateData();
map.put(key, new PartitionData(data));
} else {
data = partition.getData();
}
} else {
PartitionData partition = (PartitionData) groupData.getWindowExprData(this);
if (partition == null) {
if (ifExists) {
return null;
}
data = forOrderBy ? new ArrayList<>() : createAggregateData();
groupData.setWindowExprData(this, new PartitionData(data));
} else {
data = partition.getData();
}
}
} else {
data = groupData.getCurrentGroupExprData(this);
if (data == null) {
if (ifExists) {
return null;
}
data = forOrderBy ? new ArrayList<>() : createAggregateData();
groupData.setCurrentGroupExprData(this, data);
}
}
return data;
}
protected abstract Object createAggregateData();
@Override
public boolean isEverything(ExpressionVisitor visitor) {
if (over == null) {
return true;
}
switch (visitor.getType()) {
case ExpressionVisitor.QUERY_COMPARABLE:
case ExpressionVisitor.OPTIMIZABLE_MIN_MAX_COUNT_ALL:
case ExpressionVisitor.DETERMINISTIC:
case ExpressionVisitor.INDEPENDENT:
return false;
case ExpressionVisitor.EVALUATABLE:
case ExpressionVisitor.READONLY:
case ExpressionVisitor.NOT_FROM_RESOLVER:
case ExpressionVisitor.GET_DEPENDENCIES:
case ExpressionVisitor.SET_MAX_DATA_MODIFICATION_ID:
case ExpressionVisitor.GET_COLUMNS1:
case ExpressionVisitor.GET_COLUMNS2:
return true;
default:
throw DbException.throwInternalError("type=" + visitor.getType());
}
}
@Override
public Value getValue(Session session) {
SelectGroups groupData = select.getGroupDataIfCurrent(over != null);
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
return over == null ? getAggregatedValue(session, getData(session, groupData, true, false))
: getWindowResult(session, groupData);
}
private Value getWindowResult(Session session, SelectGroups groupData) {
PartitionData partition;
Object data;
boolean forOrderBy = over.getOrderBy() != null;
ValueArray key = over.getCurrentKey(session);
if (key != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getWindowExprData(this);
if (map == null) {
map = new ValueHashMap<>();
groupData.setWindowExprData(this, map);
}
partition = (PartitionData) map.get(key);
if (partition == null) {
data = forOrderBy ? new ArrayList<>() : createAggregateData();
partition = new PartitionData(data);
map.put(key, partition);
} else {
data = partition.getData();
}
} else {
partition = (PartitionData) groupData.getWindowExprData(this);
if (partition == null) {
data = forOrderBy ? new ArrayList<>() : createAggregateData();
partition = new PartitionData(data);
groupData.setWindowExprData(this, partition);
} else {
data = partition.getData();
}
}
if (over.getOrderBy() != null || !isAggregate()) {
return getOrderedResult(session, groupData, partition, data);
}
Value result = partition.getResult();
if (result == null) {
result = getAggregatedValue(session, data);
partition.setResult(result);
}
return result;
}
/***
* Returns aggregated value.
*
* @param session
* the session
* @param aggregateData
* the aggregate data
* @return aggregated value.
*/
protected abstract Value getAggregatedValue(Session session, Object aggregateData);
protected void updateOrderedAggregate(Session session, SelectGroups groupData, int groupRowId,
ArrayList<SelectOrderBy> orderBy) {
int ne = getNumExpressions();
int size = orderBy != null ? orderBy.size() : 0;
Value[] array = new Value[ne + size + 1];
rememberExpressions(session, array);
for (int i = 0; i < size; i++) {
@SuppressWarnings("null")
SelectOrderBy o = orderBy.get(i);
array[ne++] = o.expression.getValue(session);
}
array[ne] = ValueInt.get(groupRowId);
@SuppressWarnings("unchecked")
ArrayList<Value[]> data = (ArrayList<Value[]>) getData(session, groupData, false, true);
data.add(array);
}
private Value getOrderedResult(Session session, SelectGroups groupData, PartitionData partition, Object data) {
HashMap<Integer, Value> result = partition.getOrderedResult();
if (result == null) {
result = new HashMap<>();
@SuppressWarnings("unchecked")
ArrayList<Value[]> orderedData = (ArrayList<Value[]>) data;
int rowIdColumn = getNumExpressions();
ArrayList<SelectOrderBy> orderBy = over.getOrderBy();
if (orderBy != null) {
rowIdColumn += orderBy.size();
Collections.sort(orderedData, overOrderBySort);
}
getOrderedResultLoop(session, result, orderedData, rowIdColumn);
partition.setOrderedResult(result);
}
return result.get(groupData.getCurrentGroupRowId());
}
/**
* @param session
* the session
* @param result
* the map to append result to
* @param ordered
* ordered data
* @param rowIdColumn
* the index of row id value
*/
protected void getOrderedResultLoop(Session session, HashMap<Integer, Value> result, ArrayList<Value[]> ordered,
int rowIdColumn) {
WindowFrame frame = over.getWindowFrame();
if (frame == null || frame.isDefault()) {
Object aggregateData = createAggregateData();
for (Value[] row : ordered) {
updateFromExpressions(session, aggregateData, row);
result.put(row[rowIdColumn].getInt(), getAggregatedValue(session, aggregateData));
}
} else if (frame.isFullPartition()) {
Object aggregateData = createAggregateData();
for (Value[] row : ordered) {
updateFromExpressions(session, aggregateData, row);
}
Value value = getAggregatedValue(session, aggregateData);
for (Value[] row : ordered) {
result.put(row[rowIdColumn].getInt(), value);
}
} else {
int size = ordered.size();
for (int i = 0; i < size; i++) {
Object aggregateData = createAggregateData();
for (Iterator<Value[]> iter = frame.iterator(session, ordered, getOverOrderBySort(), i, false); iter
.hasNext();) {
updateFromExpressions(session, aggregateData, iter.next());
}
result.put(ordered.get(i)[rowIdColumn].getInt(), getAggregatedValue(session, aggregateData));
}
}
}
protected StringBuilder appendTailConditions(StringBuilder builder) {
if (over != null) {
builder.append(' ').append(over.getSQL());
}
return builder;
}
}
...@@ -40,11 +40,6 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -40,11 +40,6 @@ public class JavaAggregate extends AbstractAggregate {
this.args = args; this.args = args;
} }
@Override
public boolean isAggregate() {
return true;
}
@Override @Override
public int getCost() { public int getCost() {
int cost = 5; int cost = 5;
...@@ -140,9 +135,6 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -140,9 +135,6 @@ public class JavaAggregate extends AbstractAggregate {
} catch (SQLException e) { } catch (SQLException e) {
throw DbException.convert(e); throw DbException.convert(e);
} }
if (filterCondition != null) {
filterCondition = filterCondition.optimize(session);
}
return this; return this;
} }
...@@ -237,6 +229,7 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -237,6 +229,7 @@ public class JavaAggregate extends AbstractAggregate {
@Override @Override
protected void updateGroupAggregates(Session session, int stage) { protected void updateGroupAggregates(Session session, int stage) {
super.updateGroupAggregates(session, stage);
for (Expression expr : args) { for (Expression expr : args) {
expr.updateAggregate(session, stage); expr.updateAggregate(session, stage);
} }
......
...@@ -10,6 +10,7 @@ import java.util.HashMap; ...@@ -10,6 +10,7 @@ import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import org.h2.command.dml.Select; import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups;
import org.h2.command.dml.SelectOrderBy; import org.h2.command.dml.SelectOrderBy;
import org.h2.engine.Session; import org.h2.engine.Session;
import org.h2.expression.Expression; import org.h2.expression.Expression;
...@@ -24,7 +25,7 @@ import org.h2.value.ValueNull; ...@@ -24,7 +25,7 @@ import org.h2.value.ValueNull;
/** /**
* A window function. * A window function.
*/ */
public class WindowFunction extends AbstractAggregate { public class WindowFunction extends DataAnalysisOperation {
private final WindowFunctionType type; private final WindowFunctionType type;
...@@ -105,7 +106,7 @@ public class WindowFunction extends AbstractAggregate { ...@@ -105,7 +106,7 @@ public class WindowFunction extends AbstractAggregate {
* arguments, or null * arguments, or null
*/ */
public WindowFunction(WindowFunctionType type, Select select, Expression[] args) { public WindowFunction(WindowFunctionType type, Select select, Expression[] args) {
super(select, false); super(select);
this.type = type; this.type = type;
this.args = args; this.args = args;
} }
...@@ -145,12 +146,13 @@ public class WindowFunction extends AbstractAggregate { ...@@ -145,12 +146,13 @@ public class WindowFunction extends AbstractAggregate {
} }
@Override @Override
protected void updateAggregate(Session session, Object aggregateData) { protected void updateAggregate(Session session, SelectGroups groupData, int groupRowId) {
throw DbException.getUnsupportedException("Window function"); updateOrderedAggregate(session, groupData, groupRowId, over.getOrderBy());
} }
@Override @Override
protected void updateGroupAggregates(Session session, int stage) { protected void updateGroupAggregates(Session session, int stage) {
super.updateGroupAggregates(session, stage);
if (args != null) { if (args != null) {
for (Expression expr : args) { for (Expression expr : args) {
expr.updateAggregate(session, stage); expr.updateAggregate(session, stage);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论