提交 e8e74529 authored 作者: noelgrandin's avatar noelgrandin

Added support for ON DUPLICATE KEY UPDATE like MySQL with the values() function…

Added support for ON DUPLICATE KEY UPDATE like MySQL with the values() function to update with the value that was to be inserted.
Patch from Jean-Francois Noel.
上级 bdbdf484
...@@ -35,6 +35,8 @@ Change Log ...@@ -35,6 +35,8 @@ Change Log
</li><li>Issue 521: ScriptReader should implement Closeable </li><li>Issue 521: ScriptReader should implement Closeable
</li><li>Issue 524: RunScript.execute does not close its Statement, patch from Gaul. </li><li>Issue 524: RunScript.execute does not close its Statement, patch from Gaul.
</li><li>Add support for DB2 "WITH UR" clause, patch from litailang </li><li>Add support for DB2 "WITH UR" clause, patch from litailang
</li><li>Added support for ON DUPLICATE KEY UPDATE like MySQL with the values() function to update with the value that
was to be inserted. Patch from Jean-Francois Noel.
</li></ul> </li></ul>
<h2>Version 1.3.174 (2013-10-19)</h2> <h2>Version 1.3.174 (2013-10-19)</h2>
......
...@@ -1052,6 +1052,24 @@ public class Parser { ...@@ -1052,6 +1052,24 @@ public class Parser {
} else { } else {
command.setQuery(parseSelect()); command.setQuery(parseSelect());
} }
if (database.getMode().onDuplicateKeyUpdate) {
if (readIf("ON")) {
read("DUPLICATE");
read("KEY");
read("UPDATE");
do {
Column column = parseColumn(table);
read("=");
Expression expression;
if (readIf("DEFAULT")) {
expression = ValueExpression.getDefault();
} else {
expression = readExpression();
}
command.addAssignmentForDuplicate(column, expression);
} while (readIf(","));
}
}
return command; return command;
} }
...@@ -5681,7 +5699,7 @@ public class Parser { ...@@ -5681,7 +5699,7 @@ public class Parser {
return StringUtils.quoteIdentifier(s); return StringUtils.quoteIdentifier(s);
} }
} }
if (Parser.isKeyword(s, true)) { if (isKeyword(s, true)) {
return StringUtils.quoteIdentifier(s); return StringUtils.quoteIdentifier(s);
} }
return s; return s;
......
...@@ -7,6 +7,9 @@ ...@@ -7,6 +7,9 @@
package org.h2.command.dml; package org.h2.command.dml;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.h2.api.Trigger; import org.h2.api.Trigger;
import org.h2.command.Command; import org.h2.command.Command;
import org.h2.command.CommandInterface; import org.h2.command.CommandInterface;
...@@ -15,7 +18,10 @@ import org.h2.constant.ErrorCode; ...@@ -15,7 +18,10 @@ import org.h2.constant.ErrorCode;
import org.h2.engine.Right; import org.h2.engine.Right;
import org.h2.engine.Session; import org.h2.engine.Session;
import org.h2.engine.UndoLogRecord; import org.h2.engine.UndoLogRecord;
import org.h2.expression.Comparison;
import org.h2.expression.ConditionAndOr;
import org.h2.expression.Expression; import org.h2.expression.Expression;
import org.h2.expression.ExpressionColumn;
import org.h2.expression.Parameter; import org.h2.expression.Parameter;
import org.h2.index.Index; import org.h2.index.Index;
import org.h2.message.DbException; import org.h2.message.DbException;
...@@ -24,9 +30,11 @@ import org.h2.result.ResultTarget; ...@@ -24,9 +30,11 @@ import org.h2.result.ResultTarget;
import org.h2.result.Row; import org.h2.result.Row;
import org.h2.table.Column; import org.h2.table.Column;
import org.h2.table.Table; import org.h2.table.Table;
import org.h2.table.TableFilter;
import org.h2.util.New; import org.h2.util.New;
import org.h2.util.StatementBuilder; import org.h2.util.StatementBuilder;
import org.h2.value.Value; import org.h2.value.Value;
import org.h2.value.ValueNull;
/** /**
* This class represents the statement * This class represents the statement
...@@ -41,6 +49,11 @@ public class Insert extends Prepared implements ResultTarget { ...@@ -41,6 +49,11 @@ public class Insert extends Prepared implements ResultTarget {
private boolean sortedInsertMode; private boolean sortedInsertMode;
private int rowNumber; private int rowNumber;
private boolean insertFromSelect; private boolean insertFromSelect;
/**
* for MySQL-style INSERT ... ON DUPLICATE KEY UPDATE ....
*/
private HashMap<Column, Expression> duplicateKeyAssignmentMap;
public Insert(Session session) { public Insert(Session session) {
super(session); super(session);
...@@ -66,6 +79,23 @@ public class Insert extends Prepared implements ResultTarget { ...@@ -66,6 +79,23 @@ public class Insert extends Prepared implements ResultTarget {
this.query = query; this.query = query;
} }
/**
* Keep a collection of the columns to pass to update if a duplicate key
* happens, for MySQL-style INSERT ... ON DUPLICATE KEY UPDATE ....
*
* @param column the column
* @param expression the expression
*/
public void addAssignmentForDuplicate(Column column, Expression expression) {
if (duplicateKeyAssignmentMap == null) {
duplicateKeyAssignmentMap = New.hashMap();
}
if (duplicateKeyAssignmentMap.containsKey(column)) {
throw DbException.get(ErrorCode.DUPLICATE_COLUMN_NAME_1, column.getName());
}
duplicateKeyAssignmentMap.put(column, expression);
}
/** /**
* Add a row to this merge statement. * Add a row to this merge statement.
* *
...@@ -124,7 +154,11 @@ public class Insert extends Prepared implements ResultTarget { ...@@ -124,7 +154,11 @@ public class Insert extends Prepared implements ResultTarget {
boolean done = table.fireBeforeRow(session, null, newRow); boolean done = table.fireBeforeRow(session, null, newRow);
if (!done) { if (!done) {
table.lock(session, true, false); table.lock(session, true, false);
table.addRow(session, newRow); try {
table.addRow(session, newRow);
} catch (DbException de) {
handleOnDuplicate(de);
}
session.log(table, UndoLogRecord.INSERT, newRow); session.log(table, UndoLogRecord.INSERT, newRow);
table.fireAfterRow(session, null, newRow, false); table.fireAfterRow(session, null, newRow, false);
} }
...@@ -277,7 +311,90 @@ public class Insert extends Prepared implements ResultTarget { ...@@ -277,7 +311,90 @@ public class Insert extends Prepared implements ResultTarget {
@Override @Override
public boolean isCacheable() { public boolean isCacheable() {
return true; return duplicateKeyAssignmentMap == null || duplicateKeyAssignmentMap.isEmpty();
} }
private void handleOnDuplicate(DbException de) {
if (de.getErrorCode() != ErrorCode.DUPLICATE_KEY_1) {
throw de;
}
if (duplicateKeyAssignmentMap == null || duplicateKeyAssignmentMap.isEmpty()) {
throw de;
}
ArrayList<String> variableNames = new ArrayList<String>(duplicateKeyAssignmentMap.size());
for (int i = 0; i < columns.length; i++) {
String key = session.getCurrentSchemaName() + "." + table.getName() + "." + columns[i].getName();
variableNames.add(key);
session.setVariable(key, list.get(getCurrentRowNumber() - 1)[i].getValue(session));
}
Update command = new Update(session);
command.setTableFilter(new TableFilter(session, table, null, true, null));
for (Column column : duplicateKeyAssignmentMap.keySet()) {
command.setAssignment(column, duplicateKeyAssignmentMap.get(column));
}
Index foundIndex = searchForUpdateIndex();
if (foundIndex != null) {
command.setCondition(prepareUpdateCondition(foundIndex));
} else {
throw DbException.getUnsupportedException("Unable to apply ON DUPLICATE KEY UPDATE, no index found!");
}
command.prepare();
command.update();
for (String variableName : variableNames) {
session.setVariable(variableName, ValueNull.INSTANCE);
}
}
private Index searchForUpdateIndex() {
Index foundIndex = null;
for (Index index : table.getIndexes()) {
if (index.getIndexType().isPrimaryKey() || index.getIndexType().isUnique()) {
for (Column indexColumn : index.getColumns()) {
for (Column insertColumn : columns) {
if (indexColumn.getName() == insertColumn.getName()) {
foundIndex = index;
break;
} else {
foundIndex = null;
}
}
if (foundIndex == null) {
break;
}
}
if (foundIndex != null) {
break;
}
}
}
return foundIndex;
}
private Expression prepareUpdateCondition(Index foundIndex) {
Expression expression = null;
for (Column column : foundIndex.getColumns()) {
ExpressionColumn expressionColumn = new ExpressionColumn(session.getDatabase(),
session.getCurrentSchemaName(), null, column.getName());
for (int i = 0; i < columns.length; i++) {
if (expressionColumn.getColumnName().equals(columns[i].getName())) {
if (expression == null) {
expression = new Comparison(session, Comparison.EQUAL, expressionColumn,
list.get(getCurrentRowNumber() - 1)[i++]);
} else {
expression = new ConditionAndOr(ConditionAndOr.AND, expression, new Comparison(session,
Comparison.EQUAL, expressionColumn, list.get(0)[i++]));
}
}
}
}
return expression;
}
} }
...@@ -128,6 +128,11 @@ public class Mode { ...@@ -128,6 +128,11 @@ public class Mode {
*/ */
public boolean isolationLevelInSelectStatement; public boolean isolationLevelInSelectStatement;
/**
* MySQL style INSERT ... ON DUPLICATE KEY UPDATE ...
*/
public boolean onDuplicateKeyUpdate;
private final String name; private final String name;
static { static {
...@@ -169,6 +174,7 @@ public class Mode { ...@@ -169,6 +174,7 @@ public class Mode {
mode.convertInsertNullToZero = true; mode.convertInsertNullToZero = true;
mode.indexDefinitionInCreateTable = true; mode.indexDefinitionInCreateTable = true;
mode.lowerCaseIdentifiers = true; mode.lowerCaseIdentifiers = true;
mode.onDuplicateKeyUpdate = true;
add(mode); add(mode);
mode = new Mode("Oracle"); mode = new Mode("Oracle");
......
...@@ -102,6 +102,11 @@ public class Function extends Expression implements FunctionCall { ...@@ -102,6 +102,11 @@ public class Function extends Expression implements FunctionCall {
LINK_SCHEMA = 218, GREATEST = 219, LEAST = 220, CANCEL_SESSION = 221, SET = 222, TABLE = 223, TABLE_DISTINCT = 224, LINK_SCHEMA = 218, GREATEST = 219, LEAST = 220, CANCEL_SESSION = 221, SET = 222, TABLE = 223, TABLE_DISTINCT = 224,
FILE_READ = 225, TRANSACTION_ID = 226, TRUNCATE_VALUE = 227, NVL2 = 228, DECODE = 229, ARRAY_CONTAINS = 230; FILE_READ = 225, TRANSACTION_ID = 226, TRUNCATE_VALUE = 227, NVL2 = 228, DECODE = 229, ARRAY_CONTAINS = 230;
/**
* Used in MySQL-style INSERT ... ON DUPLICATE KEY UPDATE ... VALUES
*/
public static final int VALUES = 250;
/** /**
* This is called H2VERSION() and not VERSION(), because we return a fake value * This is called H2VERSION() and not VERSION(), because we return a fake value
* for VERSION() when running under the PostgreSQL ODBC driver. * for VERSION() when running under the PostgreSQL ODBC driver.
...@@ -365,6 +370,9 @@ public class Function extends Expression implements FunctionCall { ...@@ -365,6 +370,9 @@ public class Function extends Expression implements FunctionCall {
// pseudo function // pseudo function
addFunctionWithNull("ROW_NUMBER", ROW_NUMBER, 0, Value.LONG); addFunctionWithNull("ROW_NUMBER", ROW_NUMBER, 0, Value.LONG);
// ON DUPLICATE KEY VALUES function
addFunction("VALUES", VALUES, 1, Value.NULL, false, true, true);
} }
protected Function(Database database, FunctionInfo info) { protected Function(Database database, FunctionInfo info) {
...@@ -1339,6 +1347,9 @@ public class Function extends Expression implements FunctionCall { ...@@ -1339,6 +1347,9 @@ public class Function extends Expression implements FunctionCall {
result = ValueString.get(StringUtils.xmlText(v0.getString(), v1.getBoolean())); result = ValueString.get(StringUtils.xmlText(v0.getString(), v1.getBoolean()));
} }
break; break;
case VALUES:
result = session.getVariable(args[0].getSchemaName() + "." + args[0].getTableName() + "." + args[0].getColumnName());
break;
default: default:
throw DbException.throwInternalError("type=" + info.type); throw DbException.throwInternalError("type=" + info.type);
} }
......
...@@ -27,6 +27,7 @@ import org.h2.test.db.TestCsv; ...@@ -27,6 +27,7 @@ import org.h2.test.db.TestCsv;
import org.h2.test.db.TestDateStorage; import org.h2.test.db.TestDateStorage;
import org.h2.test.db.TestDeadlock; import org.h2.test.db.TestDeadlock;
import org.h2.test.db.TestDrop; import org.h2.test.db.TestDrop;
import org.h2.test.db.TestDuplicateKeyUpdate;
import org.h2.test.db.TestEncryptedDb; import org.h2.test.db.TestEncryptedDb;
import org.h2.test.db.TestExclusive; import org.h2.test.db.TestExclusive;
import org.h2.test.db.TestFullText; import org.h2.test.db.TestFullText;
...@@ -631,6 +632,7 @@ kill -9 `jps -l | grep "org.h2.test." | cut -d " " -f 1` ...@@ -631,6 +632,7 @@ kill -9 `jps -l | grep "org.h2.test." | cut -d " " -f 1`
new TestDateStorage().runTest(this); new TestDateStorage().runTest(this);
new TestDeadlock().runTest(this); new TestDeadlock().runTest(this);
new TestDrop().runTest(this); new TestDrop().runTest(this);
new TestDuplicateKeyUpdate().runTest(this);
new TestEncryptedDb().runTest(this); new TestEncryptedDb().runTest(this);
new TestExclusive().runTest(this); new TestExclusive().runTest(this);
new TestFullText().runTest(this); new TestFullText().runTest(this);
......
/*
* Copyright 2004-2013 H2 Group. Multiple-Licensed under the H2 License,
* Version 1.0, and under the Eclipse Public License, Version 1.0
* (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.test.db;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import org.h2.test.TestBase;
/**
* Tests for the ON DUPLICATE KEY UPDATE in the Insert class.
*/
public class TestDuplicateKeyUpdate extends TestBase {
/**
* Run just this test.
*
* @param a ignored
*/
public static void main(String... a) throws Exception {
TestBase.createCaller().init().test();
}
@Override
public void test() throws SQLException {
deleteDb("duplicateKeyUpdate");
Connection conn = getConnection("duplicateKeyUpdate;MODE=MySQL");
testDuplicateOnPrimary(conn);
testDuplicateOnUnique(conn);
testDuplicateCache(conn);
testDuplicateExpression(conn);
conn.close();
deleteDb("duplicateKeyUpdate");
}
private void testDuplicateOnPrimary(Connection conn) throws SQLException {
Statement stat = conn.createStatement();
ResultSet rs;
stat.execute("CREATE TABLE `table_test` (\n" + " `id` bigint(20) NOT NULL ,\n"
+ " `a_text` varchar(254) NOT NULL,\n" + " `some_text` varchar(254) NULL,\n"
+ " PRIMARY KEY (`id`)\n" + ") ;");
stat.execute("INSERT INTO table_test ( id, a_text, some_text ) VALUES (1, 'aaaaaaaaaa', 'aaaaaaaaaa')");
stat.execute("INSERT INTO table_test ( id, a_text, some_text ) VALUES (2, 'bbbbbbbbbb', 'bbbbbbbbbb')");
stat.execute("INSERT INTO table_test ( id, a_text, some_text ) VALUES (3, 'cccccccccc', 'cccccccccc')");
stat.execute("INSERT INTO table_test ( id, a_text, some_text ) VALUES (4, 'dddddddddd', 'dddddddddd')");
stat.execute("INSERT INTO table_test ( id, a_text, some_text ) VALUES (5, 'eeeeeeeeee', 'eeeeeeeeee')");
stat.execute("INSERT INTO table_test ( id , a_text, some_text ) VALUES (1, 'zzzzzzzzzz', 'abcdefghij') ON DUPLICATE KEY UPDATE some_text='UPDATE'");
rs = stat.executeQuery("SELECT some_text FROM table_test where id = 1");
rs.next();
assertEquals("UPDATE", rs.getNString(1));
stat.execute("INSERT INTO table_test ( id , a_text, some_text ) VALUES (3, 'zzzzzzzzzz', 'SOME TEXT') ON DUPLICATE KEY UPDATE some_text=values(some_text)");
rs = stat.executeQuery("SELECT some_text FROM table_test where id = 3");
rs.next();
assertEquals("SOME TEXT", rs.getNString(1));
}
private void testDuplicateOnUnique(Connection conn) throws SQLException {
Statement stat = conn.createStatement();
ResultSet rs;
stat.execute("CREATE TABLE `table_test2` (\n" + " `id` bigint(20) NOT NULL AUTO_INCREMENT,\n"
+ " `a_text` varchar(254) NOT NULL,\n" + " `some_text` varchar(254) NOT NULL,\n"
+ " `updatable_text` varchar(254) NULL,\n" + " PRIMARY KEY (`id`)\n" + ") ;");
stat.execute("CREATE UNIQUE INDEX index_name \n" + "ON table_test2 (a_text, some_text);");
stat.execute("INSERT INTO table_test2 ( a_text, some_text, updatable_text ) VALUES ('a', 'a', '1')");
stat.execute("INSERT INTO table_test2 ( a_text, some_text, updatable_text ) VALUES ('b', 'b', '2')");
stat.execute("INSERT INTO table_test2 ( a_text, some_text, updatable_text ) VALUES ('c', 'c', '3')");
stat.execute("INSERT INTO table_test2 ( a_text, some_text, updatable_text ) VALUES ('d', 'd', '4')");
stat.execute("INSERT INTO table_test2 ( a_text, some_text, updatable_text ) VALUES ('e', 'e', '5')");
stat.execute("INSERT INTO table_test2 ( a_text, some_text ) VALUES ('e', 'e') ON DUPLICATE KEY UPDATE updatable_text='UPDATE'");
rs = stat.executeQuery("SELECT updatable_text FROM table_test2 where a_text = 'e'");
rs.next();
assertEquals("UPDATE", rs.getNString(1));
stat.execute("INSERT INTO table_test2 (a_text, some_text, updatable_text ) VALUES ('b', 'b', 'jfno') ON DUPLICATE KEY UPDATE updatable_text=values(updatable_text)");
rs = stat.executeQuery("SELECT updatable_text FROM table_test2 where a_text = 'b'");
rs.next();
assertEquals("jfno", rs.getNString(1));
}
private void testDuplicateCache(Connection conn) throws SQLException {
Statement stat = conn.createStatement();
ResultSet rs;
stat.execute("CREATE TABLE `table_test3` (\n" + " `id` bigint(20) NOT NULL ,\n"
+ " `a_text` varchar(254) NOT NULL,\n" + " `some_text` varchar(254) NULL,\n"
+ " PRIMARY KEY (`id`)\n" + ") ;");
stat.execute("INSERT INTO table_test3 ( id, a_text, some_text ) VALUES (1, 'aaaaaaaaaa', 'aaaaaaaaaa')");
stat.execute("INSERT INTO table_test3 ( id , a_text, some_text ) VALUES (1, 'zzzzzzzzzz', 'SOME TEXT') ON DUPLICATE KEY UPDATE some_text=values(some_text)");
rs = stat.executeQuery("SELECT some_text FROM table_test3 where id = 1");
rs.next();
assertEquals("SOME TEXT", rs.getNString(1));
// Execute twice the same query to use the one from cache without
// parsing, caused the values parameter to be seen as ammbiguous
stat.execute("INSERT INTO table_test3 ( id , a_text, some_text ) VALUES (1, 'zzzzzzzzzz', 'SOME TEXT') ON DUPLICATE KEY UPDATE some_text=values(some_text)");
rs = stat.executeQuery("SELECT some_text FROM table_test3 where id = 1");
rs.next();
assertEquals("SOME TEXT", rs.getNString(1));
}
private void testDuplicateExpression(Connection conn) throws SQLException {
Statement stat = conn.createStatement();
ResultSet rs;
stat.execute("CREATE TABLE `table_test4` (\n" + " `id` bigint(20) NOT NULL ,\n"
+ " `a_text` varchar(254) NOT NULL,\n" + " `some_value` int(10) NULL,\n" + " PRIMARY KEY (`id`)\n"
+ ") ;");
stat.execute("INSERT INTO table_test4 ( id, a_text, some_value ) VALUES (1, 'aaaaaaaaaa', 5)");
stat.execute("INSERT INTO table_test4 ( id, a_text, some_value ) VALUES (2, 'aaaaaaaaaa', 5)");
stat.execute("INSERT INTO table_test4 ( id , a_text, some_value ) VALUES (1, 'b', 1) ON DUPLICATE KEY UPDATE some_value=some_value + values(some_value)");
stat.execute("INSERT INTO table_test4 ( id , a_text, some_value ) VALUES (1, 'b', 1) ON DUPLICATE KEY UPDATE some_value=some_value + 100");
stat.execute("INSERT INTO table_test4 ( id , a_text, some_value ) VALUES (2, 'b', 1) ON DUPLICATE KEY UPDATE some_value=values(some_value) + 1");
rs = stat.executeQuery("SELECT some_value FROM table_test4 where id = 1");
rs.next();
assertEquals(106, rs.getInt(1));
rs = stat.executeQuery("SELECT some_value FROM table_test4 where id = 2");
rs.next();
assertEquals(2, rs.getInt(1));
}
}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论