提交 5e3a0621 authored 作者: Thomas Mueller's avatar Thomas Mueller

There is a new Aggregate API that supports the internal H2 data types (GEOMETRY…

There is a new Aggregate API that supports the internal H2 data types (GEOMETRY for example). Thanks a lot to Nicolas Fortin for the patch!
上级 8c0bb119
......@@ -18,7 +18,9 @@ Change Log
<h1>Change Log</h1>
<h2>Next Version (unreleased)</h2>
<ul><li>Referential integrity constraints sometimes used the wrong index,
<ul><li>There is a new Aggregate API that supports the internal H2 data types
(GEOMETRY for example). Thanks a lot to Nicolas Fortin for the patch!
</li><li>Referential integrity constraints sometimes used the wrong index,
such that updating a row in the referenced table incorrectly failed with
a constraint violation.
</li><li>The Polish translation was completed and corrected by Wojtek Jurczyk. Thanks a lot!
......
/*
* Copyright 2004-2013 H2 Group. Multiple-Licensed under the H2 License,
* Version 1.0, and under the Eclipse Public License, Version 1.0
* (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.api;
import java.sql.Connection;
import java.sql.SQLException;
/**
* A user-defined aggregate function needs to implement this interface.
* The class must be public and must have a public non-argument constructor.
*/
public interface Aggregate {
/**
* This method is called when the aggregate function is used.
* A new object is created for each invocation.
*
* @param conn a connection to the database
*/
void init(Connection conn) throws SQLException;
/**
* This method must return the H2 data type, {@link org.h2.value.Value},
* of the aggregate function, given the H2 data type of the input data.
* The method should check here if the number of parameters
* passed is correct, and if not it should throw an exception.
*
* @param inputTypes the H2 data type of the parameters,
* @return the H2 data type of the result
* @throws SQLException if the number/type of parameters passed is incorrect
*/
int getInternalType(int[] inputTypes) throws SQLException;
/**
* This method is called once for each row.
* If the aggregate function is called with multiple parameters,
* those are passed as array.
*
* @param value the value(s) for this row
*/
void add(Object value) throws SQLException;
/**
* This method returns the computed aggregate value.
*
* @return the aggregated value
*/
Object getResult() throws SQLException;
}
\ No newline at end of file
......@@ -12,6 +12,10 @@ import java.sql.SQLException;
/**
* A user-defined aggregate function needs to implement this interface.
* The class must be public and must have a public non-argument constructor.
* <p>
* Please note this interface only has limited support for data types.
* If you need data types that don't have a corresponding SQL type
* (for example GEOMETRY), then use the {@link Aggregate} interface.
*/
public interface AggregateFunction {
......
......@@ -7,11 +7,16 @@
package org.h2.engine;
import org.h2.api.AggregateFunction;
import org.h2.api.Aggregate;
import org.h2.command.Parser;
import org.h2.message.DbException;
import org.h2.message.Trace;
import org.h2.table.Table;
import org.h2.util.Utils;
import org.h2.value.DataType;
import java.sql.Connection;
import java.sql.SQLException;
/**
* Represents a user-defined aggregate function.
......@@ -29,14 +34,19 @@ public class UserAggregate extends DbObjectBase {
}
}
public AggregateFunction getInstance() {
public Aggregate getInstance() {
if (javaClass == null) {
javaClass = Utils.loadUserClass(className);
}
Object obj;
try {
obj = javaClass.newInstance();
AggregateFunction agg = (AggregateFunction) obj;
Aggregate agg;
if (obj instanceof Aggregate) {
agg = (Aggregate) obj;
} else {
agg = new AggregateWrapper((AggregateFunction) obj);
}
return agg;
} catch (Exception e) {
throw DbException.convert(e);
......@@ -80,4 +90,39 @@ public class UserAggregate extends DbObjectBase {
return this.className;
}
/**
* Wrap {@link AggregateFunction} in order to behave as {@link org.h2.api.Aggregate}
**/
private static class AggregateWrapper implements Aggregate {
private final AggregateFunction aggregateFunction;
AggregateWrapper(AggregateFunction aggregateFunction) {
this.aggregateFunction = aggregateFunction;
}
@Override
public void init(Connection conn) throws SQLException {
aggregateFunction.init(conn);
}
@Override
public int getInternalType(int[] inputTypes) throws SQLException {
int[] sqlTypes = new int[inputTypes.length];
for (int i = 0; i < inputTypes.length; i++) {
sqlTypes[i] = DataType.convertTypeToSQLType(inputTypes[i]);
}
return DataType.convertSQLTypeToValueType(aggregateFunction.getType(sqlTypes));
}
@Override
public void add(Object value) throws SQLException {
aggregateFunction.add(value);
}
@Override
public Object getResult() throws SQLException {
return aggregateFunction.getResult();
}
}
}
......@@ -9,7 +9,7 @@ package org.h2.expression;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.HashMap;
import org.h2.api.AggregateFunction;
import org.h2.api.Aggregate;
import org.h2.command.Parser;
import org.h2.command.dml.Select;
import org.h2.constant.ErrorCode;
......@@ -116,17 +116,15 @@ public class JavaAggregate extends Expression {
userConnection = session.createConnection(false);
int len = args.length;
argTypes = new int[len];
int[] argSqlTypes = new int[len];
for (int i = 0; i < len; i++) {
Expression expr = args[i];
args[i] = expr.optimize(session);
int type = expr.getType();
argTypes[i] = type;
argSqlTypes[i] = DataType.convertTypeToSQLType(type);
}
try {
AggregateFunction aggregate = getInstance();
dataType = DataType.convertSQLTypeToValueType(aggregate.getType(argSqlTypes));
Aggregate aggregate = getInstance();
dataType = aggregate.getInternalType(argTypes);
} catch (SQLException e) {
throw DbException.convert(e);
}
......@@ -140,8 +138,8 @@ public class JavaAggregate extends Expression {
}
}
private AggregateFunction getInstance() throws SQLException {
AggregateFunction agg = userAggregate.getInstance();
private Aggregate getInstance() throws SQLException {
Aggregate agg = userAggregate.getInstance();
agg.init(userConnection);
return agg;
}
......@@ -153,7 +151,7 @@ public class JavaAggregate extends Expression {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
}
try {
AggregateFunction agg = (AggregateFunction) group.get(this);
Aggregate agg = (Aggregate) group.get(this);
if (agg == null) {
agg = getInstance();
}
......@@ -182,7 +180,7 @@ public class JavaAggregate extends Expression {
}
lastGroupRowId = groupRowId;
AggregateFunction agg = (AggregateFunction) group.get(this);
Aggregate agg = (Aggregate) group.get(this);
try {
if (agg == null) {
agg = getInstance();
......
......@@ -33,6 +33,8 @@ import java.util.Locale;
import java.util.Properties;
import java.util.TimeZone;
import java.util.UUID;
import org.h2.api.Aggregate;
import org.h2.api.AggregateFunction;
import org.h2.constant.ErrorCode;
import org.h2.engine.Constants;
......@@ -79,6 +81,7 @@ public class TestFunctions extends TestBase implements AggregateFunction {
testMathFunctions();
testVarArgs();
testAggregate();
testAggregateType();
testFunctions();
testFileRead();
testValue();
......@@ -558,6 +561,71 @@ public class TestFunctions extends TestBase implements AggregateFunction {
}
/**
* This median implementation keeps all objects in memory.
*/
public static class MedianStringType implements Aggregate {
private final ArrayList<String> list = New.arrayList();
@Override
public void add(Object value) {
list.add(value.toString());
}
@Override
public Object getResult() {
return list.get(list.size() / 2);
}
@Override
public int getInternalType(int[] inputTypes) throws SQLException {
return Value.STRING;
}
@Override
public void init(Connection conn) {
// nothing to do
}
}
private void testAggregateType() throws SQLException {
deleteDb("functions");
Connection conn = getConnection("functions");
Statement stat = conn.createStatement();
stat.execute("CREATE AGGREGATE MEDIAN FOR \"" + MedianStringType.class.getName() + "\"");
stat.execute("CREATE AGGREGATE IF NOT EXISTS MEDIAN FOR \"" + MedianStringType.class.getName() + "\"");
ResultSet rs = stat.executeQuery("SELECT MEDIAN(X) FROM SYSTEM_RANGE(1, 9)");
rs.next();
assertEquals("5", rs.getString(1));
conn.close();
if (config.memory) {
return;
}
conn = getConnection("functions");
stat = conn.createStatement();
stat.executeQuery("SELECT MEDIAN(X) FROM SYSTEM_RANGE(1, 9)");
DatabaseMetaData meta = conn.getMetaData();
rs = meta.getProcedures(null, null, "MEDIAN");
assertTrue(rs.next());
assertFalse(rs.next());
rs = stat.executeQuery("SCRIPT");
boolean found = false;
while (rs.next()) {
String sql = rs.getString(1);
if (sql.contains("MEDIAN")) {
found = true;
}
}
assertTrue(found);
stat.execute("DROP AGGREGATE MEDIAN");
stat.execute("DROP AGGREGATE IF EXISTS MEDIAN");
conn.close();
}
private void testAggregate() throws SQLException {
deleteDb("functions");
Connection conn = getConnection("functions");
......
......@@ -12,6 +12,9 @@ import java.sql.Savepoint;
import java.sql.Statement;
import java.sql.Types;
import java.util.Random;
import com.vividsolutions.jts.geom.Envelope;
import org.h2.api.Aggregate;
import org.h2.test.TestBase;
import org.h2.tools.SimpleResultSet;
import org.h2.tools.SimpleRowSource;
......@@ -77,6 +80,7 @@ public class TestSpatial extends TestBase {
testEquals();
testTableFunctionGeometry();
testHashCode();
testAggregateWithGeometry();
}
private void testHashCode() {
......@@ -642,4 +646,62 @@ public class TestSpatial extends TestBase {
return rs;
}
public void testAggregateWithGeometry() throws SQLException {
deleteDb("spatialIndex");
Connection conn = getConnection("spatialIndex");
try {
Statement st = conn.createStatement();
st.execute("CREATE AGGREGATE TABLE_ENVELOPE FOR \""+TableEnvelope.class.getName()+"\"");
st.execute("CREATE TABLE test(the_geom GEOMETRY)");
st.execute("INSERT INTO test VALUES ('POINT(1 1)'), ('POINT(10 5)')");
ResultSet rs = st.executeQuery("select TABLE_ENVELOPE(the_geom) from test");
assertEquals("geometry", rs.getMetaData().getColumnTypeName(1).toLowerCase());
assertTrue(rs.next());
assertTrue(rs.getObject(1) instanceof Geometry);
assertTrue(new Envelope(1, 10, 1, 5).equals(((Geometry) rs.getObject(1)).getEnvelopeInternal()));
assertFalse(rs.next());
} finally {
conn.close();
}
deleteDb("spatialIndex");
}
/**
* An aggregate function that calculates the envelope.
*/
public static class TableEnvelope implements Aggregate {
private Envelope tableEnvelope;
@Override
public int getInternalType(int[] inputTypes) throws SQLException {
for (int inputType : inputTypes) {
if (inputType != Value.GEOMETRY) {
throw new SQLException("TableEnvelope accept only Geometry argument");
}
}
return Value.GEOMETRY;
}
@Override
public void init(Connection conn) throws SQLException {
tableEnvelope = null;
}
@Override
public void add(Object value) throws SQLException {
if (value instanceof Geometry) {
if (tableEnvelope == null) {
tableEnvelope = ((Geometry) value).getEnvelopeInternal();
} else {
tableEnvelope.expandToInclude(((Geometry) value).getEnvelopeInternal());
}
}
}
@Override
public Object getResult() throws SQLException {
return new GeometryFactory().toGeometry(tableEnvelope);
}
}
}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论