提交 e87eb442 authored 作者: Evgenij Ryazanov's avatar Evgenij Ryazanov

Return LocalResult from GeneratedKeys.getKeys()

上级 ce60c783
...@@ -14,7 +14,6 @@ import org.h2.engine.Session; ...@@ -14,7 +14,6 @@ import org.h2.engine.Session;
import org.h2.expression.ParameterInterface; import org.h2.expression.ParameterInterface;
import org.h2.message.DbException; import org.h2.message.DbException;
import org.h2.message.Trace; import org.h2.message.Trace;
import org.h2.result.LocalResult;
import org.h2.result.ResultInterface; import org.h2.result.ResultInterface;
import org.h2.result.ResultWithGeneratedKeys; import org.h2.result.ResultWithGeneratedKeys;
import org.h2.util.MathUtils; import org.h2.util.MathUtils;
...@@ -262,9 +261,7 @@ public abstract class Command implements CommandInterface { ...@@ -262,9 +261,7 @@ public abstract class Command implements CommandInterface {
int updateCount = update(); int updateCount = update();
if (!Boolean.FALSE.equals(generatedKeysRequest)) { if (!Boolean.FALSE.equals(generatedKeysRequest)) {
return new ResultWithGeneratedKeys.WithKeys(updateCount, return new ResultWithGeneratedKeys.WithKeys(updateCount,
LocalResult.read(session, session.getGeneratedKeys().getKeys(session));
session.getGeneratedKeys().getKeys(),
Integer.MAX_VALUE));
} }
return ResultWithGeneratedKeys.of(updateCount); return ResultWithGeneratedKeys.of(updateCount);
} catch (DbException e) { } catch (DbException e) {
......
...@@ -10,14 +10,16 @@ import java.util.Collections; ...@@ -10,14 +10,16 @@ import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.h2.expression.Expression;
import org.h2.expression.ExpressionColumn;
import org.h2.result.LocalResult;
import org.h2.result.Row; import org.h2.result.Row;
import org.h2.table.Column; import org.h2.table.Column;
import org.h2.table.Table; import org.h2.table.Table;
import org.h2.tools.SimpleResultSet;
import org.h2.util.MathUtils;
import org.h2.util.New; import org.h2.util.New;
import org.h2.util.StringUtils; import org.h2.util.StringUtils;
import org.h2.value.DataType; import org.h2.value.Value;
import org.h2.value.ValueNull;
/** /**
* Class for gathering and processing of generated keys. * Class for gathering and processing of generated keys.
...@@ -26,7 +28,7 @@ public final class GeneratedKeys { ...@@ -26,7 +28,7 @@ public final class GeneratedKeys {
/** /**
* Data for result set with generated keys. * Data for result set with generated keys.
*/ */
private final ArrayList<Map<Column, Object>> data = New.arrayList(); private final ArrayList<Map<Column, Value>> data = New.arrayList();
/** /**
* Columns with generated keys in the current row. * Columns with generated keys in the current row.
...@@ -97,14 +99,14 @@ public final class GeneratedKeys { ...@@ -97,14 +99,14 @@ public final class GeneratedKeys {
if (size > 0) { if (size > 0) {
if (size == 1) { if (size == 1) {
Column column = row.get(0); Column column = row.get(0);
data.add(Collections.singletonMap(column, tableRow.getValue(column.getColumnId()).getObject())); data.add(Collections.singletonMap(column, tableRow.getValue(column.getColumnId())));
if (!allColumns.contains(column)) { if (!allColumns.contains(column)) {
allColumns.add(column); allColumns.add(column);
} }
} else { } else {
HashMap<Column, Object> map = new HashMap<>(); HashMap<Column, Value> map = new HashMap<>();
for (Column column : row) { for (Column column : row) {
map.put(column, tableRow.getValue(column.getColumnId()).getObject()); map.put(column, tableRow.getValue(column.getColumnId()));
if (!allColumns.contains(column)) { if (!allColumns.contains(column)) {
allColumns.add(column); allColumns.add(column);
} }
...@@ -118,17 +120,18 @@ public final class GeneratedKeys { ...@@ -118,17 +120,18 @@ public final class GeneratedKeys {
/** /**
* Returns generated keys. * Returns generated keys.
* *
* @return result set with generated keys * @return local result with generated keys
*/ */
public SimpleResultSet getKeys() { public LocalResult getKeys(Session session) {
SimpleResultSet rs = new SimpleResultSet(); Database db = session == null ? null : session.getDatabase();
if (Boolean.FALSE.equals(generatedKeysRequest)) { if (Boolean.FALSE.equals(generatedKeysRequest)) {
return rs; return new LocalResult();
} }
ArrayList<ExpressionColumn> expressionColumns;
if (Boolean.TRUE.equals(generatedKeysRequest)) { if (Boolean.TRUE.equals(generatedKeysRequest)) {
expressionColumns = new ArrayList<>(allColumns.size());
for (Column column : allColumns) { for (Column column : allColumns) {
rs.addColumn(column.getName(), DataType.convertTypeToSQLType(column.getType()), expressionColumns.add(new ExpressionColumn(db, column));
MathUtils.convertLongToInt(column.getPrecision()), column.getScale());
} }
} else if (generatedKeysRequest instanceof int[]) { } else if (generatedKeysRequest instanceof int[]) {
if (table != null) { if (table != null) {
...@@ -136,21 +139,22 @@ public final class GeneratedKeys { ...@@ -136,21 +139,22 @@ public final class GeneratedKeys {
Column[] columns = table.getColumns(); Column[] columns = table.getColumns();
int cnt = columns.length; int cnt = columns.length;
allColumns.clear(); allColumns.clear();
expressionColumns = new ArrayList<>(indices.length);
for (int idx : indices) { for (int idx : indices) {
if (idx >= 1 && idx <= cnt) { if (idx >= 1 && idx <= cnt) {
Column column = columns[idx - 1]; Column column = columns[idx - 1];
rs.addColumn(column.getName(), DataType.convertTypeToSQLType(column.getType()), expressionColumns.add(new ExpressionColumn(db, column));
MathUtils.convertLongToInt(column.getPrecision()), column.getScale());
allColumns.add(column); allColumns.add(column);
} }
} }
} else { } else {
return rs; return new LocalResult();
} }
} else if (generatedKeysRequest instanceof String[]) { } else if (generatedKeysRequest instanceof String[]) {
if (table != null) { if (table != null) {
String[] names = (String[]) generatedKeysRequest; String[] names = (String[]) generatedKeysRequest;
allColumns.clear(); allColumns.clear();
expressionColumns = new ArrayList<>(names.length);
for (String name : names) { for (String name : names) {
Column column; Column column;
search: if (table.doesColumnExist(name)) { search: if (table.doesColumnExist(name)) {
...@@ -169,30 +173,36 @@ public final class GeneratedKeys { ...@@ -169,30 +173,36 @@ public final class GeneratedKeys {
continue; continue;
} }
} }
rs.addColumn(column.getName(), DataType.convertTypeToSQLType(column.getType()), expressionColumns.add(new ExpressionColumn(db, column));
MathUtils.convertLongToInt(column.getPrecision()), column.getScale());
allColumns.add(column); allColumns.add(column);
} }
} else { } else {
return rs; return new LocalResult();
} }
} else { } else {
return rs; return new LocalResult();
} }
if (rs.getColumnCount() == 0) { int columnCount = expressionColumns.size();
return rs; if (columnCount == 0) {
return new LocalResult();
} }
for (Map<Column, Object> map : data) { LocalResult result = new LocalResult(session, expressionColumns.toArray(new Expression[0]), columnCount);
Object[] row = new Object[allColumns.size()]; for (Map<Column, Value> map : data) {
for (Map.Entry<Column, Object> entry : map.entrySet()) { Value[] row = new Value[columnCount];
for (Map.Entry<Column, Value> entry : map.entrySet()) {
int idx = allColumns.indexOf(entry.getKey()); int idx = allColumns.indexOf(entry.getKey());
if (idx >= 0) { if (idx >= 0) {
row[idx] = entry.getValue(); row[idx] = entry.getValue();
} }
} }
rs.addRow(row); for (int i = 0; i < columnCount; i++) {
if (row[i] == null) {
row[i] = ValueNull.INSTANCE;
}
}
result.addRow(row);
} }
return rs; return result;
} }
/** /**
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论