Commit f79d8fca authored by Evgenij Ryazanov

Fix return type of PERCENTILE_CONT and MEDIAN

Parent 92091718
......@@ -5,6 +5,7 @@
*/
package org.h2.expression.aggregate;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
......@@ -322,8 +323,8 @@ public class Aggregate extends AbstractAggregate {
if (v == ValueNull.INSTANCE) {
return ValueNull.INSTANCE;
}
double arg = v.getDouble();
if (arg >= 0d && arg <= 1d) {
BigDecimal arg = v.getBigDecimal();
if (arg.signum() >= 0 && arg.compareTo(BigDecimal.ONE) <= 0) {
return Percentile.getFromIndex(session, orderByList.get(0).expression, type.getValueType(),
orderByList, arg, aggregateType == AggregateType.PERCENTILE_CONT);
} else {
......@@ -332,7 +333,7 @@ public class Aggregate extends AbstractAggregate {
}
}
case MEDIAN:
return Percentile.getFromIndex(session, on, type.getValueType(), orderByList, 0.5d, true);
return Percentile.getFromIndex(session, on, type.getValueType(), orderByList, Percentile.HALF, true);
case ENVELOPE:
return ((MVSpatialIndex) AggregateDataEnvelope.getGeometryColumnIndex(on)).getBounds(session);
default:
......@@ -402,8 +403,8 @@ public class Aggregate extends AbstractAggregate {
if (v == ValueNull.INSTANCE) {
return ValueNull.INSTANCE;
}
double arg = v.getDouble();
if (arg >= 0d && arg <= 1d) {
BigDecimal arg = v.getBigDecimal();
if (arg.signum() >= 0 && arg.compareTo(BigDecimal.ONE) <= 0) {
return Percentile.getValue(session.getDatabase(), array, type.getValueType(), orderByList, arg,
aggregateType == AggregateType.PERCENTILE_CONT);
} else {
......@@ -416,7 +417,8 @@ public class Aggregate extends AbstractAggregate {
if (array == null) {
return ValueNull.INSTANCE;
}
return Percentile.getValue(session.getDatabase(), array, type.getValueType(), orderByList, 0.5d, true);
return Percentile.getValue(session.getDatabase(), array, type.getValueType(), orderByList, Percentile.HALF,
true);
}
case MODE:
return getMode(session, data);
......@@ -589,9 +591,23 @@ public class Aggregate extends AbstractAggregate {
break;
case MIN:
case MAX:
case MEDIAN:
break;
case PERCENTILE_CONT:
type = orderByList.get(0).expression.getType();
//$FALL-THROUGH$
case MEDIAN:
switch (type.getValueType()) {
case Value.BYTE:
case Value.SHORT:
case Value.INT:
case Value.LONG:
case Value.DECIMAL:
case Value.DOUBLE:
case Value.FLOAT:
type = TypeInfo.TYPE_DECIMAL_DEFAULT;
break;
}
break;
case PERCENTILE_DISC:
case MODE:
type = orderByList.get(0).expression.getType();
......
......@@ -32,11 +32,7 @@ import org.h2.value.CompareMode;
import org.h2.value.Value;
import org.h2.value.ValueDate;
import org.h2.value.ValueDecimal;
import org.h2.value.ValueDouble;
import org.h2.value.ValueFloat;
import org.h2.value.ValueInt;
import org.h2.value.ValueInterval;
import org.h2.value.ValueLong;
import org.h2.value.ValueNull;
import org.h2.value.ValueTime;
import org.h2.value.ValueTimestamp;
......@@ -47,6 +43,11 @@ import org.h2.value.ValueTimestampTimeZone;
*/
final class Percentile {
/**
* BigDecimal value of 0.5.
*/
static final BigDecimal HALF = BigDecimal.valueOf(0.5d);
private static boolean isNullsLast(Index index) {
IndexColumn ic = index.getIndexColumns()[0];
int sortType = ic.sortType;
......@@ -106,22 +107,22 @@ final class Percentile {
* @return the result
*/
static Value getValue(Database database, Value[] array, int dataType, ArrayList<SelectOrderBy> orderByList,
double percentile, boolean interpolate) {
BigDecimal percentile, boolean interpolate) {
final CompareMode compareMode = database.getCompareMode();
Arrays.sort(array, compareMode);
int count = array.length;
boolean reverseIndex = orderByList != null && (orderByList.get(0).sortType & SortOrder.DESCENDING) != 0;
double fpRow = (count - 1) * percentile;
int rowIdx1 = (int) fpRow;
double factor = fpRow - rowIdx1;
BigDecimal fpRow = BigDecimal.valueOf(count - 1).multiply(percentile);
int rowIdx1 = fpRow.intValue();
BigDecimal factor = fpRow.subtract(BigDecimal.valueOf(rowIdx1));
int rowIdx2;
if (factor == 0d) {
if (factor.signum() == 0) {
interpolate = false;
rowIdx2 = rowIdx1;
} else {
rowIdx2 = rowIdx1 + 1;
if (!interpolate) {
if (factor > 0.5d) {
if (factor.compareTo(HALF) > 0) {
rowIdx1 = rowIdx2;
} else {
rowIdx2 = rowIdx1;
......@@ -151,7 +152,7 @@ final class Percentile {
* @return the result
*/
static Value getFromIndex(Session session, Expression expression, int dataType,
ArrayList<SelectOrderBy> orderByList, double percentile, boolean interpolate) {
ArrayList<SelectOrderBy> orderByList, BigDecimal percentile, boolean interpolate) {
Index index = getColumnIndex(expression);
long count = index.getRowCount(session);
if (count == 0) {
......@@ -199,17 +200,17 @@ final class Percentile {
}
boolean reverseIndex = (orderByList != null ? orderByList.get(0).sortType & SortOrder.DESCENDING : 0)
!= (index.getIndexColumns()[0].sortType & SortOrder.DESCENDING);
double fpRow = (count - 1) * percentile;
long rowIdx1 = (long) fpRow;
double factor = fpRow - rowIdx1;
BigDecimal fpRow = BigDecimal.valueOf(count - 1).multiply(percentile);
long rowIdx1 = fpRow.longValue();
BigDecimal factor = fpRow.subtract(BigDecimal.valueOf(rowIdx1));
long rowIdx2;
if (factor == 0d) {
if (factor.signum() == 0) {
interpolate = false;
rowIdx2 = rowIdx1;
} else {
rowIdx2 = rowIdx1 + 1;
if (!interpolate) {
if (factor > 0.5d) {
if (factor.compareTo(HALF) > 0) {
rowIdx1 = rowIdx2;
} else {
rowIdx2 = rowIdx1;
......@@ -246,10 +247,10 @@ final class Percentile {
}
return interpolate(v, v2, factor, dataType, database.getMode(), database.getCompareMode());
}
return v;
return v.convertTo(dataType);
}
private static Value interpolate(Value v0, Value v1, double factor, int dataType, Mode databaseMode,
private static Value interpolate(Value v0, Value v1, BigDecimal factor, int dataType, Mode databaseMode,
CompareMode compareMode) {
if (v0.compareTo(v1, databaseMode, compareMode) == 0) {
return v0.convertTo(dataType);
......@@ -258,21 +259,18 @@ final class Percentile {
case Value.BYTE:
case Value.SHORT:
case Value.INT:
return ValueInt.get((int) (v0.getInt() * (1 - factor) + v1.getInt() * factor)).convertTo(dataType);
return ValueDecimal.get(
interpolateDecimal(BigDecimal.valueOf(v0.getInt()), BigDecimal.valueOf(v1.getInt()), factor));
case Value.LONG:
return ValueLong
.get(interpolateDecimal(BigDecimal.valueOf(v0.getLong()), BigDecimal.valueOf(v1.getLong()), factor)
.longValue());
return ValueDecimal.get(
interpolateDecimal(BigDecimal.valueOf(v0.getLong()), BigDecimal.valueOf(v1.getLong()), factor));
case Value.DECIMAL:
return ValueDecimal.get(interpolateDecimal(v0.getBigDecimal(), v1.getBigDecimal(), factor));
case Value.FLOAT:
return ValueFloat.get(
interpolateDecimal(BigDecimal.valueOf(v0.getFloat()), BigDecimal.valueOf(v1.getFloat()), factor)
.floatValue());
case Value.DOUBLE:
return ValueDouble.get(
interpolateDecimal(BigDecimal.valueOf(v0.getDouble()), BigDecimal.valueOf(v1.getDouble()), factor)
.doubleValue());
return ValueDecimal.get(
interpolateDecimal(
BigDecimal.valueOf(v0.getDouble()), BigDecimal.valueOf(v1.getDouble()), factor));
case Value.TIME: {
ValueTime t0 = (ValueTime) v0.convertTo(Value.TIME), t1 = (ValueTime) v1.convertTo(Value.TIME);
BigDecimal n0 = BigDecimal.valueOf(t0.getNanos());
......@@ -307,12 +305,15 @@ final class Percentile {
ts1 = (ValueTimestampTimeZone) v1.convertTo(Value.TIMESTAMP_TZ);
BigDecimal a0 = timestampToDecimal(ts0.getDateValue(), ts0.getTimeNanos());
BigDecimal a1 = timestampToDecimal(ts1.getDateValue(), ts1.getTimeNanos());
double offset = ts0.getTimeZoneOffsetMins() * (1 - factor) + ts1.getTimeZoneOffsetMins() * factor;
short sOffset = (short) offset;
BigDecimal offset = BigDecimal.valueOf(ts0.getTimeZoneOffsetMins())
.multiply(BigDecimal.ONE.subtract(factor))
.add(BigDecimal.valueOf(ts1.getTimeZoneOffsetMins()).multiply(factor));
short shortOffset = offset.shortValue();
BigDecimal shortOffsetBD = BigDecimal.valueOf(shortOffset);
BigDecimal bd = interpolateDecimal(a0, a1, factor);
if (offset != sOffset) {
bd = bd.add(BigDecimal.valueOf(offset - sOffset)
.multiply(BigDecimal.valueOf(DateTimeUtils.NANOS_PER_MINUTE)));
if (offset.compareTo(shortOffsetBD) != 0) {
bd = bd.add(
offset.subtract(shortOffsetBD).multiply(BigDecimal.valueOf(DateTimeUtils.NANOS_PER_MINUTE)));
}
BigInteger[] dr = bd.toBigInteger().divideAndRemainder(IntervalUtils.NANOS_PER_DAY_BI);
long absoluteDay = dr[0].longValue();
......@@ -322,7 +323,7 @@ final class Percentile {
absoluteDay--;
}
return ValueTimestampTimeZone.fromDateValueAndNanos(DateTimeUtils.dateValueFromAbsoluteDay(absoluteDay),
timeNanos, sOffset);
timeNanos, shortOffset);
}
case Value.INTERVAL_YEAR:
case Value.INTERVAL_MONTH:
......@@ -343,7 +344,7 @@ final class Percentile {
.toBigInteger());
default:
// Use the same rules as PERCENTILE_DISC
return (factor > 0.5d ? v1 : v0).convertTo(dataType);
return (factor.compareTo(HALF) > 0 ? v1 : v0).convertTo(dataType);
}
}
......@@ -352,8 +353,8 @@ final class Percentile {
.multiply(IntervalUtils.NANOS_PER_DAY_BI).add(BigInteger.valueOf(timeNanos)));
}
private static BigDecimal interpolateDecimal(BigDecimal d0, BigDecimal d1, double factor) {
return d0.multiply(BigDecimal.valueOf(1 - factor)).add(d1.multiply(BigDecimal.valueOf(factor)));
/**
 * Linearly interpolates between two values: {@code d0 * (1 - factor) + d1 * factor}.
 * Performed entirely in BigDecimal arithmetic (this commit changed {@code factor}
 * from {@code double} to {@code BigDecimal}) to avoid floating-point precision loss.
 *
 * @param d0 the value at factor 0 (weighted by {@code 1 - factor})
 * @param d1 the value at factor 1 (weighted by {@code factor})
 * @param factor the interpolation factor; callers pass values in [0, 1]
 * @return the interpolated value
 */
private static BigDecimal interpolateDecimal(BigDecimal d0, BigDecimal d1, BigDecimal factor) {
return d0.multiply(BigDecimal.ONE.subtract(factor)).add(d1.multiply(factor));
}
private Percentile() {
......
......@@ -130,6 +130,11 @@ public class ValueFloat extends Value {
return value;
}
/**
 * Returns this float value widened to a double.
 * Added so ValueFloat provides a direct override rather than relying on the
 * inherited conversion path — NOTE(review): presumably needed by the new
 * BigDecimal-based percentile code; confirm against callers.
 */
@Override
public double getDouble() {
return value;
}
@Override
public String getString() {
return Float.toString(value);
......
......@@ -35,7 +35,7 @@ select
> rows: 1
select median(distinct v) from test;
>> 15
>> 15.0
insert into test values (10);
> update count: 1
......@@ -45,8 +45,8 @@ select
percentile_disc(0.5) within group (order by v desc) d50d,
median(v) m from test;
> D50A D50D M
> ---- ---- --
> 10 20 15
> ---- ---- ----
> 10 20 15.0
> rows: 1
drop table test;
......@@ -84,7 +84,7 @@ select
> rows: 1
select median(distinct v) from test;
>> 15
>> 15.0
insert into test values (10);
> update count: 1
......@@ -94,8 +94,8 @@ select
percentile_disc(0.5) within group (order by v desc) d50d,
median(v) m from test;
> D50A D50D M
> ---- ---- --
> 10 20 15
> ---- ---- ----
> 10 20 15.0
> rows: 1
drop table test;
......@@ -133,7 +133,7 @@ select
> rows: 1
select median(distinct v) from test;
>> 15
>> 15.0
insert into test values (10);
> update count: 1
......@@ -143,8 +143,8 @@ select
percentile_disc(0.5) within group (order by v desc) d50d,
median(v) m from test;
> D50A D50D M
> ---- ---- --
> 10 20 15
> ---- ---- ----
> 10 20 15.0
> rows: 1
drop table test;
......@@ -182,7 +182,7 @@ select
> rows: 1
select median(distinct v) from test;
>> 15
>> 15.0
insert into test values (10);
> update count: 1
......@@ -192,8 +192,8 @@ select
percentile_disc(0.5) within group (order by v desc) d50d,
median(v) m from test;
> D50A D50D M
> ---- ---- --
> 10 20 15
> ---- ---- ----
> 10 20 15.0
> rows: 1
drop table test;
......@@ -231,7 +231,7 @@ select
> rows: 1
select median(distinct v) from test;
>> 15
>> 15.0
insert into test values (10);
> update count: 1
......@@ -241,8 +241,8 @@ select
percentile_disc(0.5) within group (order by v desc) d50d,
median(v) m from test;
> D50A D50D M
> ---- ---- --
> 10 20 15
> ---- ---- ----
> 10 20 15.0
> rows: 1
drop table test;
......@@ -280,7 +280,7 @@ select
> rows: 1
select median(distinct v) from test;
>> 15
>> 15.0
insert into test values (10);
> update count: 1
......@@ -290,8 +290,8 @@ select
percentile_disc(0.5) within group (order by v desc) d50d,
median(v) m from test;
> D50A D50D M
> ---- ---- --
> 10 20 15
> ---- ---- ----
> 10 20 15.0
> rows: 1
drop table test;
......@@ -313,13 +313,13 @@ select median(v) from test;
>> 20
select median(distinct v) from test;
>> 15
>> 15.0
insert into test values (10);
> update count: 1
select median(v) from test;
>> 15
>> 15.0
drop table test;
> ok
......@@ -340,13 +340,13 @@ select median(v) from test;
>> 20
select median(distinct v) from test;
>> 15
>> 15.0
insert into test values (10);
> update count: 1
select median(v) from test;
>> 15
>> 15.0
drop table test;
> ok
......@@ -367,13 +367,13 @@ select median(v) from test;
>> 20
select median(distinct v) from test;
>> 15
>> 15.0
insert into test values (10);
> update count: 1
select median(v) from test;
>> 15
>> 15.0
drop table test;
> ok
......@@ -394,13 +394,13 @@ select median(v) from test;
>> 20
select median(distinct v) from test;
>> 15
>> 15.0
insert into test values (10);
> update count: 1
select median(v) from test;
>> 15
>> 15.0
drop table test;
> ok
......@@ -421,13 +421,13 @@ select median(v) from test;
>> 2.0
select median(distinct v) from test;
>> 1.5
>> 1.50
insert into test values (1);
> update count: 1
select median(v) from test;
>> 1.5
>> 1.50
drop table test;
> ok
......@@ -448,13 +448,13 @@ select median(v) from test;
>> 2.0
select median(distinct v) from test;
>> 1.5
>> 1.50
insert into test values (1);
> update count: 1
select median(v) from test;
>> 1.5
>> 1.50
drop table test;
> ok
......@@ -650,7 +650,7 @@ insert into test values ('Group 2A', 10), ('Group 2A', 10), ('Group 2A', 20),
select name, median(value) from test group by name order by name;
> NAME MEDIAN(VALUE)
> -------- -------------
> Group 1X 45
> Group 1X 45.0
> Group 2A 10
> Group 3B null
> rows (ordered): 3
......@@ -707,7 +707,7 @@ insert into test values (10), (20);
> update count: 2
select median(v) from test;
>> 15
>> 15.0
insert into test values (20), (10), (20);
> update count: 3
......@@ -729,7 +729,7 @@ insert into test values (10), (20), (30), (40), (50), (60), (70), (80), (90), (1
select median(v), median(v) filter (where v >= 40) from test where v <= 100;
> MEDIAN(V) MEDIAN(V) FILTER (WHERE (V >= 40))
> --------- ----------------------------------
> 55 70
> 55.0 70
> rows: 1
create index test_idx on test(v);
......@@ -738,13 +738,13 @@ create index test_idx on test(v);
select median(v), median(v) filter (where v >= 40) from test where v <= 100;
> MEDIAN(V) MEDIAN(V) FILTER (WHERE (V >= 40))
> --------- ----------------------------------
> 55 70
> 55.0 70
> rows: 1
select median(v), median(v) filter (where v >= 40) from test;
> MEDIAN(V) MEDIAN(V) FILTER (WHERE (V >= 40))
> --------- ----------------------------------
> 65 80
> 65.0 80
> rows: 1
drop table test;
......@@ -774,7 +774,7 @@ select dept, median(amount) filter (where amount >= 20) from test group by dept
> ------ --------------------------------------------
> First 30
> Second 22
> Third 160
> Third 160.0
> rows (ordered): 3
select dept, median(amount) filter (where amount >= 20) from test
......@@ -782,7 +782,7 @@ select dept, median(amount) filter (where amount >= 20) from test
> DEPT MEDIAN(AMOUNT) FILTER (WHERE (AMOUNT >= 20))
> ------ --------------------------------------------
> First 30
> Second 21
> Second 21.0
> Third 150
> rows (ordered): 3
......@@ -804,10 +804,10 @@ select
percentile_cont(0.95) within group (order by v) c95a,
percentile_cont(0.95) within group (order by v desc) c95d,
g from test group by g;
> C05A C05D C50 C50D C95A C95D G
> ---- ---- --- ---- ---- ---- -
> 1 9 5 5 9 1 1
> 11 89 25 25 89 11 2
> C05A C05D C50 C50D C95A C95D G
> ----- ----- ---- ---- ----- ----- -
> 1.45 9.55 5.5 5.5 9.55 1.45 1
> 11.50 89.50 25.0 25.0 89.50 11.50 2
> rows: 2
select
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment