Unverified 提交 afb7546e authored 作者: Noel Grandin's avatar Noel Grandin 提交者: GitHub

Merge pull request #840 from katzyn/median

Add MEDIAN aggregate
...@@ -123,7 +123,12 @@ public class Aggregate extends Expression { ...@@ -123,7 +123,12 @@ public class Aggregate extends Expression {
/** /**
* The aggregate type for HISTOGRAM(expression). * The aggregate type for HISTOGRAM(expression).
*/ */
HISTOGRAM HISTOGRAM,
/**
* The aggregate type for MEDIAN(expression).
*/
MEDIAN
} }
private static final HashMap<String, AggregateType> AGGREGATES = new HashMap<>(24); private static final HashMap<String, AggregateType> AGGREGATES = new HashMap<>(24);
...@@ -187,6 +192,7 @@ public class Aggregate extends Expression { ...@@ -187,6 +192,7 @@ public class Aggregate extends Expression {
addAggregate("HISTOGRAM", AggregateType.HISTOGRAM); addAggregate("HISTOGRAM", AggregateType.HISTOGRAM);
addAggregate("BIT_OR", AggregateType.BIT_OR); addAggregate("BIT_OR", AggregateType.BIT_OR);
addAggregate("BIT_AND", AggregateType.BIT_AND); addAggregate("BIT_AND", AggregateType.BIT_AND);
addAggregate("MEDIAN", AggregateType.MEDIAN);
} }
private static void addAggregate(String name, AggregateType type) { private static void addAggregate(String name, AggregateType type) {
...@@ -287,7 +293,7 @@ public class Aggregate extends Expression { ...@@ -287,7 +293,7 @@ public class Aggregate extends Expression {
Table table = select.getTopTableFilter().getTable(); Table table = select.getTopTableFilter().getTable();
return ValueLong.get(table.getRowCount(session)); return ValueLong.get(table.getRowCount(session));
case MIN: case MIN:
case MAX: case MAX: {
boolean first = type == AggregateType.MIN; boolean first = type == AggregateType.MIN;
Index index = getMinMaxColumnIndex(); Index index = getMinMaxColumnIndex();
int sortType = index.getIndexColumns()[0].sortType; int sortType = index.getIndexColumns()[0].sortType;
...@@ -303,6 +309,10 @@ public class Aggregate extends Expression { ...@@ -303,6 +309,10 @@ public class Aggregate extends Expression {
v = row.getValue(index.getColumns()[0].getColumnId()); v = row.getValue(index.getColumns()[0].getColumnId());
} }
return v; return v;
}
case MEDIAN: {
return AggregateDataMedian.getFromIndex(session, on, dataType);
}
default: default:
DbException.throwInternalError("type=" + type); DbException.throwInternalError("type=" + type);
} }
...@@ -434,6 +444,7 @@ public class Aggregate extends Expression { ...@@ -434,6 +444,7 @@ public class Aggregate extends Expression {
break; break;
case MIN: case MIN:
case MAX: case MAX:
case MEDIAN:
break; break;
case STDDEV_POP: case STDDEV_POP:
case STDDEV_SAMP: case STDDEV_SAMP:
...@@ -568,6 +579,9 @@ public class Aggregate extends Expression { ...@@ -568,6 +579,9 @@ public class Aggregate extends Expression {
case BIT_OR: case BIT_OR:
text = "BIT_OR"; text = "BIT_OR";
break; break;
case MEDIAN:
text = "MEDIAN";
break;
default: default:
throw DbException.throwInternalError("type=" + type); throw DbException.throwInternalError("type=" + type);
} }
...@@ -606,6 +620,11 @@ public class Aggregate extends Expression { ...@@ -606,6 +620,11 @@ public class Aggregate extends Expression {
case MAX: case MAX:
Index index = getMinMaxColumnIndex(); Index index = getMinMaxColumnIndex();
return index != null; return index != null;
case MEDIAN:
if (distinct) {
return false;
}
return AggregateDataMedian.getMedianColumnIndex(on) != null;
default: default:
return false; return false;
} }
......
...@@ -31,6 +31,8 @@ abstract class AggregateData { ...@@ -31,6 +31,8 @@ abstract class AggregateData {
return new AggregateDataCount(); return new AggregateDataCount();
} else if (aggregateType == AggregateType.HISTOGRAM) { } else if (aggregateType == AggregateType.HISTOGRAM) {
return new AggregateDataHistogram(); return new AggregateDataHistogram();
} else if (aggregateType == AggregateType.MEDIAN) {
return new AggregateDataMedian();
} else { } else {
return new AggregateDataDefault(aggregateType); return new AggregateDataDefault(aggregateType);
} }
......
/*
* Copyright 2004-2018 H2 Group. Multiple-Licensed under the MPL 2.0,
* and the EPL 1.0 (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.expression;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import org.h2.engine.Database;
import org.h2.engine.Session;
import org.h2.index.Cursor;
import org.h2.index.Index;
import org.h2.result.SearchRow;
import org.h2.result.SortOrder;
import org.h2.table.Column;
import org.h2.table.IndexColumn;
import org.h2.table.Table;
import org.h2.table.TableFilter;
import org.h2.util.DateTimeUtils;
import org.h2.value.CompareMode;
import org.h2.value.Value;
import org.h2.value.ValueDate;
import org.h2.value.ValueDecimal;
import org.h2.value.ValueDouble;
import org.h2.value.ValueFloat;
import org.h2.value.ValueInt;
import org.h2.value.ValueLong;
import org.h2.value.ValueNull;
import org.h2.value.ValueTime;
import org.h2.value.ValueTimestamp;
import org.h2.value.ValueTimestampTimeZone;
/**
* Data stored while calculating a MEDIAN aggregate.
*/
class AggregateDataMedian extends AggregateData {
private Collection<Value> values;
static Index getMedianColumnIndex(Expression on) {
if (on instanceof ExpressionColumn) {
ExpressionColumn col = (ExpressionColumn) on;
Column column = col.getColumn();
TableFilter filter = col.getTableFilter();
if (filter != null) {
Table table = filter.getTable();
ArrayList<Index> indexes = table.getIndexes();
Index result = null;
if (indexes != null) {
for (int i = 1, size = indexes.size(); i < size; i++) {
Index index = indexes.get(i);
if (!index.canFindNext()) {
continue;
}
if (!index.isFirstColumn(column)) {
continue;
}
IndexColumn ic = index.getIndexColumns()[0];
if (column.isNullable()) {
int sortType = ic.sortType;
// Nulls last is not supported
if ((sortType & SortOrder.NULLS_LAST) != 0)
continue;
// Descending without nulls first is not supported
if ((sortType & SortOrder.DESCENDING) != 0 && (sortType & SortOrder.NULLS_FIRST) == 0) {
continue;
}
}
if (result == null || result.getColumns().length > index.getColumns().length) {
result = index;
}
}
}
return result;
}
}
return null;
}
static Value getFromIndex(Session session, Expression on, int dataType) {
Index index = getMedianColumnIndex(on);
long count = index.getRowCount(session);
if (count == 0) {
return ValueNull.INSTANCE;
}
Cursor cursor = index.find(session, null, null);
cursor.next();
// Skip nulls
SearchRow row;
while (count > 0) {
row = cursor.getSearchRow();
if (row == null) {
return ValueNull.INSTANCE;
}
if (row.getValue(index.getColumns()[0].getColumnId()) == ValueNull.INSTANCE) {
count--;
cursor.next();
} else
break;
}
if (count == 0) {
return ValueNull.INSTANCE;
}
long skip = (count - 1) / 2;
for (int i = 0; i < skip; i++) {
cursor.next();
}
row = cursor.getSearchRow();
Value v;
if (row == null) {
v = ValueNull.INSTANCE;
} else {
v = row.getValue(index.getColumns()[0].getColumnId());
}
if ((count & 1) == 0) {
cursor.next();
row = cursor.getSearchRow();
if (row == null) {
return v;
}
Value v2 = row.getValue(index.getColumns()[0].getColumnId());
return getMedian(v, v2, dataType, session.getDatabase().getCompareMode());
}
return v;
}
@Override
void add(Database database, int dataType, boolean distinct, Value v) {
if (v == ValueNull.INSTANCE) {
return;
}
Collection<Value> c = values;
if (c == null) {
values = c = distinct ? new HashSet<Value>() : new ArrayList<Value>();
}
c.add(v);
}
@Override
Value getValue(Database database, int dataType, boolean distinct) {
Collection<Value> c = values;
if (c == null) {
return ValueNull.INSTANCE;
}
if (distinct && c instanceof ArrayList) {
c = new HashSet<>(c);
}
Value[] a = c.toArray(new Value[0]);
final CompareMode mode = database.getCompareMode();
Arrays.sort(a, new Comparator<Value>() {
@Override
public int compare(Value o1, Value o2) {
return o1.compareTo(o2, mode);
}
});
int len = a.length;
int idx = len / 2;
Value v1 = a[idx];
if ((len & 1) == 1) {
return v1.convertTo(dataType);
}
return getMedian(a[idx - 1], v1, dataType, mode);
}
static Value getMedian(Value v0, Value v1, int dataType, CompareMode mode) {
if (v0.compareTo(v1, mode) == 0) {
return v0.convertTo(dataType);
}
switch (dataType) {
case Value.BYTE:
case Value.SHORT:
case Value.INT:
return ValueInt.get((v0.getInt() + v1.getInt()) / 2).convertTo(dataType);
case Value.LONG:
return ValueLong.get((v0.getLong() + v1.getLong()) / 2);
case Value.DECIMAL:
return ValueDecimal.get(v0.getBigDecimal().add(v1.getBigDecimal()).divide(BigDecimal.valueOf(2)));
case Value.FLOAT:
return ValueFloat.get((v0.getFloat() + v1.getFloat()) / 2);
case Value.DOUBLE:
return ValueDouble.get((v0.getFloat() + v1.getDouble()) / 2);
case Value.TIME: {
return ValueTime.fromMillis((v0.getTime().getTime() + v1.getTime().getTime()) / 2);
}
case Value.DATE: {
ValueDate d0 = (ValueDate) v0.convertTo(Value.DATE), d1 = (ValueDate) v1.convertTo(Value.DATE);
return ValueDate.fromDateValue(
DateTimeUtils.dateValueFromAbsoluteDay((DateTimeUtils.absoluteDayFromDateValue(d0.getDateValue())
+ DateTimeUtils.absoluteDayFromDateValue(d1.getDateValue())) / 2));
}
case Value.TIMESTAMP: {
ValueTimestamp ts0 = (ValueTimestamp) v0.convertTo(Value.TIMESTAMP),
ts1 = (ValueTimestamp) v1.convertTo(Value.TIMESTAMP);
long dateSum = DateTimeUtils.absoluteDayFromDateValue(ts0.getDateValue())
+ DateTimeUtils.absoluteDayFromDateValue(ts1.getDateValue());
long nanos = (ts0.getTimeNanos() + ts1.getTimeNanos()) / 2;
if ((dateSum & 1) != 0) {
nanos += DateTimeUtils.NANOS_PER_DAY / 2;
if (nanos >= DateTimeUtils.NANOS_PER_DAY) {
nanos -= DateTimeUtils.NANOS_PER_DAY;
dateSum++;
}
}
return ValueTimestamp.fromDateValueAndNanos(DateTimeUtils.dateValueFromAbsoluteDay(dateSum / 2), nanos);
}
case Value.TIMESTAMP_TZ: {
ValueTimestampTimeZone ts0 = (ValueTimestampTimeZone) v0.convertTo(Value.TIMESTAMP_TZ),
ts1 = (ValueTimestampTimeZone) v1.convertTo(Value.TIMESTAMP_TZ);
long dateSum = DateTimeUtils.absoluteDayFromDateValue(ts0.getDateValue())
+ DateTimeUtils.absoluteDayFromDateValue(ts1.getDateValue());
long nanos = (ts0.getTimeNanos() + ts1.getTimeNanos()) / 2;
if ((dateSum & 1) != 0) {
nanos += DateTimeUtils.NANOS_PER_DAY / 2;
if (nanos >= DateTimeUtils.NANOS_PER_DAY) {
nanos -= DateTimeUtils.NANOS_PER_DAY;
dateSum++;
}
}
return ValueTimestampTimeZone.fromDateValueAndNanos(DateTimeUtils.dateValueFromAbsoluteDay(dateSum / 2),
nanos, (short) ((ts0.getTimeZoneOffsetMins() + ts1.getTimeZoneOffsetMins()) / 2));
}
default:
// Just return first
return v0.convertTo(dataType);
}
}
}
...@@ -42,7 +42,10 @@ public class DateTimeUtils { ...@@ -42,7 +42,10 @@ public class DateTimeUtils {
*/ */
public static final TimeZone UTC = TimeZone.getTimeZone("UTC"); public static final TimeZone UTC = TimeZone.getTimeZone("UTC");
private static final long NANOS_PER_DAY = MILLIS_PER_DAY * 1000000; /**
* The number of nanoseconds per day.
*/
public static final long NANOS_PER_DAY = MILLIS_PER_DAY * 1000000;
private static final int SHIFT_YEAR = 9; private static final int SHIFT_YEAR = 9;
private static final int SHIFT_MONTH = 5; private static final int SHIFT_MONTH = 5;
......
...@@ -740,12 +740,12 @@ public class TestFunctions extends TestBase implements AggregateFunction { ...@@ -740,12 +740,12 @@ public class TestFunctions extends TestBase implements AggregateFunction {
deleteDb("functions"); deleteDb("functions");
Connection conn = getConnection("functions"); Connection conn = getConnection("functions");
Statement stat = conn.createStatement(); Statement stat = conn.createStatement();
stat.execute("CREATE AGGREGATE MEDIAN FOR \"" + stat.execute("CREATE AGGREGATE SIMPLE_MEDIAN FOR \"" +
MedianStringType.class.getName() + "\""); MedianStringType.class.getName() + "\"");
stat.execute("CREATE AGGREGATE IF NOT EXISTS MEDIAN FOR \"" + stat.execute("CREATE AGGREGATE IF NOT EXISTS SIMPLE_MEDIAN FOR \"" +
MedianStringType.class.getName() + "\""); MedianStringType.class.getName() + "\"");
ResultSet rs = stat.executeQuery( ResultSet rs = stat.executeQuery(
"SELECT MEDIAN(X) FROM SYSTEM_RANGE(1, 9)"); "SELECT SIMPLE_MEDIAN(X) FROM SYSTEM_RANGE(1, 9)");
rs.next(); rs.next();
assertEquals("5", rs.getString(1)); assertEquals("5", rs.getString(1));
conn.close(); conn.close();
...@@ -756,22 +756,22 @@ public class TestFunctions extends TestBase implements AggregateFunction { ...@@ -756,22 +756,22 @@ public class TestFunctions extends TestBase implements AggregateFunction {
conn = getConnection("functions"); conn = getConnection("functions");
stat = conn.createStatement(); stat = conn.createStatement();
stat.executeQuery("SELECT MEDIAN(X) FROM SYSTEM_RANGE(1, 9)"); stat.executeQuery("SELECT SIMPLE_MEDIAN(X) FROM SYSTEM_RANGE(1, 9)");
DatabaseMetaData meta = conn.getMetaData(); DatabaseMetaData meta = conn.getMetaData();
rs = meta.getProcedures(null, null, "MEDIAN"); rs = meta.getProcedures(null, null, "SIMPLE_MEDIAN");
assertTrue(rs.next()); assertTrue(rs.next());
assertFalse(rs.next()); assertFalse(rs.next());
rs = stat.executeQuery("SCRIPT"); rs = stat.executeQuery("SCRIPT");
boolean found = false; boolean found = false;
while (rs.next()) { while (rs.next()) {
String sql = rs.getString(1); String sql = rs.getString(1);
if (sql.contains("MEDIAN")) { if (sql.contains("SIMPLE_MEDIAN")) {
found = true; found = true;
} }
} }
assertTrue(found); assertTrue(found);
stat.execute("DROP AGGREGATE MEDIAN"); stat.execute("DROP AGGREGATE SIMPLE_MEDIAN");
stat.execute("DROP AGGREGATE IF EXISTS MEDIAN"); stat.execute("DROP AGGREGATE IF EXISTS SIMPLE_MEDIAN");
conn.close(); conn.close();
} }
...@@ -779,12 +779,12 @@ public class TestFunctions extends TestBase implements AggregateFunction { ...@@ -779,12 +779,12 @@ public class TestFunctions extends TestBase implements AggregateFunction {
deleteDb("functions"); deleteDb("functions");
Connection conn = getConnection("functions"); Connection conn = getConnection("functions");
Statement stat = conn.createStatement(); Statement stat = conn.createStatement();
stat.execute("CREATE AGGREGATE MEDIAN FOR \"" + stat.execute("CREATE AGGREGATE SIMPLE_MEDIAN FOR \"" +
MedianString.class.getName() + "\""); MedianString.class.getName() + "\"");
stat.execute("CREATE AGGREGATE IF NOT EXISTS MEDIAN FOR \"" + stat.execute("CREATE AGGREGATE IF NOT EXISTS SIMPLE_MEDIAN FOR \"" +
MedianString.class.getName() + "\""); MedianString.class.getName() + "\"");
ResultSet rs = stat.executeQuery( ResultSet rs = stat.executeQuery(
"SELECT MEDIAN(X) FROM SYSTEM_RANGE(1, 9)"); "SELECT SIMPLE_MEDIAN(X) FROM SYSTEM_RANGE(1, 9)");
rs.next(); rs.next();
assertEquals("5", rs.getString(1)); assertEquals("5", rs.getString(1));
conn.close(); conn.close();
...@@ -795,22 +795,22 @@ public class TestFunctions extends TestBase implements AggregateFunction { ...@@ -795,22 +795,22 @@ public class TestFunctions extends TestBase implements AggregateFunction {
conn = getConnection("functions"); conn = getConnection("functions");
stat = conn.createStatement(); stat = conn.createStatement();
stat.executeQuery("SELECT MEDIAN(X) FROM SYSTEM_RANGE(1, 9)"); stat.executeQuery("SELECT SIMPLE_MEDIAN(X) FROM SYSTEM_RANGE(1, 9)");
DatabaseMetaData meta = conn.getMetaData(); DatabaseMetaData meta = conn.getMetaData();
rs = meta.getProcedures(null, null, "MEDIAN"); rs = meta.getProcedures(null, null, "SIMPLE_MEDIAN");
assertTrue(rs.next()); assertTrue(rs.next());
assertFalse(rs.next()); assertFalse(rs.next());
rs = stat.executeQuery("SCRIPT"); rs = stat.executeQuery("SCRIPT");
boolean found = false; boolean found = false;
while (rs.next()) { while (rs.next()) {
String sql = rs.getString(1); String sql = rs.getString(1);
if (sql.contains("MEDIAN")) { if (sql.contains("SIMPLE_MEDIAN")) {
found = true; found = true;
} }
} }
assertTrue(found); assertTrue(found);
stat.execute("DROP AGGREGATE MEDIAN"); stat.execute("DROP AGGREGATE SIMPLE_MEDIAN");
stat.execute("DROP AGGREGATE IF EXISTS MEDIAN"); stat.execute("DROP AGGREGATE IF EXISTS SIMPLE_MEDIAN");
conn.close(); conn.close();
} }
......
...@@ -99,7 +99,7 @@ public class TestScript extends TestBase { ...@@ -99,7 +99,7 @@ public class TestScript extends TestBase {
testScript("datatypes/" + s + ".sql"); testScript("datatypes/" + s + ".sql");
} }
for (String s : new String[] { "avg", "bit-and", "bit-or", "count", for (String s : new String[] { "avg", "bit-and", "bit-or", "count",
"group-concat", "max", "min", "selectivity", "stddev-pop", "group-concat", "max", "median", "min", "selectivity", "stddev-pop",
"stddev-samp", "sum", "var-pop", "var-samp" }) { "stddev-samp", "sum", "var-pop", "var-samp" }) {
testScript("functions/aggregate/" + s + ".sql"); testScript("functions/aggregate/" + s + ".sql");
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论