提交 3de090e0 authored 作者: Thomas Mueller's avatar Thomas Mueller

User defined aggregate functions: the method getType expected internal data types.

上级 d6054c63
...@@ -29,10 +29,10 @@ public interface AggregateFunction { ...@@ -29,10 +29,10 @@ public interface AggregateFunction {
* The method should check here if the number of parameters passed is correct, * The method should check here if the number of parameters passed is correct,
* and if not it should throw an exception. * and if not it should throw an exception.
* *
* @param inputType the SQL type of the parameters * @param inputTypes the SQL type of the parameters
* @return the SQL type of the result * @return the SQL type of the result
*/ */
int getType(int[] inputType) throws SQLException; int getType(int[] inputTypes) throws SQLException;
/** /**
* This method is called once for each row. * This method is called once for each row.
......
...@@ -60,7 +60,7 @@ public class JavaAggregate extends Expression { ...@@ -60,7 +60,7 @@ public class JavaAggregate extends Expression {
} }
public int getScale() { public int getScale() {
return 0; return DataType.getDataType(dataType).defaultScale;
} }
public String getSQL() { public String getSQL() {
...@@ -91,6 +91,9 @@ public class JavaAggregate extends Expression { ...@@ -91,6 +91,9 @@ public class JavaAggregate extends Expression {
case ExpressionVisitor.GET_DEPENDENCIES: case ExpressionVisitor.GET_DEPENDENCIES:
visitor.addDependency(userAggregate); visitor.addDependency(userAggregate);
break; break;
case ExpressionVisitor.OPTIMIZABLE_MIN_MAX_COUNT_ALL:
// user defined aggregate functions can not be optimized
return false;
default: default:
} }
for (int i = 0; i < args.length; i++) { for (int i = 0; i < args.length; i++) {
...@@ -111,13 +114,16 @@ public class JavaAggregate extends Expression { ...@@ -111,13 +114,16 @@ public class JavaAggregate extends Expression {
public Expression optimize(Session session) throws SQLException { public Expression optimize(Session session) throws SQLException {
userConnection = session.createConnection(false); userConnection = session.createConnection(false);
argTypes = new int[args.length]; argTypes = new int[args.length];
int[] argSqlTypes = new int[args.length];
for (int i = 0; i < args.length; i++) { for (int i = 0; i < args.length; i++) {
Expression expr = args[i]; Expression expr = args[i];
args[i] = expr.optimize(session); args[i] = expr.optimize(session);
argTypes[i] = expr.getType(); int type = expr.getType();
argTypes[i] = type;
argSqlTypes[i] = DataType.convertTypeToSQLType(type);
} }
aggregate = getInstance(); aggregate = getInstance();
dataType = aggregate.getType(argTypes); dataType = DataType.convertSQLTypeToValueType(aggregate.getType(argSqlTypes));
return this; return this;
} }
......
...@@ -13,6 +13,7 @@ import org.h2.engine.FunctionAlias; ...@@ -13,6 +13,7 @@ import org.h2.engine.FunctionAlias;
import org.h2.engine.Session; import org.h2.engine.Session;
import org.h2.table.ColumnResolver; import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter; import org.h2.table.TableFilter;
import org.h2.value.DataType;
import org.h2.value.Value; import org.h2.value.Value;
import org.h2.value.ValueNull; import org.h2.value.ValueNull;
import org.h2.value.ValueResultSet; import org.h2.value.ValueResultSet;
...@@ -64,7 +65,7 @@ public class JavaFunction extends Expression implements FunctionCall { ...@@ -64,7 +65,7 @@ public class JavaFunction extends Expression implements FunctionCall {
} }
public int getScale() { public int getScale() {
return 0; return DataType.getDataType(getType()).defaultScale;
} }
public long getPrecision() { public long getPrecision() {
......
...@@ -11,6 +11,7 @@ import java.io.File; ...@@ -11,6 +11,7 @@ import java.io.File;
import java.io.FileOutputStream; import java.io.FileOutputStream;
import java.io.FileReader; import java.io.FileReader;
import java.io.InputStream; import java.io.InputStream;
import java.math.BigDecimal;
import java.sql.Blob; import java.sql.Blob;
import java.sql.Connection; import java.sql.Connection;
import java.sql.DatabaseMetaData; import java.sql.DatabaseMetaData;
...@@ -30,7 +31,7 @@ import org.h2.util.IOUtils; ...@@ -30,7 +31,7 @@ import org.h2.util.IOUtils;
/** /**
* Tests for user defined functions and aggregates. * Tests for user defined functions and aggregates.
*/ */
public class TestFunctions extends TestBase { public class TestFunctions extends TestBase implements AggregateFunction {
/** /**
* Run just this test. * Run just this test.
...@@ -42,6 +43,8 @@ public class TestFunctions extends TestBase { ...@@ -42,6 +43,8 @@ public class TestFunctions extends TestBase {
} }
public void test() throws Exception { public void test() throws Exception {
deleteDb("functions");
testPrecision();
testVarArgs(); testVarArgs();
testAggregate(); testAggregate();
testFunctions(); testFunctions();
...@@ -49,6 +52,23 @@ public class TestFunctions extends TestBase { ...@@ -49,6 +52,23 @@ public class TestFunctions extends TestBase {
deleteDb("functions"); deleteDb("functions");
} }
private void testPrecision() throws SQLException {
Connection conn = getConnection("functions");
Statement stat = conn.createStatement();
stat.execute("create alias no_op for \""+getClass().getName()+".noOp\"");
PreparedStatement prep = conn.prepareStatement("select * from dual where no_op(1.6)=?");
prep.setBigDecimal(1, new BigDecimal("1.6"));
ResultSet rs = prep.executeQuery();
assertTrue(rs.next());
stat.execute("create aggregate agg_sum for \""+getClass().getName()+"\"");
rs = stat.executeQuery("select agg_sum(1), sum(1.6) from dual");
rs.next();
assertEquals(1, rs.getMetaData().getScale(2));
assertEquals(32767, rs.getMetaData().getScale(1));
conn.close();
}
private void testVarArgs() throws SQLException { private void testVarArgs() throws SQLException {
//## Java 1.5 begin ## //## Java 1.5 begin ##
Connection conn = getConnection("functions"); Connection conn = getConnection("functions");
...@@ -433,6 +453,13 @@ public class TestFunctions extends TestBase { ...@@ -433,6 +453,13 @@ public class TestFunctions extends TestBase {
return 1; return 1;
} }
/**
* This method is called via reflection from the database.
*/
public static BigDecimal noOp(BigDecimal dec) {
return dec;
}
/** /**
* This method is called via reflection from the database. * This method is called via reflection from the database.
*/ */
...@@ -473,4 +500,21 @@ public class TestFunctions extends TestBase { ...@@ -473,4 +500,21 @@ public class TestFunctions extends TestBase {
} }
//## Java 1.5 end ## //## Java 1.5 end ##
public void add(Object value) throws SQLException {
}
public Object getResult() throws SQLException {
return new BigDecimal("1.6");
}
public int getType(int[] inputTypes) throws SQLException {
if (inputTypes.length != 1 || inputTypes[0] != Types.INTEGER) {
throw new SQLException("unexpected data type");
}
return Types.DECIMAL;
}
public void init(Connection conn) throws SQLException {
}
} }
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论