提交 0d26b305 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Forbid incorrect nesting of aggregates and window functions

上级 3f93f4ae
......@@ -191,7 +191,7 @@ public class AlterTableAddConstraint extends SchemaCommand {
String name = generateConstraintName(table);
ConstraintCheck check = new ConstraintCheck(getSchema(), id, name, table);
TableFilter filter = new TableFilter(session, table, null, false, null, 0, null);
checkExpression.mapColumns(filter, 0);
checkExpression.mapColumns(filter, 0, Expression.MAP_INITIAL);
checkExpression = checkExpression.optimize(session);
check.setExpression(checkExpression);
check.setTableFilter(filter);
......
......@@ -133,9 +133,9 @@ public class Delete extends Prepared {
@Override
public void prepare() {
if (condition != null) {
condition.mapColumns(targetTableFilter, 0);
condition.mapColumns(targetTableFilter, 0, Expression.MAP_INITIAL);
if (sourceTableFilter != null) {
condition.mapColumns(sourceTableFilter, 0);
condition.mapColumns(sourceTableFilter, 0, Expression.MAP_INITIAL);
}
condition = condition.optimize(session);
condition.createIndexConditions(session, targetTableFilter);
......
......@@ -333,7 +333,7 @@ public class Insert extends Prepared implements ResultTarget {
Expression e = expr[i];
if (e != null) {
if(sourceTableFilter!=null){
e.mapColumns(sourceTableFilter, 0);
e.mapColumns(sourceTableFilter, 0, Expression.MAP_INITIAL);
}
e = e.optimize(session);
if (e instanceof Parameter) {
......
......@@ -319,8 +319,8 @@ public class MergeUsing extends Prepared {
onCondition.addFilterConditions(sourceTableFilter, true);
onCondition.addFilterConditions(targetTableFilter, true);
onCondition.mapColumns(sourceTableFilter, 2);
onCondition.mapColumns(targetTableFilter, 1);
onCondition.mapColumns(sourceTableFilter, 2, Expression.MAP_INITIAL);
onCondition.mapColumns(targetTableFilter, 1, Expression.MAP_INITIAL);
if (keys == null) {
keys = buildColumnListFromOnCondition(targetTableFilter.getTable());
......
......@@ -1043,7 +1043,7 @@ public class Select extends Query {
if (havingIndex >= 0) {
Expression expr = expressions.get(havingIndex);
SelectListColumnResolver res = new SelectListColumnResolver(this);
expr.mapColumns(res, 0);
expr.mapColumns(res, 0, Expression.MAP_INITIAL);
}
checkInit = true;
}
......@@ -1448,10 +1448,10 @@ public class Select extends Query {
@Override
public void mapColumns(ColumnResolver resolver, int level) {
for (Expression e : expressions) {
e.mapColumns(resolver, level);
e.mapColumns(resolver, level, Expression.MAP_INITIAL);
}
if (condition != null) {
condition.mapColumns(resolver, level);
condition.mapColumns(resolver, level, Expression.MAP_INITIAL);
}
}
......
......@@ -216,15 +216,15 @@ public class Update extends Prepared {
@Override
public void prepare() {
if (condition != null) {
condition.mapColumns(targetTableFilter, 0);
condition.mapColumns(targetTableFilter, 0, Expression.MAP_INITIAL);
condition = condition.optimize(session);
condition.createIndexConditions(session, targetTableFilter);
}
for (Column c : columns) {
Expression e = expressionMap.get(c);
e.mapColumns(targetTableFilter, 0);
e.mapColumns(targetTableFilter, 0, Expression.MAP_INITIAL);
if (sourceTableFilter!=null){
e.mapColumns(sourceTableFilter, 0);
e.mapColumns(sourceTableFilter, 0, Expression.MAP_INITIAL);
}
expressionMap.put(c, e.optimize(session));
}
......
......@@ -42,8 +42,8 @@ public class Alias extends Expression {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
expr.mapColumns(resolver, level);
public void mapColumns(ColumnResolver resolver, int level, int state) {
expr.mapColumns(resolver, level, state);
}
@Override
......
......@@ -150,9 +150,9 @@ public class BinaryOperation extends Expression {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
left.mapColumns(resolver, level);
right.mapColumns(resolver, level);
public void mapColumns(ColumnResolver resolver, int level, int state) {
left.mapColumns(resolver, level, state);
right.mapColumns(resolver, level, state);
}
@Override
......
......@@ -477,11 +477,11 @@ public class CompareLike extends Condition {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
left.mapColumns(resolver, level);
right.mapColumns(resolver, level);
public void mapColumns(ColumnResolver resolver, int level, int state) {
left.mapColumns(resolver, level, state);
right.mapColumns(resolver, level, state);
if (escape != null) {
escape.mapColumns(resolver, level);
escape.mapColumns(resolver, level, state);
}
}
......
......@@ -497,10 +497,10 @@ public class Comparison extends Condition {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
left.mapColumns(resolver, level);
public void mapColumns(ColumnResolver resolver, int level, int state) {
left.mapColumns(resolver, level, state);
if (right != null) {
right.mapColumns(resolver, level);
right.mapColumns(resolver, level, state);
}
}
......
......@@ -256,9 +256,9 @@ public class ConditionAndOr extends Condition {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
left.mapColumns(resolver, level);
right.mapColumns(resolver, level);
public void mapColumns(ColumnResolver resolver, int level, int state) {
left.mapColumns(resolver, level, state);
right.mapColumns(resolver, level, state);
}
@Override
......
......@@ -54,7 +54,7 @@ public class ConditionExists extends Condition {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
public void mapColumns(ColumnResolver resolver, int level, int state) {
query.mapColumns(resolver, level + 1);
}
......
......@@ -66,10 +66,10 @@ public class ConditionIn extends Condition {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
left.mapColumns(resolver, level);
public void mapColumns(ColumnResolver resolver, int level, int state) {
left.mapColumns(resolver, level, state);
for (Expression e : valueList) {
e.mapColumns(resolver, level);
e.mapColumns(resolver, level, state);
}
this.queryLevel = Math.max(level, this.queryLevel);
}
......
......@@ -83,8 +83,8 @@ public class ConditionInConstantSet extends Condition {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
left.mapColumns(resolver, level);
public void mapColumns(ColumnResolver resolver, int level, int state) {
left.mapColumns(resolver, level, state);
this.queryLevel = Math.max(level, this.queryLevel);
}
......
......@@ -109,8 +109,8 @@ public class ConditionInParameter extends Condition {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
left.mapColumns(resolver, level);
public void mapColumns(ColumnResolver resolver, int level, int state) {
left.mapColumns(resolver, level, state);
}
@Override
......
......@@ -101,8 +101,8 @@ public class ConditionInSelect extends Condition {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
left.mapColumns(resolver, level);
public void mapColumns(ColumnResolver resolver, int level, int state) {
left.mapColumns(resolver, level, state);
query.mapColumns(resolver, level + 1);
this.queryLevel = Math.max(level, this.queryLevel);
}
......
......@@ -37,8 +37,8 @@ public class ConditionNot extends Condition {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
condition.mapColumns(resolver, level);
public void mapColumns(ColumnResolver resolver, int level, int state) {
condition.mapColumns(resolver, level, state);
}
@Override
......
......@@ -24,6 +24,23 @@ import org.h2.value.ValueArray;
*/
public abstract class Expression {
/**
* Initial state for {@link #mapColumns(ColumnResolver, int, int)}.
*/
public static final int MAP_INITIAL = 0;
/**
* State for expressions inside a window function for
* {@link #mapColumns(ColumnResolver, int, int)}.
*/
public static final int MAP_IN_WINDOW = 1;
/**
* State for expressions inside an aggregate for
* {@link #mapColumns(ColumnResolver, int, int)}.
*/
public static final int MAP_IN_AGGREGATE = 2;
private boolean addedToFilter;
/**
......@@ -47,8 +64,10 @@ public abstract class Expression {
*
* @param resolver the column resolver
* @param level the subquery nesting level
* @param state current state for nesting checks, initial value is
* {@link #MAP_INITIAL}
*/
public abstract void mapColumns(ColumnResolver resolver, int level);
public abstract void mapColumns(ColumnResolver resolver, int level, int state);
/**
* Try to optimize the expression.
......
......@@ -79,7 +79,7 @@ public class ExpressionColumn extends Expression {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
public void mapColumns(ColumnResolver resolver, int level, int state) {
if (tableAlias != null && !database.equalsIdentifiers(
tableAlias, resolver.getTableAlias())) {
return;
......
......@@ -40,9 +40,9 @@ public class ExpressionList extends Expression {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
public void mapColumns(ColumnResolver resolver, int level, int state) {
for (Expression e : list) {
e.mapColumns(resolver, level);
e.mapColumns(resolver, level, state);
}
}
......
......@@ -2062,10 +2062,10 @@ public class Function extends Expression implements FunctionCall {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
public void mapColumns(ColumnResolver resolver, int level, int state) {
for (Expression e : args) {
if (e != null) {
e.mapColumns(resolver, level);
e.mapColumns(resolver, level, state);
}
}
}
......
......@@ -249,10 +249,10 @@ public class IntervalOperation extends Expression {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
left.mapColumns(resolver, level);
public void mapColumns(ColumnResolver resolver, int level, int state) {
left.mapColumns(resolver, level, state);
if (right != null) {
right.mapColumns(resolver, level);
right.mapColumns(resolver, level, state);
}
}
......
......@@ -44,9 +44,9 @@ public class JavaFunction extends Expression implements FunctionCall {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
public void mapColumns(ColumnResolver resolver, int level, int state) {
for (Expression e : args) {
e.mapColumns(resolver, level);
e.mapColumns(resolver, level, state);
}
}
......
......@@ -71,7 +71,7 @@ public class Parameter extends Expression implements ParameterInterface {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
public void mapColumns(ColumnResolver resolver, int level, int state) {
// can't map
}
......
......@@ -38,7 +38,7 @@ public class Rownum extends Expression {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
public void mapColumns(ColumnResolver resolver, int level, int state) {
// nothing to do
}
......
......@@ -37,7 +37,7 @@ public class SequenceValue extends Expression {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
public void mapColumns(ColumnResolver resolver, int level, int state) {
// nothing to do
}
......
......@@ -58,7 +58,7 @@ public class Subquery extends Expression {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
public void mapColumns(ColumnResolver resolver, int level, int state) {
query.mapColumns(resolver, level + 1);
}
......
......@@ -37,8 +37,8 @@ public class UnaryOperation extends Expression {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
arg.mapColumns(resolver, level);
public void mapColumns(ColumnResolver resolver, int level, int state) {
arg.mapColumns(resolver, level, state);
}
@Override
......
......@@ -95,7 +95,7 @@ public class ValueExpression extends Expression {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
public void mapColumns(ColumnResolver resolver, int level, int state) {
// nothing to do
}
......
......@@ -86,7 +86,7 @@ public class Variable extends Expression {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
public void mapColumns(ColumnResolver resolver, int level, int state) {
// nothing to do
}
......
......@@ -43,7 +43,7 @@ public class Wildcard extends Expression {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
public void mapColumns(ColumnResolver resolver, int level, int state) {
throw DbException.get(ErrorCode.SYNTAX_ERROR_1, table);
}
......
......@@ -48,11 +48,11 @@ public abstract class AbstractAggregate extends DataAnalysisOperation {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
public void mapColumnsAnalysis(ColumnResolver resolver, int level, int innerState) {
if (filterCondition != null) {
filterCondition.mapColumns(resolver, level);
filterCondition.mapColumns(resolver, level, innerState);
}
super.mapColumns(resolver, level);
super.mapColumnsAnalysis(resolver, level, innerState);
}
@Override
......
......@@ -512,19 +512,19 @@ public class Aggregate extends AbstractAggregate {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
public void mapColumnsAnalysis(ColumnResolver resolver, int level, int innerState) {
if (on != null) {
on.mapColumns(resolver, level);
on.mapColumns(resolver, level, innerState);
}
if (orderByList != null) {
for (SelectOrderBy o : orderByList) {
o.expression.mapColumns(resolver, level);
o.expression.mapColumns(resolver, level, innerState);
}
}
if (groupConcatSeparator != null) {
groupConcatSeparator.mapColumns(resolver, level);
groupConcatSeparator.mapColumns(resolver, level, innerState);
}
super.mapColumns(resolver, level);
super.mapColumnsAnalysis(resolver, level, innerState);
}
@Override
......
......@@ -81,7 +81,22 @@ public abstract class DataAnalysisOperation extends Expression {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
public final void mapColumns(ColumnResolver resolver, int level, int state) {
if (over != null) {
if (state != MAP_INITIAL) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
state = MAP_IN_WINDOW;
} else {
if (state == MAP_IN_AGGREGATE) {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
state = MAP_IN_AGGREGATE;
}
mapColumnsAnalysis(resolver, level, state);
}
protected void mapColumnsAnalysis(ColumnResolver resolver, int level, int innerState) {
if (over != null) {
over.mapColumns(resolver, level);
}
......@@ -109,7 +124,7 @@ public abstract class DataAnalysisOperation extends Expression {
}
@Override
public void updateAggregate(Session session, int stage) {
public final void updateAggregate(Session session, int stage) {
if (stage == Aggregate.STAGE_RESET) {
updateGroupAggregates(session, Aggregate.STAGE_RESET);
lastGroupRowId = 0;
......@@ -122,10 +137,6 @@ public abstract class DataAnalysisOperation extends Expression {
}
return;
}
// TODO aggregates: check nested MIN(MAX(ID)) and so on
// if (on != null) {
// on.updateAggregate();
// }
SelectGroups groupData = select.getGroupDataIfCurrent(window);
if (groupData == null) {
// this is a different level (the enclosing query)
......
......@@ -110,11 +110,11 @@ public class JavaAggregate extends AbstractAggregate {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
public void mapColumnsAnalysis(ColumnResolver resolver, int level, int innerState) {
for (Expression arg : args) {
arg.mapColumns(resolver, level);
arg.mapColumns(resolver, level, innerState);
}
super.mapColumns(resolver, level);
super.mapColumnsAnalysis(resolver, level, innerState);
}
@Override
......
......@@ -82,18 +82,18 @@ public final class Window {
* the column resolver
* @param level
* the subquery nesting level
* @see Expression#mapColumns(ColumnResolver, int)
* @see Expression#mapColumns(ColumnResolver, int, int)
*/
public void mapColumns(ColumnResolver resolver, int level) {
resolveWindows(resolver);
if (partitionBy != null) {
for (Expression e : partitionBy) {
e.mapColumns(resolver, level);
e.mapColumns(resolver, level, Expression.MAP_IN_WINDOW);
}
}
if (orderBy != null) {
for (SelectOrderBy o : orderBy) {
o.expression.mapColumns(resolver, level);
o.expression.mapColumns(resolver, level, Expression.MAP_IN_WINDOW);
}
}
}
......
......@@ -382,13 +382,13 @@ public class WindowFunction extends DataAnalysisOperation {
}
@Override
public void mapColumns(ColumnResolver resolver, int level) {
public void mapColumnsAnalysis(ColumnResolver resolver, int level, int innerState) {
if (args != null) {
for (Expression arg : args) {
arg.mapColumns(resolver, level);
arg.mapColumns(resolver, level, innerState);
}
}
super.mapColumns(resolver, level);
super.mapColumnsAnalysis(resolver, level, innerState);
}
@Override
......
......@@ -475,11 +475,11 @@ public class Column {
if (defaultExpression != null || onUpdateExpression != null) {
computeTableFilter = new TableFilter(session, table, null, false, null, 0, null);
if (defaultExpression != null) {
defaultExpression.mapColumns(computeTableFilter, 0);
defaultExpression.mapColumns(computeTableFilter, 0, Expression.MAP_INITIAL);
defaultExpression = defaultExpression.optimize(session);
}
if (onUpdateExpression != null) {
onUpdateExpression.mapColumns(computeTableFilter, 0);
onUpdateExpression.mapColumns(computeTableFilter, 0, Expression.MAP_INITIAL);
onUpdateExpression = onUpdateExpression.optimize(session);
}
}
......@@ -683,7 +683,7 @@ public class Column {
if (name == null) {
name = "VALUE";
}
expr.mapColumns(resolver, 0);
expr.mapColumns(resolver, 0, Expression.MAP_INITIAL);
name = oldName;
}
expr = expr.optimize(session);
......
......@@ -663,7 +663,7 @@ public class TableFilter implements ColumnResolver {
*/
public void addJoin(TableFilter filter, boolean outer, Expression on) {
if (on != null) {
on.mapColumns(this, 0);
on.mapColumns(this, 0, Expression.MAP_INITIAL);
TableFilterVisitor visitor = new MapColumnsVisitor(on);
visit(visitor);
filter.visit(visitor);
......@@ -697,11 +697,11 @@ public class TableFilter implements ColumnResolver {
* @param on the condition
*/
public void mapAndAddFilter(Expression on) {
on.mapColumns(this, 0);
on.mapColumns(this, 0, Expression.MAP_INITIAL);
addFilterCondition(on, true);
on.createIndexConditions(session, this);
if (nestedJoin != null) {
on.mapColumns(nestedJoin, 0);
on.mapColumns(nestedJoin, 0, Expression.MAP_INITIAL);
on.createIndexConditions(session, nestedJoin);
}
if (join != null) {
......@@ -1219,7 +1219,7 @@ public class TableFilter implements ColumnResolver {
@Override
public void accept(TableFilter f) {
on.mapColumns(f, 0);
on.mapColumns(f, 0, Expression.MAP_INITIAL);
}
}
......
......@@ -41,3 +41,12 @@ SELECT *, LAST_VALUE(ID) OVER W FROM TEST
DROP TABLE TEST;
> ok
SELECT MAX(MAX(X) OVER ()) OVER () FROM VALUES (1);
> exception INVALID_USE_OF_AGGREGATE_FUNCTION_1
SELECT MAX(MAX(X) OVER ()) FROM VALUES (1);
> exception INVALID_USE_OF_AGGREGATE_FUNCTION_1
SELECT MAX(MAX(X)) FROM VALUES (1);
> exception INVALID_USE_OF_AGGREGATE_FUNCTION_1
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论