提交 1b19a776 authored 作者: Thomas Mueller's avatar Thomas Mueller

A minimal perfect hash function tool: use universal hashing callback (protection…

A minimal perfect hash function tool: use universal hashing callback (protection against hash flooding)
上级 05719a66
......@@ -29,8 +29,8 @@ public class TestPerfectHash extends TestBase {
*/
public static void main(String... a) throws Exception {
TestPerfectHash test = (TestPerfectHash) TestBase.createCaller().init();
test.measure();
test.test();
test.measure();
}
/**
......@@ -38,25 +38,36 @@ public class TestPerfectHash extends TestBase {
*/
public void measure() {
int size = 1000000;
testMinimal(size / 10);
int s;
long time = System.currentTimeMillis();
s = testMinimal(size);
time = System.currentTimeMillis() - time;
System.out.println((double) s / size + " bits/key (minimal) in " + time + " ms");
System.out.println((double) s / size + " bits/key (minimal) in " +
time + " ms");
time = System.currentTimeMillis();
s = testMinimalWithString(size);
time = System.currentTimeMillis() - time;
System.out.println((double) s / size + " bits/key (minimal; String keys) in " + time + " ms");
System.out.println((double) s / size +
" bits/key (minimal; String keys) in " +
time + " ms");
time = System.currentTimeMillis();
s = test(size, true);
System.out.println((double) s / size + " bits/key (minimal old)");
time = System.currentTimeMillis() - time;
System.out.println((double) s / size + " bits/key (minimal old) in " +
time + " ms");
time = System.currentTimeMillis();
s = test(size, false);
System.out.println((double) s / size + " bits/key (not minimal)");
time = System.currentTimeMillis() - time;
System.out.println((double) s / size + " bits/key (not minimal) in " +
time + " ms");
}
@Override
public void test() {
testBrokenHashFunction();
for (int i = 0; i < 100; i++) {
testMinimal(i);
}
......@@ -73,6 +84,31 @@ public class TestPerfectHash extends TestBase {
}
}
private void testBrokenHashFunction() {
int size = 10000;
Random r = new Random(10000);
HashSet<String> set = new HashSet<String>(size);
while (set.size() < size) {
set.add("x " + r.nextDouble());
}
for (int test = 1; test < 10; test++) {
final int badUntilLevel = test;
UniversalHash<String> badHash = new UniversalHash<String>() {
@Override
public int hashCode(String o, int index) {
if (index < badUntilLevel) {
return 0;
}
return StringHash.getFastHash(o, index);
}
};
byte[] desc = MinimalPerfectHash.generate(set, badHash);
testMinimal(desc, set, badHash);
}
}
private int test(int size, boolean minimal) {
Random r = new Random(size);
HashSet<Integer> set = new HashSet<Integer>();
......@@ -108,7 +144,7 @@ public class TestPerfectHash extends TestBase {
private int testMinimal(int size) {
Random r = new Random(size);
HashSet<Long> set = new HashSet<Long>();
HashSet<Long> set = new HashSet<Long>(size);
while (set.size() < size) {
set.add((long) r.nextInt());
}
......@@ -121,7 +157,7 @@ public class TestPerfectHash extends TestBase {
private int testMinimalWithString(int size) {
Random r = new Random(size);
HashSet<String> set = new HashSet<String>();
HashSet<String> set = new HashSet<String>(size);
while (set.size() < size) {
set.add("x " + r.nextDouble());
}
......
......@@ -27,25 +27,38 @@ import java.util.zip.Inflater;
* At the end of the generation process, the data is compressed using a general
* purpose compression tool (Deflate / Huffman coding) down to 2.0 bits per key.
* The uncompressed data is around 2.2 bits per key. With arithmetic coding,
* about 1.9 bits per key are needed. Generating the hash function takes about 4
* second per million keys with 8 cores (multithreaded). At the expense of
* processing time, a lower number of bits per key would be possible (for
* example 1.85 bits per key with 33000 keys, using 10 seconds generation time,
* with Huffman coding). The algorithm automatically scales with the number of
* available CPUs (using as many threads as there are processors).
* about 1.9 bits per key are needed. Generating the hash function takes about
* 2.5 seconds per million keys with 8 cores (multithreaded). The algorithm
* automatically scales with the number of available CPUs (using as many threads
* as there are processors). At the expense of processing time, a lower number
* of bits per key would be possible (for example 1.84 bits per key with 100000
* keys, using 32 seconds generation time, with Huffman coding).
* <p>
* The memory usage to efficiently calculate hash values is around 2.5 bits per
* key (the space needed for the uncompressed description, plus 8 bytes for
* every top-level bucket).
* <p>
* At each level, only one user defined hash function per object is called
* (about 3 hash functions per key). The result is further processed using a
* supplemental hash function, so that the default user defined hash function
* doesn't need to be sophisticated (it doesn't need to be non-linear, have a
* good avalanche effect, or generate random looking data; it just should
* produce few conflicts if possible).
* <p>
* To protect against hash flooding and similar attacks, cryptographically
* secure functions such as SipHash or SHA-256 can be used. However, such slower
* functions only need to be used in higher recursions levels, so that in the
* normal case (where no attack is happening), only fast, but less secure, hash
* functions are needed.
* secure functions such as SipHash or SHA-256 can be used. However, such
* (slower) functions only need to be used if regular hash functions produce too
* many conflicts. This case is detected when generating the perfect hash
* function, by checking if there are too many conflicts (more than 2160 entries
* in one top-level bucket). In this case, the next hash function is used. That
* way, in the normal case, where no attack is happening, only fast, but less
* secure, hash functions are called. It is fine to use the regular hashCode
* method as the level 0 hash function.
* <p>
* In-place updating of the hash table is not implemented but possible in
* theory, by patching the hash function description.
* theory, by patching the hash function description. With a small change,
* non-minimal perfect hash functions can be calculated (for example 1.22 bits
* per key at a fill rate of 81%).
*
* @param <K> the key type
*/
......@@ -101,16 +114,23 @@ public class MinimalPerfectHash<K> {
private final byte[] data;
/**
* The size up to the given top-level bucket in the data array. Used to
* The size up to the given root-level bucket in the data array. Used to
* speed up calculating the hash of a key.
*/
private final int[] topSize;
private final int[] rootSize;
/**
* The position of the given top-level bucket in the data array. Used to
* The position of the given root-level bucket in the data array. Used to
* speed up calculating the hash of a key.
*/
private final int[] topPos;
private final int[] rootPos;
/**
* The hash function level at the root of the tree. Typically 0, except if
* the hash function at that level didn't split the entries as expected
* (which can be due to a bad hash function, or due to an attack).
*/
private final int rootLevel;
/**
* Create a hash object to convert keys to hashes.
......@@ -121,21 +141,23 @@ public class MinimalPerfectHash<K> {
this.hash = hash;
byte[] b = data = expand(desc);
if (b[0] == SPLIT_MANY) {
rootLevel = b[b.length - 1] & 255;
int split = readVarInt(b, 1);
topSize = new int[split];
topPos = new int[split];
rootSize = new int[split];
rootPos = new int[split];
int pos = 1 + getVarIntLength(b, 1);
int sizeSum = 0;
for (int i = 0; i < split; i++) {
topSize[i] = sizeSum;
topPos[i] = pos;
rootSize[i] = sizeSum;
rootPos[i] = pos;
int start = pos;
pos = getNextPos(pos);
sizeSum += getSizeSum(start, pos);
}
} else {
topSize = null;
topPos = null;
rootLevel = 0;
rootSize = null;
rootPos = null;
}
}
......@@ -146,7 +168,7 @@ public class MinimalPerfectHash<K> {
* @return the hash value
*/
public int get(K x) {
return get(0, x, 0);
return get(0, x, true, rootLevel);
}
/**
......@@ -154,10 +176,11 @@ public class MinimalPerfectHash<K> {
*
* @param pos the start position
* @param x the key
* @param isRoot whether this is the root of the tree
* @param level the level
* @return the hash value
*/
private int get(int pos, K x, int level) {
private int get(int pos, K x, boolean isRoot, int level) {
int n = readVarInt(data, pos);
if (n < 2) {
return 0;
......@@ -176,9 +199,9 @@ public class MinimalPerfectHash<K> {
}
int h = hash(x, hash, level, 0, split);
int s;
if (level == 0 && topPos != null) {
s = topSize[h];
pos = topPos[h];
if (isRoot && rootPos != null) {
s = rootSize[h];
pos = rootPos[h];
} else {
int start = pos;
for (int i = 0; i < h; i++) {
......@@ -186,7 +209,7 @@ public class MinimalPerfectHash<K> {
}
s = getSizeSum(start, pos);
}
return s + get(pos, x, level + 1);
return s + get(pos, x, false, level + 1);
}
/**
......@@ -281,7 +304,7 @@ public class MinimalPerfectHash<K> {
int level, ByteArrayOutputStream out) {
int size = list.size();
if (size <= 1) {
writeVarInt(out, size);
out.write(size);
return;
}
if (size <= MAX_SIZE) {
......@@ -312,34 +335,49 @@ public class MinimalPerfectHash<K> {
split = (size - 47) / DIVIDE;
}
split = Math.max(2, split);
if (split >= SPLIT_MANY) {
writeVarInt(out, SPLIT_MANY);
}
writeVarInt(out, split);
ArrayList<ArrayList<K>> lists =
new ArrayList<ArrayList<K>>(split);
boolean isRoot = level == 0;
ArrayList<ArrayList<K>> lists;
do {
lists = new ArrayList<ArrayList<K>>(split);
for (int i = 0; i < split; i++) {
lists.add(new ArrayList<K>(size / split));
}
for (int i = 0; i < size; i++) {
K x = list.get(i);
lists.get(hash(x, hash, level, 0, split)).add(x);
ArrayList<K> l = lists.get(hash(x, hash, level, 0, split));
l.add(x);
if (isRoot && split >= SPLIT_MANY &&
l.size() > 36 * DIVIDE * 10) {
// a bad hash function or attack was detected
level++;
lists = null;
break;
}
}
boolean multiThreaded = level == 0 && list.size() > 1000;
} while (lists == null);
if (split >= SPLIT_MANY) {
out.write(SPLIT_MANY);
}
writeVarInt(out, split);
boolean multiThreaded = isRoot && list.size() > 1000;
list.clear();
list.trimToSize();
if (multiThreaded) {
generateMultiThreaded(lists, hash, out);
generateMultiThreaded(lists, hash, level, out);
} else {
for (ArrayList<K> s2 : lists) {
generate(s2, hash, level + 1, out);
}
}
if (isRoot && split >= SPLIT_MANY) {
out.write(level);
}
}
private static <K> void generateMultiThreaded(
final ArrayList<ArrayList<K>> lists,
final UniversalHash<K> hash,
final int level,
ByteArrayOutputStream out) {
final ArrayList<ByteArrayOutputStream> outList =
new ArrayList<ByteArrayOutputStream>();
......@@ -360,7 +398,7 @@ public class MinimalPerfectHash<K> {
list = lists.remove(0);
outList.add(temp);
}
generate(list, hash, 1, temp);
generate(list, hash, level + 1, temp);
}
}
};
......@@ -567,6 +605,13 @@ public class MinimalPerfectHash<K> {
return getSipHash24(o, index, 0);
}
/**
* A cryptographically weak hash function. It is supposed to be fast.
*
* @param o the string
* @param x the seed
* @return the hash value
*/
public static int getFastHash(String o, int x) {
int result = o.length();
for (int i = 0; i < o.length(); i++) {
......@@ -580,12 +625,12 @@ public class MinimalPerfectHash<K> {
* A cryptographically relatively secure hash function. It is supposed
* to protected against hash-flooding denial-of-service attacks.
*
* @param o the object
* @param o the string
* @param k0 key 0
* @param k1 key 1
* @return the hash value
*/
private static int getSipHash24(String o, long k0, long k1) {
public static int getSipHash24(String o, long k0, long k1) {
long v0 = k0 ^ 0x736f6d6570736575L;
long v1 = k1 ^ 0x646f72616e646f6dL;
long v2 = k0 ^ 0x6c7967656e657261L;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论