提交 5e3a0621 authored 作者: Thomas Mueller's avatar Thomas Mueller

There is a new Aggregate API that supports the internal H2 data types (GEOMETRY…

There is a new Aggregate API that supports the internal H2 data types (GEOMETRY for example). Thanks a lot to Nicolas Fortin for the patch!
上级 8c0bb119
...@@ -18,7 +18,9 @@ Change Log ...@@ -18,7 +18,9 @@ Change Log
<h1>Change Log</h1> <h1>Change Log</h1>
<h2>Next Version (unreleased)</h2> <h2>Next Version (unreleased)</h2>
<ul><li>Referential integrity constraints sometimes used the wrong index, <ul><li>There is a new Aggregate API that supports the internal H2 data types
(GEOMETRY for example). Thanks a lot to Nicolas Fortin for the patch!
</li><li>Referential integrity constraints sometimes used the wrong index,
such that updating a row in the referenced table incorrectly failed with such that updating a row in the referenced table incorrectly failed with
a constraint violation. a constraint violation.
</li><li>The Polish translation was completed and corrected by Wojtek Jurczyk. Thanks a lot! </li><li>The Polish translation was completed and corrected by Wojtek Jurczyk. Thanks a lot!
......
/*
* Copyright 2004-2013 H2 Group. Multiple-Licensed under the H2 License,
* Version 1.0, and under the Eclipse Public License, Version 1.0
* (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.api;
import java.sql.Connection;
import java.sql.SQLException;
/**
* A user-defined aggregate function needs to implement this interface.
* The class must be public and must have a public non-argument constructor.
*/
public interface Aggregate {
/**
* This method is called when the aggregate function is used.
* A new object is created for each invocation.
*
* @param conn a connection to the database
*/
void init(Connection conn) throws SQLException;
/**
* This method must return the H2 data type, {@link org.h2.value.Value},
* of the aggregate function, given the H2 data type of the input data.
* The method should check here if the number of parameters
* passed is correct, and if not it should throw an exception.
*
* @param inputTypes the H2 data type of the parameters,
* @return the H2 data type of the result
* @throws SQLException if the number/type of parameters passed is incorrect
*/
int getInternalType(int[] inputTypes) throws SQLException;
/**
* This method is called once for each row.
* If the aggregate function is called with multiple parameters,
* those are passed as array.
*
* @param value the value(s) for this row
*/
void add(Object value) throws SQLException;
/**
* This method returns the computed aggregate value.
*
* @return the aggregated value
*/
Object getResult() throws SQLException;
}
\ No newline at end of file
...@@ -12,6 +12,10 @@ import java.sql.SQLException; ...@@ -12,6 +12,10 @@ import java.sql.SQLException;
/** /**
* A user-defined aggregate function needs to implement this interface. * A user-defined aggregate function needs to implement this interface.
* The class must be public and must have a public non-argument constructor. * The class must be public and must have a public non-argument constructor.
* <p>
* Please note this interface only has limited support for data types.
* If you need data types that don't have a corresponding SQL type
* (for example GEOMETRY), then use the {@link Aggregate} interface.
*/ */
public interface AggregateFunction { public interface AggregateFunction {
......
...@@ -7,11 +7,16 @@ ...@@ -7,11 +7,16 @@
package org.h2.engine; package org.h2.engine;
import org.h2.api.AggregateFunction; import org.h2.api.AggregateFunction;
import org.h2.api.Aggregate;
import org.h2.command.Parser; import org.h2.command.Parser;
import org.h2.message.DbException; import org.h2.message.DbException;
import org.h2.message.Trace; import org.h2.message.Trace;
import org.h2.table.Table; import org.h2.table.Table;
import org.h2.util.Utils; import org.h2.util.Utils;
import org.h2.value.DataType;
import java.sql.Connection;
import java.sql.SQLException;
/** /**
* Represents a user-defined aggregate function. * Represents a user-defined aggregate function.
...@@ -29,14 +34,19 @@ public class UserAggregate extends DbObjectBase { ...@@ -29,14 +34,19 @@ public class UserAggregate extends DbObjectBase {
} }
} }
public AggregateFunction getInstance() { public Aggregate getInstance() {
if (javaClass == null) { if (javaClass == null) {
javaClass = Utils.loadUserClass(className); javaClass = Utils.loadUserClass(className);
} }
Object obj; Object obj;
try { try {
obj = javaClass.newInstance(); obj = javaClass.newInstance();
AggregateFunction agg = (AggregateFunction) obj; Aggregate agg;
if (obj instanceof Aggregate) {
agg = (Aggregate) obj;
} else {
agg = new AggregateWrapper((AggregateFunction) obj);
}
return agg; return agg;
} catch (Exception e) { } catch (Exception e) {
throw DbException.convert(e); throw DbException.convert(e);
...@@ -80,4 +90,39 @@ public class UserAggregate extends DbObjectBase { ...@@ -80,4 +90,39 @@ public class UserAggregate extends DbObjectBase {
return this.className; return this.className;
} }
/**
* Wrap {@link AggregateFunction} in order to behave as {@link org.h2.api.Aggregate}
**/
private static class AggregateWrapper implements Aggregate {
private final AggregateFunction aggregateFunction;
AggregateWrapper(AggregateFunction aggregateFunction) {
this.aggregateFunction = aggregateFunction;
}
@Override
public void init(Connection conn) throws SQLException {
aggregateFunction.init(conn);
}
@Override
public int getInternalType(int[] inputTypes) throws SQLException {
int[] sqlTypes = new int[inputTypes.length];
for (int i = 0; i < inputTypes.length; i++) {
sqlTypes[i] = DataType.convertTypeToSQLType(inputTypes[i]);
}
return DataType.convertSQLTypeToValueType(aggregateFunction.getType(sqlTypes));
}
@Override
public void add(Object value) throws SQLException {
aggregateFunction.add(value);
}
@Override
public Object getResult() throws SQLException {
return aggregateFunction.getResult();
}
}
} }
...@@ -9,7 +9,7 @@ package org.h2.expression; ...@@ -9,7 +9,7 @@ package org.h2.expression;
import java.sql.Connection; import java.sql.Connection;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.HashMap; import java.util.HashMap;
import org.h2.api.AggregateFunction; import org.h2.api.Aggregate;
import org.h2.command.Parser; import org.h2.command.Parser;
import org.h2.command.dml.Select; import org.h2.command.dml.Select;
import org.h2.constant.ErrorCode; import org.h2.constant.ErrorCode;
...@@ -116,17 +116,15 @@ public class JavaAggregate extends Expression { ...@@ -116,17 +116,15 @@ public class JavaAggregate extends Expression {
userConnection = session.createConnection(false); userConnection = session.createConnection(false);
int len = args.length; int len = args.length;
argTypes = new int[len]; argTypes = new int[len];
int[] argSqlTypes = new int[len];
for (int i = 0; i < len; i++) { for (int i = 0; i < len; i++) {
Expression expr = args[i]; Expression expr = args[i];
args[i] = expr.optimize(session); args[i] = expr.optimize(session);
int type = expr.getType(); int type = expr.getType();
argTypes[i] = type; argTypes[i] = type;
argSqlTypes[i] = DataType.convertTypeToSQLType(type);
} }
try { try {
AggregateFunction aggregate = getInstance(); Aggregate aggregate = getInstance();
dataType = DataType.convertSQLTypeToValueType(aggregate.getType(argSqlTypes)); dataType = aggregate.getInternalType(argTypes);
} catch (SQLException e) { } catch (SQLException e) {
throw DbException.convert(e); throw DbException.convert(e);
} }
...@@ -140,8 +138,8 @@ public class JavaAggregate extends Expression { ...@@ -140,8 +138,8 @@ public class JavaAggregate extends Expression {
} }
} }
private AggregateFunction getInstance() throws SQLException { private Aggregate getInstance() throws SQLException {
AggregateFunction agg = userAggregate.getInstance(); Aggregate agg = userAggregate.getInstance();
agg.init(userConnection); agg.init(userConnection);
return agg; return agg;
} }
...@@ -153,7 +151,7 @@ public class JavaAggregate extends Expression { ...@@ -153,7 +151,7 @@ public class JavaAggregate extends Expression {
throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL()); throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
} }
try { try {
AggregateFunction agg = (AggregateFunction) group.get(this); Aggregate agg = (Aggregate) group.get(this);
if (agg == null) { if (agg == null) {
agg = getInstance(); agg = getInstance();
} }
...@@ -182,7 +180,7 @@ public class JavaAggregate extends Expression { ...@@ -182,7 +180,7 @@ public class JavaAggregate extends Expression {
} }
lastGroupRowId = groupRowId; lastGroupRowId = groupRowId;
AggregateFunction agg = (AggregateFunction) group.get(this); Aggregate agg = (Aggregate) group.get(this);
try { try {
if (agg == null) { if (agg == null) {
agg = getInstance(); agg = getInstance();
......
...@@ -33,6 +33,8 @@ import java.util.Locale; ...@@ -33,6 +33,8 @@ import java.util.Locale;
import java.util.Properties; import java.util.Properties;
import java.util.TimeZone; import java.util.TimeZone;
import java.util.UUID; import java.util.UUID;
import org.h2.api.Aggregate;
import org.h2.api.AggregateFunction; import org.h2.api.AggregateFunction;
import org.h2.constant.ErrorCode; import org.h2.constant.ErrorCode;
import org.h2.engine.Constants; import org.h2.engine.Constants;
...@@ -79,6 +81,7 @@ public class TestFunctions extends TestBase implements AggregateFunction { ...@@ -79,6 +81,7 @@ public class TestFunctions extends TestBase implements AggregateFunction {
testMathFunctions(); testMathFunctions();
testVarArgs(); testVarArgs();
testAggregate(); testAggregate();
testAggregateType();
testFunctions(); testFunctions();
testFileRead(); testFileRead();
testValue(); testValue();
...@@ -558,6 +561,71 @@ public class TestFunctions extends TestBase implements AggregateFunction { ...@@ -558,6 +561,71 @@ public class TestFunctions extends TestBase implements AggregateFunction {
} }
/**
* This median implementation keeps all objects in memory.
*/
public static class MedianStringType implements Aggregate {
private final ArrayList<String> list = New.arrayList();
@Override
public void add(Object value) {
list.add(value.toString());
}
@Override
public Object getResult() {
return list.get(list.size() / 2);
}
@Override
public int getInternalType(int[] inputTypes) throws SQLException {
return Value.STRING;
}
@Override
public void init(Connection conn) {
// nothing to do
}
}
private void testAggregateType() throws SQLException {
deleteDb("functions");
Connection conn = getConnection("functions");
Statement stat = conn.createStatement();
stat.execute("CREATE AGGREGATE MEDIAN FOR \"" + MedianStringType.class.getName() + "\"");
stat.execute("CREATE AGGREGATE IF NOT EXISTS MEDIAN FOR \"" + MedianStringType.class.getName() + "\"");
ResultSet rs = stat.executeQuery("SELECT MEDIAN(X) FROM SYSTEM_RANGE(1, 9)");
rs.next();
assertEquals("5", rs.getString(1));
conn.close();
if (config.memory) {
return;
}
conn = getConnection("functions");
stat = conn.createStatement();
stat.executeQuery("SELECT MEDIAN(X) FROM SYSTEM_RANGE(1, 9)");
DatabaseMetaData meta = conn.getMetaData();
rs = meta.getProcedures(null, null, "MEDIAN");
assertTrue(rs.next());
assertFalse(rs.next());
rs = stat.executeQuery("SCRIPT");
boolean found = false;
while (rs.next()) {
String sql = rs.getString(1);
if (sql.contains("MEDIAN")) {
found = true;
}
}
assertTrue(found);
stat.execute("DROP AGGREGATE MEDIAN");
stat.execute("DROP AGGREGATE IF EXISTS MEDIAN");
conn.close();
}
private void testAggregate() throws SQLException { private void testAggregate() throws SQLException {
deleteDb("functions"); deleteDb("functions");
Connection conn = getConnection("functions"); Connection conn = getConnection("functions");
......
...@@ -12,6 +12,9 @@ import java.sql.Savepoint; ...@@ -12,6 +12,9 @@ import java.sql.Savepoint;
import java.sql.Statement; import java.sql.Statement;
import java.sql.Types; import java.sql.Types;
import java.util.Random; import java.util.Random;
import com.vividsolutions.jts.geom.Envelope;
import org.h2.api.Aggregate;
import org.h2.test.TestBase; import org.h2.test.TestBase;
import org.h2.tools.SimpleResultSet; import org.h2.tools.SimpleResultSet;
import org.h2.tools.SimpleRowSource; import org.h2.tools.SimpleRowSource;
...@@ -77,6 +80,7 @@ public class TestSpatial extends TestBase { ...@@ -77,6 +80,7 @@ public class TestSpatial extends TestBase {
testEquals(); testEquals();
testTableFunctionGeometry(); testTableFunctionGeometry();
testHashCode(); testHashCode();
testAggregateWithGeometry();
} }
private void testHashCode() { private void testHashCode() {
...@@ -641,5 +645,63 @@ public class TestSpatial extends TestBase { ...@@ -641,5 +645,63 @@ public class TestSpatial extends TestBase {
rs.addRow(factory.createPoint(new Coordinate(x, y))); rs.addRow(factory.createPoint(new Coordinate(x, y)));
return rs; return rs;
} }
public void testAggregateWithGeometry() throws SQLException {
deleteDb("spatialIndex");
Connection conn = getConnection("spatialIndex");
try {
Statement st = conn.createStatement();
st.execute("CREATE AGGREGATE TABLE_ENVELOPE FOR \""+TableEnvelope.class.getName()+"\"");
st.execute("CREATE TABLE test(the_geom GEOMETRY)");
st.execute("INSERT INTO test VALUES ('POINT(1 1)'), ('POINT(10 5)')");
ResultSet rs = st.executeQuery("select TABLE_ENVELOPE(the_geom) from test");
assertEquals("geometry", rs.getMetaData().getColumnTypeName(1).toLowerCase());
assertTrue(rs.next());
assertTrue(rs.getObject(1) instanceof Geometry);
assertTrue(new Envelope(1, 10, 1, 5).equals(((Geometry) rs.getObject(1)).getEnvelopeInternal()));
assertFalse(rs.next());
} finally {
conn.close();
}
deleteDb("spatialIndex");
}
/**
* An aggregate function that calculates the envelope.
*/
public static class TableEnvelope implements Aggregate {
private Envelope tableEnvelope;
@Override
public int getInternalType(int[] inputTypes) throws SQLException {
for (int inputType : inputTypes) {
if (inputType != Value.GEOMETRY) {
throw new SQLException("TableEnvelope accept only Geometry argument");
}
}
return Value.GEOMETRY;
}
@Override
public void init(Connection conn) throws SQLException {
tableEnvelope = null;
}
@Override
public void add(Object value) throws SQLException {
if (value instanceof Geometry) {
if (tableEnvelope == null) {
tableEnvelope = ((Geometry) value).getEnvelopeInternal();
} else {
tableEnvelope.expandToInclude(((Geometry) value).getEnvelopeInternal());
}
}
}
@Override
public Object getResult() throws SQLException {
return new GeometryFactory().toGeometry(tableEnvelope);
}
}
} }
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论