본문 바로가기
Computer Science

[C++] Segment Tree 라이브러리 만들기

by invrtd.h 2022. 8. 16.

Segment Tree는 PS에서 정말 자주 사용되는 자료구조 중 하나다. 그러나 Segment Tree에 들어갈 타입도 int, long long, derived-type 등 제각각, 연산도 add, multiply, min, max 등 제각각이므로 그냥 Generic programming으로 클래스 하나를 만들었다. 이 Segment Tree는 타입 T와 함수객체 BinaryOP를 받는다. T에 타입 넣고 BinaryOP에 인자 2개 받는 operator() 정의하면 끝.

/*
 * Segment-tree!
 */
template<typename T, typename BinaryOP>
class Segment_tree {
    std::vector<T> tree;
    int size;
public:
    explicit Segment_tree(const std::vector<T> &init_list);
    T operate(int l, int r) const;
    void update(int idx, T t);
private:
    static constexpr int log2(long long n);
    T init(int l, int r, int node_idx, const std::vector<T> &init_list);
    T operate(int l, int r, int lnow, int rnow, int node_idx) const;
    void update(int idx, T t, int lnow, int rnow, int node_idx);
};

template<typename T, typename BinaryOP>
Segment_tree<T, BinaryOP>::Segment_tree(const std::vector<T> &init_list) : size(init_list.size()) {
    tree.resize(1 << (log2(init_list.size()) + 2));
    init(0, init_list.size() - 1, 1, init_list);
}

template<typename T, typename BinaryOP>
T Segment_tree<T, BinaryOP>::operate(int l, int r) const {
    return operate(l, r, 0, size - 1, 1);
}

template<typename T, typename BinaryOP>
constexpr int Segment_tree<T, BinaryOP>::log2(long long int n) {
    return n <= 1 ? 0 : log2(n / 2) + 1;
}

template<typename T, typename BinaryOP>
T Segment_tree<T, BinaryOP>::init(int l, int r, int node_idx, const std::vector<T> &init_list) {
    if (l == r) return tree[node_idx] = init_list[l];
    int mid = (l + r) / 2;
    return tree[node_idx] = BinaryOP()(
            init(l, mid, node_idx * 2, init_list),
            init(mid + 1, r, node_idx * 2 + 1, init_list)
    );
}

template<typename T, typename BinaryOP>
T Segment_tree<T, BinaryOP>::operate(int l, int r, int lnow, int rnow, int node_idx) const {
    if (l == lnow and r == rnow) return tree[node_idx];
    int mid = (lnow + rnow) / 2;
    if (r <= mid) return operate(l, r, lnow, mid, node_idx * 2);
    if (mid < l) return operate(l, r, mid + 1, rnow, node_idx * 2 + 1);
    return BinaryOP()(
            operate(l, mid, lnow, mid, node_idx * 2),
            operate(mid + 1, r, mid + 1, rnow, node_idx * 2 + 1)
    );
}

template<typename T, typename BinaryOP>
void Segment_tree<T, BinaryOP>::update(int idx, T t) {
    update(idx, t, 0, size - 1, 1);
}

template<typename T, typename BinaryOP>
void Segment_tree<T, BinaryOP>::update(int idx, T t, int lnow, int rnow, int node_idx) {
    if (lnow == rnow) {
        tree[node_idx] = t;
        return;
    }
    int mid = (lnow + rnow) / 2;
    idx <= mid ? update(idx, t, lnow, mid, node_idx * 2) : update(idx, t, mid + 1, rnow, node_idx * 2 + 1);
    tree[node_idx] = BinaryOP()(tree[node_idx * 2], tree[node_idx * 2 + 1]);
}
/*
 * Segment-tree end
 */

댓글