提交 72d728e3 authored 作者: Noel Grandin's avatar Noel Grandin

Issue #156: Add support for getGeneratedKeys() when executing commands via…

Issue #156: Add support for getGeneratedKeys() when executing commands via PreparedStatement#executeBatch
上级 00a1b0e2
...@@ -21,6 +21,8 @@ Change Log ...@@ -21,6 +21,8 @@ Change Log
<h2>Next Version (unreleased)</h2> <h2>Next Version (unreleased)</h2>
<ul> <ul>
<li>Issue #156: Add support for getGeneratedKeys() when executing commands via PreparedStatement#executeBatch
</li>
<li>Issue #195: The new Maven uses a .cmd file instead of a .bat file <li>Issue #195: The new Maven uses a .cmd file instead of a .bat file
</li> </li>
<li>Issue #212: EXPLAIN PLAN for UPDATE statement did not display LIMIT expression <li>Issue #212: EXPLAIN PLAN for UPDATE statement did not display LIMIT expression
......
...@@ -25,13 +25,13 @@ import java.sql.Statement; ...@@ -25,13 +25,13 @@ import java.sql.Statement;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Calendar; import java.util.Calendar;
import java.util.HashMap; import java.util.HashMap;
import org.h2.api.ErrorCode; import org.h2.api.ErrorCode;
import org.h2.command.CommandInterface; import org.h2.command.CommandInterface;
import org.h2.expression.ParameterInterface; import org.h2.expression.ParameterInterface;
import org.h2.message.DbException; import org.h2.message.DbException;
import org.h2.message.TraceObject; import org.h2.message.TraceObject;
import org.h2.result.ResultInterface; import org.h2.result.ResultInterface;
import org.h2.tools.SimpleResultSet;
import org.h2.util.DateTimeUtils; import org.h2.util.DateTimeUtils;
import org.h2.util.IOUtils; import org.h2.util.IOUtils;
import org.h2.util.New; import org.h2.util.New;
...@@ -61,6 +61,7 @@ public class JdbcPreparedStatement extends JdbcStatement implements ...@@ -61,6 +61,7 @@ public class JdbcPreparedStatement extends JdbcStatement implements
protected CommandInterface command; protected CommandInterface command;
private final String sqlStatement; private final String sqlStatement;
private ArrayList<Value[]> batchParameters; private ArrayList<Value[]> batchParameters;
private ArrayList<Object> batchIdentities;
private HashMap<String, Integer> cachedColumnLabelMap; private HashMap<String, Integer> cachedColumnLabelMap;
JdbcPreparedStatement(JdbcConnection conn, String sql, int id, JdbcPreparedStatement(JdbcConnection conn, String sql, int id,
...@@ -97,6 +98,7 @@ public class JdbcPreparedStatement extends JdbcStatement implements ...@@ -97,6 +98,7 @@ public class JdbcPreparedStatement extends JdbcStatement implements
if (isDebugEnabled()) { if (isDebugEnabled()) {
debugCodeAssign("ResultSet", TraceObject.RESULT_SET, id, "executeQuery()"); debugCodeAssign("ResultSet", TraceObject.RESULT_SET, id, "executeQuery()");
} }
batchIdentities = null;
synchronized (session) { synchronized (session) {
checkClosed(); checkClosed();
closeOldResultSet(); closeOldResultSet();
...@@ -139,6 +141,7 @@ public class JdbcPreparedStatement extends JdbcStatement implements ...@@ -139,6 +141,7 @@ public class JdbcPreparedStatement extends JdbcStatement implements
try { try {
debugCodeCall("executeUpdate"); debugCodeCall("executeUpdate");
checkClosedForWrite(); checkClosedForWrite();
batchIdentities = null;
try { try {
return executeUpdateInternal(); return executeUpdateInternal();
} finally { } finally {
...@@ -724,6 +727,7 @@ public class JdbcPreparedStatement extends JdbcStatement implements ...@@ -724,6 +727,7 @@ public class JdbcPreparedStatement extends JdbcStatement implements
* *
* @deprecated since JDBC 2.0, use setCharacterStream * @deprecated since JDBC 2.0, use setCharacterStream
*/ */
@Deprecated
@Override @Override
public void setUnicodeStream(int parameterIndex, InputStream x, int length) public void setUnicodeStream(int parameterIndex, InputStream x, int length)
throws SQLException { throws SQLException {
...@@ -1158,12 +1162,14 @@ public class JdbcPreparedStatement extends JdbcStatement implements ...@@ -1158,12 +1162,14 @@ public class JdbcPreparedStatement extends JdbcStatement implements
@Override @Override
public int[] executeBatch() throws SQLException { public int[] executeBatch() throws SQLException {
try { try {
int id = getNextId(TraceObject.PREPARED_STATEMENT);
debugCodeCall("executeBatch"); debugCodeCall("executeBatch");
if (batchParameters == null) { if (batchParameters == null) {
// TODO batch: check what other database do if no parameters are // TODO batch: check what other database do if no parameters are
// set // set
batchParameters = New.arrayList(); batchParameters = New.arrayList();
} }
batchIdentities = New.arrayList();
int size = batchParameters.size(); int size = batchParameters.size();
int[] result = new int[size]; int[] result = new int[size];
boolean error = false; boolean error = false;
...@@ -1181,6 +1187,10 @@ public class JdbcPreparedStatement extends JdbcStatement implements ...@@ -1181,6 +1187,10 @@ public class JdbcPreparedStatement extends JdbcStatement implements
} }
try { try {
result[i] = executeUpdateInternal(); result[i] = executeUpdateInternal();
ResultSet rs = conn.getGeneratedKeys(this, id);
while (rs.next()) {
batchIdentities.add(rs.getObject(1));
}
} catch (Exception re) { } catch (Exception re) {
SQLException e = logAndConvert(re); SQLException e = logAndConvert(re);
if (next == null) { if (next == null) {
...@@ -1207,6 +1217,20 @@ public class JdbcPreparedStatement extends JdbcStatement implements ...@@ -1207,6 +1217,20 @@ public class JdbcPreparedStatement extends JdbcStatement implements
} }
} }
@Override
public ResultSet getGeneratedKeys() throws SQLException {
if (batchIdentities != null && !batchIdentities.isEmpty()) {
SimpleResultSet rs = new SimpleResultSet();
rs.addColumn("identity", java.sql.Types.INTEGER,
10, 0);
for (Object o : batchIdentities) {
rs.addRow(o);
}
return rs;
}
return super.getGeneratedKeys();
}
/** /**
* Adds the current settings to the batch. * Adds the current settings to the batch.
*/ */
......
...@@ -23,7 +23,6 @@ import java.sql.Statement; ...@@ -23,7 +23,6 @@ import java.sql.Statement;
import java.sql.Timestamp; import java.sql.Timestamp;
import java.sql.Types; import java.sql.Types;
import java.util.UUID; import java.util.UUID;
import org.h2.api.ErrorCode; import org.h2.api.ErrorCode;
import org.h2.api.Trigger; import org.h2.api.Trigger;
import org.h2.test.TestBase; import org.h2.test.TestBase;
...@@ -79,6 +78,7 @@ public class TestPreparedStatement extends TestBase { ...@@ -79,6 +78,7 @@ public class TestPreparedStatement extends TestBase {
testSubquery(conn); testSubquery(conn);
testObject(conn); testObject(conn);
testIdentity(conn); testIdentity(conn);
testBatchGeneratedKeys(conn);
testDataTypes(conn); testDataTypes(conn);
testGetMoreResults(conn); testGetMoreResults(conn);
testBlob(conn); testBlob(conn);
...@@ -504,11 +504,10 @@ public class TestPreparedStatement extends TestBase { ...@@ -504,11 +504,10 @@ public class TestPreparedStatement extends TestBase {
private void testScopedGeneratedKey(Connection conn) throws SQLException { private void testScopedGeneratedKey(Connection conn) throws SQLException {
Statement stat = conn.createStatement(); Statement stat = conn.createStatement();
Trigger t = new SequenceTrigger();
stat.execute("create table test(id identity)"); stat.execute("create table test(id identity)");
stat.execute("create sequence seq start with 1000"); stat.execute("create sequence seq start with 1000");
stat.execute("create trigger test_ins after insert on test call \"" + stat.execute("create trigger test_ins after insert on test call \"" +
t.getClass().getName() + "\""); SequenceTrigger.class.getName() + "\"");
stat.execute("insert into test values(null)"); stat.execute("insert into test values(null)");
ResultSet rs = stat.getGeneratedKeys(); ResultSet rs = stat.getGeneratedKeys();
rs.next(); rs.next();
...@@ -1049,6 +1048,28 @@ public class TestPreparedStatement extends TestBase { ...@@ -1049,6 +1048,28 @@ public class TestPreparedStatement extends TestBase {
assertFalse(rs.next()); assertFalse(rs.next());
stat.execute("DROP TABLE TEST"); stat.execute("DROP TABLE TEST");
stat.execute("DROP SEQUENCE SEQ");
}
private void testBatchGeneratedKeys(Connection conn) throws SQLException {
Statement stat = conn.createStatement();
stat.execute("CREATE SEQUENCE SEQ");
stat.execute("CREATE TABLE TEST(ID INT)");
PreparedStatement prep = conn.prepareStatement("INSERT INTO TEST VALUES(NEXT VALUE FOR SEQ)");
prep.addBatch();
prep.addBatch();
prep.addBatch();
prep.executeBatch();
ResultSet keys = prep.getGeneratedKeys();
keys.next();
assertEquals(1, keys.getLong(1));
keys.next();
assertEquals(2, keys.getLong(1));
keys.next();
assertEquals(3, keys.getLong(1));
assertFalse(keys.next());
stat.execute("DROP TABLE TEST");
stat.execute("DROP SEQUENCE SEQ");
} }
private int getLength() { private int getLength() {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论