提交 74462788 authored 作者: Thomas Mueller's avatar Thomas Mueller

A minimal perfect hash function tool: use universal hashing callback (with sample implementations)

上级 912789cb
...@@ -5,11 +5,15 @@ ...@@ -5,11 +5,15 @@
*/ */
package org.h2.test.unit; package org.h2.test.unit;
import java.util.BitSet;
import java.util.HashSet; import java.util.HashSet;
import java.util.Random; import java.util.Random;
import java.util.Set; import java.util.Set;
import org.h2.dev.hash.MinimalPerfectHash; import org.h2.dev.hash.MinimalPerfectHash;
import org.h2.dev.hash.MinimalPerfectHash.LongHash;
import org.h2.dev.hash.MinimalPerfectHash.StringHash;
import org.h2.dev.hash.MinimalPerfectHash.UniversalHash;
import org.h2.dev.hash.PerfectHash; import org.h2.dev.hash.PerfectHash;
import org.h2.test.TestBase; import org.h2.test.TestBase;
...@@ -25,8 +29,8 @@ public class TestPerfectHash extends TestBase { ...@@ -25,8 +29,8 @@ public class TestPerfectHash extends TestBase {
*/ */
public static void main(String... a) throws Exception { public static void main(String... a) throws Exception {
TestPerfectHash test = (TestPerfectHash) TestBase.createCaller().init(); TestPerfectHash test = (TestPerfectHash) TestBase.createCaller().init();
test.test();
test.measure(); test.measure();
test.test();
} }
/** /**
...@@ -34,9 +38,16 @@ public class TestPerfectHash extends TestBase { ...@@ -34,9 +38,16 @@ public class TestPerfectHash extends TestBase {
*/ */
public void measure() { public void measure() {
int size = 1000000; int size = 1000000;
int s;
int s = testMinimal(size); long time = System.currentTimeMillis();
System.out.println((double) s / size + " bits/key (minimal)"); s = testMinimal(size);
time = System.currentTimeMillis() - time;
System.out.println((double) s / size + " bits/key (minimal) in " + time + " ms");
time = System.currentTimeMillis();
s = testMinimalWithString(size);
time = System.currentTimeMillis() - time;
System.out.println((double) s / size + " bits/key (minimal; String keys) in " + time + " ms");
s = test(size, true); s = test(size, true);
System.out.println((double) s / size + " bits/key (minimal old)"); System.out.println((double) s / size + " bits/key (minimal old)");
...@@ -97,27 +108,41 @@ public class TestPerfectHash extends TestBase { ...@@ -97,27 +108,41 @@ public class TestPerfectHash extends TestBase {
private int testMinimal(int size) { private int testMinimal(int size) {
Random r = new Random(size); Random r = new Random(size);
HashSet<Integer> set = new HashSet<Integer>(); HashSet<Long> set = new HashSet<Long>();
while (set.size() < size) { while (set.size() < size) {
set.add(r.nextInt()); set.add((long) r.nextInt());
} }
byte[] desc = MinimalPerfectHash.generate(set); LongHash hf = new LongHash();
int max = testMinimal(desc, set); byte[] desc = MinimalPerfectHash.generate(set, hf);
int max = testMinimal(desc, set, hf);
assertEquals(size - 1, max);
return desc.length * 8;
}
private int testMinimalWithString(int size) {
Random r = new Random(size);
HashSet<String> set = new HashSet<String>();
while (set.size() < size) {
set.add("x " + r.nextDouble());
}
StringHash hf = new StringHash();
byte[] desc = MinimalPerfectHash.generate(set, hf);
int max = testMinimal(desc, set, hf);
assertEquals(size - 1, max); assertEquals(size - 1, max);
return desc.length * 8; return desc.length * 8;
} }
private int testMinimal(byte[] desc, Set<Integer> set) { private <K> int testMinimal(byte[] desc, Set<K> set, UniversalHash<K> hf) {
int max = -1; int max = -1;
HashSet<Integer> test = new HashSet<Integer>(); BitSet test = new BitSet();
MinimalPerfectHash hash = new MinimalPerfectHash(desc); MinimalPerfectHash<K> hash = new MinimalPerfectHash<K>(desc, hf);
for (int x : set) { for (K x : set) {
int h = hash.get(x); int h = hash.get(x);
assertTrue(h >= 0); assertTrue(h >= 0);
assertTrue(h <= set.size() * 3); assertTrue(h <= set.size() * 3);
max = Math.max(max, h); max = Math.max(max, h);
assertFalse(test.contains(h)); assertFalse(test.get(h));
test.add(h); test.set(h);
} }
return max; return max;
} }
......
...@@ -7,6 +7,7 @@ package org.h2.dev.hash; ...@@ -7,6 +7,7 @@ package org.h2.dev.hash;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.Charset;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Set; import java.util.Set;
import java.util.zip.Deflater; import java.util.zip.Deflater;
...@@ -26,8 +27,8 @@ import java.util.zip.Inflater; ...@@ -26,8 +27,8 @@ import java.util.zip.Inflater;
* At the end of the generation process, the data is compressed using a general * At the end of the generation process, the data is compressed using a general
* purpose compression tool (Deflate / Huffman coding) down to 2.0 bits per key. * purpose compression tool (Deflate / Huffman coding) down to 2.0 bits per key.
* The uncompressed data is around 2.2 bits per key. With arithmetic coding, * The uncompressed data is around 2.2 bits per key. With arithmetic coding,
* about 1.9 bits per key are needed. Generating the hash function takes about * about 1.9 bits per key are needed. Generating the hash function takes about 4
* 2.5 second per million keys with 8 cores (multithreaded). At the expense of * second per million keys with 8 cores (multithreaded). At the expense of
* processing time, a lower number of bits per key would be possible (for * processing time, a lower number of bits per key would be possible (for
* example 1.85 bits per key with 33000 keys, using 10 seconds generation time, * example 1.85 bits per key with 33000 keys, using 10 seconds generation time,
* with Huffman coding). The algorithm automatically scales with the number of * with Huffman coding). The algorithm automatically scales with the number of
...@@ -37,10 +38,18 @@ import java.util.zip.Inflater; ...@@ -37,10 +38,18 @@ import java.util.zip.Inflater;
* key (the space needed for the uncompressed description, plus 8 bytes for * key (the space needed for the uncompressed description, plus 8 bytes for
* every top-level bucket). * every top-level bucket).
* <p> * <p>
* To protect against hash flooding and similar attacks, cryptographically
* secure functions such as SipHash or SHA-256 can be used. However, such slower
* functions only need to be used in higher recursion levels, so that in the
* normal case (where no attack is happening), only fast, but less secure, hash
* functions are needed.
* <p>
* In-place updating of the hash table is not implemented but possible in * In-place updating of the hash table is not implemented but possible in
* theory, by patching the hash function description. * theory, by patching the hash function description.
*
* @param <K> the key type
*/ */
public class MinimalPerfectHash { public class MinimalPerfectHash<K> {
/** /**
* Large buckets are typically divided into buckets of this size. * Large buckets are typically divided into buckets of this size.
...@@ -79,6 +88,11 @@ public class MinimalPerfectHash { ...@@ -79,6 +88,11 @@ public class MinimalPerfectHash {
} }
SIZE_OFFSETS[SIZE_OFFSETS.length - 1] = last; SIZE_OFFSETS[SIZE_OFFSETS.length - 1] = last;
} }
/**
* The universal hash function.
*/
private final UniversalHash<K> hash;
/** /**
* The description of the hash function. Used for calculating the hash of a * The description of the hash function. Used for calculating the hash of a
...@@ -103,7 +117,8 @@ public class MinimalPerfectHash { ...@@ -103,7 +117,8 @@ public class MinimalPerfectHash {
* *
* @param desc the data returned by the generate method * @param desc the data returned by the generate method
*/ */
public MinimalPerfectHash(byte[] desc) { public MinimalPerfectHash(byte[] desc, UniversalHash<K> hash) {
this.hash = hash;
byte[] b = data = expand(desc); byte[] b = data = expand(desc);
if (b[0] == SPLIT_MANY) { if (b[0] == SPLIT_MANY) {
int split = readVarInt(b, 1); int split = readVarInt(b, 1);
...@@ -130,7 +145,7 @@ public class MinimalPerfectHash { ...@@ -130,7 +145,7 @@ public class MinimalPerfectHash {
* @param x the key * @param x the key
* @return the hash value * @return the hash value
*/ */
public int get(int x) { public int get(K x) {
return get(0, x, 0); return get(0, x, 0);
} }
...@@ -142,14 +157,14 @@ public class MinimalPerfectHash { ...@@ -142,14 +157,14 @@ public class MinimalPerfectHash {
* @param level the level * @param level the level
* @return the hash value * @return the hash value
*/ */
private int get(int pos, int x, int level) { private int get(int pos, K x, int level) {
int n = readVarInt(data, pos); int n = readVarInt(data, pos);
if (n < 2) { if (n < 2) {
return 0; return 0;
} else if (n > SPLIT_MANY) { } else if (n > SPLIT_MANY) {
int size = getSize(n); int size = getSize(n);
int offset = getOffset(n, size); int offset = getOffset(n, size);
return hash(x, level, offset, size); return hash(x, hash, level, offset, size);
} }
pos++; pos++;
int split; int split;
...@@ -159,7 +174,7 @@ public class MinimalPerfectHash { ...@@ -159,7 +174,7 @@ public class MinimalPerfectHash {
} else { } else {
split = n; split = n;
} }
int h = hash(x, level, 0, split); int h = hash(x, hash, level, 0, split);
int s; int s;
if (level == 0 && topPos != null) { if (level == 0 && topPos != null) {
s = topSize[h]; s = topSize[h];
...@@ -247,11 +262,11 @@ public class MinimalPerfectHash { ...@@ -247,11 +262,11 @@ public class MinimalPerfectHash {
* @param set the data * @param set the data
* @return the hash function description * @return the hash function description
*/ */
public static byte[] generate(Set<Integer> set) { public static <K> byte[] generate(Set<K> set, UniversalHash<K> hash) {
ArrayList<Integer> list = new ArrayList<Integer>(); ArrayList<K> list = new ArrayList<K>();
list.addAll(set); list.addAll(set);
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
generate(list, 0, out); generate(list, hash, 0, out);
return compress(out.toByteArray()); return compress(out.toByteArray());
} }
...@@ -262,8 +277,8 @@ public class MinimalPerfectHash { ...@@ -262,8 +277,8 @@ public class MinimalPerfectHash {
* @param level the recursion level * @param level the recursion level
* @param out the output stream * @param out the output stream
*/ */
static void generate(ArrayList<Integer> list, int level, static <K> void generate(ArrayList<K> list, UniversalHash<K> hash,
ByteArrayOutputStream out) { int level, ByteArrayOutputStream out) {
int size = list.size(); int size = list.size();
if (size <= 1) { if (size <= 1) {
writeVarInt(out, size); writeVarInt(out, size);
...@@ -271,11 +286,15 @@ public class MinimalPerfectHash { ...@@ -271,11 +286,15 @@ public class MinimalPerfectHash {
} }
if (size <= MAX_SIZE) { if (size <= MAX_SIZE) {
int maxOffset = MAX_OFFSETS[size]; int maxOffset = MAX_OFFSETS[size];
int[] hashes = new int[size];
for (int i = 0; i < size; i++) {
hashes[i] = hash.hashCode(list.get(i), level);
}
nextOffset: nextOffset:
for (int offset = 0; offset < maxOffset; offset++) { for (int offset = 0; offset < maxOffset; offset++) {
int bits = 0; int bits = 0;
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
int x = list.get(i); int x = hashes[i];
int h = hash(x, level, offset, size); int h = hash(x, level, offset, size);
if ((bits & (1 << h)) != 0) { if ((bits & (1 << h)) != 0) {
continue nextOffset; continue nextOffset;
...@@ -297,29 +316,30 @@ public class MinimalPerfectHash { ...@@ -297,29 +316,30 @@ public class MinimalPerfectHash {
writeVarInt(out, SPLIT_MANY); writeVarInt(out, SPLIT_MANY);
} }
writeVarInt(out, split); writeVarInt(out, split);
ArrayList<ArrayList<Integer>> lists = ArrayList<ArrayList<K>> lists =
new ArrayList<ArrayList<Integer>>(split); new ArrayList<ArrayList<K>>(split);
for (int i = 0; i < split; i++) { for (int i = 0; i < split; i++) {
lists.add(new ArrayList<Integer>(size / split)); lists.add(new ArrayList<K>(size / split));
} }
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
int x = list.get(i); K x = list.get(i);
lists.get(hash(x, level, 0, split)).add(x); lists.get(hash(x, hash, level, 0, split)).add(x);
} }
boolean multiThreaded = level == 0 && list.size() > 1000; boolean multiThreaded = level == 0 && list.size() > 1000;
list.clear(); list.clear();
list.trimToSize(); list.trimToSize();
if (multiThreaded) { if (multiThreaded) {
generateMultiThreaded(lists, out); generateMultiThreaded(lists, hash, out);
} else { } else {
for (ArrayList<Integer> s2 : lists) { for (ArrayList<K> s2 : lists) {
generate(s2, level + 1, out); generate(s2, hash, level + 1, out);
} }
} }
} }
private static void generateMultiThreaded( private static <K> void generateMultiThreaded(
final ArrayList<ArrayList<Integer>> lists, final ArrayList<ArrayList<K>> lists,
final UniversalHash<K> hash,
ByteArrayOutputStream out) { ByteArrayOutputStream out) {
final ArrayList<ByteArrayOutputStream> outList = final ArrayList<ByteArrayOutputStream> outList =
new ArrayList<ByteArrayOutputStream>(); new ArrayList<ByteArrayOutputStream>();
...@@ -330,7 +350,7 @@ public class MinimalPerfectHash { ...@@ -330,7 +350,7 @@ public class MinimalPerfectHash {
@Override @Override
public void run() { public void run() {
while (true) { while (true) {
ArrayList<Integer> list; ArrayList<K> list;
ByteArrayOutputStream temp = ByteArrayOutputStream temp =
new ByteArrayOutputStream(); new ByteArrayOutputStream();
synchronized (lists) { synchronized (lists) {
...@@ -340,7 +360,7 @@ public class MinimalPerfectHash { ...@@ -340,7 +360,7 @@ public class MinimalPerfectHash {
list = lists.remove(0); list = lists.remove(0);
outList.add(temp); outList.add(temp);
} }
generate(list, 1, temp); generate(list, hash, 1, temp);
} }
} }
}; };
...@@ -366,13 +386,22 @@ public class MinimalPerfectHash { ...@@ -366,13 +386,22 @@ public class MinimalPerfectHash {
* Calculate the hash of a key. The result depends on the key, the recursion * Calculate the hash of a key. The result depends on the key, the recursion
* level, and the offset. * level, and the offset.
* *
* @param x the key * @param o the key
* @param level the recursion level * @param level the recursion level
* @param offset the index of the hash function * @param offset the index of the hash function
* @param size the size of the bucket * @param size the size of the bucket
* @return the hash (a value between 0, including, and the size, excluding) * @return the hash (a value between 0, including, and the size, excluding)
*/ */
private static int hash(int x, int level, int offset, int size) { private static <K> int hash(K o, UniversalHash<K> hash, int level, int offset, int size) {
int x = hash.hashCode(o, level);
x += level + offset * 16;
x = ((x >>> 16) ^ x) * 0x45d9f3b;
x = ((x >>> 16) ^ x) * 0x45d9f3b;
x = (x >>> 16) ^ x;
return Math.abs(x % size);
}
private static <K> int hash(int x, int level, int offset, int size) {
x += level + offset * 16; x += level + offset * 16;
x = ((x >>> 16) ^ x) * 0x45d9f3b; x = ((x >>> 16) ^ x) * 0x45d9f3b;
x = ((x >>> 16) ^ x) * 0x45d9f3b; x = ((x >>> 16) ^ x) * 0x45d9f3b;
...@@ -466,5 +495,142 @@ public class MinimalPerfectHash { ...@@ -466,5 +495,142 @@ public class MinimalPerfectHash {
} }
return out.toByteArray(); return out.toByteArray();
} }
/**
* An interface that can calculate multiple hash values for an object. Each
* index selects a different hash function. Two distinct objects may get the
* same value for a given hash function index, but as hash functions with
* further indexes are called for those objects, the returned values must
* eventually be different.
* <p>
* The returned value does not need to be uniformly distributed.
*
* @param <T> the type of the objects to hash
*/
public interface UniversalHash<T> {
/**
* Calculate the hash of the given object.
*
* @param o the object
* @param index the hash function index (index 0 is used first, so the
* method should be very fast with index 0; index 1 and so on
* are only called when really needed)
* @return the hash value
*/
int hashCode(T o, int index);
}
/**
 * A sample hash implementation for long keys.
 * <p>
 * Index 0 returns the regular {@link Long#hashCode()}. Any other index
 * below 8 mixes the value (plus the index) with two multiply / xor-shift
 * steps. All remaining indexes return one 32-bit half of the value: the
 * lower half for even indexes, the upper half for odd indexes.
 */
public static class LongHash implements UniversalHash<Long> {

    @Override
    public int hashCode(Long o, int index) {
        if (index == 0) {
            return o.hashCode();
        }
        long value = o.longValue();
        if (index >= 8) {
            // even index: lower 32 bits; odd index: upper 32 bits
            int halfShift = (index & 1) * 32;
            return (int) (value >>> halfShift);
        }
        // the index acts as a seed, so each index yields another function
        long mixed = value + index;
        mixed = (mixed ^ (mixed >>> 32)) * 0x45d9f3b;
        mixed = (mixed ^ (mixed >>> 32)) * 0x45d9f3b;
        return (int) ((mixed >>> 32) ^ mixed);
    }

}
/**
 * A sample hash implementation for string keys.
 * <p>
 * Index 0 returns {@link String#hashCode()}, any other index below 8 uses
 * {@link #getFastHash(String, int)}, and all remaining indexes use a
 * SipHash-2-4 style function keyed with the index.
 */
public static class StringHash implements UniversalHash<String> {

    private static final Charset UTF8 = Charset.forName("UTF-8");

    @Override
    public int hashCode(String o, int index) {
        if (index >= 8) {
            // this method is supposed to be cryptographically secure;
            // SHA-256 could be used for higher indexes instead
            return getSipHash24(o, index, 0);
        }
        if (index == 0) {
            // the default hash of a string, which might already be
            // available
            return o.hashCode();
        }
        // a different hash function, which is fast but not
        // cryptographically secure
        return getFastHash(o, index);
    }

    /**
     * A fast, but not cryptographically secure, hash function. The seed is
     * re-mixed for every character, and the result starts with the length
     * of the string.
     *
     * @param o the string
     * @param x the seed (the hash function index)
     * @return the hash value
     */
    public static int getFastHash(String o, int x) {
        int length = o.length();
        int result = length;
        for (int i = 0; i < length; i++) {
            x = 31 + ((x >>> 16) ^ x) * 0x45d9f3b;
            result += x * (1 + o.charAt(i));
        }
        return result;
    }

    /**
     * A cryptographically relatively secure hash function. It is supposed
     * to protect against hash-flooding denial-of-service attacks.
     * <p>
     * The string is hashed over its UTF-8 bytes in blocks of 8 bytes, with
     * 2 mixing rounds per block and 4 rounds for the final step. The 64-bit
     * state is folded into an int by the cast of the xor of all four state
     * words.
     *
     * @param o the object
     * @param k0 key 0
     * @param k1 key 1
     * @return the hash value
     */
    private static int getSipHash24(String o, long k0, long k1) {
        long v0 = k0 ^ 0x736f6d6570736575L;
        long v1 = k1 ^ 0x646f72616e646f6dL;
        long v2 = k0 ^ 0x6c7967656e657261L;
        long v3 = k1 ^ 0x7465646279746573L;
        byte[] bytes = o.getBytes(UTF8);
        int len = bytes.length;
        // the last iteration (off > len) is the finalization step
        for (int off = 0; off <= len + 8; off += 8) {
            boolean finalization = off > len;
            long m = 0;
            int rounds;
            if (finalization) {
                v2 ^= 0xff;
                rounds = 4;
            } else {
                // little-endian read of up to 8 bytes
                int count = Math.min(8, len - off);
                for (int i = 0; i < count; i++) {
                    m |= (bytes[off + i] & 0xffL) << (8 * i);
                }
                if (count < 8) {
                    // the incomplete (possibly empty) last block carries
                    // the length in its top byte
                    m |= ((long) len) << 56;
                }
                v3 ^= m;
                rounds = 2;
            }
            while (rounds-- > 0) {
                v0 += v1;
                v2 += v3;
                v1 = Long.rotateLeft(v1, 13);
                v3 = Long.rotateLeft(v3, 16);
                v1 ^= v0;
                v3 ^= v2;
                v0 = Long.rotateLeft(v0, 32);
                v2 += v1;
                v0 += v3;
                v1 = Long.rotateLeft(v1, 17);
                v3 = Long.rotateLeft(v3, 21);
                v1 ^= v2;
                v3 ^= v0;
                v2 = Long.rotateLeft(v2, 32);
            }
            // m is 0 in the finalization step, so this is a no-op there
            v0 ^= m;
        }
        return (int) (v0 ^ v1 ^ v2 ^ v3);
    }

}
} }
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论