この記事について
Union-findに続き、コーディング面接に備える上で大事そうなデータ構造であるセグメント木を勉強したので、セグメント木を使用して解ける問題を解いてみます。
セグメント木とは?
セグメント木自体の説明は、また他の方のスライドに丸投げします。。下記スライドの33ページ目から詳しく説明してあります。
www.slideshare.net要するに、セグメント木とは、配列のある範囲内の最大値だったり最小値だったりという興味のある値を、完全二分木を使って管理するデータ構造です。
初期化はの時間がかかりますが、一度初期化してしまうと、どんな範囲内の最小値もで求めることができます。
今回解く問題
このセグメント木を使って、今回はLeetCodeの下記の問題を解きます。LeetCode自体の紹介記事はこちら。
LeetCode: コーディング面接に向けた練習に使えるサイトの紹介 - 今日も窓辺でプログラム
整数の配列 nums が与えられたとき、[i, j]の範囲内にある要素の和を求めよ(sumRange(i, j))。
ただし、update(i, val)関数でnums[i]をvalに変更することができるとする。
https://leetcode.com/problems/range-sum-query-mutable/
配列numはupdate関数以外で変更されることはなく、和を求める関数とupdate関数は均等に呼ばれると仮定してよいそうです。
例えば、nums = [1, 3, 5] のとき、
sumRange(0, 2) -> 9
update(1, 2) -> nums = [1, 2, 5]
sumRange(0, 2) -> 8
というような動きになります。
考え方と実装の方針
今回実装すべきsumRangeという、指定された範囲内の合計を求める関数は、セグメント木の考え方をそのまま適用することができます。
上で紹介したスライドでは、セグメント木の各ノードで最小値を保持していましが、その代わりに合計値を保持するようにするだけでOKです。
セグメント木は完全二分木なので、ヒープを使うと扱いやすそうなので、ヒープを使って実装してみます。
実装
これが私の実装したコードになります。
コメントを山のようにつけたのでコードを追っていくとわかっていただけるのではないでしょうか。
スライドに紹介されていたように、始点はinclusive、終点はexclusiveで実装しているので、実装は比較的簡潔になっているかと思います。
class NumArray {
public:
NumArray(vector<int> &nums) {
m_size = 1;
while (m_size < nums.size()) { m_size <<= 1; }
m_tree.resize(m_size * 2 - 1);
for (int i = 0; i < nums.size(); i++) {
m_tree[m_size - 1 + i] = nums[i];
}
for (int i = m_size - 2; i >= 0; i--) {
m_tree[i] = m_tree[i * 2 + 1] + m_tree[i * 2 + 2];
}
}
void update(int i, int val) {
int idx = m_size - 1 + i;
m_tree[idx] = val;
while (idx > 0) {
idx = (idx - 1) / 2;
m_tree[idx] = m_tree[idx * 2 + 1] + m_tree[idx * 2 + 2];
}
}
int sumRange(int i, int j) {
return sumRange(i, j + 1, 0, 0, m_size);
}
private:
int sumRange(int i, int j, int idx, int left, int right) {
if ((i == left) && (j == right)) {
return m_tree[idx];
}
auto mid = (left + right) / 2;
if (j <= mid) {
return sumRange(i, j, idx * 2 + 1, left, mid);
}
if (i >= mid) {
return sumRange(i, j, idx * 2 + 2, mid, right);
}
return sumRange(i, mid, idx * 2 + 1, left, mid) + sumRange(mid, j, idx * 2 + 2, mid, right);
}
vector<int> m_tree;
int m_size;
};