今日も窓辺でプログラム

外資系企業勤めのエンジニアが勉強した内容をまとめておくブログ

LeetCode 307: Range Sum Query - Mutable をセグメント木で解く

この記事について

Union-findに続き、コーディング面接に備える上で大事そうなデータ構造であるセグメント木を勉強したので、セグメント木を使用して解ける問題を解いてみます。

セグメント木とは?

セグメント木自体の説明は、また他の方のスライドに丸投げします。。下記スライドの33ページ目から詳しく説明してあります。

www.slideshare.net

要するに、セグメント木とは、配列のある範囲内の最大値だったり最小値だったりという興味のある値を、完全二分木を使って管理するデータ構造です。
初期化は O(n)の時間がかかりますが、一度初期化してしまうと、どんな範囲内の最小値も O(\log n)で求めることができます。

今回解く問題

このセグメント木を使って、今回はLeetCodeの下記の問題を解きます。LeetCode自体の紹介記事はこちら。
LeetCode: コーディング面接に向けた練習に使えるサイトの紹介 - 今日も窓辺でプログラム

整数の配列 nums が与えられたとき、[i, j]の範囲内にある要素の和を求めよ(sumRange(i, j))。
ただし、update(i, val)関数でnums[i]をvalに変更することができるとする。

https://leetcode.com/problems/range-sum-query-mutable/

配列numはupdate関数以外で変更されることはなく、和を求める関数とupdate関数は均等に呼ばれると仮定してよいそうです。

例えば、nums = [1, 3, 5] のとき、
sumRange(0, 2) -> 9
update(1, 2) -> nums = [1, 2, 5]
sumRange(0, 2) -> 8
というような動きになります。

考え方と実装の方針

今回実装すべきsumRangeという、指定された範囲内の合計を求める関数は、セグメント木の考え方をそのまま適用することができます。
上で紹介したスライドでは、セグメント木の各ノードで最小値を保持していましが、その代わりに合計値を保持するようにするだけでOKです。

セグメント木は完全二分木なので、ヒープを使うと扱いやすそうなので、ヒープを使って実装してみます。

実装

これが私の実装したコードになります。
コメントを山のようにつけたのでコードを追っていくとわかっていただけるのではないでしょうか。

スライドに紹介されていたように、始点はinclusive、終点はexclusiveで実装しているので、実装は比較的簡潔になっているかと思います。

class NumArray {
public:
    NumArray(vector<int> &nums) {
        // 与えられた配列のサイズが n のとき、 2の累乗の中でn以上の最小のものを求める。
        // 例えば、[1, 2, 3] という配列であっても、[1, 2, 3, 0]という配列と考えたほうが
        // 実装上扱いやすいため。
        m_size = 1;
        while (m_size < nums.size()) { m_size <<= 1; }
        m_tree.resize(m_size * 2 - 1);

        // セグメント木の初期化。まずは葉ノードを与えられた配列で初期化する。
        // ヒープを使って完全二分木を保持しているので、 m_size - 1番目以降の要素が葉ノードとなる。
        for (int i = 0; i < nums.size(); i++) {
            m_tree[m_size - 1 + i] = nums[i];
        }

        // その後は葉から根に向かって順にノードを初期化していく。
        // 各ノードは2つの子の合計値を持つので、根から初期化しないように注意。
        // n番目のノードの子ノードは 2n + 1, 2n + 2でアクセスできる。
        for (int i = m_size - 2; i >= 0; i--) {
            m_tree[i] = m_tree[i * 2 + 1] + m_tree[i * 2 + 2];
        }
    }

    void update(int i, int val) {
        // まずは対象の葉ノードを更新する
        int idx = m_size - 1 + i;
        m_tree[idx] = val;

        // 対象の葉から根に到達するまで親ノードを順に更新していく
        // n番目のノードの親へは (n - 1)/2 でアクセスできる。
        while (idx > 0) {
            idx = (idx - 1) / 2;
            m_tree[idx] = m_tree[idx * 2 + 1] + m_tree[idx * 2 + 2];
        }
    }

    int sumRange(int i, int j) {
        // 根ノードから探索を開始する。引数の詳細は、下記のsumRange関数を参照。
        // 第2引数はexclusiveなので j + 1 になっている。
        // 根ノードから探索を開始するので、ノードのインデックス(第3引数)は0、
        // 第4と第5引数はルートノードのカバーする範囲 [0, m_size) となる
        return sumRange(i, j + 1, 0, 0, m_size);
    }

private:
    /*
    i, j: 合計値を求める範囲
    idx: 現在探索しているセグメント木のノード
    left, right: 現在のノードが持っている合計値の範囲

    注意:2つの範囲はそれぞれ[i, j), [left, right)という形で
    始点はinclusiveだが、終点はexclusiveで保持している。
    */
    int sumRange(int i, int j, int idx, int left, int right) {
        // 合計値が欲しい範囲と、現在のノードの対象範囲が位置した場合は、
        // そのノードが保持している合計値が欲しかった値
        if ((i == left) && (j == right)) {
            return m_tree[idx];
        }

        // ノードの対象範囲の中心となる位置 mid を求める
        // すると、左右の子ノードが保持している合計値の範囲はそれぞれ
        // [left, mid), [mid, right) という形になる。
        auto mid = (left + right) / 2;

        // パターン1: 欲しい合計値が、左側の子だけによってカバーされている場合
        if (j <= mid) {
            // 探索を左の子ノードに進めればよい
            return sumRange(i, j, idx * 2 + 1, left, mid);
        }

        // パターン2: 欲しい合計値が、右側の子だけによってカバーされている場合
        if (i >= mid) {
            // 探索を右の子ノードに進めればよい
            return sumRange(i, j, idx * 2 + 2, mid, right);
        }

        // パターン3: 欲しい合計値の範囲が、左右両方の子にまたがっている場合
        // この場合は、対象としている範囲[i, j)を、[i, mid)と[mid, j)の2つの範囲に分割して、
        // それぞれの合計を求めてやればよい。
        return sumRange(i, mid, idx * 2 + 1, left, mid) + sumRange(mid, j, idx * 2 + 2, mid, right);
    }

    // セグメント木を保持するヒープ
    vector<int> m_tree;

    // セグメント木の葉ノードの数。2の累乗。
    int m_size;
};