今日も窓辺でプログラム

外資系企業勤めのエンジニアが勉強した内容をまとめておくブログ

LeetCode 338: Counting Bits

この記事で扱う問題

LeetCodeのCounting Bitsという問題を解きます。問題を適当日本語訳で引用すると、

負でない整数 num が与えられる。 0 <= i <= num を満たす全ての i について、iの2進数表現に含まれる1の数を計算して配列として出力せよ。

例:num = 5 の場合、戻り値は [0,1,1,2,1,2] となる。

https://leetcode.com/problems/counting-bits/

といった感じです。

LeetCode自体の解説記事はこちら:
LeetCode: コーディング面接に向けた練習に使えるサイトの紹介 - 今日も窓辺でプログラム

最も単純な方法

最も単純な方法は、0 <= i <= num のすべての i について、1が何個含まれるか数えることです。iひとつについて整数のビット数分の計算量が発生してしまうので、トータルでO(n * sizeof(integer))の計算量となります。

例を詳しく追ってみる

問題文には num = 5 の場合の例が与えられています。iの2進数表現に含まれる1のビットの数をf(i)とすると、0 <= i <= 5 での2進数表現、f(i)は次の表のようになります。

i 2進数 f(i)
0 0b000 0
1 0b001 1
2 0b010 1
3 0b011 2
4 0b100 1
5 0b101 2

この表を見ていると、f(i)はそれより前のf(i')を使用して計算することができそうです。動的計画法を使って、小さいほうからf(i)を順に埋めていくと、O(n)で解けそうです。

最上位の1であるビットを反転した値を使用する方法

まず思いついた方法は、最上位の1であるビットを0に反転させた値を使用する方法でした。例えば5 = 0b101 の場合は、最上位ビットの1を0に反転させると 0b001 (=1)になります。
動的計画法で小さいiから順に計算していると、f(0b001)は既知なので、f(0b101) = f(0b001) + 1 と計算することができます。
現在の最上位ビットの位置を mask という変数で保持しておけば、最上位ビットの反転は簡単にできます。
C++でのコードはこのような形になります。

class Solution {
public:
    vector<int> countBits(int num) {
        // ans[0] = 0 なので、全要素0で初期化しておく。
        vector<int> ans(num + 1, 0);

        // この mask が現在のiの最も大きい'1'ビットの位置を保持している
        int mask = 1;
        for (int i = 1; i <= num; i++) {
            // mask は常に正しくなるように、iが2の累乗になるごとに左にずらす
            if ((mask << 1) <= i) { mask <<= 1; }

            // i^maskで最上位の1が0に反転する。
            // ans[i^mask] に、0に反転させた最上位の1をカウントしてans[i]を求める
            ans[i] = ans[i^mask] + 1;
        }

        return ans;
    }
};

このコードは、時間計算量・空間計算量ともにO(n)となります。

より短いコードもある

最上位ビットを反転させる方法は、例を順番に観察していて思いついたのですが、何も反転させるのは最上位の1でなくてもいいのです。例えば最下位の1のビットを反転させても同様のコードが書けるでしょう。
そして、最下位の1のビットを反転させる方法は、以前の記事で書いた n&(n-1) という少し変わった式が使えるので、より短いコードで済みます。
n & (n - 1) == 0 の意味とは? - 今日も窓辺でプログラム

詳細は上記記事に譲りますが、n&(n-1)という計算をするとnの最下位の1であるビットが0になります。
これを使うと、先ほどのコードはここまで短くなります。

class Solution {
public:
    vector<int> countBits(int num) {
        // ans[0] = 0 なので、全要素0で初期化しておく。
        vector<int> ans(num + 1, 0);

        // i&(i-1)が最下位の1を反転させたものなので、
        // ans[i&(i-1)]と最下位の1を足すとans[i]が求められる
        for (int i = 1; i <= num; i++) {
            ans[i] = ans[i&(i - 1)] + 1;
        }

        return ans;
    }
};

計算量は、先ほどのコードと同様で時間空間ともにO(n)です。maskの保持・更新がない分、若干こちらのほうが計算回数は少ないです。