提交 e5428a0e authored 作者: Thomas Mueller's avatar Thomas Mueller

Java methods with variable number of parameters can now be used (for Java 1.5 or newer).

上级 9afa3362
...@@ -52,6 +52,12 @@ public class SysProperties { ...@@ -52,6 +52,12 @@ public class SysProperties {
*/ */
public static final String FILE_SEPARATOR = getStringSetting("file.separator", "/"); public static final String FILE_SEPARATOR = getStringSetting("file.separator", "/");
/**
* System property <code>java.specification.version</code>.<br />
* It is set by the system. Examples: 1.4, 1.5, 1.6.
*/
public static final String JAVA_SPECIFICATION_VERSION = getStringSetting("java.specification.version", "1.4");
/** /**
* System property <code>line.separator</code> (default: \n).<br /> * System property <code>line.separator</code> (default: \n).<br />
* It is usually set by the system, and used by the script and trace tools. * It is usually set by the system, and used by the script and trace tools.
......
...@@ -6,10 +6,12 @@ ...@@ -6,10 +6,12 @@
*/ */
package org.h2.engine; package org.h2.engine;
import java.lang.reflect.Array;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.lang.reflect.Modifier; import java.lang.reflect.Modifier;
import java.sql.Connection; import java.sql.Connection;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.Arrays;
import org.h2.command.Parser; import org.h2.command.Parser;
import org.h2.constant.ErrorCode; import org.h2.constant.ErrorCode;
...@@ -68,7 +70,7 @@ public class FunctionAlias extends DbObjectBase { ...@@ -68,7 +70,7 @@ public class FunctionAlias extends DbObjectBase {
continue; continue;
} }
if (m.getName().equals(methodName) || getMethodSignature(m).equals(methodName)) { if (m.getName().equals(methodName) || getMethodSignature(m).equals(methodName)) {
JavaMethod javaMethod = new JavaMethod(m); JavaMethod javaMethod = new JavaMethod(m, i);
for (int j = 0; j < list.size(); j++) { for (int j = 0; j < list.size(); j++) {
JavaMethod old = (JavaMethod) list.get(j); JavaMethod old = (JavaMethod) list.get(j);
if (old.getParameterCount() == javaMethod.getParameterCount()) { if (old.getParameterCount() == javaMethod.getParameterCount()) {
...@@ -88,6 +90,11 @@ public class FunctionAlias extends DbObjectBase { ...@@ -88,6 +90,11 @@ public class FunctionAlias extends DbObjectBase {
} }
javaMethods = new JavaMethod[list.size()]; javaMethods = new JavaMethod[list.size()];
list.toArray(javaMethods); list.toArray(javaMethods);
// Sort elements. Methods with a variable number of arguments must be at
// the end. Reason: there could be one method without parameters and one
// with a variable number. The one without parameters needs to be used
// if no parameters are given.
Arrays.sort(javaMethods);
} }
private String getMethodSignature(Method m) { private String getMethodSignature(Method m) {
...@@ -155,8 +162,10 @@ public class FunctionAlias extends DbObjectBase { ...@@ -155,8 +162,10 @@ public class FunctionAlias extends DbObjectBase {
load(); load();
int parameterCount = args.length; int parameterCount = args.length;
for (int i = 0; i < javaMethods.length; i++) { for (int i = 0; i < javaMethods.length; i++) {
if (javaMethods[i].getParameterCount() == parameterCount) { JavaMethod m = javaMethods[i];
return javaMethods[i]; int count = m.getParameterCount();
if (count == parameterCount || (m.isVarArgs() && count <= parameterCount + 1)) {
return m;
} }
} }
throw Message.getSQLException(ErrorCode.METHOD_NOT_FOUND_1, methodName + " (" + className + ", parameter count: " + parameterCount + ")"); throw Message.getSQLException(ErrorCode.METHOD_NOT_FOUND_1, methodName + " (" + className + ", parameter count: " + parameterCount + ")");
...@@ -185,14 +194,18 @@ public class FunctionAlias extends DbObjectBase { ...@@ -185,14 +194,18 @@ public class FunctionAlias extends DbObjectBase {
* Each method must have a different number of parameters however. * Each method must have a different number of parameters however.
* This helper class represents one such method. * This helper class represents one such method.
*/ */
public static class JavaMethod { public static class JavaMethod implements Comparable {
private Method method; private final int id;
private int paramCount; private final Method method;
private final int dataType;
private boolean hasConnectionParam; private boolean hasConnectionParam;
private int dataType; private boolean varArgs;
private Class varArgClass;
private int paramCount;
JavaMethod(Method method) throws SQLException { JavaMethod(Method method, int id) throws SQLException {
this.method = method; this.method = method;
this.id = id;
Class[] paramClasses = method.getParameterTypes(); Class[] paramClasses = method.getParameterTypes();
paramCount = paramClasses.length; paramCount = paramClasses.length;
if (paramCount > 0) { if (paramCount > 0) {
...@@ -202,6 +215,13 @@ public class FunctionAlias extends DbObjectBase { ...@@ -202,6 +215,13 @@ public class FunctionAlias extends DbObjectBase {
paramCount--; paramCount--;
} }
} }
if (paramCount > 0) {
Class lastArg = paramClasses[paramClasses.length - 1];
if (lastArg.isArray() && ClassUtils.isVarArgs(method)) {
varArgs = true;
varArgClass = lastArg.getComponentType();
}
}
Class returnClass = method.getReturnType(); Class returnClass = method.getReturnType();
dataType = DataType.getTypeFromClass(returnClass); dataType = DataType.getTypeFromClass(returnClass);
} }
...@@ -234,8 +254,23 @@ public class FunctionAlias extends DbObjectBase { ...@@ -234,8 +254,23 @@ public class FunctionAlias extends DbObjectBase {
if (hasConnectionParam && params.length > 0) { if (hasConnectionParam && params.length > 0) {
params[p++] = session.createConnection(columnList); params[p++] = session.createConnection(columnList);
} }
for (int a = 0; a < args.length && p < params.length; a++, p++) {
Class paramClass = paramClasses[p]; // allocate array for varArgs parameters
Object varArg = null;
if (varArgs) {
int len = args.length - params.length + 1 + (hasConnectionParam ? 1 : 0);
varArg = Array.newInstance(varArgClass, len);
params[params.length - 1] = varArg;
}
for (int a = 0; a < args.length; a++, p++) {
boolean currentIsVarArg = varArgs && p >= paramClasses.length - 1;
Class paramClass;
if (currentIsVarArg) {
paramClass = varArgClass;
} else {
paramClass = paramClasses[p];
}
int type = DataType.getTypeFromClass(paramClass); int type = DataType.getTypeFromClass(paramClass);
Value v = args[a].getValue(session); Value v = args[a].getValue(session);
v = v.convertTo(type); v = v.convertTo(type);
...@@ -258,8 +293,12 @@ public class FunctionAlias extends DbObjectBase { ...@@ -258,8 +293,12 @@ public class FunctionAlias extends DbObjectBase {
o = DataType.convertTo(session, session.createConnection(false), v, paramClass); o = DataType.convertTo(session, session.createConnection(false), v, paramClass);
} }
} }
if (currentIsVarArg) {
Array.set(varArg, p - params.length + 1, o);
} else {
params[p] = o; params[p] = o;
} }
}
boolean old = session.getAutoCommit(); boolean old = session.getAutoCommit();
try { try {
session.setAutoCommit(false); session.setAutoCommit(false);
...@@ -291,6 +330,24 @@ public class FunctionAlias extends DbObjectBase { ...@@ -291,6 +330,24 @@ public class FunctionAlias extends DbObjectBase {
return paramCount; return paramCount;
} }
public boolean isVarArgs() {
return varArgs;
}
public int compareTo(Object o) {
JavaMethod m = (JavaMethod) o;
if (varArgs != m.varArgs) {
return varArgs ? 1 : -1;
}
if (paramCount != m.paramCount) {
return paramCount - m.paramCount;
}
if (hasConnectionParam != m.hasConnectionParam) {
return hasConnectionParam ? 1 : -1;
}
return id - m.id;
}
} }
} }
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
*/ */
package org.h2.util; package org.h2.util;
import java.lang.reflect.Method;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet; import java.util.HashSet;
...@@ -88,4 +89,27 @@ public class ClassUtils { ...@@ -88,4 +89,27 @@ public class ClassUtils {
} }
} }
/**
* Checks if the given method takes a variable number of arguments. For Java
* 1.4 and older, false is returned. Example:
* <pre>
* public static double mean(double... values)
* </pre>
*
* @param m the method to test
* @return true if the method takes a variable number of arguments.
*/
public static boolean isVarArgs(Method m) {
if ("1.5".compareTo(SysProperties.JAVA_SPECIFICATION_VERSION) > 0) {
return false;
}
try {
Method isVarArgs = m.getClass().getMethod("isVarArgs", new Class[0]);
Boolean result = (Boolean) isVarArgs.invoke(m, new Object[0]);
return result.booleanValue();
} catch (Exception e) {
return false;
}
}
} }
...@@ -555,9 +555,13 @@ public abstract class TestBase { ...@@ -555,9 +555,13 @@ public abstract class TestBase {
*/ */
protected void assertEquals(double expected, double actual) throws Exception { protected void assertEquals(double expected, double actual) throws Exception {
if (expected != actual) { if (expected != actual) {
if (Double.isNaN(expected) && Double.isNaN(actual)) {
// if both a NaN, then there is no error
} else {
fail("Expected: " + expected + " actual: " + actual); fail("Expected: " + expected + " actual: " + actual);
} }
} }
}
/** /**
* Check if two values are equal, and if not throw an exception. * Check if two values are equal, and if not throw an exception.
...@@ -568,9 +572,13 @@ public abstract class TestBase { ...@@ -568,9 +572,13 @@ public abstract class TestBase {
*/ */
protected void assertEquals(float expected, float actual) throws Exception { protected void assertEquals(float expected, float actual) throws Exception {
if (expected != actual) { if (expected != actual) {
if (Float.isNaN(expected) && Float.isNaN(actual)) {
// if both a NaN, then there is no error
} else {
fail("Expected: " + expected + " actual: " + actual); fail("Expected: " + expected + " actual: " + actual);
} }
} }
}
/** /**
* Check if two values are equal, and if not throw an exception. * Check if two values are equal, and if not throw an exception.
......
...@@ -32,17 +32,53 @@ import org.h2.util.IOUtils; ...@@ -32,17 +32,53 @@ import org.h2.util.IOUtils;
*/ */
public class TestFunctions extends TestBase { public class TestFunctions extends TestBase {
private Statement stat;
public void test() throws Exception { public void test() throws Exception {
testVarArgs();
testAggregate(); testAggregate();
testFunctions(); testFunctions();
testFileRead(); testFileRead();
} }
private void testVarArgs() throws Exception {
//## Java 1.5 begin ##
Connection conn = getConnection("functions");
Statement stat = conn.createStatement();
stat.execute("CREATE ALIAS mean FOR \"" +
getClass().getName() + ".mean\"");
ResultSet rs = stat.executeQuery(
"select mean(), mean(10), mean(10, 20), mean(10, 20, 30)");
rs.next();
assertEquals(1.0, rs.getDouble(1));
assertEquals(10.0, rs.getDouble(2));
assertEquals(15.0, rs.getDouble(3));
assertEquals(20.0, rs.getDouble(4));
stat.execute("CREATE ALIAS mean2 FOR \"" +
getClass().getName() + ".mean2\"");
rs = stat.executeQuery(
"select mean2(), mean2(10), mean2(10, 20)");
rs.next();
assertEquals(Double.NaN, rs.getDouble(1));
assertEquals(10.0, rs.getDouble(2));
assertEquals(15.0, rs.getDouble(3));
stat.execute("CREATE ALIAS printMean FOR \"" +
getClass().getName() + ".printMean\"");
rs = stat.executeQuery(
"select printMean('A'), printMean('A', 10), " +
"printMean('BB', 10, 20), printMean ('CCC', 10, 20, 30)");
rs.next();
assertEquals("A: 0", rs.getString(1));
assertEquals("A: 10", rs.getString(2));
assertEquals("BB: 15", rs.getString(3));
assertEquals("CCC: 20", rs.getString(4));
conn.close();
//## Java 1.5 end ##
}
private void testFileRead() throws Exception { private void testFileRead() throws Exception {
Connection conn = getConnection("functions"); Connection conn = getConnection("functions");
stat = conn.createStatement(); Statement stat = conn.createStatement();
File f = new File(baseDir + "/test.txt"); File f = new File(baseDir + "/test.txt");
Properties prop = System.getProperties(); Properties prop = System.getProperties();
FileOutputStream out = new FileOutputStream(f); FileOutputStream out = new FileOutputStream(f);
...@@ -94,7 +130,7 @@ public class TestFunctions extends TestBase { ...@@ -94,7 +130,7 @@ public class TestFunctions extends TestBase {
private void testAggregate() throws Exception { private void testAggregate() throws Exception {
deleteDb("functions"); deleteDb("functions");
Connection conn = getConnection("functions"); Connection conn = getConnection("functions");
stat = conn.createStatement(); Statement stat = conn.createStatement();
stat.execute("CREATE AGGREGATE MEDIAN FOR \"" + MedianString.class.getName() + "\""); stat.execute("CREATE AGGREGATE MEDIAN FOR \"" + MedianString.class.getName() + "\"");
stat.execute("CREATE AGGREGATE IF NOT EXISTS MEDIAN FOR \"" + MedianString.class.getName() + "\""); stat.execute("CREATE AGGREGATE IF NOT EXISTS MEDIAN FOR \"" + MedianString.class.getName() + "\"");
ResultSet rs = stat.executeQuery("SELECT MEDIAN(X) FROM SYSTEM_RANGE(1, 9)"); ResultSet rs = stat.executeQuery("SELECT MEDIAN(X) FROM SYSTEM_RANGE(1, 9)");
...@@ -130,10 +166,10 @@ public class TestFunctions extends TestBase { ...@@ -130,10 +166,10 @@ public class TestFunctions extends TestBase {
private void testFunctions() throws Exception { private void testFunctions() throws Exception {
deleteDb("functions"); deleteDb("functions");
Connection conn = getConnection("functions"); Connection conn = getConnection("functions");
stat = conn.createStatement(); Statement stat = conn.createStatement();
test("abs(null)", null); test(stat, "abs(null)", null);
test("abs(1)", "1"); test(stat, "abs(1)", "1");
test("abs(1)", "1"); test(stat, "abs(1)", "1");
stat.execute("CREATE TABLE TEST(ID INT PRIMARY KEY, NAME VARCHAR)"); stat.execute("CREATE TABLE TEST(ID INT PRIMARY KEY, NAME VARCHAR)");
stat.execute("CREATE ALIAS ADD_ROW FOR \"" + getClass().getName() + ".addRow\""); stat.execute("CREATE ALIAS ADD_ROW FOR \"" + getClass().getName() + ".addRow\"");
...@@ -262,7 +298,7 @@ public class TestFunctions extends TestBase { ...@@ -262,7 +298,7 @@ public class TestFunctions extends TestBase {
conn.close(); conn.close();
} }
private void test(String sql, String value) throws Exception { private void test(Statement stat, String sql, String value) throws Exception {
ResultSet rs = stat.executeQuery("CALL " + sql); ResultSet rs = stat.executeQuery("CALL " + sql);
rs.next(); rs.next();
String s = rs.getString(1); String s = rs.getString(1);
...@@ -380,4 +416,51 @@ public class TestFunctions extends TestBase { ...@@ -380,4 +416,51 @@ public class TestFunctions extends TestBase {
return (int) Math.sqrt(value); return (int) Math.sqrt(value);
} }
/**
* This method is called via reflection from the database.
*/
public static double mean() {
return 1;
}
/**
* This method is called via reflection from the database.
*/
//## Java 1.5 begin ##
public static double mean(double... values) {
double sum = 0;
for (double x : values) {
sum += x;
}
return sum / values.length;
}
//## Java 1.5 end ##
/**
* This method is called via reflection from the database.
*/
//## Java 1.5 begin ##
public static double mean2(Connection conn, double... values) {
conn.getClass();
double sum = 0;
for (double x : values) {
sum += x;
}
return sum / values.length;
}
//## Java 1.5 end ##
/**
* This method is called via reflection from the database.
*/
//## Java 1.5 begin ##
public static String printMean(String prefix, double... values) {
double sum = 0;
for (double x : values) {
sum += x;
}
return prefix + ": " + (int) (sum / values.length);
}
//## Java 1.5 end ##
} }
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论