提交 e5428a0e authored 作者: Thomas Mueller's avatar Thomas Mueller

Java methods with variable number of parameters can now be used (for Java 1.5 or newer).

上级 9afa3362
......@@ -52,6 +52,12 @@ public class SysProperties {
*/
public static final String FILE_SEPARATOR = getStringSetting("file.separator", "/");
/**
* System property <code>java.specification.version</code>.<br />
* It is set by the system. Examples: 1.4, 1.5, 1.6.
*/
public static final String JAVA_SPECIFICATION_VERSION = getStringSetting("java.specification.version", "1.4");
/**
* System property <code>line.separator</code> (default: \n).<br />
* It is usually set by the system, and used by the script and trace tools.
......
......@@ -6,10 +6,12 @@
*/
package org.h2.engine;
import java.lang.reflect.Array;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Arrays;
import org.h2.command.Parser;
import org.h2.constant.ErrorCode;
......@@ -68,7 +70,7 @@ public class FunctionAlias extends DbObjectBase {
continue;
}
if (m.getName().equals(methodName) || getMethodSignature(m).equals(methodName)) {
JavaMethod javaMethod = new JavaMethod(m);
JavaMethod javaMethod = new JavaMethod(m, i);
for (int j = 0; j < list.size(); j++) {
JavaMethod old = (JavaMethod) list.get(j);
if (old.getParameterCount() == javaMethod.getParameterCount()) {
......@@ -88,6 +90,11 @@ public class FunctionAlias extends DbObjectBase {
}
javaMethods = new JavaMethod[list.size()];
list.toArray(javaMethods);
// Sort elements. Methods with a variable number of arguments must be at
// the end. Reason: there could be one method without parameters and one
// with a variable number. The one without parameters needs to be used
// if no parameters are given.
Arrays.sort(javaMethods);
}
private String getMethodSignature(Method m) {
......@@ -155,8 +162,10 @@ public class FunctionAlias extends DbObjectBase {
load();
int parameterCount = args.length;
for (int i = 0; i < javaMethods.length; i++) {
if (javaMethods[i].getParameterCount() == parameterCount) {
return javaMethods[i];
JavaMethod m = javaMethods[i];
int count = m.getParameterCount();
if (count == parameterCount || (m.isVarArgs() && count <= parameterCount + 1)) {
return m;
}
}
throw Message.getSQLException(ErrorCode.METHOD_NOT_FOUND_1, methodName + " (" + className + ", parameter count: " + parameterCount + ")");
......@@ -185,14 +194,18 @@ public class FunctionAlias extends DbObjectBase {
* Each method must have a different number of parameters however.
* This helper class represents one such method.
*/
public static class JavaMethod {
private Method method;
private int paramCount;
public static class JavaMethod implements Comparable {
private final int id;
private final Method method;
private final int dataType;
private boolean hasConnectionParam;
private int dataType;
private boolean varArgs;
private Class varArgClass;
private int paramCount;
JavaMethod(Method method) throws SQLException {
JavaMethod(Method method, int id) throws SQLException {
this.method = method;
this.id = id;
Class[] paramClasses = method.getParameterTypes();
paramCount = paramClasses.length;
if (paramCount > 0) {
......@@ -202,6 +215,13 @@ public class FunctionAlias extends DbObjectBase {
paramCount--;
}
}
if (paramCount > 0) {
Class lastArg = paramClasses[paramClasses.length - 1];
if (lastArg.isArray() && ClassUtils.isVarArgs(method)) {
varArgs = true;
varArgClass = lastArg.getComponentType();
}
}
Class returnClass = method.getReturnType();
dataType = DataType.getTypeFromClass(returnClass);
}
......@@ -234,8 +254,23 @@ public class FunctionAlias extends DbObjectBase {
if (hasConnectionParam && params.length > 0) {
params[p++] = session.createConnection(columnList);
}
for (int a = 0; a < args.length && p < params.length; a++, p++) {
Class paramClass = paramClasses[p];
// allocate array for varArgs parameters
Object varArg = null;
if (varArgs) {
int len = args.length - params.length + 1 + (hasConnectionParam ? 1 : 0);
varArg = Array.newInstance(varArgClass, len);
params[params.length - 1] = varArg;
}
for (int a = 0; a < args.length; a++, p++) {
boolean currentIsVarArg = varArgs && p >= paramClasses.length - 1;
Class paramClass;
if (currentIsVarArg) {
paramClass = varArgClass;
} else {
paramClass = paramClasses[p];
}
int type = DataType.getTypeFromClass(paramClass);
Value v = args[a].getValue(session);
v = v.convertTo(type);
......@@ -258,8 +293,12 @@ public class FunctionAlias extends DbObjectBase {
o = DataType.convertTo(session, session.createConnection(false), v, paramClass);
}
}
if (currentIsVarArg) {
Array.set(varArg, p - params.length + 1, o);
} else {
params[p] = o;
}
}
boolean old = session.getAutoCommit();
try {
session.setAutoCommit(false);
......@@ -291,6 +330,24 @@ public class FunctionAlias extends DbObjectBase {
return paramCount;
}
public boolean isVarArgs() {
return varArgs;
}
public int compareTo(Object o) {
JavaMethod m = (JavaMethod) o;
if (varArgs != m.varArgs) {
return varArgs ? 1 : -1;
}
if (paramCount != m.paramCount) {
return paramCount - m.paramCount;
}
if (hasConnectionParam != m.hasConnectionParam) {
return hasConnectionParam ? 1 : -1;
}
return id - m.id;
}
}
}
......@@ -6,6 +6,7 @@
*/
package org.h2.util;
import java.lang.reflect.Method;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashSet;
......@@ -88,4 +89,27 @@ public class ClassUtils {
}
}
/**
* Checks if the given method takes a variable number of arguments. For Java
* 1.4 and older, false is returned. Example:
* <pre>
* public static double mean(double... values)
* </pre>
*
* @param m the method to test
* @return true if the method takes a variable number of arguments.
*/
public static boolean isVarArgs(Method m) {
if ("1.5".compareTo(SysProperties.JAVA_SPECIFICATION_VERSION) > 0) {
return false;
}
try {
Method isVarArgs = m.getClass().getMethod("isVarArgs", new Class[0]);
Boolean result = (Boolean) isVarArgs.invoke(m, new Object[0]);
return result.booleanValue();
} catch (Exception e) {
return false;
}
}
}
......@@ -555,9 +555,13 @@ public abstract class TestBase {
*/
protected void assertEquals(double expected, double actual) throws Exception {
if (expected != actual) {
if (Double.isNaN(expected) && Double.isNaN(actual)) {
// if both a NaN, then there is no error
} else {
fail("Expected: " + expected + " actual: " + actual);
}
}
}
/**
* Check if two values are equal, and if not throw an exception.
......@@ -568,9 +572,13 @@ public abstract class TestBase {
*/
protected void assertEquals(float expected, float actual) throws Exception {
if (expected != actual) {
if (Float.isNaN(expected) && Float.isNaN(actual)) {
// if both a NaN, then there is no error
} else {
fail("Expected: " + expected + " actual: " + actual);
}
}
}
/**
* Check if two values are equal, and if not throw an exception.
......
......@@ -32,17 +32,53 @@ import org.h2.util.IOUtils;
*/
public class TestFunctions extends TestBase {
private Statement stat;
public void test() throws Exception {
testVarArgs();
testAggregate();
testFunctions();
testFileRead();
}
private void testVarArgs() throws Exception {
//## Java 1.5 begin ##
Connection conn = getConnection("functions");
Statement stat = conn.createStatement();
stat.execute("CREATE ALIAS mean FOR \"" +
getClass().getName() + ".mean\"");
ResultSet rs = stat.executeQuery(
"select mean(), mean(10), mean(10, 20), mean(10, 20, 30)");
rs.next();
assertEquals(1.0, rs.getDouble(1));
assertEquals(10.0, rs.getDouble(2));
assertEquals(15.0, rs.getDouble(3));
assertEquals(20.0, rs.getDouble(4));
stat.execute("CREATE ALIAS mean2 FOR \"" +
getClass().getName() + ".mean2\"");
rs = stat.executeQuery(
"select mean2(), mean2(10), mean2(10, 20)");
rs.next();
assertEquals(Double.NaN, rs.getDouble(1));
assertEquals(10.0, rs.getDouble(2));
assertEquals(15.0, rs.getDouble(3));
stat.execute("CREATE ALIAS printMean FOR \"" +
getClass().getName() + ".printMean\"");
rs = stat.executeQuery(
"select printMean('A'), printMean('A', 10), " +
"printMean('BB', 10, 20), printMean ('CCC', 10, 20, 30)");
rs.next();
assertEquals("A: 0", rs.getString(1));
assertEquals("A: 10", rs.getString(2));
assertEquals("BB: 15", rs.getString(3));
assertEquals("CCC: 20", rs.getString(4));
conn.close();
//## Java 1.5 end ##
}
private void testFileRead() throws Exception {
Connection conn = getConnection("functions");
stat = conn.createStatement();
Statement stat = conn.createStatement();
File f = new File(baseDir + "/test.txt");
Properties prop = System.getProperties();
FileOutputStream out = new FileOutputStream(f);
......@@ -94,7 +130,7 @@ public class TestFunctions extends TestBase {
private void testAggregate() throws Exception {
deleteDb("functions");
Connection conn = getConnection("functions");
stat = conn.createStatement();
Statement stat = conn.createStatement();
stat.execute("CREATE AGGREGATE MEDIAN FOR \"" + MedianString.class.getName() + "\"");
stat.execute("CREATE AGGREGATE IF NOT EXISTS MEDIAN FOR \"" + MedianString.class.getName() + "\"");
ResultSet rs = stat.executeQuery("SELECT MEDIAN(X) FROM SYSTEM_RANGE(1, 9)");
......@@ -130,10 +166,10 @@ public class TestFunctions extends TestBase {
private void testFunctions() throws Exception {
deleteDb("functions");
Connection conn = getConnection("functions");
stat = conn.createStatement();
test("abs(null)", null);
test("abs(1)", "1");
test("abs(1)", "1");
Statement stat = conn.createStatement();
test(stat, "abs(null)", null);
test(stat, "abs(1)", "1");
test(stat, "abs(1)", "1");
stat.execute("CREATE TABLE TEST(ID INT PRIMARY KEY, NAME VARCHAR)");
stat.execute("CREATE ALIAS ADD_ROW FOR \"" + getClass().getName() + ".addRow\"");
......@@ -262,7 +298,7 @@ public class TestFunctions extends TestBase {
conn.close();
}
private void test(String sql, String value) throws Exception {
private void test(Statement stat, String sql, String value) throws Exception {
ResultSet rs = stat.executeQuery("CALL " + sql);
rs.next();
String s = rs.getString(1);
......@@ -380,4 +416,51 @@ public class TestFunctions extends TestBase {
return (int) Math.sqrt(value);
}
/**
* This method is called via reflection from the database.
*/
public static double mean() {
return 1;
}
/**
* This method is called via reflection from the database.
*/
//## Java 1.5 begin ##
public static double mean(double... values) {
double sum = 0;
for (double x : values) {
sum += x;
}
return sum / values.length;
}
//## Java 1.5 end ##
/**
* This method is called via reflection from the database.
*/
//## Java 1.5 begin ##
public static double mean2(Connection conn, double... values) {
conn.getClass();
double sum = 0;
for (double x : values) {
sum += x;
}
return sum / values.length;
}
//## Java 1.5 end ##
/**
* This method is called via reflection from the database.
*/
//## Java 1.5 begin ##
public static String printMean(String prefix, double... values) {
double sum = 0;
for (double x : values) {
sum += x;
}
return prefix + ": " + (int) (sum / values.length);
}
//## Java 1.5 end ##
}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论