JavaのHashMap(特に、Integerが絡む場合)のオーバーヘッドについての知見メモ。
突然だが、次の問題を考えよう。
32bit整数の列がたくさん(<=2*10^6個)与えられる。どの数が幾つずつ存在するか調べて報告せよ。
HashMapを用いて素直に組むと、次のような実装になるだろう。
// HashMap版
public void find(int[] a) {
Map<Integer,Integer> degree = new HashMap<>();
for (int i = 0 ; i < a.length ; i++) {
degree.put(a[i], degree.getOrDefault(a[i], 0)+1);
}
// something to do
}
これでハッシュの実装が良く、与えられる数に意地悪がなければ計算量は O(n)
だろうか?
公式ドキュメントを読むと、以下の記述が見つかる。
HashMapのインスタンスには、その性能に影響を与える2つのパラメータである初期容量および負荷係数があります。容量はハッシュ表のバケット数であり、初期容量は単純にハッシュ表が作成された時点での容量です。負荷係数は、ハッシュ表がどの程度いっぱいになると、その容量が自動的に増加されるかの基準です。ハッシュ表エントリ数が負荷係数と現在の容量の積を超えると、ハッシュ表のハッシュがやり直され(つまり、内部データ構造が再構築され)、ハッシュ表のバケット数は約2倍になります。
で、OpenJDKの実装を見る と、初期容量は16、負荷係数が0.75とある。したがって、約n種類のキーを挿入した場合、おおよそ logn
回、容量が2倍になり、これまで入れたキーの再配置が行われる。よって、マージテク と同じ考え方で一つのキーに着目すると、最大 logn
回の移動が行われるので、大雑把に見積もっても計算量は O(nlogn)
になる。ほんとか?
もう少しまともに解析すると、エントリの移動回数の合計は 1 + 2 + 4 + ... + n
なのでやっぱり O(n)
。
ちなみに、初期容量を十分に大きく設定すれば速くなるのだろうか?
// HashMap版
public void find(int[] a) {
Map<Integer,Integer> degree = new HashMap<>(a.length * 2);
for (int i = 0 ; i < a.length ; i++) {
degree.put(degree.getOrDefault(a[i], 0)+1);
}
// something to do
}
残念ながらそんなに早くならない。(測定結果は本エントリ末尾にまとめて示す。) ところで、Javaで競プロをやる人たちの間では、こういうタスクにソートを用いると速いことが知られている(要出典)。
// Sort
public void find(int[] a) {
Arrays.sort(a);
for (int i = 0 ; i < a.length ; ) {
int j = i;
while (j < a.length && a[i] == a[j]) {
j++;
}
// a[i] appears j-i times
i = j;
}
return response;
}
計算量は O(nlogn)
だが、こちらの方がずっと速い。
Integerのボクシング処理とノードオブジェクトの生成部分がつらいのでは。(適当)
O(n)
のくせにソートに負けるとは面汚しよ・・・ということで、ボクシングやオブジェクト生成のオーバーヘッドを削った IntHashMap
を書いた。
削除ができない が、今回の用途の場合は追加と参照があれば十分だ。ハッシュ関数は OpenJDK6のHashMap から借用した。
public class IntHashMap {
int defaultValue;
int M;
int[] next;
int[] entryKey;
int[] entryValue;
int nextIndex;
public IntHashMap(int capacity, int defaultValue) {
int c = Math.max(32, Integer.highestOneBit(capacity-1)<<1);
this.defaultValue = defaultValue;
this.next = new int[c];
this.entryKey = new int[c];
this.entryValue = new int[c];
this.M = c / 2;
this.nextIndex = M;
Arrays.fill(next, -1);
Arrays.fill(entryValue, defaultValue);
}
public boolean containsKey(int key) {
return detectEntry(key) >= 0;
}
public int get(int key) {
int pos = detectEntry(key);
if (pos < 0) {
return defaultValue;
}
return entryValue[pos];
}
public void put(int key, int value) {
int pos = detectEntry(key);
if (pos < 0) {
addEntry(-pos-1, key, value);
} else {
entryValue[pos] = value;
}
}
private int detectEntry(int key) {
int pos = hashPosition(key);
while (true) {
if (entryKey[pos] == key) {
return pos;
}
if (next[pos] == -1) {
break;
}
pos = next[pos];
}
return -(pos+1);
}
private void addEntry(int pos, int key, int value) {
assert(next[pos] == -1);
next[pos] = nextIndex++;
int newpos = next[pos];
entryKey[newpos] = key;
entryValue[newpos] = value;
}
private int hashPosition(int key) {
return hash(key) & (M-1);
}
private static int hash(int h) {
h ^= (h >>> 20) ^ (h >>> 12);
return h ^ (h >>> 7) ^ (h >>> 4);
}
}
測定に用いた環境、結果を示す。測定プログラムは これ。
seed に 1〜10を与えて200万個の配列を生成
public static int[] gen(long seed) {
Random rand = new Random(seed);
int size = 2_000_000;
int[] ret = new int[size];
for (int i = 0 ; i < size ; i++) {
ret[i] = rand.nextInt(1_000_000_000);
}
return ret;
}
関数を呼んでから戻るまでを各seedで測定、平均を取った
何 | seed1〜10の平均(ms) | 入力壊す? |
---|---|---|
HashMap版(デフォルト初期容量) | 770.816 | No |
HashMap版(初期容量400万) | 851.236 | No |
ソート版 | 255.595 | Yes |
IntHashMap版 | 330.828 | No |
うん、まぁそれはそうという結果に。(ソート版は入力を壊してるので厳密にはアンフェア)