提交 35b50ffc authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Add initial implementation of FIRST_VALUE(), LAST_VALUE(), and NTH_VALUE()

上级 bd2fb602
...@@ -3270,8 +3270,19 @@ public class Parser { ...@@ -3270,8 +3270,19 @@ public class Parser {
if (currentSelect == null) { if (currentSelect == null) {
throw getSyntaxError(); throw getSyntaxError();
} }
int numArgs = WindowFunction.getArgumentCount(type);
Expression[] args = null;
if (numArgs > 0) {
args = new Expression[numArgs];
for (int i = 0; i < numArgs; i++) {
if (i > 0) {
read(COMMA);
}
args[i] = readExpression();
}
}
read(CLOSE_PAREN); read(CLOSE_PAREN);
WindowFunction function = new WindowFunction(type, currentSelect); WindowFunction function = new WindowFunction(type, currentSelect, args);
readFilterAndOver(function); readFilterAndOver(function);
return function; return function;
} }
......
...@@ -10,10 +10,14 @@ import java.util.HashMap; ...@@ -10,10 +10,14 @@ import java.util.HashMap;
import org.h2.command.dml.Select; import org.h2.command.dml.Select;
import org.h2.engine.Session; import org.h2.engine.Session;
import org.h2.expression.Expression;
import org.h2.message.DbException; import org.h2.message.DbException;
import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter;
import org.h2.value.Value; import org.h2.value.Value;
import org.h2.value.ValueDouble; import org.h2.value.ValueDouble;
import org.h2.value.ValueInt; import org.h2.value.ValueInt;
import org.h2.value.ValueNull;
/** /**
* A window function. * A window function.
...@@ -50,6 +54,21 @@ public class WindowFunction extends AbstractAggregate { ...@@ -50,6 +54,21 @@ public class WindowFunction extends AbstractAggregate {
*/ */
CUME_DIST, CUME_DIST,
/**
* The type for FIRST_VALUE() window function.
*/
FIRST_VALUE,
/**
* The type for LAST_VALUE() window function.
*/
LAST_VALUE,
/**
* The type for NTH_VALUE() window function.
*/
NTH_VALUE,
; ;
/** /**
...@@ -62,15 +81,21 @@ public class WindowFunction extends AbstractAggregate { ...@@ -62,15 +81,21 @@ public class WindowFunction extends AbstractAggregate {
public static WindowFunctionType get(String name) { public static WindowFunctionType get(String name) {
switch (name) { switch (name) {
case "ROW_NUMBER": case "ROW_NUMBER":
return WindowFunctionType.ROW_NUMBER; return ROW_NUMBER;
case "RANK": case "RANK":
return RANK; return RANK;
case "DENSE_RANK": case "DENSE_RANK":
return WindowFunctionType.DENSE_RANK; return DENSE_RANK;
case "PERCENT_RANK": case "PERCENT_RANK":
return WindowFunctionType.PERCENT_RANK; return PERCENT_RANK;
case "CUME_DIST": case "CUME_DIST":
return WindowFunctionType.CUME_DIST; return CUME_DIST;
case "FIRST_VALUE":
return FIRST_VALUE;
case "LAST_VALUE":
return LAST_VALUE;
case "NTH_VALUE":
return NTH_VALUE;
default: default:
return null; return null;
} }
...@@ -80,6 +105,27 @@ public class WindowFunction extends AbstractAggregate { ...@@ -80,6 +105,27 @@ public class WindowFunction extends AbstractAggregate {
private final WindowFunctionType type; private final WindowFunctionType type;
private final Expression[] args;
/**
* Returns number of arguments for the specified type.
*
* @param type
* the type of a window function
* @return number of arguments
*/
public static int getArgumentCount(WindowFunctionType type) {
switch (type) {
case FIRST_VALUE:
case LAST_VALUE:
return 1;
case NTH_VALUE:
return 2;
default:
return 0;
}
}
/** /**
* Creates new instance of a window function. * Creates new instance of a window function.
* *
...@@ -87,10 +133,13 @@ public class WindowFunction extends AbstractAggregate { ...@@ -87,10 +133,13 @@ public class WindowFunction extends AbstractAggregate {
* the type * the type
* @param select * @param select
* the select statement * the select statement
* @param args
* arguments, or null
*/ */
public WindowFunction(WindowFunctionType type, Select select) { public WindowFunction(WindowFunctionType type, Select select, Expression[] args) {
super(select, false); super(select, false);
this.type = type; this.type = type;
this.args = args;
} }
@Override @Override
...@@ -105,17 +154,24 @@ public class WindowFunction extends AbstractAggregate { ...@@ -105,17 +154,24 @@ public class WindowFunction extends AbstractAggregate {
@Override @Override
protected void updateGroupAggregates(Session session, int stage) { protected void updateGroupAggregates(Session session, int stage) {
// Nothing to do if (args != null) {
for (Expression expr : args) {
expr.updateAggregate(session, stage);
}
}
} }
@Override @Override
protected int getNumExpressions() { protected int getNumExpressions() {
return 0; return getArgumentCount(type);
} }
@Override @Override
protected void rememberExpressions(Session session, Value[] array) { protected void rememberExpressions(Session session, Value[] array) {
// Nothing to do int cnt = getNumExpressions();
for (int i = 0; i < cnt; i++) {
array[i] = args[i].getValue(session);
}
} }
@Override @Override
...@@ -175,6 +231,25 @@ public class WindowFunction extends AbstractAggregate { ...@@ -175,6 +231,25 @@ public class WindowFunction extends AbstractAggregate {
v = ValueDouble.get((double) nm / size); v = ValueDouble.get((double) nm / size);
break; break;
} }
case FIRST_VALUE:
v = ordered.get(0)[0];
break;
case LAST_VALUE:
v = row[0];
break;
case NTH_VALUE: {
int n = row[1].getInt();
if (n <= 0) {
throw DbException.getInvalidValueException("nth row", n);
}
n--;
if (n < 0 || n > i) {
v = ValueNull.INSTANCE;
} else {
v = ordered.get(n)[0];
}
break;
}
default: default:
throw DbException.throwInternalError("type=" + type); throw DbException.throwInternalError("type=" + type);
} }
...@@ -205,6 +280,37 @@ public class WindowFunction extends AbstractAggregate { ...@@ -205,6 +280,37 @@ public class WindowFunction extends AbstractAggregate {
throw DbException.getUnsupportedException("Window function"); throw DbException.getUnsupportedException("Window function");
} }
@Override
public void mapColumns(ColumnResolver resolver, int level) {
if (args != null) {
for (Expression arg : args) {
arg.mapColumns(resolver, level);
}
}
super.mapColumns(resolver, level);
}
@Override
public Expression optimize(Session session) {
super.optimize(session);
if (args != null) {
for (int i = 0; i < args.length; i++) {
args[i] = args[i].optimize(session);
}
}
return this;
}
@Override
public void setEvaluatable(TableFilter tableFilter, boolean b) {
if (args != null) {
for (Expression e : args) {
e.setEvaluatable(tableFilter, b);
}
}
super.setEvaluatable(tableFilter, b);
}
@Override @Override
public int getType() { public int getType() {
switch (type) { switch (type) {
...@@ -215,6 +321,10 @@ public class WindowFunction extends AbstractAggregate { ...@@ -215,6 +321,10 @@ public class WindowFunction extends AbstractAggregate {
case PERCENT_RANK: case PERCENT_RANK:
case CUME_DIST: case CUME_DIST:
return Value.DOUBLE; return Value.DOUBLE;
case FIRST_VALUE:
case LAST_VALUE:
case NTH_VALUE:
return args[0].getType();
default: default:
throw DbException.throwInternalError("type=" + type); throw DbException.throwInternalError("type=" + type);
} }
...@@ -222,7 +332,14 @@ public class WindowFunction extends AbstractAggregate { ...@@ -222,7 +332,14 @@ public class WindowFunction extends AbstractAggregate {
@Override @Override
public int getScale() { public int getScale() {
return 0; switch (type) {
case FIRST_VALUE:
case LAST_VALUE:
case NTH_VALUE:
return args[0].getScale();
default:
return 0;
}
} }
@Override @Override
...@@ -235,6 +352,10 @@ public class WindowFunction extends AbstractAggregate { ...@@ -235,6 +352,10 @@ public class WindowFunction extends AbstractAggregate {
case PERCENT_RANK: case PERCENT_RANK:
case CUME_DIST: case CUME_DIST:
return ValueDouble.PRECISION; return ValueDouble.PRECISION;
case FIRST_VALUE:
case LAST_VALUE:
case NTH_VALUE:
return args[0].getPrecision();
default: default:
throw DbException.throwInternalError("type=" + type); throw DbException.throwInternalError("type=" + type);
} }
...@@ -250,6 +371,10 @@ public class WindowFunction extends AbstractAggregate { ...@@ -250,6 +371,10 @@ public class WindowFunction extends AbstractAggregate {
case PERCENT_RANK: case PERCENT_RANK:
case CUME_DIST: case CUME_DIST:
return ValueDouble.DISPLAY_SIZE; return ValueDouble.DISPLAY_SIZE;
case FIRST_VALUE:
case LAST_VALUE:
case NTH_VALUE:
return args[0].getDisplaySize();
default: default:
throw DbException.throwInternalError("type=" + type); throw DbException.throwInternalError("type=" + type);
} }
...@@ -258,6 +383,7 @@ public class WindowFunction extends AbstractAggregate { ...@@ -258,6 +383,7 @@ public class WindowFunction extends AbstractAggregate {
@Override @Override
public String getSQL() { public String getSQL() {
String text; String text;
int numArgs = 0;
switch (type) { switch (type) {
case ROW_NUMBER: case ROW_NUMBER:
text = "ROW_NUMBER"; text = "ROW_NUMBER";
...@@ -274,16 +400,40 @@ public class WindowFunction extends AbstractAggregate { ...@@ -274,16 +400,40 @@ public class WindowFunction extends AbstractAggregate {
case CUME_DIST: case CUME_DIST:
text = "CUME_DIST"; text = "CUME_DIST";
break; break;
case FIRST_VALUE:
text = "FIRST_VALUE";
numArgs = 1;
break;
case LAST_VALUE:
text = "LAST_VALUE";
numArgs = 1;
break;
case NTH_VALUE:
text = "NTH_VALUE";
numArgs = 2;
break;
default: default:
throw DbException.throwInternalError("type=" + type); throw DbException.throwInternalError("type=" + type);
} }
StringBuilder builder = new StringBuilder().append(text).append("()"); StringBuilder builder = new StringBuilder().append(text).append('(');
for (int i = 0; i < numArgs; i++) {
if (i > 0) {
builder.append(", ");
}
builder.append(args[i].getSQL());
}
builder.append(')');
return appendTailConditions(builder).toString(); return appendTailConditions(builder).toString();
} }
@Override @Override
public int getCost() { public int getCost() {
int cost = 1; int cost = 1;
if (args != null) {
for (Expression expr : args) {
cost += expr.getCost();
}
}
return cost; return cost;
} }
......
...@@ -179,7 +179,7 @@ public class TestScript extends TestDb { ...@@ -179,7 +179,7 @@ public class TestScript extends TestDb {
"parsedatetime", "quarter", "second", "truncate", "week", "year", "date_trunc" }) { "parsedatetime", "quarter", "second", "truncate", "week", "year", "date_trunc" }) {
testScript("functions/timeanddate/" + s + ".sql"); testScript("functions/timeanddate/" + s + ".sql");
} }
for (String s : new String[] { "row_number" }) { for (String s : new String[] { "row_number", "nth_value" }) {
testScript("functions/window/" + s + ".sql"); testScript("functions/window/" + s + ".sql");
} }
......
-- Copyright 2004-2018 H2 Group. Multiple-Licensed under the MPL 2.0,
-- and the EPL 1.0 (http://h2database.com/html/license.html).
-- Initial Developer: H2 Group
--
CREATE TABLE TEST (ID INT PRIMARY KEY, CATEGORY INT, VALUE INT);
> ok
INSERT INTO TEST VALUES
(1, 1, NULL),
(2, 1, 12),
(3, 1, NULL),
(4, 1, 13),
(5, 1, NULL),
(6, 1, 13),
(7, 2, 21),
(8, 2, 22),
(9, 3, 31),
(10, 3, 32),
(11, 3, 33),
(12, 4, 41),
(13, 4, NULL);
> update count: 13
SELECT *,
FIRST_VALUE(VALUE) OVER (ORDER BY ID) FIRST,
LAST_VALUE(VALUE) OVER (ORDER BY ID) LAST,
NTH_VALUE(VALUE, 2) OVER (ORDER BY ID) NTH
FROM TEST FETCH FIRST 4 ROWS ONLY;
> ID CATEGORY VALUE FIRST LAST NTH
> -- -------- ----- ----- ---- ----
> 1 1 null null null null
> 2 1 12 null 12 12
> 3 1 null null null 12
> 4 1 13 null 13 12
> rows (ordered): 4
SELECT NTH_VALUE(VALUE, 0) OVER (ORDER BY ID) FROM TEST;
> exception INVALID_VALUE_2
SELECT *,
FIRST_VALUE(VALUE) OVER (PARTITION BY CATEGORY ORDER BY ID) FIRST,
LAST_VALUE(VALUE) OVER (PARTITION BY CATEGORY ORDER BY ID) LAST,
NTH_VALUE(VALUE, 2) OVER (PARTITION BY CATEGORY ORDER BY ID) NTH
FROM TEST;
> ID CATEGORY VALUE FIRST LAST NTH
> -- -------- ----- ----- ---- ----
> 1 1 null null null null
> 2 1 12 null 12 12
> 3 1 null null null 12
> 4 1 13 null 13 12
> 5 1 null null null 12
> 6 1 13 null 13 12
> 7 2 21 21 21 null
> 8 2 22 21 22 22
> 9 3 31 31 31 null
> 10 3 32 31 32 32
> 11 3 33 31 33 32
> 12 4 41 41 41 null
> 13 4 null 41 null null
> rows (ordered): 13
DROP TABLE TEST;
> ok
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论