提交 20eeda89 authored 作者: Thomas Mueller's avatar Thomas Mueller

Use HMAC for authenticating remote LOB id's, removing the need for maintaining a…

Use HMAC for authenticating remote LOB id's, removing the need for maintaining a cache, and removing the limit on the number of LOBs per result set.
上级 57acb158
...@@ -698,7 +698,9 @@ public class SessionRemote extends SessionWithState implements DataHandler { ...@@ -698,7 +698,9 @@ public class SessionRemote extends SessionWithState implements DataHandler {
traceOperation("LOB_READ", (int) lobId); traceOperation("LOB_READ", (int) lobId);
transfer.writeInt(SessionRemote.LOB_READ); transfer.writeInt(SessionRemote.LOB_READ);
transfer.writeLong(lobId); transfer.writeLong(lobId);
transfer.writeBytes(hmac); if (clientVersion >= Constants.TCP_PROTOCOL_VERSION_12) {
transfer.writeBytes(hmac);
}
transfer.writeLong(offset); transfer.writeLong(offset);
transfer.writeInt(length); transfer.writeInt(length);
done(transfer); done(transfer);
......
...@@ -37,6 +37,7 @@ import org.h2.util.SmallMap; ...@@ -37,6 +37,7 @@ import org.h2.util.SmallMap;
import org.h2.util.StringUtils; import org.h2.util.StringUtils;
import org.h2.value.Transfer; import org.h2.value.Transfer;
import org.h2.value.Value; import org.h2.value.Value;
import org.h2.value.ValueLobDb;
/** /**
* One server thread is opened per client connection. * One server thread is opened per client connection.
...@@ -50,7 +51,10 @@ public class TcpServerThread implements Runnable { ...@@ -50,7 +51,10 @@ public class TcpServerThread implements Runnable {
private Thread thread; private Thread thread;
private Command commit; private Command commit;
private SmallMap cache = new SmallMap(SysProperties.SERVER_CACHED_OBJECTS); private SmallMap cache = new SmallMap(SysProperties.SERVER_CACHED_OBJECTS);
private SmallLRUCache<Long, CachedInputStream> lobs = SmallLRUCache.newInstance(SysProperties.SERVER_CACHED_OBJECTS); private SmallLRUCache<Long, CachedInputStream> lobs =
SmallLRUCache.newInstance(Math.max(
SysProperties.SERVER_CACHED_OBJECTS,
SysProperties.SERVER_RESULT_SET_FETCH_SIZE * 5));
private int threadId; private int threadId;
private int clientVersion; private int clientVersion;
private String sessionId; private String sessionId;
...@@ -392,13 +396,27 @@ public class TcpServerThread implements Runnable { ...@@ -392,13 +396,27 @@ public class TcpServerThread implements Runnable {
break; break;
} }
case SessionRemote.LOB_READ: { case SessionRemote.LOB_READ: {
byte[] hmac = transfer.readBytes();
long lobId = transfer.readLong(); long lobId = transfer.readLong();
transfer.verifyLobMac(hmac, lobId); byte[] hmac;
CachedInputStream in = lobs.get(lobId); CachedInputStream in;
if (in == null) { if (clientVersion >= Constants.TCP_PROTOCOL_VERSION_11) {
in = new CachedInputStream(null); if (clientVersion >= Constants.TCP_PROTOCOL_VERSION_12) {
lobs.put(lobId, in); hmac = transfer.readBytes();
transfer.verifyLobMac(hmac, lobId);
} else {
hmac = null;
}
in = lobs.get(lobId);
if (in == null) {
in = new CachedInputStream(null);
lobs.put(lobId, in);
}
} else {
hmac = null;
in = lobs.get(lobId);
if (in == null) {
throw DbException.get(ErrorCode.OBJECT_CLOSED);
}
} }
long offset = transfer.readLong(); long offset = transfer.readLong();
if (in.getPos() != offset) { if (in.getPos() != offset) {
...@@ -425,7 +443,7 @@ public class TcpServerThread implements Runnable { ...@@ -425,7 +443,7 @@ public class TcpServerThread implements Runnable {
close(); close();
} }
} }
private int getState(int oldModificationId) { private int getState(int oldModificationId) {
if (session.getModificationId() == oldModificationId) { if (session.getModificationId() == oldModificationId) {
return SessionRemote.STATUS_OK; return SessionRemote.STATUS_OK;
...@@ -438,13 +456,30 @@ public class TcpServerThread implements Runnable { ...@@ -438,13 +456,30 @@ public class TcpServerThread implements Runnable {
transfer.writeBoolean(true); transfer.writeBoolean(true);
Value[] v = result.currentRow(); Value[] v = result.currentRow();
for (int i = 0; i < result.getVisibleColumnCount(); i++) { for (int i = 0; i < result.getVisibleColumnCount(); i++) {
transfer.writeValue(v[i]); if (clientVersion >= Constants.TCP_PROTOCOL_VERSION_12) {
transfer.writeValue(v[i]);
} else {
writeValue(v[i]);
}
} }
} else { } else {
transfer.writeBoolean(false); transfer.writeBoolean(false);
} }
} }
private void writeValue(Value v) throws IOException {
if (v.getType() == Value.CLOB || v.getType() == Value.BLOB) {
if (v instanceof ValueLobDb) {
ValueLobDb lob = (ValueLobDb) v;
if (lob.isStored()) {
long id = lob.getLobId();
lobs.put(id, new CachedInputStream(null));
}
}
}
transfer.writeValue(v);
}
void setThread(Thread thread) { void setThread(Thread thread) {
this.thread = thread; this.thread = thread;
} }
......
...@@ -21,6 +21,7 @@ import java.sql.ResultSetMetaData; ...@@ -21,6 +21,7 @@ import java.sql.ResultSetMetaData;
import java.sql.SQLException; import java.sql.SQLException;
import java.sql.Time; import java.sql.Time;
import java.sql.Timestamp; import java.sql.Timestamp;
import java.util.Arrays;
import org.h2.constant.ErrorCode; import org.h2.constant.ErrorCode;
import org.h2.engine.Constants; import org.h2.engine.Constants;
import org.h2.engine.SessionInterface; import org.h2.engine.SessionInterface;
...@@ -46,7 +47,7 @@ public class Transfer { ...@@ -46,7 +47,7 @@ public class Transfer {
private static final int BUFFER_SIZE = 16 * 1024; private static final int BUFFER_SIZE = 16 * 1024;
private static final int LOB_MAGIC = 0x1234; private static final int LOB_MAGIC = 0x1234;
private static final int LOB_MAC_SALT_LENGTH = 16; private static final int LOB_MAC_SALT_LENGTH = 16;
private Socket socket; private Socket socket;
private DataInputStream in; private DataInputStream in;
private DataOutputStream out; private DataOutputStream out;
...@@ -409,14 +410,16 @@ public class Transfer { ...@@ -409,14 +410,16 @@ public class Transfer {
writeString(v.getString()); writeString(v.getString());
break; break;
case Value.BLOB: { case Value.BLOB: {
if (version >= Constants.TCP_PROTOCOL_VERSION_12) { if (version >= Constants.TCP_PROTOCOL_VERSION_11) {
if (v instanceof ValueLobDb) { if (v instanceof ValueLobDb) {
ValueLobDb lob = (ValueLobDb) v; ValueLobDb lob = (ValueLobDb) v;
if (lob.isStored()) { if (lob.isStored()) {
writeLong(-1); writeLong(-1);
writeInt(lob.getTableId()); writeInt(lob.getTableId());
writeLong(lob.getLobId()); writeLong(lob.getLobId());
writeBytes(calculateLobMac(lob.getLobId())); if (version >= Constants.TCP_PROTOCOL_VERSION_12) {
writeBytes(calculateLobMac(lob.getLobId()));
}
writeLong(lob.getPrecision()); writeLong(lob.getPrecision());
break; break;
} }
...@@ -435,14 +438,16 @@ public class Transfer { ...@@ -435,14 +438,16 @@ public class Transfer {
break; break;
} }
case Value.CLOB: { case Value.CLOB: {
if (version >= Constants.TCP_PROTOCOL_VERSION_12) { if (version >= Constants.TCP_PROTOCOL_VERSION_11) {
if (v instanceof ValueLobDb) { if (v instanceof ValueLobDb) {
ValueLobDb lob = (ValueLobDb) v; ValueLobDb lob = (ValueLobDb) v;
if (lob.isStored()) { if (lob.isStored()) {
writeLong(-1); writeLong(-1);
writeInt(lob.getTableId()); writeInt(lob.getTableId());
writeBytes(calculateLobMac(lob.getLobId()));
writeLong(lob.getLobId()); writeLong(lob.getLobId());
if (version >= Constants.TCP_PROTOCOL_VERSION_12) {
writeBytes(calculateLobMac(lob.getLobId()));
}
writeLong(lob.getPrecision()); writeLong(lob.getPrecision());
break; break;
} }
...@@ -573,11 +578,16 @@ public class Transfer { ...@@ -573,11 +578,16 @@ public class Transfer {
return ValueStringFixed.get(readString()); return ValueStringFixed.get(readString());
case Value.BLOB: { case Value.BLOB: {
long length = readLong(); long length = readLong();
if (version >= Constants.TCP_PROTOCOL_VERSION_12) { if (version >= Constants.TCP_PROTOCOL_VERSION_11) {
if (length == -1) { if (length == -1) {
int tableId = readInt(); int tableId = readInt();
long id = readLong(); long id = readLong();
byte[] hmac = readBytes(); byte[] hmac;
if (version >= Constants.TCP_PROTOCOL_VERSION_12) {
hmac = readBytes();
} else {
hmac = null;
}
long precision = readLong(); long precision = readLong();
return ValueLobDb.create(Value.BLOB, session.getDataHandler().getLobStorage(), tableId, id, hmac, precision); return ValueLobDb.create(Value.BLOB, session.getDataHandler().getLobStorage(), tableId, id, hmac, precision);
} }
...@@ -599,11 +609,16 @@ public class Transfer { ...@@ -599,11 +609,16 @@ public class Transfer {
} }
case Value.CLOB: { case Value.CLOB: {
long length = readLong(); long length = readLong();
if (version >= Constants.TCP_PROTOCOL_VERSION_12) { if (version >= Constants.TCP_PROTOCOL_VERSION_11) {
if (length == -1) { if (length == -1) {
int tableId = readInt(); int tableId = readInt();
long id = readLong(); long id = readLong();
byte[] hmac = readBytes(); byte[] hmac;
if (version >= Constants.TCP_PROTOCOL_VERSION_12) {
hmac = readBytes();
} else {
hmac = null;
}
long precision = readLong(); long precision = readLong();
return ValueLobDb.create(Value.CLOB, session.getDataHandler().getLobStorage(), tableId, id, hmac, precision); return ValueLobDb.create(Value.CLOB, session.getDataHandler().getLobStorage(), tableId, id, hmac, precision);
} }
...@@ -712,11 +727,13 @@ public class Transfer { ...@@ -712,11 +727,13 @@ public class Transfer {
} }
/** /**
* Verify the HMAC.
*
* @throws DbException if the HMAC does not verify * @throws DbException if the HMAC does not verify
*/ */
public void verifyLobMac(byte[] hmacData, long lobId) { public void verifyLobMac(byte[] hmacData, long lobId) {
byte[] result = calculateLobMac(lobId); byte[] result = calculateLobMac(lobId);
if (!result.equals(hmacData)) { if (!Arrays.equals(result, hmacData)) {
throw DbException.get(ErrorCode.REMOTE_CONNECTION_NOT_ALLOWED); throw DbException.get(ErrorCode.REMOTE_CONNECTION_NOT_ALLOWED);
} }
} }
...@@ -725,23 +742,10 @@ public class Transfer { ...@@ -725,23 +742,10 @@ public class Transfer {
if (lobMacSalt == null) { if (lobMacSalt == null) {
lobMacSalt = MathUtils.secureRandomBytes(LOB_MAC_SALT_LENGTH); lobMacSalt = MathUtils.secureRandomBytes(LOB_MAC_SALT_LENGTH);
} }
byte[] data = new byte[8];
byte[] hmacData = SHA256.getHashWithSalt(longToBytes(lobId), lobMacSalt); Utils.writeLong(data, 0, lobId);
byte[] hmacData = SHA256.getHashWithSalt(data, lobMacSalt);
return hmacData; return hmacData;
} }
private static byte[] longToBytes(long src) {
byte[] data = new byte[8];
data[0] = (byte) (src >> 56);
data[1] = (byte) (src >> 48);
data[2] = (byte) (src >> 40);
data[3] = (byte) (src >> 32);
data[4] = (byte) (src >> 24);
data[5] = (byte) (src >> 16);
data[6] = (byte) (src >> 8);
data[7] = (byte) src;
return data;
}
} }
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论