Commit 1b19a776 authored by Thomas Mueller


A minimal perfect hash function tool: use universal hashing callback (protection against hash flooding)
Parent 05719a66
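For context, here is a minimal sketch of how the universal hashing callback introduced by this commit is meant to be supplied by a caller. It only uses API visible in the diff below (MinimalPerfectHash.generate, the MinimalPerfectHash(desc, hash) constructor, UniversalHash.hashCode(key, index), and StringHash.getFastHash); the driver class, the key set, and the import layout are illustrative assumptions and not part of the commit, and the imports for the hash classes themselves are omitted.

import java.util.HashSet;

// Illustrative driver class (not part of the commit).
public class PerfectHashSketch {

    public static void main(String... args) {
        HashSet<String> keys = new HashSet<String>();
        for (int i = 0; i < 1000; i++) {
            keys.add("key " + i);
        }
        // The hash function is supplied as a callback. The generator asks for a
        // larger "index" when the function at the previous index produced too
        // many conflicts (for example during a hash-flooding attack).
        UniversalHash<String> hash = new UniversalHash<String>() {
            @Override
            public int hashCode(String o, int index) {
                return StringHash.getFastHash(o, index);
            }
        };
        // Generate the description once, then use it (plus the same callback)
        // to map each key to a distinct value in [0, keys.size()).
        byte[] desc = MinimalPerfectHash.generate(keys, hash);
        MinimalPerfectHash<String> mph = new MinimalPerfectHash<String>(desc, hash);
        System.out.println(mph.get("key 42"));
    }
}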
@@ -29,8 +29,8 @@ public class TestPerfectHash extends TestBase {
      */
     public static void main(String... a) throws Exception {
         TestPerfectHash test = (TestPerfectHash) TestBase.createCaller().init();
-        test.measure();
         test.test();
+        test.measure();
     }
 
     /**
@@ -38,25 +38,36 @@ public class TestPerfectHash extends TestBase {
      */
     public void measure() {
         int size = 1000000;
+        testMinimal(size / 10);
         int s;
         long time = System.currentTimeMillis();
         s = testMinimal(size);
         time = System.currentTimeMillis() - time;
-        System.out.println((double) s / size + " bits/key (minimal) in " + time + " ms");
+        System.out.println((double) s / size + " bits/key (minimal) in " +
+                time + " ms");
         time = System.currentTimeMillis();
         s = testMinimalWithString(size);
         time = System.currentTimeMillis() - time;
-        System.out.println((double) s / size + " bits/key (minimal; String keys) in " + time + " ms");
+        System.out.println((double) s / size +
+                " bits/key (minimal; String keys) in " +
+                time + " ms");
+        time = System.currentTimeMillis();
         s = test(size, true);
-        System.out.println((double) s / size + " bits/key (minimal old)");
+        time = System.currentTimeMillis() - time;
+        System.out.println((double) s / size + " bits/key (minimal old) in " +
+                time + " ms");
+        time = System.currentTimeMillis();
         s = test(size, false);
-        System.out.println((double) s / size + " bits/key (not minimal)");
+        time = System.currentTimeMillis() - time;
+        System.out.println((double) s / size + " bits/key (not minimal) in " +
+                time + " ms");
     }
 
     @Override
     public void test() {
+        testBrokenHashFunction();
         for (int i = 0; i < 100; i++) {
             testMinimal(i);
         }
@@ -72,6 +83,31 @@ public class TestPerfectHash extends TestBase {
             test(i, false);
         }
     }
 
+    private void testBrokenHashFunction() {
+        int size = 10000;
+        Random r = new Random(10000);
+        HashSet<String> set = new HashSet<String>(size);
+        while (set.size() < size) {
+            set.add("x " + r.nextDouble());
+        }
+        for (int test = 1; test < 10; test++) {
+            final int badUntilLevel = test;
+            UniversalHash<String> badHash = new UniversalHash<String>() {
+
+                @Override
+                public int hashCode(String o, int index) {
+                    if (index < badUntilLevel) {
+                        return 0;
+                    }
+                    return StringHash.getFastHash(o, index);
+                }
+
+            };
+            byte[] desc = MinimalPerfectHash.generate(set, badHash);
+            testMinimal(desc, set, badHash);
+        }
+    }
+
     private int test(int size, boolean minimal) {
         Random r = new Random(size);
@@ -108,7 +144,7 @@ public class TestPerfectHash extends TestBase {
     private int testMinimal(int size) {
         Random r = new Random(size);
-        HashSet<Long> set = new HashSet<Long>();
+        HashSet<Long> set = new HashSet<Long>(size);
         while (set.size() < size) {
             set.add((long) r.nextInt());
         }
@@ -121,7 +157,7 @@ public class TestPerfectHash extends TestBase {
     private int testMinimalWithString(int size) {
         Random r = new Random(size);
-        HashSet<String> set = new HashSet<String>();
+        HashSet<String> set = new HashSet<String>(size);
         while (set.size() < size) {
             set.add("x " + r.nextDouble());
         }
...
@@ -27,25 +27,38 @@ import java.util.zip.Inflater;
  * At the end of the generation process, the data is compressed using a general
  * purpose compression tool (Deflate / Huffman coding) down to 2.0 bits per key.
  * The uncompressed data is around 2.2 bits per key. With arithmetic coding,
- * about 1.9 bits per key are needed. Generating the hash function takes about 4
- * second per million keys with 8 cores (multithreaded). At the expense of
- * processing time, a lower number of bits per key would be possible (for
- * example 1.85 bits per key with 33000 keys, using 10 seconds generation time,
- * with Huffman coding). The algorithm automatically scales with the number of
- * available CPUs (using as many threads as there are processors).
+ * about 1.9 bits per key are needed. Generating the hash function takes about
+ * 2.5 seconds per million keys with 8 cores (multithreaded). The algorithm
+ * automatically scales with the number of available CPUs (using as many threads
+ * as there are processors). At the expense of processing time, a lower number
+ * of bits per key would be possible (for example 1.84 bits per key with 100000
+ * keys, using 32 seconds generation time, with Huffman coding).
  * <p>
  * The memory usage to efficiently calculate hash values is around 2.5 bits per
  * key (the space needed for the uncompressed description, plus 8 bytes for
  * every top-level bucket).
  * <p>
+ * At each level, only one user defined hash function per object is called
+ * (about 3 hash functions per key). The result is further processed using a
+ * supplemental hash function, so that the default user defined hash function
+ * doesn't need to be sophisticated (it doesn't need to be non-linear, have a
+ * good avalanche effect, or generate random looking data; it just should
+ * produce few conflicts if possible).
+ * <p>
  * To protect against hash flooding and similar attacks, cryptographically
- * secure functions such as SipHash or SHA-256 can be used. However, such slower
- * functions only need to be used in higher recursions levels, so that in the
- * normal case (where no attack is happening), only fast, but less secure, hash
- * functions are needed.
+ * secure functions such as SipHash or SHA-256 can be used. However, such
+ * (slower) functions only need to be used if regular hash functions produce too
+ * many conflicts. This case is detected when generating the perfect hash
+ * function, by checking if there are too many conflicts (more than 2160 entries
+ * in one top-level bucket). In this case, the next hash function is used. That
+ * way, in the normal case, where no attack is happening, only fast, but less
+ * secure, hash functions are called. It is fine to use the regular hashCode
+ * method as the level 0 hash function.
  * <p>
  * In-place updating of the hash table is not implemented but possible in
- * theory, by patching the hash function description.
+ * theory, by patching the hash function description. With a small change,
+ * non-minimal perfect hash functions can be calculated (for example 1.22 bits
+ * per key at a fill rate of 81%).
  *
  * @param <K> the key type
  */
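The paragraphs above describe the intended split between cheap and secure hash functions. Below is a sketch of a caller-supplied hash that follows that advice; the index thresholds are illustrative assumptions and not necessarily what the built-in StringHash does, but the three calls it makes (the regular Object.hashCode, StringHash.getFastHash, and the now-public StringHash.getSipHash24) all appear in this commit.

// Sketch: regular hashCode as the level 0 function, the fast hash for the next
// few indexes, and the keyed SipHash only when the generator escalates further
// because the cheaper functions produced too many conflicts.
UniversalHash<String> hash = new UniversalHash<String>() {
    @Override
    public int hashCode(String o, int index) {
        if (index == 0) {
            // "It is fine to use the regular hashCode method as the
            // level 0 hash function."
            return o.hashCode();
        } else if (index < 8) {
            // illustrative cut-over point, not the library's own choice
            return StringHash.getFastHash(o, index);
        }
        // slower, but resistant to hash flooding; made public by this commit
        return StringHash.getSipHash24(o, index, 0);
    }
};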
@@ -101,16 +114,23 @@ public class MinimalPerfectHash<K> {
     private final byte[] data;
 
     /**
-     * The size up to the given top-level bucket in the data array. Used to
+     * The size up to the given root-level bucket in the data array. Used to
      * speed up calculating the hash of a key.
      */
-    private final int[] topSize;
+    private final int[] rootSize;
 
     /**
-     * The position of the given top-level bucket in the data array. Used to
+     * The position of the given root-level bucket in the data array. Used to
      * speed up calculating the hash of a key.
      */
-    private final int[] topPos;
+    private final int[] rootPos;
+
+    /**
+     * The hash function level at the root of the tree. Typically 0, except if
+     * the hash function at that level didn't split the entries as expected
+     * (which can be due to a bad hash function, or due to an attack).
+     */
+    private final int rootLevel;
 
     /**
      * Create a hash object to convert keys to hashes.
@@ -121,21 +141,23 @@ public class MinimalPerfectHash<K> {
         this.hash = hash;
         byte[] b = data = expand(desc);
         if (b[0] == SPLIT_MANY) {
+            rootLevel = b[b.length - 1] & 255;
             int split = readVarInt(b, 1);
-            topSize = new int[split];
-            topPos = new int[split];
+            rootSize = new int[split];
+            rootPos = new int[split];
             int pos = 1 + getVarIntLength(b, 1);
             int sizeSum = 0;
             for (int i = 0; i < split; i++) {
-                topSize[i] = sizeSum;
-                topPos[i] = pos;
+                rootSize[i] = sizeSum;
+                rootPos[i] = pos;
                 int start = pos;
                 pos = getNextPos(pos);
                 sizeSum += getSizeSum(start, pos);
             }
         } else {
-            topSize = null;
-            topPos = null;
+            rootLevel = 0;
+            rootSize = null;
+            rootPos = null;
         }
     }
@@ -146,7 +168,7 @@ public class MinimalPerfectHash<K> {
      * @return the hash value
      */
     public int get(K x) {
-        return get(0, x, 0);
+        return get(0, x, true, rootLevel);
     }
 
     /**
@@ -154,10 +176,11 @@ public class MinimalPerfectHash<K> {
      *
      * @param pos the start position
      * @param x the key
+     * @param isRoot whether this is the root of the tree
      * @param level the level
      * @return the hash value
      */
-    private int get(int pos, K x, int level) {
+    private int get(int pos, K x, boolean isRoot, int level) {
         int n = readVarInt(data, pos);
         if (n < 2) {
             return 0;
@@ -176,9 +199,9 @@ public class MinimalPerfectHash<K> {
         }
         int h = hash(x, hash, level, 0, split);
         int s;
-        if (level == 0 && topPos != null) {
-            s = topSize[h];
-            pos = topPos[h];
+        if (isRoot && rootPos != null) {
+            s = rootSize[h];
+            pos = rootPos[h];
         } else {
             int start = pos;
             for (int i = 0; i < h; i++) {
@@ -186,7 +209,7 @@ public class MinimalPerfectHash<K> {
             }
             s = getSizeSum(start, pos);
         }
-        return s + get(pos, x, level + 1);
+        return s + get(pos, x, false, level + 1);
     }
 
     /**
@@ -281,7 +304,7 @@ public class MinimalPerfectHash<K> {
             int level, ByteArrayOutputStream out) {
         int size = list.size();
         if (size <= 1) {
-            writeVarInt(out, size);
+            out.write(size);
             return;
         }
         if (size <= MAX_SIZE) {
@@ -312,34 +335,49 @@ public class MinimalPerfectHash<K> {
             split = (size - 47) / DIVIDE;
         }
         split = Math.max(2, split);
+        boolean isRoot = level == 0;
+        ArrayList<ArrayList<K>> lists;
+        do {
+            lists = new ArrayList<ArrayList<K>>(split);
+            for (int i = 0; i < split; i++) {
+                lists.add(new ArrayList<K>(size / split));
+            }
+            for (int i = 0; i < size; i++) {
+                K x = list.get(i);
+                ArrayList<K> l = lists.get(hash(x, hash, level, 0, split));
+                l.add(x);
+                if (isRoot && split >= SPLIT_MANY &&
+                        l.size() > 36 * DIVIDE * 10) {
+                    // a bad hash function or attack was detected
+                    level++;
+                    lists = null;
+                    break;
+                }
+            }
+        } while (lists == null);
         if (split >= SPLIT_MANY) {
-            writeVarInt(out, SPLIT_MANY);
+            out.write(SPLIT_MANY);
         }
         writeVarInt(out, split);
-        ArrayList<ArrayList<K>> lists =
-                new ArrayList<ArrayList<K>>(split);
-        for (int i = 0; i < split; i++) {
-            lists.add(new ArrayList<K>(size / split));
-        }
-        for (int i = 0; i < size; i++) {
-            K x = list.get(i);
-            lists.get(hash(x, hash, level, 0, split)).add(x);
-        }
-        boolean multiThreaded = level == 0 && list.size() > 1000;
+        boolean multiThreaded = isRoot && list.size() > 1000;
         list.clear();
         list.trimToSize();
         if (multiThreaded) {
-            generateMultiThreaded(lists, hash, out);
+            generateMultiThreaded(lists, hash, level, out);
         } else {
             for (ArrayList<K> s2 : lists) {
                 generate(s2, hash, level + 1, out);
             }
         }
+        if (isRoot && split >= SPLIT_MANY) {
+            out.write(level);
+        }
     }
 
     private static <K> void generateMultiThreaded(
             final ArrayList<ArrayList<K>> lists,
             final UniversalHash<K> hash,
+            final int level,
             ByteArrayOutputStream out) {
         final ArrayList<ByteArrayOutputStream> outList =
                 new ArrayList<ByteArrayOutputStream>();
@@ -360,7 +398,7 @@ public class MinimalPerfectHash<K> {
                         list = lists.remove(0);
                         outList.add(temp);
                     }
-                    generate(list, hash, 1, temp);
+                    generate(list, hash, level + 1, temp);
                 }
             }
         };
@@ -567,6 +605,13 @@ public class MinimalPerfectHash<K> {
         return getSipHash24(o, index, 0);
     }
 
+    /**
+     * A cryptographically weak hash function. It is supposed to be fast.
+     *
+     * @param o the string
+     * @param x the seed
+     * @return the hash value
+     */
     public static int getFastHash(String o, int x) {
         int result = o.length();
         for (int i = 0; i < o.length(); i++) {
@@ -580,12 +625,12 @@ public class MinimalPerfectHash<K> {
      * A cryptographically relatively secure hash function. It is supposed
      * to protected against hash-flooding denial-of-service attacks.
      *
-     * @param o the object
+     * @param o the string
      * @param k0 key 0
      * @param k1 key 1
     * @return the hash value
      */
-    private static int getSipHash24(String o, long k0, long k1) {
+    public static int getSipHash24(String o, long k0, long k1) {
         long v0 = k0 ^ 0x736f6d6570736575L;
         long v1 = k1 ^ 0x646f72616e646f6dL;
         long v2 = k0 ^ 0x6c7967656e657261L;
...
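The test diff above calls testMinimal(desc, set, badHash) to validate the generated description, but its body is not part of this commit. The following is therefore only an assumed sketch of what such a check looks like: a minimal perfect hash over n keys must map them onto the values 0 to n-1 with no collisions (java.util.BitSet and java.util.Set are assumed to be imported).

// Assumed verification sketch, not the actual test code from the commit.
static <K> void assertMinimalPerfect(byte[] desc, Set<K> set, UniversalHash<K> hash) {
    MinimalPerfectHash<K> mph = new MinimalPerfectHash<K>(desc, hash);
    int n = set.size();
    BitSet seen = new BitSet(n);
    for (K k : set) {
        int h = mph.get(k);
        if (h < 0 || h >= n || seen.get(h)) {
            throw new AssertionError("not a minimal perfect hash at key " + k + ": " + h);
        }
        seen.set(h);
    }
}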