提交 0707251f authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Fix prepared statements with multiple commands with parameters

上级 6666417b
...@@ -21,6 +21,8 @@ Change Log ...@@ -21,6 +21,8 @@ Change Log
<h2>Next Version (unreleased)</h2> <h2>Next Version (unreleased)</h2>
<ul> <ul>
<li>Issue #1590: Error on executing "DELETE FROM table1 WHERE ID = ?; DELETE FROM table2 WHERE ID = ?;"
</li>
<li>Issue #1727: Support ISODOW as identifier for the extract function additional to ISO_DAY_OF_WEEK <li>Issue #1727: Support ISODOW as identifier for the extract function additional to ISO_DAY_OF_WEEK
</li> </li>
<li>PR #1580, #1726: Disable remote database creation by default <li>PR #1580, #1726: Disable remote database creation by default
......
...@@ -8,6 +8,7 @@ package org.h2.command; ...@@ -8,6 +8,7 @@ package org.h2.command;
import java.util.ArrayList; import java.util.ArrayList;
import org.h2.engine.Session; import org.h2.engine.Session;
import org.h2.expression.Parameter;
import org.h2.expression.ParameterInterface; import org.h2.expression.ParameterInterface;
import org.h2.result.ResultInterface; import org.h2.result.ResultInterface;
...@@ -16,51 +17,53 @@ import org.h2.result.ResultInterface; ...@@ -16,51 +17,53 @@ import org.h2.result.ResultInterface;
*/ */
class CommandList extends Command { class CommandList extends Command {
private final Command command; private final ArrayList<Command> commands;
private final String remaining; private final ArrayList<Parameter> parameters;
CommandList(Session session, String sql, Command c, String remaining) { CommandList(Session session, String sql, ArrayList<Command> commands, ArrayList<Parameter> parameters) {
super(session, sql); super(session, sql);
this.command = c; this.commands = commands;
this.remaining = remaining; this.parameters = parameters;
} }
@Override @Override
public ArrayList<? extends ParameterInterface> getParameters() { public ArrayList<? extends ParameterInterface> getParameters() {
return command.getParameters(); return parameters;
} }
private void executeRemaining() { private void executeRemaining() {
Command remainingCommand = session.prepareLocal(remaining); for (int i = 1, l = commands.size(); i < l; i++) {
if (remainingCommand.isQuery()) { Command command = commands.get(i);
remainingCommand.query(0); if (command.isQuery()) {
command.query(0);
} else { } else {
remainingCommand.update(); command.update();
}
} }
} }
@Override @Override
public int update() { public int update() {
int updateCount = command.executeUpdate(false).getUpdateCount(); int updateCount = commands.get(0).executeUpdate(false).getUpdateCount();
executeRemaining(); executeRemaining();
return updateCount; return updateCount;
} }
@Override @Override
public void prepareJoinBatch() { public void prepareJoinBatch() {
command.prepareJoinBatch(); commands.get(0).prepareJoinBatch();
} }
@Override @Override
public ResultInterface query(int maxrows) { public ResultInterface query(int maxrows) {
ResultInterface result = command.query(maxrows); ResultInterface result = commands.get(0).query(maxrows);
executeRemaining(); executeRemaining();
return result; return result;
} }
@Override @Override
public boolean isQuery() { public boolean isQuery() {
return command.isQuery(); return commands.get(0).isQuery();
} }
@Override @Override
...@@ -75,12 +78,12 @@ class CommandList extends Command { ...@@ -75,12 +78,12 @@ class CommandList extends Command {
@Override @Override
public ResultInterface queryMeta() { public ResultInterface queryMeta() {
return command.queryMeta(); return commands.get(0).queryMeta();
} }
@Override @Override
public int getCommandType() { public int getCommandType() {
return command.getCommandType(); return commands.get(0).getCommandType();
} }
} }
...@@ -634,14 +634,15 @@ public class Parser { ...@@ -634,14 +634,15 @@ public class Parser {
private Prepared currentPrepared; private Prepared currentPrepared;
private Select currentSelect; private Select currentSelect;
private ArrayList<Parameter> parameters; private ArrayList<Parameter> parameters;
private ArrayList<Parameter> indexedParameterList;
private ArrayList<Parameter> suppliedParameters;
private ArrayList<Parameter> suppliedParameterList;
private String schemaName; private String schemaName;
private ArrayList<String> expectedList; private ArrayList<String> expectedList;
private boolean rightsChecked; private boolean rightsChecked;
private boolean recompileAlways; private boolean recompileAlways;
private boolean literalsChecked; private boolean literalsChecked;
private ArrayList<Parameter> indexedParameterList;
private int orderInFrom; private int orderInFrom;
private ArrayList<Parameter> suppliedParameterList;
public Parser(Session session) { public Parser(Session session) {
this.database = session.getDatabase(); this.database = session.getDatabase();
...@@ -672,9 +673,35 @@ public class Parser { ...@@ -672,9 +673,35 @@ public class Parser {
*/ */
public Command prepareCommand(String sql) { public Command prepareCommand(String sql) {
try { try {
Command c = prepareSingleCommand(sql);
if (currentTokenType == SEMICOLON) {
String remaining = originalSQL.substring(parseIndex);
if (!StringUtils.isWhitespaceOrEmpty(remaining)) {
c = prepareCommandList(c, sql, remaining);
}
}
return c;
} catch (DbException e) {
throw e.addSQL(originalSQL);
}
}
private Command prepareCommandList(Command c, String sql, String remaining) {
ArrayList<Command> list = Utils.newSmallArrayList();
list.add(c);
do {
suppliedParameters = parameters;
suppliedParameterList = indexedParameterList;
list.add(prepareSingleCommand(remaining));
} while (currentTokenType == SEMICOLON
&& !StringUtils.isWhitespaceOrEmpty(remaining = originalSQL.substring(parseIndex)));
return new CommandList(session, sql, list, parameters);
}
private Command prepareSingleCommand(String sql) {
Prepared p = parse(sql); Prepared p = parse(sql);
boolean hasMore = isToken(SEMICOLON); if (currentTokenType != SEMICOLON && currentTokenType != END) {
if (!hasMore && currentTokenType != END) { addExpected(SEMICOLON);
throw getSyntaxError(); throw getSyntaxError();
} }
try { try {
...@@ -683,17 +710,7 @@ public class Parser { ...@@ -683,17 +710,7 @@ public class Parser {
CommandContainer.clearCTE(session, p); CommandContainer.clearCTE(session, p);
throw t; throw t;
} }
Command c = new CommandContainer(session, sql, p); return new CommandContainer(session, sql, p);
if (hasMore) {
String remaining = originalSQL.substring(parseIndex);
if (!StringUtils.isWhitespaceOrEmpty(remaining)) {
c = new CommandList(session, sql, c, remaining);
}
}
return c;
} catch (DbException e) {
throw e.addSQL(originalSQL);
}
} }
/** /**
...@@ -727,12 +744,12 @@ public class Parser { ...@@ -727,12 +744,12 @@ public class Parser {
} else { } else {
expectedList = null; expectedList = null;
} }
parameters = Utils.newSmallArrayList(); parameters = suppliedParameters != null ? suppliedParameters : Utils.<Parameter>newSmallArrayList();
indexedParameterList = suppliedParameterList;
currentSelect = null; currentSelect = null;
currentPrepared = null; currentPrepared = null;
createView = null; createView = null;
recompileAlways = false; recompileAlways = false;
indexedParameterList = suppliedParameterList;
read(); read();
return parsePrepared(); return parsePrepared();
} }
......
...@@ -199,6 +199,7 @@ public class TestPreparedStatement extends TestDb { ...@@ -199,6 +199,7 @@ public class TestPreparedStatement extends TestDb {
testColumnMetaDataWithEquals(conn); testColumnMetaDataWithEquals(conn);
testColumnMetaDataWithIn(conn); testColumnMetaDataWithIn(conn);
testValueResultSet(conn); testValueResultSet(conn);
testMultipleStatements(conn);
conn.close(); conn.close();
testPreparedStatementWithLiteralsNone(); testPreparedStatementWithLiteralsNone();
testPreparedStatementWithIndexedParameterAndLiteralsNone(); testPreparedStatementWithIndexedParameterAndLiteralsNone();
...@@ -1758,4 +1759,31 @@ public class TestPreparedStatement extends TestDb { ...@@ -1758,4 +1759,31 @@ public class TestPreparedStatement extends TestDb {
} }
} }
private void testMultipleStatements(Connection conn) throws SQLException {
assertThrows(ErrorCode.CANNOT_MIX_INDEXED_AND_UNINDEXED_PARAMS, conn).prepareStatement("SELECT ?; SELECT ?1");
assertThrows(ErrorCode.CANNOT_MIX_INDEXED_AND_UNINDEXED_PARAMS, conn).prepareStatement("SELECT ?1; SELECT ?");
Statement stmt = conn.createStatement();
stmt.execute("CREATE TABLE TEST (ID IDENTITY, V INT)");
PreparedStatement ps = conn.prepareStatement("INSERT INTO TEST(V) VALUES ?; INSERT INTO TEST(V) VALUES ?");
ps.setInt(1, 1);
ps.setInt(2, 2);
ps.executeUpdate();
ps = conn.prepareStatement("INSERT INTO TEST(V) VALUES ?2; INSERT INTO TEST(V) VALUES ?1;");
ps.setInt(1, 3);
ps.setInt(2, 4);
ps.executeUpdate();
try (ResultSet rs = stmt.executeQuery("SELECT V FROM TEST ORDER BY ID")) {
assertTrue(rs.next());
assertEquals(1, rs.getInt(1));
assertTrue(rs.next());
assertEquals(2, rs.getInt(1));
assertTrue(rs.next());
assertEquals(4, rs.getInt(1));
assertTrue(rs.next());
assertEquals(3, rs.getInt(1));
assertFalse(rs.next());
}
stmt.execute("DROP TABLE TEST");
}
} }
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论