提交 3de090e0 authored 作者: Thomas Mueller's avatar Thomas Mueller

User defined aggregate functions: the method getType expected internal data types.

上级 d6054c63
......@@ -29,10 +29,10 @@ public interface AggregateFunction {
* The method should check here if the number of parameters passed is correct,
* and if not it should throw an exception.
*
* @param inputType the SQL type of the parameters
* @param inputTypes the SQL type of the parameters
* @return the SQL type of the result
*/
int getType(int[] inputType) throws SQLException;
int getType(int[] inputTypes) throws SQLException;
/**
* This method is called once for each row.
......
......@@ -60,7 +60,7 @@ public class JavaAggregate extends Expression {
}
public int getScale() {
return 0;
return DataType.getDataType(dataType).defaultScale;
}
public String getSQL() {
......@@ -91,6 +91,9 @@ public class JavaAggregate extends Expression {
case ExpressionVisitor.GET_DEPENDENCIES:
visitor.addDependency(userAggregate);
break;
case ExpressionVisitor.OPTIMIZABLE_MIN_MAX_COUNT_ALL:
// user defined aggregate functions can not be optimized
return false;
default:
}
for (int i = 0; i < args.length; i++) {
......@@ -111,13 +114,16 @@ public class JavaAggregate extends Expression {
public Expression optimize(Session session) throws SQLException {
userConnection = session.createConnection(false);
argTypes = new int[args.length];
int[] argSqlTypes = new int[args.length];
for (int i = 0; i < args.length; i++) {
Expression expr = args[i];
args[i] = expr.optimize(session);
argTypes[i] = expr.getType();
int type = expr.getType();
argTypes[i] = type;
argSqlTypes[i] = DataType.convertTypeToSQLType(type);
}
aggregate = getInstance();
dataType = aggregate.getType(argTypes);
dataType = DataType.convertSQLTypeToValueType(aggregate.getType(argSqlTypes));
return this;
}
......
......@@ -13,6 +13,7 @@ import org.h2.engine.FunctionAlias;
import org.h2.engine.Session;
import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter;
import org.h2.value.DataType;
import org.h2.value.Value;
import org.h2.value.ValueNull;
import org.h2.value.ValueResultSet;
......@@ -64,7 +65,7 @@ public class JavaFunction extends Expression implements FunctionCall {
}
public int getScale() {
return 0;
return DataType.getDataType(getType()).defaultScale;
}
public long getPrecision() {
......
......@@ -11,6 +11,7 @@ import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.InputStream;
import java.math.BigDecimal;
import java.sql.Blob;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
......@@ -30,7 +31,7 @@ import org.h2.util.IOUtils;
/**
* Tests for user defined functions and aggregates.
*/
public class TestFunctions extends TestBase {
public class TestFunctions extends TestBase implements AggregateFunction {
/**
* Run just this test.
......@@ -42,6 +43,8 @@ public class TestFunctions extends TestBase {
}
public void test() throws Exception {
deleteDb("functions");
testPrecision();
testVarArgs();
testAggregate();
testFunctions();
......@@ -49,6 +52,23 @@ public class TestFunctions extends TestBase {
deleteDb("functions");
}
private void testPrecision() throws SQLException {
Connection conn = getConnection("functions");
Statement stat = conn.createStatement();
stat.execute("create alias no_op for \""+getClass().getName()+".noOp\"");
PreparedStatement prep = conn.prepareStatement("select * from dual where no_op(1.6)=?");
prep.setBigDecimal(1, new BigDecimal("1.6"));
ResultSet rs = prep.executeQuery();
assertTrue(rs.next());
stat.execute("create aggregate agg_sum for \""+getClass().getName()+"\"");
rs = stat.executeQuery("select agg_sum(1), sum(1.6) from dual");
rs.next();
assertEquals(1, rs.getMetaData().getScale(2));
assertEquals(32767, rs.getMetaData().getScale(1));
conn.close();
}
private void testVarArgs() throws SQLException {
//## Java 1.5 begin ##
Connection conn = getConnection("functions");
......@@ -433,6 +453,13 @@ public class TestFunctions extends TestBase {
return 1;
}
/**
* This method is called via reflection from the database.
*/
public static BigDecimal noOp(BigDecimal dec) {
return dec;
}
/**
* This method is called via reflection from the database.
*/
......@@ -473,4 +500,21 @@ public class TestFunctions extends TestBase {
}
//## Java 1.5 end ##
public void add(Object value) throws SQLException {
}
public Object getResult() throws SQLException {
return new BigDecimal("1.6");
}
public int getType(int[] inputTypes) throws SQLException {
if (inputTypes.length != 1 || inputTypes[0] != Types.INTEGER) {
throw new SQLException("unexpected data type");
}
return Types.DECIMAL;
}
public void init(Connection conn) throws SQLException {
}
}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论