提交 ca80a773 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Fix ConditionIn with table functions

上级 a06e2742
......@@ -54,7 +54,15 @@ public class ConditionIn extends Condition {
}
boolean result = false;
boolean hasNull = false;
for (Expression e : valueList) {
int size = valueList.size();
if (size == 1) {
Expression e = valueList.get(0);
if (e instanceof TableFunction) {
return ConditionInParameter.getValue(database, l, e.getValue(session));
}
}
for (int i = 0; i < size; i++) {
Expression e = valueList.get(i);
Value r = e.getValue(session);
if (r == ValueNull.INSTANCE) {
hasNull = true;
......@@ -87,9 +95,25 @@ public class ConditionIn extends Condition {
if (constant && left == ValueExpression.getNull()) {
return left;
}
int size = valueList.size();
if (size == 1) {
Expression right = valueList.get(0);
if (right instanceof TableFunction) {
TableFunction tf = (TableFunction) right;
if (tf.getFunctionType() == Function.UNNEST) {
Expression[] args = tf.getArgs();
if (args.length == 1) {
Expression arg = args[0];
if (arg instanceof Parameter) {
return new ConditionInParameter(database, left, (Parameter) arg);
}
}
}
return this;
}
}
boolean allValuesConstant = true;
boolean allValuesNull = true;
int size = valueList.size();
for (int i = 0; i < size; i++) {
Expression e = valueList.get(i);
e = e.optimize(session);
......@@ -109,22 +133,7 @@ public class ConditionIn extends Condition {
return ValueExpression.get(getValue(session));
}
if (size == 1) {
Expression right = valueList.get(0);
if (right instanceof TableFunction) {
TableFunction tf = (TableFunction) right;
if (tf.getFunctionType() == Function.UNNEST) {
Expression[] args = tf.getArgs();
if (args.length == 1) {
Expression arg = args[0];
if (arg instanceof Parameter) {
return new ConditionInParameter(database, left, (Parameter) arg);
}
}
}
}
Expression expr = new Comparison(session, Comparison.EQUAL, left, right);
expr = expr.optimize(session);
return expr;
return new Comparison(session, Comparison.EQUAL, left, valueList.get(0)).optimize(session);
}
if (allValuesConstant && !allValuesNull) {
int leftType = left.getType();
......
......@@ -15,6 +15,7 @@ import org.h2.expression.ExpressionVisitor;
import org.h2.expression.Parameter;
import org.h2.expression.ValueExpression;
import org.h2.index.IndexCondition;
import org.h2.result.ResultInterface;
import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter;
import org.h2.value.Value;
......@@ -64,6 +65,41 @@ public class ConditionInParameter extends Condition {
private final Parameter parameter;
static Value getValue(Database database, Value l, Value value) {
boolean result = false;
boolean hasNull = false;
if (value == ValueNull.INSTANCE) {
hasNull = true;
} else if (value.getType() == Value.RESULT_SET) {
for (ResultInterface ri = value.getResult(); ri.next();) {
Value r = ri.currentRow()[0];
if (r == ValueNull.INSTANCE) {
hasNull = true;
} else {
result = Comparison.compareNotNull(database, l, r, Comparison.EQUAL);
if (result) {
break;
}
}
}
} else {
for (Value r : ((ValueArray) value.convertTo(Value.ARRAY)).getList()) {
if (r == ValueNull.INSTANCE) {
hasNull = true;
} else {
result = Comparison.compareNotNull(database, l, r, Comparison.EQUAL);
if (result) {
break;
}
}
}
}
if (!result && hasNull) {
return ValueNull.INSTANCE;
}
return ValueBoolean.get(result);
}
/**
* Create a new {@code IN(UNNEST(?))} condition.
*
......@@ -86,27 +122,7 @@ public class ConditionInParameter extends Condition {
if (l == ValueNull.INSTANCE) {
return l;
}
boolean result = false;
boolean hasNull = false;
Value value = parameter.getValue(session);
if (value == ValueNull.INSTANCE) {
hasNull = true;
} else {
for (Value r : ((ValueArray) value.convertTo(Value.ARRAY)).getList()) {
if (r == ValueNull.INSTANCE) {
hasNull = true;
} else {
result = Comparison.compareNotNull(database, l, r, Comparison.EQUAL);
if (result) {
break;
}
}
}
}
if (!result && hasNull) {
return ValueNull.INSTANCE;
}
return ValueBoolean.get(result);
return getValue(database, l, parameter.getValue(session));
}
@Override
......
......@@ -40,3 +40,49 @@ EXPLAIN SELECT * FROM UNNEST(ARRAY[1]);
EXPLAIN SELECT * FROM UNNEST(ARRAY[1]) WITH ORDINALITY;
>> SELECT UNNEST.C1, UNNEST.NORD FROM UNNEST((1,)) WITH ORDINALITY /* function */
SELECT 1 IN(UNNEST(ARRAY[1, 2, 3]));
>> TRUE
SELECT 4 IN(UNNEST(ARRAY[1, 2, 3]));
>> FALSE
SELECT X, X IN(UNNEST(ARRAY[2, 4])) FROM SYSTEM_RANGE(1, 5);
> X X IN(UNNEST((2, 4)))
> - --------------------
> 1 FALSE
> 2 TRUE
> 3 FALSE
> 4 TRUE
> 5 FALSE
> rows: 5
SELECT X, X IN(UNNEST(?)) FROM SYSTEM_RANGE(1, 5);
{
2
> X X IN(UNNEST(?1))
> - ----------------
> 1 FALSE
> 2 TRUE
> 3 FALSE
> 4 FALSE
> 5 FALSE
> rows: 5
};
> update count: 0
CREATE TABLE TEST(A INT, B ARRAY);
> ok
INSERT INTO TEST VALUES (2, ARRAY[2, 4]), (3, ARRAY[2, 5]);
> update count: 2
SELECT A, B, A IN(UNNEST(B)) FROM TEST;
> A B A IN(UNNEST(B))
> - ------ ---------------
> 2 (2, 4) TRUE
> 3 (2, 5) FALSE
> rows: 2
DROP TABLE TEST;
> ok
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论