Unverified 提交 3b4620a2 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov 提交者: GitHub

Merge pull request #1459 from katzyn/window

Improve window clause correctness checks
......@@ -2555,7 +2555,7 @@ ROWS|RANGE|GROUP
[EXCLUDE {CURRENT ROW|GROUP|TIES|NO OTHERS}]
","
A window frame clause.
Is currently supported only in aggregates and FIRST_VALUE(), LAST_VALUE(), and NTH_VALUE() window functions.
May be specified only for aggregates and FIRST_VALUE(), LAST_VALUE(), and NTH_VALUE() window functions.
","
RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW EXCLUDE GROUP
"
......
......@@ -3049,40 +3049,7 @@ public class Parser {
}
Window over = null;
if (readIf("OVER")) {
read(OPEN_PAREN);
ArrayList<Expression> partitionBy = null;
if (readIf("PARTITION")) {
read("BY");
partitionBy = Utils.newSmallArrayList();
do {
Expression expr = readExpression();
partitionBy.add(expr);
} while (readIf(COMMA));
}
ArrayList<SelectOrderBy> orderBy = null;
if (readIf(ORDER)) {
read("BY");
orderBy = parseSimpleOrderList();
} else if (!isAggregate) {
orderBy = new ArrayList<>(0);
}
WindowFrame frame;
if (aggregate instanceof WindowFunction) {
WindowFunction w = (WindowFunction) aggregate;
switch (w.getFunctionType()) {
case FIRST_VALUE:
case LAST_VALUE:
case NTH_VALUE:
frame = readWindowFrame();
break;
default:
frame = null;
}
} else {
frame = readWindowFrame();
}
read(CLOSE_PAREN);
over = new Window(partitionBy, orderBy, frame);
over = readWindowSpecification();
aggregate.setOverCondition(over);
currentSelect.setWindowQuery();
} else if (!isAggregate) {
......@@ -3092,6 +3059,27 @@ public class Parser {
}
}
private Window readWindowSpecification() {
read(OPEN_PAREN);
ArrayList<Expression> partitionBy = null;
if (readIf("PARTITION")) {
read("BY");
partitionBy = Utils.newSmallArrayList();
do {
Expression expr = readExpression();
partitionBy.add(expr);
} while (readIf(COMMA));
}
ArrayList<SelectOrderBy> orderBy = null;
if (readIf(ORDER)) {
read("BY");
orderBy = parseSimpleOrderList();
}
WindowFrame frame = readWindowFrame();
read(CLOSE_PAREN);
return new Window(partitionBy, orderBy, frame);
}
private WindowFrame readWindowFrame() {
WindowFrameUnits units;
if (readIf("ROWS")) {
......
......@@ -118,6 +118,8 @@ public abstract class AbstractAggregate extends Expression {
ArrayList<SelectOrderBy> orderBy = over.getOrderBy();
if (orderBy != null) {
overOrderBySort = createOrder(session, orderBy, getNumExpressions());
} else if (!isAggregate()) {
overOrderBySort = new SortOrder(session.getDatabase(), new int[getNumExpressions()], new int[0], null);
}
}
return this;
......@@ -176,7 +178,7 @@ public abstract class AbstractAggregate extends Expression {
}
if (over != null) {
ArrayList<SelectOrderBy> orderBy = over.getOrderBy();
if (orderBy != null) {
if (orderBy != null || !isAggregate()) {
updateOrderedAggregate(session, groupData, groupRowId, orderBy);
return;
}
......@@ -359,7 +361,7 @@ public abstract class AbstractAggregate extends Expression {
data = partition.getData();
}
}
if (over.getOrderBy() != null) {
if (over.getOrderBy() != null || !isAggregate()) {
return getOrderedResult(session, groupData, partition, data);
}
Value result = partition.getResult();
......@@ -384,10 +386,11 @@ public abstract class AbstractAggregate extends Expression {
private void updateOrderedAggregate(Session session, SelectGroups groupData, int groupRowId,
ArrayList<SelectOrderBy> orderBy) {
int ne = getNumExpressions();
int size = orderBy.size();
int size = orderBy != null ? orderBy.size() : 0;
Value[] array = new Value[ne + size + 1];
rememberExpressions(session, array);
for (int i = 0; i < size; i++) {
@SuppressWarnings("null")
SelectOrderBy o = orderBy.get(i);
array[ne++] = o.expression.getValue(session);
}
......@@ -395,7 +398,6 @@ public abstract class AbstractAggregate extends Expression {
@SuppressWarnings("unchecked")
ArrayList<Value[]> data = (ArrayList<Value[]>) getData(session, groupData, false, true);
data.add(array);
return;
}
private Value getOrderedResult(Session session, SelectGroups groupData, PartitionData partition, Object data) {
......@@ -404,9 +406,12 @@ public abstract class AbstractAggregate extends Expression {
result = new HashMap<>();
@SuppressWarnings("unchecked")
ArrayList<Value[]> orderedData = (ArrayList<Value[]>) data;
int ne = getNumExpressions();
int rowIdColumn = ne + over.getOrderBy().size();
Collections.sort(orderedData, overOrderBySort);
int rowIdColumn = getNumExpressions();
ArrayList<SelectOrderBy> orderBy = over.getOrderBy();
if (orderBy != null) {
rowIdColumn += orderBy.size();
Collections.sort(orderedData, overOrderBySort);
}
getOrderedResultLoop(session, result, orderedData, rowIdColumn);
partition.setOrderedResult(result);
}
......@@ -426,7 +431,7 @@ public abstract class AbstractAggregate extends Expression {
protected void getOrderedResultLoop(Session session, HashMap<Integer, Value> result, ArrayList<Value[]> ordered,
int rowIdColumn) {
WindowFrame frame = over.getWindowFrame();
if (frame.isDefault()) {
if (frame == null || frame.isDefault()) {
Object aggregateData = createAggregateData();
for (Value[] row : ordered) {
updateFromExpressions(session, aggregateData, row);
......
......@@ -64,11 +64,6 @@ public final class Window {
public Window(ArrayList<Expression> partitionBy, ArrayList<SelectOrderBy> orderBy, WindowFrame frame) {
this.partitionBy = partitionBy;
this.orderBy = orderBy;
if (frame == null) {
frame = new WindowFrame(WindowFrameUnits.RANGE,
new WindowFrameBound(WindowFrameBoundType.UNBOUNDED_PRECEDING, null), null,
WindowFrameExclusion.EXCLUDE_NO_OTHERS);
}
this.frame = frame;
}
......@@ -146,9 +141,9 @@ public final class Window {
}
/**
* Returns window frame.
* Returns window frame, or null.
*
* @return window frame
* @return window frame, or null
*/
public WindowFrame getWindowFrame() {
return frame;
......@@ -182,7 +177,7 @@ public final class Window {
* @see Expression#getSQL()
*/
public String getSQL() {
if (partitionBy == null && orderBy == null) {
if (partitionBy == null && orderBy == null && frame == null) {
return "OVER ()";
}
StringBuilder builder = new StringBuilder().append("OVER (");
......@@ -196,8 +191,11 @@ public final class Window {
}
}
appendOrderBy(builder, orderBy);
if (!frame.isDefault()) {
builder.append(' ').append(frame.getSQL());
if (frame != null && !frame.isDefault()) {
if (builder.charAt(builder.length() - 1) != '(') {
builder.append(' ');
}
builder.append(frame.getSQL());
}
return builder.append(')').toString();
}
......
......@@ -24,7 +24,7 @@ import org.h2.value.Value;
*/
public final class WindowFrame {
private abstract class Itr implements Iterator<Value[]> {
private static abstract class Itr implements Iterator<Value[]> {
final ArrayList<Value[]> orderedRows;
......@@ -39,7 +39,7 @@ public final class WindowFrame {
}
private final class PlainItr extends Itr {
private static final class PlainItr extends Itr {
private final int endIndex;
......@@ -66,7 +66,7 @@ public final class WindowFrame {
}
private final class PlainReverseItr extends Itr {
private static final class PlainReverseItr extends Itr {
private final int startIndex;
......@@ -93,7 +93,7 @@ public final class WindowFrame {
}
private abstract class AbstractBitSetItr extends Itr {
private static abstract class AbstractBitSetItr extends Itr {
final BitSet set;
......@@ -111,7 +111,7 @@ public final class WindowFrame {
}
private final class BitSetItr extends AbstractBitSetItr {
private static final class BitSetItr extends AbstractBitSetItr {
BitSetItr(ArrayList<Value[]> orderedRows, BitSet set) {
super(orderedRows, set);
......@@ -130,7 +130,7 @@ public final class WindowFrame {
}
private final class BitSetReverseItr extends AbstractBitSetItr {
private static final class BitSetReverseItr extends AbstractBitSetItr {
BitSetReverseItr(ArrayList<Value[]> orderedRows, BitSet set) {
super(orderedRows, set);
......@@ -157,6 +157,31 @@ public final class WindowFrame {
private final WindowFrameExclusion exclusion;
/**
* Returns iterator for the specified frame, or default iterator if frame is
* null.
*
* @param frame
* window frame, or null
* @param session
* the session
* @param orderedRows
* ordered rows
* @param sortOrder
* sort order
* @param currentRow
* index of the current row
* @param reverse
* whether iterator should iterate in reverse order
*
* @return iterator
*/
public static Iterator<Value[]> iterator(WindowFrame frame, Session session, ArrayList<Value[]> orderedRows,
SortOrder sortOrder, int currentRow, boolean reverse) {
return frame != null ? frame.iterator(session, orderedRows, sortOrder, currentRow, reverse)
: reverse ? new PlainReverseItr(orderedRows, 0, currentRow) : new PlainItr(orderedRows, 0, currentRow);
}
private static int toGroupStart(ArrayList<Value[]> orderedRows, SortOrder sortOrder, int offset, int minOffset) {
Value[] row = orderedRows.get(offset);
while (offset > minOffset && sortOrder.compare(row, orderedRows.get(offset - 1)) == 0) {
......
......@@ -10,6 +10,7 @@ import java.util.HashMap;
import java.util.Iterator;
import org.h2.command.dml.Select;
import org.h2.command.dml.SelectOrderBy;
import org.h2.engine.Session;
import org.h2.expression.Expression;
import org.h2.message.DbException;
......@@ -358,10 +359,12 @@ public class WindowFunction extends AbstractAggregate {
Value v;
switch (type) {
case FIRST_VALUE:
v = getNthValue(frame.iterator(session, ordered, getOverOrderBySort(), i, false), 0, ignoreNulls);
v = getNthValue(WindowFrame.iterator(frame, session, ordered, getOverOrderBySort(), i, false), 0,
ignoreNulls);
break;
case LAST_VALUE:
v = getNthValue(frame.iterator(session, ordered, getOverOrderBySort(), i, true), 0, ignoreNulls);
v = getNthValue(WindowFrame.iterator(frame, session, ordered, getOverOrderBySort(), i, true), 0,
ignoreNulls);
break;
case NTH_VALUE: {
int n = row[1].getInt();
......@@ -369,7 +372,8 @@ public class WindowFunction extends AbstractAggregate {
throw DbException.getInvalidValueException("nth row", n);
}
n--;
Iterator<Value[]> iter = frame.iterator(session, ordered, getOverOrderBySort(), i, fromLast);
Iterator<Value[]> iter = WindowFrame.iterator(frame, session, ordered, getOverOrderBySort(), i,
fromLast);
v = getNthValue(iter, n, ignoreNulls);
break;
}
......@@ -397,6 +401,30 @@ public class WindowFunction extends AbstractAggregate {
@Override
public Expression optimize(Session session) {
if (over.getWindowFrame() != null) {
switch (type) {
case FIRST_VALUE:
case LAST_VALUE:
case NTH_VALUE:
break;
default:
String sql = getSQL();
throw DbException.getSyntaxError(sql, sql.length() - 1);
}
}
ArrayList<SelectOrderBy> orderBy = over.getOrderBy();
if (orderBy == null || orderBy.isEmpty()) {
switch (type) {
case RANK:
case DENSE_RANK:
case NTILE:
case LEAD:
case LAG:
String sql = getSQL();
throw DbException.getSyntaxError(sql, sql.length() - 1, "ORDER BY");
default:
}
}
super.optimize(session);
if (args != null) {
for (int i = 0; i < args.length; i++) {
......
......@@ -129,5 +129,17 @@ SELECT LEAD(VALUE, -1) OVER (ORDER BY ID) FROM TEST;
SELECT LAG(VALUE, -1) OVER (ORDER BY ID) FROM TEST;
> exception INVALID_VALUE_2
SELECT LEAD(VALUE) OVER () FROM TEST;
> exception SYNTAX_ERROR_2
SELECT LAG(VALUE) OVER () FROM TEST;
> exception SYNTAX_ERROR_2
SELECT LEAD(VALUE) OVER (ORDER BY ID RANGE CURRENT ROW) FROM TEST;
> exception SYNTAX_ERROR_1
SELECT LAG(VALUE) OVER (ORDER BY ID RANGE CURRENT ROW) FROM TEST;
> exception SYNTAX_ERROR_1
DROP TABLE TEST;
> ok
......@@ -113,3 +113,8 @@ SELECT NTILE(X) OVER (ORDER BY X) FROM (SELECT * FROM SYSTEM_RANGE(1, 6));
> 6
> rows (ordered): 6
SELECT NTILE(X) OVER () FROM (SELECT * FROM SYSTEM_RANGE(1, 1));
> exception SYNTAX_ERROR_2
SELECT NTILE(X) OVER (ORDER BY X RANGE CURRENT ROW) FROM (SELECT * FROM SYSTEM_RANGE(1, 1));
> exception SYNTAX_ERROR_1
......@@ -20,8 +20,6 @@ INSERT INTO TEST VALUES
SELECT *,
ROW_NUMBER() OVER () RN,
RANK() OVER () RK,
DENSE_RANK() OVER () DR,
ROUND(PERCENT_RANK() OVER (), 2) PR,
ROUND(CUME_DIST() OVER (), 2) CD,
......@@ -32,17 +30,17 @@ SELECT *,
ROUND(CUME_DIST() OVER (ORDER BY ID), 2) CDO
FROM TEST;
> ID CATEGORY VALUE RN RK DR PR CD RNO RKO DRO PRO CDO
> -- -------- ----- -- -- -- --- --- --- --- --- ---- ----
> 1 1 11 1 1 1 0.0 1.0 1 1 1 0.0 0.11
> 2 1 12 2 1 1 0.0 1.0 2 2 2 0.13 0.22
> 3 1 13 3 1 1 0.0 1.0 3 3 3 0.25 0.33
> 4 2 21 4 1 1 0.0 1.0 4 4 4 0.38 0.44
> 5 2 22 5 1 1 0.0 1.0 5 5 5 0.5 0.56
> 6 3 31 6 1 1 0.0 1.0 6 6 6 0.63 0.67
> 7 3 32 7 1 1 0.0 1.0 7 7 7 0.75 0.78
> 8 3 33 8 1 1 0.0 1.0 8 8 8 0.88 0.89
> 9 4 41 9 1 1 0.0 1.0 9 9 9 1.0 1.0
> ID CATEGORY VALUE RN PR CD RNO RKO DRO PRO CDO
> -- -------- ----- -- --- --- --- --- --- ---- ----
> 1 1 11 1 0.0 1.0 1 1 1 0.0 0.11
> 2 1 12 2 0.0 1.0 2 2 2 0.13 0.22
> 3 1 13 3 0.0 1.0 3 3 3 0.25 0.33
> 4 2 21 4 0.0 1.0 4 4 4 0.38 0.44
> 5 2 22 5 0.0 1.0 5 5 5 0.5 0.56
> 6 3 31 6 0.0 1.0 6 6 6 0.63 0.67
> 7 3 32 7 0.0 1.0 7 7 7 0.75 0.78
> 8 3 33 8 0.0 1.0 8 8 8 0.88 0.89
> 9 4 41 9 0.0 1.0 9 9 9 1.0 1.0
> rows (ordered): 9
SELECT *,
......@@ -86,19 +84,41 @@ SELECT *,
> rows (ordered): 9
SELECT
ROW_NUMBER() OVER () RN,
RANK() OVER () RK,
DENSE_RANK() OVER () DR,
ROW_NUMBER() OVER (ORDER BY CATEGORY) RN,
RANK() OVER (ORDER BY CATEGORY) RK,
DENSE_RANK() OVER (ORDER BY CATEGORY) DR,
PERCENT_RANK() OVER () PR,
CUME_DIST() OVER () CD
FROM TEST GROUP BY CATEGORY;
> RN RK DR PR CD
> -- -- -- --- ---
> 1 1 1 0.0 1.0
> 2 1 1 0.0 1.0
> 3 1 1 0.0 1.0
> 4 1 1 0.0 1.0
> rows: 4
CUME_DIST() OVER () CD,
CATEGORY C
FROM TEST GROUP BY CATEGORY ORDER BY RN;
> RN RK DR PR CD C
> -- -- -- --- --- -
> 1 1 1 0.0 1.0 1
> 2 2 2 0.0 1.0 2
> 3 3 3 0.0 1.0 3
> 4 4 4 0.0 1.0 4
> rows (ordered): 4
SELECT RANK() OVER () FROM TEST;
> exception SYNTAX_ERROR_2
SELECT DENSE_RANK() OVER () FROM TEST;
> exception SYNTAX_ERROR_2
SELECT ROW_NUMBER() OVER (ORDER BY ID RANGE CURRENT ROW) FROM TEST;
> exception SYNTAX_ERROR_1
SELECT RANK() OVER (ORDER BY ID RANGE CURRENT ROW) FROM TEST;
> exception SYNTAX_ERROR_1
SELECT DENSE_RANK() OVER (ORDER BY ID RANGE CURRENT ROW) FROM TEST;
> exception SYNTAX_ERROR_1
SELECT PERCENT_RANK() OVER (ORDER BY ID RANGE CURRENT ROW) FROM TEST;
> exception SYNTAX_ERROR_1
SELECT CUME_DIST() OVER (ORDER BY ID RANGE CURRENT ROW) FROM TEST;
> exception SYNTAX_ERROR_1
DROP TABLE TEST;
> ok
......@@ -121,5 +141,11 @@ SELECT ROW_NUMBER() OVER(ORDER /**/ BY TYPE) RN, TYPE, SUM(CNT) SUM FROM TEST GR
> 3 c 4
> rows: 3
SELECT RANK () OVER () FROM TEST;
> exception SYNTAX_ERROR_2
SELECT DENSE_RANK () OVER () FROM TEST;
> exception SYNTAX_ERROR_2
DROP TABLE TEST;
> ok
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论