Unverified 提交 f49d2b6a authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov 提交者: GitHub

Merge pull request #1462 from katzyn/window

Separate aggregate and window code in some places
...@@ -256,7 +256,12 @@ ${item.example}</p> ...@@ -256,7 +256,12 @@ ${item.example}</p>
<c:forEach var="item" items="otherGrammar"> <c:forEach var="item" items="otherGrammar">
<h3 id="${item.link}" class="notranslate" onclick="switchBnf(this)">${item.topic}</h3> <h3 id="${item.link}" class="notranslate" onclick="switchBnf(this)">${item.topic}</h3>
<!-- railroad-start --> <!-- railroad-start -->
<pre name="bnf" style="display: none">
${item.syntax}
</pre>
<div name="railroad">
${item.railroad} ${item.railroad}
</div>
<!-- railroad-end --> <!-- railroad-end -->
<!-- syntax-start <!-- syntax-start
<pre> <pre>
......
...@@ -171,6 +171,7 @@ import org.h2.expression.UnaryOperation; ...@@ -171,6 +171,7 @@ import org.h2.expression.UnaryOperation;
import org.h2.expression.ValueExpression; import org.h2.expression.ValueExpression;
import org.h2.expression.Variable; import org.h2.expression.Variable;
import org.h2.expression.Wildcard; import org.h2.expression.Wildcard;
import org.h2.expression.aggregate.DataAnalysisOperation;
import org.h2.expression.aggregate.AbstractAggregate; import org.h2.expression.aggregate.AbstractAggregate;
import org.h2.expression.aggregate.Aggregate; import org.h2.expression.aggregate.Aggregate;
import org.h2.expression.aggregate.Aggregate.AggregateType; import org.h2.expression.aggregate.Aggregate.AggregateType;
...@@ -3053,23 +3054,24 @@ public class Parser { ...@@ -3053,23 +3054,24 @@ public class Parser {
} }
private void readFilterAndOver(AbstractAggregate aggregate) { private void readFilterAndOver(AbstractAggregate aggregate) {
boolean isAggregate = aggregate.isAggregate(); if (readIf("FILTER")) {
if (isAggregate && readIf("FILTER")) {
read(OPEN_PAREN); read(OPEN_PAREN);
read(WHERE); read(WHERE);
Expression filterCondition = readExpression(); Expression filterCondition = readExpression();
read(CLOSE_PAREN); read(CLOSE_PAREN);
aggregate.setFilterCondition(filterCondition); aggregate.setFilterCondition(filterCondition);
} }
Window over = null; readOver(aggregate);
}
private void readOver(DataAnalysisOperation operation) {
if (readIf("OVER")) { if (readIf("OVER")) {
over = readWindowNameOrSpecification(); operation.setOverCondition(readWindowNameOrSpecification());
aggregate.setOverCondition(over);
currentSelect.setWindowQuery(); currentSelect.setWindowQuery();
} else if (!isAggregate) { } else if (operation.isAggregate()) {
throw getSyntaxError();
} else {
currentSelect.setGroupQuery(); currentSelect.setGroupQuery();
} else {
throw getSyntaxError();
} }
} }
...@@ -3440,7 +3442,7 @@ public class Parser { ...@@ -3440,7 +3442,7 @@ public class Parser {
default: default:
// Avoid warning // Avoid warning
} }
readFilterAndOver(function); readOver(function);
return function; return function;
} }
......
...@@ -14,6 +14,7 @@ import java.util.Map.Entry; ...@@ -14,6 +14,7 @@ import java.util.Map.Entry;
import org.h2.engine.Session; import org.h2.engine.Session;
import org.h2.expression.Expression; import org.h2.expression.Expression;
import org.h2.expression.aggregate.DataAnalysisOperation;
import org.h2.value.Value; import org.h2.value.Value;
import org.h2.value.ValueArray; import org.h2.value.ValueArray;
...@@ -210,7 +211,7 @@ public abstract class SelectGroups { ...@@ -210,7 +211,7 @@ public abstract class SelectGroups {
/** /**
* Maps an expression object to its data. * Maps an expression object to its data.
*/ */
private final HashMap<Expression, Object> windowData = new HashMap<>(); private final HashMap<DataAnalysisOperation, Object> windowData = new HashMap<>();
/** /**
* The id of the current group. * The id of the current group.
...@@ -251,14 +252,9 @@ public abstract class SelectGroups { ...@@ -251,14 +252,9 @@ public abstract class SelectGroups {
* *
* @param expr * @param expr
* expression * expression
* @param window
* true if expression is a window expression
* @return expression data or null * @return expression data or null
*/ */
public Object getCurrentGroupExprData(Expression expr, boolean window) { public final Object getCurrentGroupExprData(Expression expr) {
if (window) {
return windowData.get(expr);
}
Integer index = exprToIndexInGroupByData.get(expr); Integer index = exprToIndexInGroupByData.get(expr);
if (index == null) { if (index == null) {
return null; return null;
...@@ -273,15 +269,8 @@ public abstract class SelectGroups { ...@@ -273,15 +269,8 @@ public abstract class SelectGroups {
* expression * expression
* @param object * @param object
* expression data to set * expression data to set
* @param window
* true if expression is a window expression
*/ */
public void setCurrentGroupExprData(Expression expr, Object obj, boolean window) { public final void setCurrentGroupExprData(Expression expr, Object obj) {
if (window) {
Object old = windowData.put(expr, obj);
assert old == null;
return;
}
Integer index = exprToIndexInGroupByData.get(expr); Integer index = exprToIndexInGroupByData.get(expr);
if (index != null) { if (index != null) {
assert currentGroupByExprData[index] == null; assert currentGroupByExprData[index] == null;
...@@ -297,6 +286,30 @@ public abstract class SelectGroups { ...@@ -297,6 +286,30 @@ public abstract class SelectGroups {
currentGroupByExprData[index] = obj; currentGroupByExprData[index] = obj;
} }
/**
* Get the window data for the specified expression.
*
* @param expr
* expression
* @return expression data or null
*/
public final Object getWindowExprData(DataAnalysisOperation expr) {
return windowData.get(expr);
}
/**
* Set the window data for the specified expression.
*
* @param expr
* expression
* @param object
* expression data to set
*/
public final void setWindowExprData(DataAnalysisOperation expr, Object obj) {
Object old = windowData.put(expr, obj);
assert old == null;
}
abstract void updateCurrentGroupExprData(); abstract void updateCurrentGroupExprData();
/** /**
......
...@@ -164,9 +164,9 @@ public class ExpressionColumn extends Expression { ...@@ -164,9 +164,9 @@ public class ExpressionColumn extends Expression {
// this is a different level (the enclosing query) // this is a different level (the enclosing query)
return; return;
} }
Value v = (Value) groupData.getCurrentGroupExprData(this, false); Value v = (Value) groupData.getCurrentGroupExprData(this);
if (v == null) { if (v == null) {
groupData.setCurrentGroupExprData(this, now, false); groupData.setCurrentGroupExprData(this, now);
} else { } else {
if (!database.areEqual(now, v)) { if (!database.areEqual(now, v)) {
throw DbException.get(ErrorCode.MUST_GROUP_BY_COLUMN_1, getSQL()); throw DbException.get(ErrorCode.MUST_GROUP_BY_COLUMN_1, getSQL());
...@@ -180,7 +180,7 @@ public class ExpressionColumn extends Expression { ...@@ -180,7 +180,7 @@ public class ExpressionColumn extends Expression {
if (select != null) { if (select != null) {
SelectGroups groupData = select.getGroupDataIfCurrent(false); SelectGroups groupData = select.getGroupDataIfCurrent(false);
if (groupData != null) { if (groupData != null) {
Value v = (Value) groupData.getCurrentGroupExprData(this, false); Value v = (Value) groupData.getCurrentGroupExprData(this);
if (v != null) { if (v != null) {
return v; return v;
} }
......
...@@ -6,60 +6,34 @@ ...@@ -6,60 +6,34 @@
package org.h2.expression.aggregate; package org.h2.expression.aggregate;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import org.h2.api.ErrorCode;
import org.h2.command.dml.Select; import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups; import org.h2.command.dml.SelectGroups;
import org.h2.command.dml.SelectOrderBy; import org.h2.command.dml.SelectOrderBy;
import org.h2.engine.Session; import org.h2.engine.Session;
import org.h2.expression.Expression; import org.h2.expression.Expression;
import org.h2.expression.ExpressionVisitor;
import org.h2.message.DbException;
import org.h2.result.SortOrder;
import org.h2.table.ColumnResolver; import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter; import org.h2.table.TableFilter;
import org.h2.util.ValueHashMap;
import org.h2.value.Value;
import org.h2.value.ValueArray;
import org.h2.value.ValueInt;
/** /**
* A base class for aggregates and window functions. * A base class for aggregate functions.
*/ */
public abstract class AbstractAggregate extends Expression { public abstract class AbstractAggregate extends DataAnalysisOperation {
protected final Select select;
protected final boolean distinct; protected final boolean distinct;
protected Expression filterCondition; protected Expression filterCondition;
protected Window over;
protected SortOrder overOrderBySort;
private int lastGroupRowId;
protected static SortOrder createOrder(Session session, ArrayList<SelectOrderBy> orderBy, int offset) {
int size = orderBy.size();
int[] index = new int[size];
int[] sortType = new int[size];
for (int i = 0; i < size; i++) {
SelectOrderBy o = orderBy.get(i);
index[i] = i + offset;
sortType[i] = o.sortType;
}
return new SortOrder(session.getDatabase(), index, sortType, null);
}
AbstractAggregate(Select select, boolean distinct) { AbstractAggregate(Select select, boolean distinct) {
this.select = select; super(select);
this.distinct = distinct; this.distinct = distinct;
} }
@Override
public final boolean isAggregate() {
return true;
}
/** /**
* Sets the FILTER condition. * Sets the FILTER condition.
* *
...@@ -67,38 +41,7 @@ public abstract class AbstractAggregate extends Expression { ...@@ -67,38 +41,7 @@ public abstract class AbstractAggregate extends Expression {
* FILTER condition * FILTER condition
*/ */
public void setFilterCondition(Expression filterCondition) { public void setFilterCondition(Expression filterCondition) {
if (isAggregate()) { this.filterCondition = filterCondition;
this.filterCondition = filterCondition;
} else {
throw DbException.getUnsupportedException("Window function");
}
}
/**
* Sets the OVER condition.
*
* @param over
* OVER condition
*/
public void setOverCondition(Window over) {
this.over = over;
}
/**
* Checks whether this expression is an aggregate function.
*
* @return true if this is an aggregate function (including aggregates with
* OVER clause), false if this is a window function
*/
public abstract boolean isAggregate();
/**
* Returns the sort order for OVER clause.
*
* @return the sort order for OVER clause
*/
SortOrder getOverOrderBySort() {
return overOrderBySort;
} }
@Override @Override
...@@ -106,23 +49,15 @@ public abstract class AbstractAggregate extends Expression { ...@@ -106,23 +49,15 @@ public abstract class AbstractAggregate extends Expression {
if (filterCondition != null) { if (filterCondition != null) {
filterCondition.mapColumns(resolver, level); filterCondition.mapColumns(resolver, level);
} }
if (over != null) { super.mapColumns(resolver, level);
over.mapColumns(resolver, level);
}
} }
@Override @Override
public Expression optimize(Session session) { public Expression optimize(Session session) {
if (over != null) { if (filterCondition != null) {
over.optimize(session); filterCondition = filterCondition.optimize(session);
ArrayList<SelectOrderBy> orderBy = over.getOrderBy();
if (orderBy != null) {
overOrderBySort = createOrder(session, orderBy, getNumExpressions());
} else if (!isAggregate()) {
overOrderBySort = new SortOrder(session.getDatabase(), new int[getNumExpressions()], new int[0], null);
}
} }
return this; return super.optimize(session);
} }
@Override @Override
...@@ -130,70 +65,19 @@ public abstract class AbstractAggregate extends Expression { ...@@ -130,70 +65,19 @@ public abstract class AbstractAggregate extends Expression {
if (filterCondition != null) { if (filterCondition != null) {
filterCondition.setEvaluatable(tableFilter, b); filterCondition.setEvaluatable(tableFilter, b);
} }
if (over != null) { super.setEvaluatable(tableFilter, b);
over.setEvaluatable(tableFilter, b);
}
} }
@Override @Override
public void updateAggregate(Session session, int stage) { protected void updateAggregate(Session session, SelectGroups groupData, int groupRowId) {
if (stage == Aggregate.STAGE_RESET) { if (filterCondition == null || filterCondition.getBooleanValue(session)) {
updateSubAggregates(session, Aggregate.STAGE_RESET); ArrayList<SelectOrderBy> orderBy;
lastGroupRowId = 0; if (over != null && (orderBy = over.getOrderBy()) != null) {
return;
}
boolean window = stage == Aggregate.STAGE_WINDOW;
if (window != (over != null)) {
if (!window && select.isWindowQuery()) {
updateSubAggregates(session, stage);
}
return;
}
// TODO aggregates: check nested MIN(MAX(ID)) and so on
// if (on != null) {
// on.updateAggregate();
// }
SelectGroups groupData = select.getGroupDataIfCurrent(window);
if (groupData == null) {
// this is a different level (the enclosing query)
return;
}
int groupRowId = groupData.getCurrentGroupRowId();
if (lastGroupRowId == groupRowId) {
// already visited
return;
}
lastGroupRowId = groupRowId;
if (over != null) {
if (!select.isGroupQuery()) {
over.updateAggregate(session, stage);
}
}
if (filterCondition != null) {
if (!filterCondition.getBooleanValue(session)) {
return;
}
}
if (over != null) {
ArrayList<SelectOrderBy> orderBy = over.getOrderBy();
if (orderBy != null || !isAggregate()) {
updateOrderedAggregate(session, groupData, groupRowId, orderBy); updateOrderedAggregate(session, groupData, groupRowId, orderBy);
return; } else {
updateAggregate(session, getData(session, groupData, false, false));
} }
} }
updateAggregate(session, getData(session, groupData, false, false));
}
private void updateSubAggregates(Session session, int stage) {
updateGroupAggregates(session, stage);
if (filterCondition != null) {
filterCondition.updateAggregate(session, stage);
}
if (over != null) {
over.updateAggregate(session, stage);
}
} }
/** /**
...@@ -206,267 +90,20 @@ public abstract class AbstractAggregate extends Expression { ...@@ -206,267 +90,20 @@ public abstract class AbstractAggregate extends Expression {
*/ */
protected abstract void updateAggregate(Session session, Object aggregateData); protected abstract void updateAggregate(Session session, Object aggregateData);
/**
* Invoked when processing group stage of grouped window queries to update
* arguments of this aggregate.
*
* @param session
* the session
* @param stage
* select stage
*/
protected abstract void updateGroupAggregates(Session session, int stage);
/**
* Returns the number of expressions, excluding FILTER and OVER clauses.
*
* @return the number of expressions
*/
protected abstract int getNumExpressions();
/**
* Stores current values of expressions into the specified array.
*
* @param session
* the session
* @param array
* array to store values of expressions
*/
protected abstract void rememberExpressions(Session session, Value[] array);
/**
* Updates the provided aggregate data from the remembered expressions.
*
* @param session
* the session
* @param aggregateData
* aggregate data
* @param array
* values of expressions
*/
protected abstract void updateFromExpressions(Session session, Object aggregateData, Value[] array);
protected Object getData(Session session, SelectGroups groupData, boolean ifExists, boolean forOrderBy) {
Object data;
if (over != null) {
ValueArray key = over.getCurrentKey(session);
if (key != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getCurrentGroupExprData(this, true);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
PartitionData partition = (PartitionData) map.get(key);
if (partition == null) {
if (ifExists) {
return null;
}
data = forOrderBy ? new ArrayList<>() : createAggregateData();
map.put(key, new PartitionData(data));
} else {
data = partition.getData();
}
} else {
PartitionData partition = (PartitionData) groupData.getCurrentGroupExprData(this, true);
if (partition == null) {
if (ifExists) {
return null;
}
data = forOrderBy ? new ArrayList<>() : createAggregateData();
groupData.setCurrentGroupExprData(this, new PartitionData(data), true);
} else {
data = partition.getData();
}
}
} else {
data = groupData.getCurrentGroupExprData(this, false);
if (data == null) {
if (ifExists) {
return null;
}
data = forOrderBy ? new ArrayList<>() : createAggregateData();
groupData.setCurrentGroupExprData(this, data, false);
}
}
return data;
}
protected abstract Object createAggregateData();
@Override @Override
public boolean isEverything(ExpressionVisitor visitor) { protected void updateGroupAggregates(Session session, int stage) {
if (over == null) { if (filterCondition != null) {
return true; filterCondition.updateAggregate(session, stage);
}
switch (visitor.getType()) {
case ExpressionVisitor.QUERY_COMPARABLE:
case ExpressionVisitor.OPTIMIZABLE_MIN_MAX_COUNT_ALL:
case ExpressionVisitor.DETERMINISTIC:
case ExpressionVisitor.INDEPENDENT:
return false;
case ExpressionVisitor.EVALUATABLE:
case ExpressionVisitor.READONLY:
case ExpressionVisitor.NOT_FROM_RESOLVER:
case ExpressionVisitor.GET_DEPENDENCIES:
case ExpressionVisitor.SET_MAX_DATA_MODIFICATION_ID:
case ExpressionVisitor.GET_COLUMNS1:
case ExpressionVisitor.GET_COLUMNS2:
return true;
default:
throw DbException.throwInternalError("type=" + visitor.getType());
} }
super.updateGroupAggregates(session, stage);
} }
@Override @Override
public Value getValue(Session session) {
SelectGroups groupData = select.getGroupDataIfCurrent(over != null);
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
return over == null ? getAggregatedValue(session, getData(session, groupData, true, false))
: getWindowResult(session, groupData);
}
private Value getWindowResult(Session session, SelectGroups groupData) {
PartitionData partition;
Object data;
boolean forOrderBy = over.getOrderBy() != null;
ValueArray key = over.getCurrentKey(session);
if (key != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getCurrentGroupExprData(this, true);
if (map == null) {
map = new ValueHashMap<>();
groupData.setCurrentGroupExprData(this, map, true);
}
partition = (PartitionData) map.get(key);
if (partition == null) {
data = forOrderBy ? new ArrayList<>() : createAggregateData();
partition = new PartitionData(data);
map.put(key, partition);
} else {
data = partition.getData();
}
} else {
partition = (PartitionData) groupData.getCurrentGroupExprData(this, true);
if (partition == null) {
data = forOrderBy ? new ArrayList<>() : createAggregateData();
partition = new PartitionData(data);
groupData.setCurrentGroupExprData(this, partition, true);
} else {
data = partition.getData();
}
}
if (over.getOrderBy() != null || !isAggregate()) {
return getOrderedResult(session, groupData, partition, data);
}
Value result = partition.getResult();
if (result == null) {
result = getAggregatedValue(session, data);
partition.setResult(result);
}
return result;
}
/***
* Returns aggregated value.
*
* @param session
* the session
* @param aggregateData
* the aggregate data
* @return aggregated value.
*/
protected abstract Value getAggregatedValue(Session session, Object aggregateData);
private void updateOrderedAggregate(Session session, SelectGroups groupData, int groupRowId,
ArrayList<SelectOrderBy> orderBy) {
int ne = getNumExpressions();
int size = orderBy != null ? orderBy.size() : 0;
Value[] array = new Value[ne + size + 1];
rememberExpressions(session, array);
for (int i = 0; i < size; i++) {
@SuppressWarnings("null")
SelectOrderBy o = orderBy.get(i);
array[ne++] = o.expression.getValue(session);
}
array[ne] = ValueInt.get(groupRowId);
@SuppressWarnings("unchecked")
ArrayList<Value[]> data = (ArrayList<Value[]>) getData(session, groupData, false, true);
data.add(array);
}
private Value getOrderedResult(Session session, SelectGroups groupData, PartitionData partition, Object data) {
HashMap<Integer, Value> result = partition.getOrderedResult();
if (result == null) {
result = new HashMap<>();
@SuppressWarnings("unchecked")
ArrayList<Value[]> orderedData = (ArrayList<Value[]>) data;
int rowIdColumn = getNumExpressions();
ArrayList<SelectOrderBy> orderBy = over.getOrderBy();
if (orderBy != null) {
rowIdColumn += orderBy.size();
Collections.sort(orderedData, overOrderBySort);
}
getOrderedResultLoop(session, result, orderedData, rowIdColumn);
partition.setOrderedResult(result);
}
return result.get(groupData.getCurrentGroupRowId());
}
/**
* @param session
* the session
* @param result
* the map to append result to
* @param ordered
* ordered data
* @param rowIdColumn
* the index of row id value
*/
protected void getOrderedResultLoop(Session session, HashMap<Integer, Value> result, ArrayList<Value[]> ordered,
int rowIdColumn) {
WindowFrame frame = over.getWindowFrame();
if (frame == null || frame.isDefault()) {
Object aggregateData = createAggregateData();
for (Value[] row : ordered) {
updateFromExpressions(session, aggregateData, row);
result.put(row[rowIdColumn].getInt(), getAggregatedValue(session, aggregateData));
}
} else if (frame.isFullPartition()) {
Object aggregateData = createAggregateData();
for (Value[] row : ordered) {
updateFromExpressions(session, aggregateData, row);
}
Value value = getAggregatedValue(session, aggregateData);
for (Value[] row : ordered) {
result.put(row[rowIdColumn].getInt(), value);
}
} else {
int size = ordered.size();
for (int i = 0; i < size; i++) {
Object aggregateData = createAggregateData();
for (Iterator<Value[]> iter = frame.iterator(session, ordered, getOverOrderBySort(), i, false); iter
.hasNext();) {
updateFromExpressions(session, aggregateData, iter.next());
}
result.put(ordered.get(i)[rowIdColumn].getInt(), getAggregatedValue(session, aggregateData));
}
}
}
protected StringBuilder appendTailConditions(StringBuilder builder) { protected StringBuilder appendTailConditions(StringBuilder builder) {
if (filterCondition != null) { if (filterCondition != null) {
builder.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')'); builder.append(" FILTER (WHERE ").append(filterCondition.getSQL()).append(')');
} }
if (over != null) { return super.appendTailConditions(builder);
builder.append(' ').append(over.getSQL());
}
return builder;
} }
} }
...@@ -251,11 +251,6 @@ public class Aggregate extends AbstractAggregate { ...@@ -251,11 +251,6 @@ public class Aggregate extends AbstractAggregate {
return AGGREGATES.get(name); return AGGREGATES.get(name);
} }
@Override
public boolean isAggregate() {
return true;
}
/** /**
* Set the order for ARRAY_AGG() or GROUP_CONCAT() aggregate. * Set the order for ARRAY_AGG() or GROUP_CONCAT() aggregate.
* *
...@@ -312,6 +307,7 @@ public class Aggregate extends AbstractAggregate { ...@@ -312,6 +307,7 @@ public class Aggregate extends AbstractAggregate {
@Override @Override
protected void updateGroupAggregates(Session session, int stage) { protected void updateGroupAggregates(Session session, int stage) {
super.updateGroupAggregates(session, stage);
if (on != null) { if (on != null) {
on.updateAggregate(session, stage); on.updateAggregate(session, stage);
} }
...@@ -514,9 +510,6 @@ public class Aggregate extends AbstractAggregate { ...@@ -514,9 +510,6 @@ public class Aggregate extends AbstractAggregate {
if (groupConcatSeparator != null) { if (groupConcatSeparator != null) {
groupConcatSeparator = groupConcatSeparator.optimize(session); groupConcatSeparator = groupConcatSeparator.optimize(session);
} }
if (filterCondition != null) {
filterCondition = filterCondition.optimize(session);
}
switch (type) { switch (type) {
case GROUP_CONCAT: case GROUP_CONCAT:
dataType = Value.STRING; dataType = Value.STRING;
......
/*
* Copyright 2004-2018 H2 Group. Multiple-Licensed under the MPL 2.0,
* and the EPL 1.0 (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.expression.aggregate;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import org.h2.api.ErrorCode;
import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups;
import org.h2.command.dml.SelectOrderBy;
import org.h2.engine.Session;
import org.h2.expression.Expression;
import org.h2.expression.ExpressionVisitor;
import org.h2.message.DbException;
import org.h2.result.SortOrder;
import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter;
import org.h2.util.ValueHashMap;
import org.h2.value.Value;
import org.h2.value.ValueArray;
import org.h2.value.ValueInt;
/**
* A base class for data analysis operations such as aggregates and window
* functions.
*/
public abstract class DataAnalysisOperation extends Expression {
protected final Select select;
protected Window over;
protected SortOrder overOrderBySort;
private int lastGroupRowId;
protected static SortOrder createOrder(Session session, ArrayList<SelectOrderBy> orderBy, int offset) {
int size = orderBy.size();
int[] index = new int[size];
int[] sortType = new int[size];
for (int i = 0; i < size; i++) {
SelectOrderBy o = orderBy.get(i);
index[i] = i + offset;
sortType[i] = o.sortType;
}
return new SortOrder(session.getDatabase(), index, sortType, null);
}
DataAnalysisOperation(Select select) {
this.select = select;
}
/**
* Sets the OVER condition.
*
* @param over
* OVER condition
*/
public void setOverCondition(Window over) {
this.over = over;
}
/**
* Checks whether this expression is an aggregate function.
*
* @return true if this is an aggregate function (including aggregates with
* OVER clause), false if this is a window function
*/
public abstract boolean isAggregate();
/**
* Returns the sort order for OVER clause.
*
* @return the sort order for OVER clause
*/
SortOrder getOverOrderBySort() {
return overOrderBySort;
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
if (over != null) {
over.mapColumns(resolver, level);
}
}
@Override
public Expression optimize(Session session) {
if (over != null) {
over.optimize(session);
ArrayList<SelectOrderBy> orderBy = over.getOrderBy();
if (orderBy != null) {
overOrderBySort = createOrder(session, orderBy, getNumExpressions());
} else if (!isAggregate()) {
overOrderBySort = new SortOrder(session.getDatabase(), new int[getNumExpressions()], new int[0], null);
}
}
return this;
}
@Override
public void setEvaluatable(TableFilter tableFilter, boolean b) {
if (over != null) {
over.setEvaluatable(tableFilter, b);
}
}
@Override
public void updateAggregate(Session session, int stage) {
if (stage == Aggregate.STAGE_RESET) {
updateGroupAggregates(session, Aggregate.STAGE_RESET);
lastGroupRowId = 0;
return;
}
boolean window = stage == Aggregate.STAGE_WINDOW;
if (window != (over != null)) {
if (!window && select.isWindowQuery()) {
updateGroupAggregates(session, stage);
}
return;
}
// TODO aggregates: check nested MIN(MAX(ID)) and so on
// if (on != null) {
// on.updateAggregate();
// }
SelectGroups groupData = select.getGroupDataIfCurrent(window);
if (groupData == null) {
// this is a different level (the enclosing query)
return;
}
int groupRowId = groupData.getCurrentGroupRowId();
if (lastGroupRowId == groupRowId) {
// already visited
return;
}
lastGroupRowId = groupRowId;
if (over != null) {
if (!select.isGroupQuery()) {
over.updateAggregate(session, stage);
}
}
updateAggregate(session, groupData, groupRowId);
}
protected abstract void updateAggregate(Session session, SelectGroups groupData, int groupRowId);
/**
* Invoked when processing group stage of grouped window queries to update
* arguments of this aggregate.
*
* @param session
* the session
* @param stage
* select stage
*/
protected void updateGroupAggregates(Session session, int stage) {
if (over != null) {
over.updateAggregate(session, stage);
}
}
/**
* Returns the number of expressions, excluding FILTER and OVER clauses.
*
* @return the number of expressions
*/
protected abstract int getNumExpressions();
/**
* Stores current values of expressions into the specified array.
*
* @param session
* the session
* @param array
* array to store values of expressions
*/
protected abstract void rememberExpressions(Session session, Value[] array);
/**
* Updates the provided aggregate data from the remembered expressions.
*
* @param session
* the session
* @param aggregateData
* aggregate data
* @param array
* values of expressions
*/
protected abstract void updateFromExpressions(Session session, Object aggregateData, Value[] array);
protected Object getData(Session session, SelectGroups groupData, boolean ifExists, boolean forOrderBy) {
Object data;
if (over != null) {
ValueArray key = over.getCurrentKey(session);
if (key != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getWindowExprData(this);
if (map == null) {
if (ifExists) {
return null;
}
map = new ValueHashMap<>();
groupData.setWindowExprData(this, map);
}
PartitionData partition = (PartitionData) map.get(key);
if (partition == null) {
if (ifExists) {
return null;
}
data = forOrderBy ? new ArrayList<>() : createAggregateData();
map.put(key, new PartitionData(data));
} else {
data = partition.getData();
}
} else {
PartitionData partition = (PartitionData) groupData.getWindowExprData(this);
if (partition == null) {
if (ifExists) {
return null;
}
data = forOrderBy ? new ArrayList<>() : createAggregateData();
groupData.setWindowExprData(this, new PartitionData(data));
} else {
data = partition.getData();
}
}
} else {
data = groupData.getCurrentGroupExprData(this);
if (data == null) {
if (ifExists) {
return null;
}
data = forOrderBy ? new ArrayList<>() : createAggregateData();
groupData.setCurrentGroupExprData(this, data);
}
}
return data;
}
protected abstract Object createAggregateData();
@Override
public boolean isEverything(ExpressionVisitor visitor) {
if (over == null) {
return true;
}
switch (visitor.getType()) {
case ExpressionVisitor.QUERY_COMPARABLE:
case ExpressionVisitor.OPTIMIZABLE_MIN_MAX_COUNT_ALL:
case ExpressionVisitor.DETERMINISTIC:
case ExpressionVisitor.INDEPENDENT:
return false;
case ExpressionVisitor.EVALUATABLE:
case ExpressionVisitor.READONLY:
case ExpressionVisitor.NOT_FROM_RESOLVER:
case ExpressionVisitor.GET_DEPENDENCIES:
case ExpressionVisitor.SET_MAX_DATA_MODIFICATION_ID:
case ExpressionVisitor.GET_COLUMNS1:
case ExpressionVisitor.GET_COLUMNS2:
return true;
default:
throw DbException.throwInternalError("type=" + visitor.getType());
}
}
@Override
public Value getValue(Session session) {
SelectGroups groupData = select.getGroupDataIfCurrent(over != null);
if (groupData == null) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
return over == null ? getAggregatedValue(session, getData(session, groupData, true, false))
: getWindowResult(session, groupData);
}
private Value getWindowResult(Session session, SelectGroups groupData) {
PartitionData partition;
Object data;
boolean forOrderBy = over.getOrderBy() != null;
ValueArray key = over.getCurrentKey(session);
if (key != null) {
@SuppressWarnings("unchecked")
ValueHashMap<Object> map = (ValueHashMap<Object>) groupData.getWindowExprData(this);
if (map == null) {
map = new ValueHashMap<>();
groupData.setWindowExprData(this, map);
}
partition = (PartitionData) map.get(key);
if (partition == null) {
data = forOrderBy ? new ArrayList<>() : createAggregateData();
partition = new PartitionData(data);
map.put(key, partition);
} else {
data = partition.getData();
}
} else {
partition = (PartitionData) groupData.getWindowExprData(this);
if (partition == null) {
data = forOrderBy ? new ArrayList<>() : createAggregateData();
partition = new PartitionData(data);
groupData.setWindowExprData(this, partition);
} else {
data = partition.getData();
}
}
if (over.getOrderBy() != null || !isAggregate()) {
return getOrderedResult(session, groupData, partition, data);
}
Value result = partition.getResult();
if (result == null) {
result = getAggregatedValue(session, data);
partition.setResult(result);
}
return result;
}
/***
* Returns aggregated value.
*
* @param session
* the session
* @param aggregateData
* the aggregate data
* @return aggregated value.
*/
protected abstract Value getAggregatedValue(Session session, Object aggregateData);
protected void updateOrderedAggregate(Session session, SelectGroups groupData, int groupRowId,
ArrayList<SelectOrderBy> orderBy) {
int ne = getNumExpressions();
int size = orderBy != null ? orderBy.size() : 0;
Value[] array = new Value[ne + size + 1];
rememberExpressions(session, array);
for (int i = 0; i < size; i++) {
@SuppressWarnings("null")
SelectOrderBy o = orderBy.get(i);
array[ne++] = o.expression.getValue(session);
}
array[ne] = ValueInt.get(groupRowId);
@SuppressWarnings("unchecked")
ArrayList<Value[]> data = (ArrayList<Value[]>) getData(session, groupData, false, true);
data.add(array);
}
private Value getOrderedResult(Session session, SelectGroups groupData, PartitionData partition, Object data) {
HashMap<Integer, Value> result = partition.getOrderedResult();
if (result == null) {
result = new HashMap<>();
@SuppressWarnings("unchecked")
ArrayList<Value[]> orderedData = (ArrayList<Value[]>) data;
int rowIdColumn = getNumExpressions();
ArrayList<SelectOrderBy> orderBy = over.getOrderBy();
if (orderBy != null) {
rowIdColumn += orderBy.size();
Collections.sort(orderedData, overOrderBySort);
}
getOrderedResultLoop(session, result, orderedData, rowIdColumn);
partition.setOrderedResult(result);
}
return result.get(groupData.getCurrentGroupRowId());
}
/**
* @param session
* the session
* @param result
* the map to append result to
* @param ordered
* ordered data
* @param rowIdColumn
* the index of row id value
*/
protected void getOrderedResultLoop(Session session, HashMap<Integer, Value> result, ArrayList<Value[]> ordered,
int rowIdColumn) {
WindowFrame frame = over.getWindowFrame();
if (frame == null || frame.isDefault()) {
Object aggregateData = createAggregateData();
for (Value[] row : ordered) {
updateFromExpressions(session, aggregateData, row);
result.put(row[rowIdColumn].getInt(), getAggregatedValue(session, aggregateData));
}
} else if (frame.isFullPartition()) {
Object aggregateData = createAggregateData();
for (Value[] row : ordered) {
updateFromExpressions(session, aggregateData, row);
}
Value value = getAggregatedValue(session, aggregateData);
for (Value[] row : ordered) {
result.put(row[rowIdColumn].getInt(), value);
}
} else {
int size = ordered.size();
for (int i = 0; i < size; i++) {
Object aggregateData = createAggregateData();
for (Iterator<Value[]> iter = frame.iterator(session, ordered, getOverOrderBySort(), i, false); iter
.hasNext();) {
updateFromExpressions(session, aggregateData, iter.next());
}
result.put(ordered.get(i)[rowIdColumn].getInt(), getAggregatedValue(session, aggregateData));
}
}
}
protected StringBuilder appendTailConditions(StringBuilder builder) {
if (over != null) {
builder.append(' ').append(over.getSQL());
}
return builder;
}
}
...@@ -40,11 +40,6 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -40,11 +40,6 @@ public class JavaAggregate extends AbstractAggregate {
this.args = args; this.args = args;
} }
@Override
public boolean isAggregate() {
return true;
}
@Override @Override
public int getCost() { public int getCost() {
int cost = 5; int cost = 5;
...@@ -140,9 +135,6 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -140,9 +135,6 @@ public class JavaAggregate extends AbstractAggregate {
} catch (SQLException e) { } catch (SQLException e) {
throw DbException.convert(e); throw DbException.convert(e);
} }
if (filterCondition != null) {
filterCondition = filterCondition.optimize(session);
}
return this; return this;
} }
...@@ -237,6 +229,7 @@ public class JavaAggregate extends AbstractAggregate { ...@@ -237,6 +229,7 @@ public class JavaAggregate extends AbstractAggregate {
@Override @Override
protected void updateGroupAggregates(Session session, int stage) { protected void updateGroupAggregates(Session session, int stage) {
super.updateGroupAggregates(session, stage);
for (Expression expr : args) { for (Expression expr : args) {
expr.updateAggregate(session, stage); expr.updateAggregate(session, stage);
} }
......
...@@ -10,6 +10,7 @@ import java.util.HashMap; ...@@ -10,6 +10,7 @@ import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import org.h2.command.dml.Select; import org.h2.command.dml.Select;
import org.h2.command.dml.SelectGroups;
import org.h2.command.dml.SelectOrderBy; import org.h2.command.dml.SelectOrderBy;
import org.h2.engine.Session; import org.h2.engine.Session;
import org.h2.expression.Expression; import org.h2.expression.Expression;
...@@ -24,7 +25,7 @@ import org.h2.value.ValueNull; ...@@ -24,7 +25,7 @@ import org.h2.value.ValueNull;
/** /**
* A window function. * A window function.
*/ */
public class WindowFunction extends AbstractAggregate { public class WindowFunction extends DataAnalysisOperation {
private final WindowFunctionType type; private final WindowFunctionType type;
...@@ -105,7 +106,7 @@ public class WindowFunction extends AbstractAggregate { ...@@ -105,7 +106,7 @@ public class WindowFunction extends AbstractAggregate {
* arguments, or null * arguments, or null
*/ */
public WindowFunction(WindowFunctionType type, Select select, Expression[] args) { public WindowFunction(WindowFunctionType type, Select select, Expression[] args) {
super(select, false); super(select);
this.type = type; this.type = type;
this.args = args; this.args = args;
} }
...@@ -145,12 +146,13 @@ public class WindowFunction extends AbstractAggregate { ...@@ -145,12 +146,13 @@ public class WindowFunction extends AbstractAggregate {
} }
@Override @Override
protected void updateAggregate(Session session, Object aggregateData) { protected void updateAggregate(Session session, SelectGroups groupData, int groupRowId) {
throw DbException.getUnsupportedException("Window function"); updateOrderedAggregate(session, groupData, groupRowId, over.getOrderBy());
} }
@Override @Override
protected void updateGroupAggregates(Session session, int stage) { protected void updateGroupAggregates(Session session, int stage) {
super.updateGroupAggregates(session, stage);
if (args != null) { if (args != null) {
for (Expression expr : args) { for (Expression expr : args) {
expr.updateAggregate(session, stage); expr.updateAggregate(session, stage);
......
...@@ -449,9 +449,10 @@ public class TransactionStore { ...@@ -449,9 +449,10 @@ public class TransactionStore {
if (map != null) { // might be null if map was removed later if (map != null) { // might be null if map was removed later
Object key = op[1]; Object key = op[1];
commitDecisionMaker.setUndoKey(undoKey); commitDecisionMaker.setUndoKey(undoKey);
// although second parameter (value) is not really used // although second parameter (value) is not really
// by CommitDecisionMaker, MVRTreeMap has weird traversal logic based on it, // used by CommitDecisionMaker, MVRTreeMap has weird
// and any non-null value will do, to signify update, not removal // traversal logic based on it, and any non-null
// value will do, to signify update, not removal
map.operate(key, VersionedValue.DUMMY, commitDecisionMaker); map.operate(key, VersionedValue.DUMMY, commitDecisionMaker);
} }
} }
......
...@@ -796,4 +796,4 @@ interior envelopes multilinestring multipoint packed exterior normalization awkw ...@@ -796,4 +796,4 @@ interior envelopes multilinestring multipoint packed exterior normalization awkw
xym normalizes coord setz xyzm geometrycollection multipolygon mixup rings polygons rejection finite xym normalizes coord setz xyzm geometrycollection multipolygon mixup rings polygons rejection finite
pointzm pointz pointm dimensionality redefine forum measures pointzm pointz pointm dimensionality redefine forum measures
mpg casted pzm mls constrained subtypes complains mpg casted pzm mls constrained subtypes complains
ranks rno dro rko precede cume reopens preceding unbounded rightly itr lag maximal tiles tile ntile ranks rno dro rko precede cume reopens preceding unbounded rightly itr lag maximal tiles tile ntile signify
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论