提交 83f041ba authored 作者: Thomas Mueller's avatar Thomas Mueller

Arithmetic and ANS compression (for the minimum perfect hash function)

上级 677a9793
......@@ -149,7 +149,9 @@ import org.h2.test.synth.TestRandomSQL;
import org.h2.test.synth.TestTimer;
import org.h2.test.synth.sql.TestSynth;
import org.h2.test.synth.thread.TestMulti;
import org.h2.test.unit.TestAnsCompression;
import org.h2.test.unit.TestAutoReconnect;
import org.h2.test.unit.TestBinaryArithmeticStream;
import org.h2.test.unit.TestBitField;
import org.h2.test.unit.TestBitStream;
import org.h2.test.unit.TestBnf;
......@@ -792,7 +794,9 @@ kill -9 `jps -l | grep "org.h2.test." | cut -d " " -f 1`
addTest(new TestTransactionStore());
// unit
addTest(new TestAnsCompression());
addTest(new TestAutoReconnect());
addTest(new TestBinaryArithmeticStream());
addTest(new TestBitField());
addTest(new TestBitStream());
addTest(new TestBnf());
......
/*
* Copyright 2004-2014 H2 Group. Multiple-Licensed under the MPL 2.0,
* and the EPL 1.0 (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.test.unit;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import org.h2.dev.util.AnsCompression;
import org.h2.dev.util.BinaryArithmeticStream;
import org.h2.dev.util.BitStream;
import org.h2.test.TestBase;
/**
* Tests the ANS (Asymmetric Numeral Systems) compression tool.
*/
public class TestAnsCompression extends TestBase {
/**
* Run just this test.
*
* @param a ignored
*/
public static void main(String... a) throws Exception {
TestBase.createCaller().init().test();
}
@Override
public void test() throws Exception {
testScaleFrequencies();
testRandomized();
testCompressionRate();
}
private void testCompressionRate() throws IOException {
byte[] data = new byte[1024 * 1024];
Random r = new Random(1);
for (int i = 0; i < data.length; i++) {
data[i] = (byte) (r.nextInt(4) * r.nextInt(4));
}
int[] freq = new int[256];
AnsCompression.countFrequencies(freq, data);
int lenAns = AnsCompression.encode(freq, data).length;
BitStream.Huffman huff = new BitStream.Huffman(freq);
ByteArrayOutputStream out = new ByteArrayOutputStream();
BitStream.Out o = new BitStream.Out(out);
for (byte x : data) {
huff.write(o, x & 255);
}
o.flush();
int lenHuff = out.toByteArray().length;
BinaryArithmeticStream.Huffman aHuff = new BinaryArithmeticStream.Huffman(
freq);
out = new ByteArrayOutputStream();
BinaryArithmeticStream.Out o2 = new BinaryArithmeticStream.Out(out);
for (byte x : data) {
aHuff.write(o2, x & 255);
}
o2.flush();
int lenArithmetic = out.toByteArray().length;
assertTrue(lenAns < lenArithmetic);
assertTrue(lenArithmetic < lenHuff);
assertTrue(lenHuff < data.length);
}
private void testScaleFrequencies() {
Random r = new Random(1);
for (int j = 0; j < 100; j++) {
int symbolCount = r.nextInt(200) + 1;
int[] freq = new int[symbolCount];
for (int total = symbolCount * 2; total < 10000; total *= 2) {
for (int i = 0; i < freq.length; i++) {
freq[i] = r.nextInt(1000) + 1;
}
AnsCompression.scaleFrequencies(freq, total);
}
}
int[] freq = new int[]{0, 1, 1, 1000};
AnsCompression.scaleFrequencies(freq, 100);
assertEquals("[0, 1, 1, 98]", Arrays.toString(freq));
}
private void testRandomized() {
Random r = new Random(1);
int symbolCount = r.nextInt(200) + 1;
int[] freq = new int[symbolCount];
for (int i = 0; i < freq.length; i++) {
freq[i] = r.nextInt(1000) + 1;
}
int seed = r.nextInt();
r.setSeed(seed);
int len = 10000;
byte[] data = new byte[len];
r.nextBytes(data);
freq = new int[256];
AnsCompression.countFrequencies(freq, data);
byte[] encoded = AnsCompression.encode(freq, data);
byte[] decoded = AnsCompression.decode(freq, encoded, data.length);
for (int i = 0; i < len; i++) {
int expected = data[i];
assertEquals(expected, decoded[i]);
}
}
}
/*
* Copyright 2004-2014 H2 Group. Multiple-Licensed under the MPL 2.0,
* and the EPL 1.0 (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.test.unit;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Random;
import org.h2.dev.util.BinaryArithmeticStream;
import org.h2.dev.util.BinaryArithmeticStream.Huffman;
import org.h2.dev.util.BinaryArithmeticStream.In;
import org.h2.dev.util.BinaryArithmeticStream.Out;
import org.h2.dev.util.BitStream;
import org.h2.test.TestBase;
/**
* Test the binary arithmetic stream utility.
*/
public class TestBinaryArithmeticStream extends TestBase {
/**
* Run just this test.
*
* @param a ignored
*/
public static void main(String... a) throws Exception {
TestBase.createCaller().init().test();
}
@Override
public void test() throws Exception {
testCompareWithHuffman();
testHuffmanRandomized();
testCompressionRatio();
testRandomized();
testPerformance();
}
private void testCompareWithHuffman() throws IOException {
Random r = new Random(1);
for (int test = 0; test < 10; test++) {
int[] freq = new int[4];
for (int i = 0; i < freq.length; i++) {
freq[i] = 0 + r.nextInt(1000);
}
BinaryArithmeticStream.Huffman ah = new BinaryArithmeticStream.Huffman(
freq);
BitStream.Huffman hh = new BitStream.Huffman(freq);
ByteArrayOutputStream hbOut = new ByteArrayOutputStream();
ByteArrayOutputStream abOut = new ByteArrayOutputStream();
BitStream.Out bOut = new BitStream.Out(hbOut);
BinaryArithmeticStream.Out aOut = new BinaryArithmeticStream.Out(abOut);
for (int i = 0; i < freq.length; i++) {
for (int j = 0; j < freq[i]; j++) {
int x = i;
hh.write(bOut, x);
ah.write(aOut, x);
}
}
assertTrue(hbOut.toByteArray().length >= abOut.toByteArray().length);
}
}
private void testHuffmanRandomized() throws IOException {
Random r = new Random(1);
int[] freq = new int[r.nextInt(200) + 1];
for (int i = 0; i < freq.length; i++) {
freq[i] = r.nextInt(1000) + 1;
}
int seed = r.nextInt();
r.setSeed(seed);
Huffman huff = new Huffman(freq);
ByteArrayOutputStream byteOut = new ByteArrayOutputStream();
Out out = new Out(byteOut);
for (int i = 0; i < 10000; i++) {
huff.write(out, r.nextInt(freq.length));
}
out.flush();
In in = new In(new ByteArrayInputStream(byteOut.toByteArray()));
r.setSeed(seed);
for (int i = 0; i < 10000; i++) {
int expected = r.nextInt(freq.length);
int got = huff.read(in);
assertEquals(expected, got);
}
}
private void testPerformance() throws IOException {
Random r = new Random();
// long time = System.currentTimeMillis();
// Profiler prof = new Profiler().startCollecting();
for (int seed = 0; seed < 10000; seed++) {
r.setSeed(seed);
ByteArrayOutputStream byteOut = new ByteArrayOutputStream();
Out out = new Out(byteOut);
int len = 100;
for (int i = 0; i < len; i++) {
boolean v = r.nextBoolean();
int prob = r.nextInt(BinaryArithmeticStream.MAX_PROBABILITY);
out.writeBit(v, prob);
}
out.flush();
r.setSeed(seed);
ByteArrayInputStream byteIn = new ByteArrayInputStream(
byteOut.toByteArray());
In in = new In(byteIn);
for (int i = 0; i < len; i++) {
boolean expected = r.nextBoolean();
int prob = r.nextInt(BinaryArithmeticStream.MAX_PROBABILITY);
assertEquals(expected, in.readBit(prob));
}
}
// time = System.currentTimeMillis() - time;
// System.out.println("time: " + time);
// System.out.println(prof.getTop(5));
}
private void testCompressionRatio() throws IOException {
ByteArrayOutputStream byteOut = new ByteArrayOutputStream();
Out out = new Out(byteOut);
int prob = 1000;
int len = 1024;
for (int i = 0; i < len; i++) {
out.writeBit(true, prob);
}
out.flush();
ByteArrayInputStream byteIn = new ByteArrayInputStream(
byteOut.toByteArray());
In in = new In(byteIn);
for (int i = 0; i < len; i++) {
assertTrue(in.readBit(prob));
}
// System.out.println(len / 8 + " comp: " +
// byteOut.toByteArray().length);
}
private void testRandomized() throws IOException {
for (int i = 0; i < 10000; i = (int) ((i + 10) * 1.1)) {
testRandomized(i);
}
}
private void testRandomized(int len) throws IOException {
Random r = new Random();
int seed = r.nextInt();
r.setSeed(seed);
ByteArrayOutputStream byteOut = new ByteArrayOutputStream();
Out out = new Out(byteOut);
for (int i = 0; i < len; i++) {
int prob = r.nextInt(BinaryArithmeticStream.MAX_PROBABILITY);
out.writeBit(r.nextBoolean(), prob);
}
out.flush();
byteOut.write(r.nextInt(255));
ByteArrayInputStream byteIn = new ByteArrayInputStream(
byteOut.toByteArray());
In in = new In(byteIn);
r.setSeed(seed);
for (int i = 0; i < len; i++) {
int prob = r.nextInt(BinaryArithmeticStream.MAX_PROBABILITY);
boolean expected = r.nextBoolean();
boolean got = in.readBit(prob);
assertEquals(expected, got);
}
assertEquals(r.nextInt(255), byteIn.read());
}
}
/*
* Copyright 2004-2014 H2 Group. Multiple-Licensed under the MPL 2.0,
* and the EPL 1.0 (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.dev.util;
import java.nio.ByteBuffer;
import java.util.Arrays;
/**
* An ANS (Asymmetric Numeral Systems) compression tool.
* It uses the range variant.
*/
public class AnsCompression {
private static final long TOP = 1L << 24;
private static final int SHIFT = 12;
private static final int MASK = (1 << SHIFT) - 1;
private static final long MAX = (TOP >> SHIFT) << 32;
private AnsCompression() {
// a utility class
}
/**
* Count the frequencies of codes in the data, and increment the target
* frequency table.
*
* @param freq the target frequency table
* @param data the data
*/
public static void countFrequencies(int[] freq, byte[] data) {
for (byte x : data) {
freq[x & 0xff]++;
}
}
/**
* Scale the frequencies to a new total. Frequencies of 0 are kept as 0;
* larger frequencies result in at least 1.
*
* @param freq the (source and target) frequency table
* @param total the target total (sum of all frequencies)
*/
public static void scaleFrequencies(int[] freq, int total) {
int len = freq.length, sum = 0;
for (int x : freq) {
sum += x;
}
// the list of: (error << 8) + index
int[] errors = new int[len];
int totalError = -total;
for (int i = 0; i < len; i++) {
int old = freq[i];
if (old == 0) {
continue;
}
int ideal = (int) (old * total * 256L / sum);
// 1 too high so we can decrement if needed
int x = 1 + ideal / 256;
freq[i] = x;
totalError += x;
errors[i] = ((x * 256 - ideal) << 8) + i;
}
// we don't need to sort, we could just calculate
// which one is the nth element - but sorting is simpler
Arrays.sort(errors);
if (totalError < 0) {
// integer overflow
throw new IllegalArgumentException();
}
while (totalError > 0) {
for (int i = 0; totalError > 0 && i < len; i++) {
int index = errors[i] & 0xff;
if (freq[index] > 1) {
freq[index]--;
totalError--;
}
}
}
}
/**
* Generate the cumulative frequency table.
*
* @param freq the source frequency table
* @return the cumulative table, with one entry more
*/
static int[] generateCumulativeFrequencies(int[] freq) {
int len = freq.length;
int[] cumulativeFreq = new int[len + 1];
for (int i = 0, x = 0; i < len; i++) {
x += freq[i];
cumulativeFreq[i + 1] = x;
}
return cumulativeFreq;
}
/**
* Generate the frequency-to-code table.
*
* @param cumulativeFreq the cumulative frequency table
* @return the result
*/
private static byte[] generateFrequencyToCode(int[] cumulativeFreq) {
byte[] freqToCode = new byte[1 << SHIFT];
int x = 0;
byte s = -1;
for (int i : cumulativeFreq) {
while (x < i) {
freqToCode[x++] = s;
}
s++;
}
return freqToCode;
}
/**
* Encode the data.
*
* @param freq the frequency table (will be scaled)
* @param data the source data (uncompressed)
* @return the compressed data
*/
public static byte[] encode(int[] freq, byte[] data) {
scaleFrequencies(freq, 1 << SHIFT);
int[] cumulativeFreq = generateCumulativeFrequencies(freq);
ByteBuffer buff = ByteBuffer.allocate(data.length * 2);
buff = encode(data, freq, cumulativeFreq, buff);
return Arrays.copyOfRange(buff.array(),
buff.arrayOffset() + buff.position(), buff.arrayOffset() + buff.limit());
}
private static ByteBuffer encode(byte[] data, int[] freq,
int[] cumulativeFreq, ByteBuffer buff) {
long state = TOP;
// encoding happens backwards
int b = buff.limit();
for (int p = data.length - 1; p >= 0; p--) {
int x = data[p] & 0xff;
int f = freq[x];
while (state >= MAX * f) {
b -= 4;
buff.putInt(b, (int) state);
state >>>= 32;
}
state = ((state / f) << SHIFT) + (state % f) + cumulativeFreq[x];
}
b -= 8;
buff.putLong(b, state);
buff.position(b);
return buff.slice();
}
/**
* Decode the data.
*
* @param freq the frequency table (will be scaled)
* @param data the compressed data
* @param length the target length
* @return the uncompressed result
*/
public static byte[] decode(int[] freq, byte[] data, int length) {
scaleFrequencies(freq, 1 << SHIFT);
int[] cumulativeFreq = generateCumulativeFrequencies(freq);
byte[] freqToCode = generateFrequencyToCode(cumulativeFreq);
byte[] out = new byte[length];
decode(data, freq, cumulativeFreq, freqToCode, out);
return out;
}
private static void decode(byte[] data, int[] freq, int[] cumulativeFreq,
byte[] freqToCode, byte[] out) {
ByteBuffer buff = ByteBuffer.wrap(data);
long state = buff.getLong();
for (int i = 0, size = out.length; i < size; i++) {
int x = (int) state & MASK;
int c = freqToCode[x] & 0xff;
out[i] = (byte) c;
state = (freq[c] * (state >> SHIFT)) + x - cumulativeFreq[c];
while (state < TOP) {
state = (state << 32) | (buff.getInt() & 0xffffffffL);
}
}
}
}
/*
* Copyright 2004-2014 H2 Group. Multiple-Licensed under the MPL 2.0,
* and the EPL 1.0 (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.dev.util;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.PriorityQueue;
/**
* A binary arithmetic stream.
*/
public class BinaryArithmeticStream {
/**
* The maximum probability.
*/
public static final int MAX_PROBABILITY = (1 << 12) - 1;
/**
* The low marker.
*/
protected int low;
/**
* The high marker.
*/
protected int high = 0xffffffff;
/**
* A binary arithmetic input stream.
*/
public static class In extends BinaryArithmeticStream {
private final InputStream in;
private int data;
public In(InputStream in) throws IOException {
this.in = in;
data = ((in.read() & 0xff) << 24) |
((in.read() & 0xff) << 16) |
((in.read() & 0xff) << 8) |
(in.read() & 0xff);
}
/**
* Read a bit.
*
* @param probability the probability that the value is true
* @return the value
*/
public boolean readBit(int probability) throws IOException {
int split = low + probability * ((high - low) >>> 12);
boolean value;
// compare unsigned
if (data + Integer.MIN_VALUE > split + Integer.MIN_VALUE) {
low = split + 1;
value = false;
} else {
high = split;
value = true;
}
while (low >>> 24 == high >>> 24) {
data = (data << 8) | (in.read() & 0xff);
low <<= 8;
high = (high << 8) | 0xff;
}
return value;
}
/**
* Read a value that is stored as a Golomb code.
*
* @param divisor the divisor
* @return the value
*/
public int readGolomb(int divisor) throws IOException {
int q = 0;
while (readBit(MAX_PROBABILITY / 2)) {
q++;
}
int bit = 31 - Integer.numberOfLeadingZeros(divisor - 1);
int r = 0;
if (bit >= 0) {
int cutOff = (2 << bit) - divisor;
for (; bit > 0; bit--) {
r = (r << 1) + (readBit(MAX_PROBABILITY / 2) ? 1 : 0);
}
if (r >= cutOff) {
r = (r << 1) + (readBit(MAX_PROBABILITY / 2) ? 1 : 0) - cutOff;
}
}
return q * divisor + r;
}
}
/**
* A binary arithmetic output stream.
*/
public static class Out extends BinaryArithmeticStream {
private final OutputStream out;
public Out(OutputStream out) {
this.out = out;
}
/**
* Write a bit.
*
* @param value the value
* @param probability the probability that the value is true
*/
public void writeBit(boolean value, int probability) throws IOException {
int split = low + probability * ((high - low) >>> 12);
if (value) {
high = split;
} else {
low = split + 1;
}
while (low >>> 24 == high >>> 24) {
out.write(high >> 24);
low <<= 8;
high = (high << 8) | 0xff;
}
}
/**
* Flush the stream.
*/
public void flush() throws IOException {
out.write(high >> 24);
out.write(high >> 16);
out.write(high >> 8);
out.write(high);
}
/**
* Write the Golomb code of a value.
*
* @param divisor the divisor
* @param value the value
*/
public void writeGolomb(int divisor, int value) throws IOException {
int q = value / divisor;
for (int i = 0; i < q; i++) {
writeBit(true, MAX_PROBABILITY / 2);
}
writeBit(false, MAX_PROBABILITY / 2);
int r = value - q * divisor;
int bit = 31 - Integer.numberOfLeadingZeros(divisor - 1);
if (r < ((2 << bit) - divisor)) {
bit--;
} else {
r += (2 << bit) - divisor;
}
for (; bit >= 0; bit--) {
writeBit(((r >>> bit) & 1) == 1, MAX_PROBABILITY / 2);
}
}
}
/**
* A Huffman code table / tree.
*/
public static class Huffman {
private final int[] codes;
private final Node tree;
public Huffman(int[] frequencies) {
PriorityQueue<Node> queue = new PriorityQueue<Node>();
for (int i = 0; i < frequencies.length; i++) {
int f = frequencies[i];
if (f > 0) {
queue.offer(new Node(i, f));
}
}
while (queue.size() > 1) {
queue.offer(new Node(queue.poll(), queue.poll()));
}
codes = new int[frequencies.length];
tree = queue.poll();
if (tree != null) {
tree.initCodes(codes, 1);
}
}
/**
* Write a value.
*
* @param out the output stream
* @param value the value to write
*/
public void write(Out out, int value) throws IOException {
int code = codes[value];
int bitCount = 30 - Integer.numberOfLeadingZeros(code);
Node n = tree;
for (int i = bitCount; i >= 0; i--) {
boolean goRight = ((code >> i) & 1) == 1;
int prob = MAX_PROBABILITY *
n.right.frequency / n.frequency;
out.writeBit(goRight, prob);
n = goRight ? n.right : n.left;
}
}
/**
* Read a value.
*
* @param in the input stream
* @return the value
*/
public int read(In in) throws IOException {
Node n = tree;
while (n.left != null) {
int prob = MAX_PROBABILITY *
n.right.frequency / n.frequency;
boolean goRight = in.readBit(prob);
n = goRight ? n.right : n.left;
}
return n.value;
}
}
/**
* A Huffman code node.
*/
private static class Node implements Comparable<Node> {
int value;
Node left;
Node right;
final int frequency;
Node(int value, int frequency) {
this.frequency = frequency;
this.value = value;
}
Node(Node left, Node right) {
this.left = left;
this.right = right;
this.frequency = left.frequency + right.frequency;
}
@Override
public int compareTo(Node o) {
return frequency - o.frequency;
}
void initCodes(int[] codes, int bits) {
if (left == null) {
codes[value] = bits;
} else {
left.initCodes(codes, bits << 1);
right.initCodes(codes, (bits << 1) + 1);
}
}
}
}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论