tomabouの日記

Haskellなどで勉強したことなどを書いていきます

忘れてしまうsegment木

競技プログラミングは忘れたころにやりたくなります
セグメント木は便利なので忘れたころに使いたくなります
しかしライブラリを作ろうにも使いたくなるころには忘れているので紛失してしまいます。
なので備忘録変わりにセグメント木をあげておきます

普通のセグメント木

プログラミングコンテストチャレンジブックに書いてあるセグメント木は以下のようなものです。
数列と、結合法則が成り立ち零元が存在する(モノイドという)演算に対して次のことが出来ます。(例えば最小値をとる演算とします)

  • ある区間の最小値をとる
  • ある値を更新する
#include<algorithm>
#include<functional>
#include<vector>

template<typename T>
class segtree {
public:
	segtree(int n) {//数列の長さを指定してセグメント木を作る
		for (size = 1; size < n; size *= 2);
		vec.resize(2 * size -1, zero);
	}
	segtree(std::vector<T> v) {//vectorで数列を与えてセグメント木を作る
		int n = v.size();
		for (size = 1; size < n; size *= 2);
		vec.resize(2 * size -1, zero);
		for (int i = 0; i < n; i++) {
			vec[i + size-1] = v[i];
		}
		for (int i = size - 2; i >= 0; i--) {
			vec[i] = func(vec[i * 2 + 1], vec[i * 2 + 2]);
		}
	}
	void set(int n, T x) {//値の更新
		n += size - 1;
		while (n >= 0) {
			vec[n] = func(vec[n], x);
			n = (n + 1) / 2 - 1;
		}
	}
	int get(int a, int b) {//値の取得([a,b)で指定)
		return get_func(0, 0, size, a, b);
	}
private:
	T zero =2147483647;//モノイドの零元 ここを下の関数と一緒に適切なものに変えればよい
	std::function<T(T, T)> func = [](T a, T b) {return std::min(a, b); };//演算
	std::vector<T> vec;
	int size;
	
	int get_func(int n, int p, int q, int x, int y) {
		if (q - p == 1)return vec[n];
		if (x == p && q == y) return vec[n];
		int mean = (p + q) / 2;
		T l = zero;
		T r = zero;
		if (x < mean)l = get_func(n * 2 + 1, p, mean, x, std::min(y,mean));
		if (mean < y)r = get_func(n * 2 + 2, mean, q, std::max(x,mean), y);
		return func(l, r);
	}
};

普通のセグメント木の双対

数列と可換なモノイドに対して次のことが出来ます(例えば和とるとします)

  • ある区間全体に同じ数字を足す
  • ある値を取得する

これは普通のセグメント木のデータの流れを逆向きにすると実装できます(このコードはその双対が分かりやすいと思います)

template<typename T>
class segtree2 {
public:
	segtree2(int n) {
		for (size = 1; size < n; size *= 2);
		vec.resize(2 * size - 1, zero);
	}
	segtree2(std::vector<T> v) {
		int n = v.size();
		for (size = 1; size < n; size *= 2);
		vec.resize(2 * size - 1, zero);
		for (int i = 0; i < n; i++) {
			vec[i + size - 1] = v[i];
		}
		for (int i = size - 2; i >= 0; i--) {
			vec[i] = func(vec[i * 2 + 1], vec[i * 2 + 2]);
		}
	}
	T get(int n) {
		T ans = zero;
		n += size - 1;
		while (n >= 0) {
			ans = func(vec[n], ans);
			n = (n + 1) / 2 - 1;
		}
		return ans;
	}
	void set(int a, int b, T x) {
		return set_func(0, 0, size, a, b,x);
	}
private:
	T zero = 0;
	std::function<T(T, T)> func = [](T a, T b) {return a+ b; };
	std::vector<T> vec;
	int size;

	void set_func(int n, int p, int q, int x, int y, T val) {
		if (q - p == 1) {
			vec[n] = func(vec[n],val);
			return;
		}
		if (x == p && q == y) {
			vec[n] = func(vec[n], val);
			return ;
		}
		int mean = (p + q) / 2;
		if (x < mean)set_func(n * 2 + 1, p, mean, x, std::min(y, mean),val);
		if (mean < y)set_func(n * 2 + 2, mean, q, std::max(x, mean), y,val);
		return;
	}
};