提交 ba9bb532 authored 作者: Thomas Mueller's avatar Thomas Mueller

The functions SUM and AVG could overflow.

上级 48f24252
...@@ -18,7 +18,9 @@ Change Log ...@@ -18,7 +18,9 @@ Change Log
<h1>Change Log</h1> <h1>Change Log</h1>
<h2>Next Version (unreleased)</h2> <h2>Next Version (unreleased)</h2>
<ul><li>The emergency reserve file has been removed. It didn't provide an appropriate <ul><li>The function SUM could overflow when using large values. It returns now a data type that is safe.
</li><li>The function AVG could overflow when using large values. Fixed.
</li><li>The emergency reserve file has been removed. It didn't provide an appropriate
solution for the problem. It is still possible for an application to detect and deal with solution for the problem. It is still possible for an application to detect and deal with
the low disk space problem (deleting temporary files for example) the low disk space problem (deleting temporary files for example)
using DatabaseEventListener.diskSpaceIsLow, but this method is now always called using DatabaseEventListener.diskSpaceIsLow, but this method is now always called
......
...@@ -52,65 +52,65 @@ public class Aggregate extends Expression { ...@@ -52,65 +52,65 @@ public class Aggregate extends Expression {
*/ */
public static final int COUNT = 1; public static final int COUNT = 1;
/**
* The aggregate type for GROUP_CONCAT(...).
*/
public static final int GROUP_CONCAT = 2;
/** /**
* The aggregate type for SUM(expression). * The aggregate type for SUM(expression).
*/ */
public static final int SUM = 2; static final int SUM = 3;
/** /**
* The aggregate type for MIN(expression). * The aggregate type for MIN(expression).
*/ */
public static final int MIN = 3; static final int MIN = 4;
/** /**
* The aggregate type for MAX(expression). * The aggregate type for MAX(expression).
*/ */
public static final int MAX = 4; static final int MAX = 5;
/** /**
* The aggregate type for AVG(expression). * The aggregate type for AVG(expression).
*/ */
public static final int AVG = 5; static final int AVG = 6;
/**
* The aggregate type for GROUP_CONCAT(...).
*/
public static final int GROUP_CONCAT = 6;
/** /**
* The aggregate type for STDDEV_POP(expression). * The aggregate type for STDDEV_POP(expression).
*/ */
public static final int STDDEV_POP = 7; static final int STDDEV_POP = 7;
/** /**
* The aggregate type for STDDEV_SAMP(expression). * The aggregate type for STDDEV_SAMP(expression).
*/ */
public static final int STDDEV_SAMP = 8; static final int STDDEV_SAMP = 8;
/** /**
* The aggregate type for VAR_POP(expression). * The aggregate type for VAR_POP(expression).
*/ */
public static final int VAR_POP = 9; static final int VAR_POP = 9;
/** /**
* The aggregate type for VAR_SAMP(expression). * The aggregate type for VAR_SAMP(expression).
*/ */
public static final int VAR_SAMP = 10; static final int VAR_SAMP = 10;
/** /**
* The aggregate type for BOOL_OR(expression). * The aggregate type for BOOL_OR(expression).
*/ */
public static final int BOOL_OR = 11; static final int BOOL_OR = 11;
/** /**
* The aggregate type for BOOL_AND(expression). * The aggregate type for BOOL_AND(expression).
*/ */
public static final int BOOL_AND = 12; static final int BOOL_AND = 12;
/** /**
* The aggregate type for SELECTIVITY(expression). * The aggregate type for SELECTIVITY(expression).
*/ */
public static final int SELECTIVITY = 13; static final int SELECTIVITY = 13;
private static final HashMap AGGREGATES = new HashMap(); private static final HashMap AGGREGATES = new HashMap();
...@@ -233,7 +233,7 @@ public class Aggregate extends Expression { ...@@ -233,7 +233,7 @@ public class Aggregate extends Expression {
AggregateData data = (AggregateData) group.get(this); AggregateData data = (AggregateData) group.get(this);
if (data == null) { if (data == null) {
data = new AggregateData(type); data = new AggregateData(type, dataType);
group.put(this, data); group.put(this, data);
} }
Value v = on == null ? null : on.getValue(session); Value v = on == null ? null : on.getValue(session);
...@@ -287,7 +287,7 @@ public class Aggregate extends Expression { ...@@ -287,7 +287,7 @@ public class Aggregate extends Expression {
} }
AggregateData data = (AggregateData) group.get(this); AggregateData data = (AggregateData) group.get(this);
if (data == null) { if (data == null) {
data = new AggregateData(type); data = new AggregateData(type, dataType);
} }
Value v = data.getValue(session.getDatabase(), distinct); Value v = data.getValue(session.getDatabase(), distinct);
if (type == GROUP_CONCAT) { if (type == GROUP_CONCAT) {
...@@ -394,6 +394,11 @@ public class Aggregate extends Expression { ...@@ -394,6 +394,11 @@ public class Aggregate extends Expression {
displaySize = ValueInt.DISPLAY_SIZE; displaySize = ValueInt.DISPLAY_SIZE;
break; break;
case SUM: case SUM:
if (!DataType.supportsAdd(dataType)) {
throw Message.getSQLException(ErrorCode.SUM_OR_AVG_ON_WRONG_DATATYPE_1, getSQL());
}
dataType = DataType.getAddProofType(dataType);
break;
case AVG: case AVG:
if (!DataType.supportsAdd(dataType)) { if (!DataType.supportsAdd(dataType)) {
throw Message.getSQLException(ErrorCode.SUM_OR_AVG_ON_WRONG_DATATYPE_1, getSQL()); throw Message.getSQLException(ErrorCode.SUM_OR_AVG_ON_WRONG_DATATYPE_1, getSQL());
......
...@@ -13,6 +13,7 @@ import org.h2.engine.Database; ...@@ -13,6 +13,7 @@ import org.h2.engine.Database;
import org.h2.message.Message; import org.h2.message.Message;
import org.h2.util.ObjectArray; import org.h2.util.ObjectArray;
import org.h2.util.ValueHashMap; import org.h2.util.ValueHashMap;
import org.h2.value.DataType;
import org.h2.value.Value; import org.h2.value.Value;
import org.h2.value.ValueBoolean; import org.h2.value.ValueBoolean;
import org.h2.value.ValueDouble; import org.h2.value.ValueDouble;
...@@ -25,14 +26,16 @@ import org.h2.value.ValueNull; ...@@ -25,14 +26,16 @@ import org.h2.value.ValueNull;
*/ */
class AggregateData { class AggregateData {
private final int aggregateType; private final int aggregateType;
private final int dataType;
private long count; private long count;
private ValueHashMap distinctValues; private ValueHashMap distinctValues;
private Value value; private Value value;
private double sum, vpn; private double sum, vpn;
private ObjectArray list; private ObjectArray list;
AggregateData(int aggregateType) { AggregateData(int aggregateType, int dataType) {
this.aggregateType = aggregateType; this.aggregateType = aggregateType;
this.dataType = dataType;
} }
/** /**
...@@ -75,9 +78,16 @@ class AggregateData { ...@@ -75,9 +78,16 @@ class AggregateData {
case Aggregate.COUNT: case Aggregate.COUNT:
return; return;
case Aggregate.SUM: case Aggregate.SUM:
if (value == null) {
value = v.convertTo(dataType);
} else {
v = v.convertTo(value.getType());
value = value.add(v);
}
break;
case Aggregate.AVG: case Aggregate.AVG:
if (value == null) { if (value == null) {
value = v; value = v.convertTo(DataType.getAddProofType(dataType));
} else { } else {
v = v.convertTo(value.getType()); v = v.convertTo(value.getType());
value = value.add(v); value = value.add(v);
...@@ -216,7 +226,7 @@ class AggregateData { ...@@ -216,7 +226,7 @@ class AggregateData {
default: default:
throw Message.getInternalError("type=" + aggregateType); throw Message.getInternalError("type=" + aggregateType);
} }
return v == null ? ValueNull.INSTANCE : v; return v == null ? ValueNull.INSTANCE : v.convertTo(dataType);
} }
private Value divide(Value a, long count) throws SQLException { private Value divide(Value a, long count) throws SQLException {
......
...@@ -894,6 +894,29 @@ public class DataType { ...@@ -894,6 +894,29 @@ public class DataType {
} }
} }
/**
* Get the data type that will not overflow when calling 'add' 2 billion times.
*
* @param type the value type
* @return the data type that supports adding
*/
public static int getAddProofType(int type) {
switch (type) {
case Value.BYTE:
return Value.LONG;
case Value.FLOAT:
return Value.DOUBLE;
case Value.INT:
return Value.LONG;
case Value.LONG:
return Value.DECIMAL;
case Value.SHORT:
return Value.LONG;
default:
return type;
}
}
/** /**
* Get the default value in the form of a Java object for the given Java class. * Get the default value in the form of a Java object for the given Java class.
* *
......
...@@ -268,7 +268,7 @@ public class SamplesTest extends TestBase { ...@@ -268,7 +268,7 @@ public class SamplesTest extends TestBase {
private void testSum() { private void testSum() {
Product p = new Product(); Product p = new Product();
Integer sum = db.from(p).selectFirst(sum(p.unitsInStock)); Long sum = db.from(p).selectFirst(sum(p.unitsInStock));
assertEquals(323, sum.intValue()); assertEquals(323, sum.intValue());
Double sumPrice = db.from(p).selectFirst(sum(p.unitPrice)); Double sumPrice = db.from(p).selectFirst(sum(p.unitPrice));
assertEquals(313.35, sumPrice.doubleValue()); assertEquals(313.35, sumPrice.doubleValue());
......
select sum(cast(x as int)) from system_range(2147483547, 2147483637);
> 195421006872;
select sum(x) from system_range(9223372036854775707, 9223372036854775797);
> 839326855353784593432;
select sum(cast(100 as tinyint)) from system_range(1, 1000);
> 100000;
select sum(cast(100 as smallint)) from system_range(1, 1000);
> 100000;
select avg(cast(x as int)) from system_range(2147483547, 2147483637);
> 2147483592;
select avg(x) from system_range(9223372036854775707, 9223372036854775797);
> 9223372036854775752;
select avg(cast(100 as tinyint)) from system_range(1, 1000);
> 100;
select avg(cast(100 as smallint)) from system_range(1, 1000);
> 100;
select datediff(yyyy, now(), now()); select datediff(yyyy, now(), now());
> 0; > 0;
create table t(d date) as select '2008-11-01' union select '2008-11-02'; create table t(d date) as select '2008-11-01' union select '2008-11-02';
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论