提交 e05502f1 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Imprement returning of multiple generated keys and multiple generated columns

上级 e0619d3a
......@@ -12,6 +12,7 @@ import org.h2.api.Trigger;
import org.h2.command.Command;
import org.h2.command.CommandInterface;
import org.h2.command.Prepared;
import org.h2.engine.GeneratedKeys;
import org.h2.engine.Right;
import org.h2.engine.Session;
import org.h2.engine.UndoLogRecord;
......@@ -20,6 +21,7 @@ import org.h2.expression.ConditionAndOr;
import org.h2.expression.Expression;
import org.h2.expression.ExpressionColumn;
import org.h2.expression.Parameter;
import org.h2.expression.SequenceValue;
import org.h2.index.Index;
import org.h2.message.DbException;
import org.h2.mvstore.db.MVPrimaryIndex;
......@@ -145,8 +147,10 @@ public class Insert extends Prepared implements ResultTarget {
int listSize = list.size();
if (listSize > 0) {
int columnLen = columns.length;
GeneratedKeys generatedKeys = session.getGeneratedKeys();
for (int x = 0; x < listSize; x++) {
session.startStatementWithinTransaction();
generatedKeys.nextRow();
Row newRow = table.getTemplateRow();
Expression[] expr = list.get(x);
setCurrentRowNumber(x + 1);
......@@ -160,6 +164,9 @@ public class Insert extends Prepared implements ResultTarget {
try {
Value v = c.convert(e.getValue(session), session.getDatabase().getMode());
newRow.setValue(index, v);
if (e instanceof SequenceValue) {
generatedKeys.add(c, v);
}
} catch (DbException ex) {
throw setRow(ex, x, getSQL(expr));
}
......@@ -179,6 +186,7 @@ public class Insert extends Prepared implements ResultTarget {
continue;
}
}
generatedKeys.confirmRow();
session.log(table, UndoLogRecord.INSERT, newRow);
table.fireAfterRow(session, null, newRow, false);
}
......
......@@ -11,11 +11,13 @@ import org.h2.api.Trigger;
import org.h2.command.Command;
import org.h2.command.CommandInterface;
import org.h2.command.Prepared;
import org.h2.engine.GeneratedKeys;
import org.h2.engine.Right;
import org.h2.engine.Session;
import org.h2.engine.UndoLogRecord;
import org.h2.expression.Expression;
import org.h2.expression.Parameter;
import org.h2.expression.SequenceValue;
import org.h2.index.Index;
import org.h2.message.DbException;
import org.h2.result.ResultInterface;
......@@ -87,8 +89,10 @@ public class Merge extends Prepared {
if (valuesExpressionList.size() > 0) {
// process values in list
count = 0;
GeneratedKeys generatedKeys = session.getGeneratedKeys();
for (int x = 0, size = valuesExpressionList.size(); x < size; x++) {
setCurrentRowNumber(x + 1);
generatedKeys.nextRow();
Expression[] expr = valuesExpressionList.get(x);
Row newRow = targetTable.getTemplateRow();
for (int i = 0, len = columns.length; i < len; i++) {
......@@ -100,6 +104,9 @@ public class Merge extends Prepared {
try {
Value v = c.convert(e.getValue(session));
newRow.setValue(index, v);
if (e instanceof SequenceValue) {
generatedKeys.add(c, v);
}
} catch (DbException ex) {
throw setRow(ex, count, getSQL(expr));
}
......@@ -171,6 +178,7 @@ public class Merge extends Prepared {
if (!done) {
targetTable.lock(session, true, false);
targetTable.addRow(session, row);
session.getGeneratedKeys().confirmRow();
session.log(targetTable, UndoLogRecord.INSERT, row);
targetTable.fireAfterRow(session, null, row, false);
}
......
/*
* Copyright 2004-2018 H2 Group. Multiple-Licensed under the MPL 2.0,
* and the EPL 1.0 (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.engine;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.h2.table.Column;
import org.h2.tools.SimpleResultSet;
import org.h2.util.MathUtils;
import org.h2.util.New;
import org.h2.value.DataType;
import org.h2.value.Value;
/**
* Generated keys.
*/
public final class GeneratedKeys {
private final ArrayList<Map<Column, Object>> data = New.arrayList();
private final ArrayList<Object> row = New.arrayList();
private final ArrayList<Column> columns = New.arrayList();
public void add(Column column, Value value) {
row.add(column);
row.add(value.getObject());
}
public void clear() {
data.clear();
row.clear();
columns.clear();
}
public void confirmRow() {
int size = row.size();
if (size > 0) {
if (size == 2) {
Column column = (Column) row.get(0);
data.add(Collections.singletonMap(column, row.get(1)));
if (!columns.contains(column)) {
columns.add(column);
}
} else {
HashMap<Column, Object> map = new HashMap<>();
for (int i = 0; i < size; i += 2) {
Column column = (Column) row.get(i);
map.put(column, row.get(i + 1));
if (!columns.contains(column)) {
columns.add(column);
}
}
data.add(map);
}
row.clear();
}
}
public SimpleResultSet getKeys() {
SimpleResultSet rs = new SimpleResultSet();
for (Column column : columns) {
rs.addColumn(column.getName(), DataType.convertTypeToSQLType(column.getType()),
MathUtils.convertLongToInt(column.getPrecision()), column.getScale());
}
for (Map<Column, Object> map : data) {
Object[] row = new Object[columns.size()];
for (Map.Entry<Column, Object> entry : map.entrySet()) {
row[columns.indexOf(entry.getKey())] = entry.getValue();
}
rs.addRow(row);
}
return rs;
}
public void nextRow() {
row.clear();
}
@Override
public String toString() {
return columns + ": " + data.size();
}
}
......@@ -85,6 +85,7 @@ public class Session extends SessionWithState {
private Value lastIdentity = ValueLong.get(0);
private Value lastScopeIdentity = ValueLong.get(0);
private Value lastTriggerIdentity;
private GeneratedKeys generatedKeys;
private int firstUncommittedLog = Session.LOG_WRITTEN;
private int firstUncommittedPos = Session.LOG_WRITTEN;
private HashMap<String, Savepoint> savepoints;
......@@ -1075,6 +1076,13 @@ public class Session extends SessionWithState {
return lastTriggerIdentity;
}
public GeneratedKeys getGeneratedKeys() {
if (generatedKeys == null) {
generatedKeys = new GeneratedKeys();
}
return generatedKeys;
}
/**
* Called when a log entry for this session is added. The session keeps
* track of the first entry in the transaction log that is not yet
......@@ -1240,6 +1248,11 @@ public class Session extends SessionWithState {
*/
public void setCurrentCommand(Command command) {
this.currentCommand = command;
// Preserve generated keys in case of a new query so they can be read with
// CALL GET_GENERATED_KEYS()
if (command != null && !command.isQuery() && generatedKeys != null) {
generatedKeys.clear();
}
if (queryTimeout > 0 && command != null) {
currentCommandStart = System.currentTimeMillis();
long now = System.nanoTime();
......
......@@ -115,7 +115,7 @@ public class Function extends Expression implements FunctionCall {
public static final int DATABASE = 150, USER = 151, CURRENT_USER = 152,
IDENTITY = 153, SCOPE_IDENTITY = 154, AUTOCOMMIT = 155,
READONLY = 156, DATABASE_PATH = 157, LOCK_TIMEOUT = 158,
DISK_SPACE_USED = 159, SIGNAL = 160;
DISK_SPACE_USED = 159, SIGNAL = 160, GET_GENERATED_KEYS = 161;
private static final Pattern SIGNAL_PATTERN = Pattern.compile("[0-9A-Z]{5}");
......@@ -491,6 +491,7 @@ public class Function extends Expression implements FunctionCall {
addFunctionNotDeterministic("DISK_SPACE_USED", DISK_SPACE_USED,
1, Value.LONG);
addFunctionWithNull("SIGNAL", SIGNAL, 2, Value.NULL);
addFunction("GET_GENERATED_KEYS", GET_GENERATED_KEYS, 0, Value.RESULT_SET);
addFunction("H2VERSION", H2VERSION, 0, Value.STRING);
// TableFunction
......@@ -924,6 +925,9 @@ public class Function extends Expression implements FunctionCall {
case DISK_SPACE_USED:
result = ValueLong.get(getDiskSpaceUsed(session, v0));
break;
case GET_GENERATED_KEYS:
result = ValueResultSet.get(session.getGeneratedKeys().getKeys());
break;
case CAST:
case CONVERT: {
v0 = v0.convertTo(dataType);
......
......@@ -85,6 +85,7 @@ public class JdbcConnection extends TraceObject
private CommandInterface getReadOnly, getGeneratedKeys;
private CommandInterface setLockMode, getLockMode;
private CommandInterface setQueryTimeout, getQueryTimeout;
private boolean oldGetGeneratedKeys;
private int savepointId;
private String catalog;
......@@ -1560,6 +1561,21 @@ public class JdbcConnection extends TraceObject
* INTERNAL
*/
ResultSet getGeneratedKeys(JdbcStatement stat, int id) {
if (!oldGetGeneratedKeys) {
try {
getGeneratedKeys = prepareCommand("CALL GET_GENERATED_KEYS()", getGeneratedKeys);
ResultInterface result = getGeneratedKeys.executeQuery(Integer.MAX_VALUE, true);
ResultSet rs = new JdbcResultSet(this, stat, getGeneratedKeys, result,
id, false, true, false);
return rs;
} catch (DbException ex) {
if (ex.getErrorCode() == ErrorCode.FUNCTION_NOT_FOUND_1) {
oldGetGeneratedKeys = true;
} else {
throw ex;
}
}
}
getGeneratedKeys = prepareCommand(
"SELECT SCOPE_IDENTITY() "
+ "WHERE SCOPE_IDENTITY() IS NOT NULL",
......
......@@ -31,7 +31,7 @@ import org.h2.expression.ParameterInterface;
import org.h2.message.DbException;
import org.h2.message.TraceObject;
import org.h2.result.ResultInterface;
import org.h2.tools.SimpleResultSet;
import org.h2.tools.MergedResultSet;
import org.h2.util.DateTimeUtils;
import org.h2.util.IOUtils;
import org.h2.util.New;
......@@ -61,7 +61,7 @@ public class JdbcPreparedStatement extends JdbcStatement implements
protected CommandInterface command;
private final String sqlStatement;
private ArrayList<Value[]> batchParameters;
private ArrayList<Object> batchIdentities;
private MergedResultSet batchIdentities;
private HashMap<String, Integer> cachedColumnLabelMap;
JdbcPreparedStatement(JdbcConnection conn, String sql, int id,
......@@ -1243,7 +1243,7 @@ public class JdbcPreparedStatement extends JdbcStatement implements
// set
batchParameters = New.arrayList();
}
batchIdentities = New.arrayList();
batchIdentities = new MergedResultSet();
int size = batchParameters.size();
int[] result = new int[size];
boolean error = false;
......@@ -1262,9 +1262,7 @@ public class JdbcPreparedStatement extends JdbcStatement implements
try {
result[i] = executeUpdateInternal();
ResultSet rs = conn.getGeneratedKeys(this, id);
while (rs.next()) {
batchIdentities.add(rs.getObject(1));
}
batchIdentities.add(rs);
} catch (Exception re) {
SQLException e = logAndConvert(re);
if (next == null) {
......@@ -1293,14 +1291,8 @@ public class JdbcPreparedStatement extends JdbcStatement implements
@Override
public ResultSet getGeneratedKeys() throws SQLException {
if (batchIdentities != null && !batchIdentities.isEmpty()) {
SimpleResultSet rs = new SimpleResultSet();
rs.addColumn("identity", java.sql.Types.INTEGER,
10, 0);
for (Object o : batchIdentities) {
rs.addRow(o);
}
return rs;
if (batchIdentities != null) {
return batchIdentities.getKeys();
}
return super.getGeneratedKeys();
}
......
......@@ -203,6 +203,7 @@ public class TriggerObject extends SchemaObjectBase {
* times for each statement.
*
* @param session the session
* @param table the table
* @param oldRow the old row
* @param newRow the new row
* @param beforeAction true if this method is called before the operation is
......@@ -210,7 +211,7 @@ public class TriggerObject extends SchemaObjectBase {
* @param rollback when the operation occurred within a rollback
* @return true if no further action is required (for 'instead of' triggers)
*/
public boolean fireRow(Session session, Row oldRow, Row newRow,
public boolean fireRow(Session session, Table table, Row oldRow, Row newRow,
boolean beforeAction, boolean rollback) {
if (!rowBased || before != beforeAction) {
return false;
......@@ -260,6 +261,7 @@ public class TriggerObject extends SchemaObjectBase {
Object o = newList[i];
if (o != newListBackup[i]) {
Value v = DataType.convertToValue(session, o, Value.UNKNOWN);
session.getGeneratedKeys().add(table.getColumn(i), v);
newRow.setValue(i, v);
}
}
......
......@@ -321,6 +321,7 @@ public class Column {
value = ValueNull.INSTANCE;
} else {
value = localDefaultExpression.getValue(session).convertTo(type);
session.getGeneratedKeys().add(this, value);
if (primaryKey) {
session.setLastIdentity(value);
}
......@@ -330,6 +331,7 @@ public class Column {
if (value == ValueNull.INSTANCE) {
if (convertNullToDefault) {
value = localDefaultExpression.getValue(session).convertTo(type);
session.getGeneratedKeys().add(this, value);
}
if (value == ValueNull.INSTANCE && !nullable) {
if (mode.convertInsertNullToZero) {
......
......@@ -1026,7 +1026,7 @@ public abstract class Table extends SchemaObjectBase {
boolean beforeAction, boolean rollback) {
if (triggers != null) {
for (TriggerObject trigger : triggers) {
boolean done = trigger.fireRow(session, oldRow, newRow, beforeAction, rollback);
boolean done = trigger.fireRow(session, this, oldRow, newRow, beforeAction, rollback);
if (done) {
return true;
}
......
......@@ -91,7 +91,7 @@ public class TestPreparedStatement extends TestBase {
testSubquery(conn);
testObject(conn);
testIdentity(conn);
testBatchGeneratedKeys(conn);
testGeneratedKeys(conn);
testDataTypes(conn);
testGetMoreResults(conn);
testBlob(conn);
......@@ -1318,11 +1318,115 @@ public class TestPreparedStatement extends TestBase {
stat.execute("DROP SEQUENCE SEQ");
}
private void testBatchGeneratedKeys(Connection conn) throws SQLException {
public static class TestGeneratedKeysTrigger implements Trigger {
@Override
public void init(Connection conn, String schemaName, String triggerName, String tableName, boolean before,
int type) throws SQLException {
}
@Override
public void fire(Connection conn, Object[] oldRow, Object[] newRow) throws SQLException {
if (newRow[0] == null) {
newRow[0] = UUID.randomUUID();
}
}
@Override
public void close() throws SQLException {
}
@Override
public void remove() throws SQLException {
}
}
private void testGeneratedKeys(Connection conn) throws SQLException {
Statement stat = conn.createStatement();
stat.execute("create table test(id bigint)");
stat.execute("create sequence seq");
PreparedStatement prep = conn.prepareStatement(
"insert into test values (30), (next value for seq),"
+ " (next value for seq), (next value for seq), (20)",
PreparedStatement.RETURN_GENERATED_KEYS);
prep.executeUpdate();
ResultSet rs = prep.getGeneratedKeys();
rs.next();
assertEquals(1L, rs.getLong(1));
rs.next();
assertEquals(2L, rs.getLong(1));
rs.next();
assertEquals(3L, rs.getLong(1));
assertFalse(rs.next());
stat.execute("drop table test");
stat.execute("drop sequence seq");
stat.execute("create table test(id bigint auto_increment, value int)");
stat.execute("insert into test(value) values (1), (2)");
rs = stat.getGeneratedKeys();
rs.next();
assertEquals(1L, rs.getLong(1));
rs.next();
assertEquals(2L, rs.getLong(1));
assertFalse(rs.next());
stat.execute("drop table test");
stat.execute("create table test(id bigint auto_increment, uid uuid default random_uuid(), value int)");
prep = conn.prepareStatement("insert into test(value) values (?), (?)",
PreparedStatement.RETURN_GENERATED_KEYS);
prep.setInt(1, 1);
prep.setInt(2, 2);
prep.addBatch();
prep.setInt(1, 3);
prep.setInt(1, 4);
prep.addBatch();
prep.executeBatch();
rs = prep.getGeneratedKeys();
rs.next();
assertEquals(1L, rs.getLong(1));
UUID u1 = (UUID) rs.getObject(2);
assertTrue(u1 != null);
rs.next();
assertEquals(2L, rs.getLong(1));
UUID u2 = (UUID) rs.getObject(2);
assertTrue(u2 != null);
rs.next();
assertEquals(3L, rs.getLong(1));
UUID u3 = (UUID) rs.getObject(2);
assertTrue(u3 != null);
rs.next();
assertEquals(4L, rs.getLong(1));
UUID u4 = (UUID) rs.getObject(2);
assertTrue(u4 != null);
assertFalse(rs.next());
assertFalse(u1.equals(u2));
assertFalse(u2.equals(u3));
assertFalse(u3.equals(u4));
stat.execute("drop table test");
stat.execute("create table test(id uuid, value int)");
stat.execute("create trigger test_insert before insert on test for each row call \""
+ TestGeneratedKeysTrigger.class.getName()
+ '"');
stat.executeUpdate("insert into test(value) values (10), (20)");
rs = stat.getGeneratedKeys();
rs.next();
u1 = (UUID) rs.getObject(1);
rs.next();
u2 = (UUID) rs.getObject(1);
assertFalse(rs.next());
rs = stat.executeQuery("select id from test order by value");
rs.next();
assertEquals(u1, rs.getObject(1));
rs.next();
assertEquals(u2, rs.getObject(1));
stat.execute("drop trigger test_insert");
stat.execute("drop table test");
stat.execute("CREATE SEQUENCE SEQ");
stat.execute("CREATE TABLE TEST(ID INT)");
PreparedStatement prep = conn.prepareStatement(
prep = conn.prepareStatement(
"INSERT INTO TEST VALUES(NEXT VALUE FOR SEQ)");
prep.addBatch();
prep.addBatch();
......
......@@ -447,7 +447,7 @@ public class TestWeb extends TestBase {
assertContains(result, "There is currently no running statement");
result = client.get(url,
"query.do?sql=@generated insert into test(id) values(test_sequence.nextval)");
assertContains(result, "SCOPE_IDENTITY()");
assertContains(result, "<tr><th>ID</th></tr><tr><td>1</td></tr>");
result = client.get(url, "query.do?sql=@maxrows 2000");
assertContains(result, "Max rowcount is set");
result = client.get(url, "query.do?sql=@password_hash user password");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论