본문 바로가기
Computer Science/C++

[C++] 객체지향, 제네릭 세그먼트 트리 라이브러리 구현하기

by invrtd.h 2023. 10. 3.

이 글에서는 재사용 가능한 제네릭 세그먼트 트리 라이브러리를 구축하는 방법에 대해 알아본다. 세그먼트 트리를 사용할 어떤 상황이 나와도 최대한 복붙만으로 문제를 해결할 수 있게 하는 것을 목표로 한다. 세그먼트 트리의 작동 원리 자체에 대해 설명하는 글은 아니기 때문에 독자가 세그먼트 트리에 대해 어느 정도 이해하고 있다고 가정한다. 

 

Motivation

세그먼트 트리는 point update & range query를 O(log N)에 처리할 때 유용하게 사용 가능한 자료구조다. 여기서 range query는 다음과 같은 것들이 주어질 수 있다:

  • 구간 a[l..r]의 합/곱/xor을 구하기
  • 구간 a[l..r]의 최솟값/최댓값을 구하기
  • 구간 a[l..r]의 최대 부분연속합을 구하기 (속칭 금광세그)
  • 구간 a[l..r]에서 k번째로 큰 원소를 구하기 (머지 소트 트리)

range query가 바뀔 때마다 세그먼트 트리를 수정해 주는 것은 불필요하게 노가다 같다는 느낌이 든다. 따라서 C++ 템플릿으로 세그먼트 트리를 구현해서 코드의 재사용성을 높였으면 좋겠다.

 

Concepts

다음 사실에 주목해보자. 주어지는 range query들은 모두 이항 연산의 반복이다. 덧셈도, 곱셈도, xor도, max/min도 모두 이항 연산이다. 따라서 이항 연산을 잘 표현할 수 있는 구조체를 하나 만들려고 한다. 그래서 이런 구조체가 특정 요구조건을 만족시켜야 한다는 것을 라이브러리 사용자에게 알리고 싶다. C++20에서는 Concept를 사용해서 이 요구조건을 코드로 문서화하는 것이 가능하다.

template<typename Op, typename C>
concept Optype = std::default_initializable<Op> &&
        requires (Op op, C lhs, C rhs) {
    {op(lhs, rhs)} -> std::same_as<C>;
    {Op::identity()} -> std::same_as<C>;
};

이 요구조건은 값을 나타내는 타입 C에 대해서(C는 Cell의 앞글자다) 타입 Op이 다음 요구조건을 만족시켜야 함을 나타낸다.

  • std::default_initializable<Op>. 즉 빈 생성자 Op{}이 유효한 구문이어야 한다.
  • Op 타입의 객체 op에 대해서 op에는 함수호출 연산자 ()이 오버로드되어 있어서, op(lhs, rhs)가 유효한 구문이어야 하고, 그 결과 C 타입의 값을 리턴해야 한다.
  • Op 클래스 내에 static 함수 identity()가 정의되어 있어서, C 타입의 값을 리턴해야 한다.

identity() 함수는 세그먼트 트리 내부의 빈 노드를 init하는 데 쓰는 등 곳곳에 요긴하게 쓰인다. 그런데 여기서 내가 identity() 함수를 C 클래스가 아니라 Op 클래스에 정의할 것을 굳이 강제한 이유가 있는데, 다음과 같은 이유 때문이다.

  • C는 세그먼트 트리 내의 노드 값을 나타내는데, C가 int일 수도 있다. 즉 int 내부에 identity() 함수를 정의할 수 없다.
  • 나는 항등원을 C의 성질이라기보다는 Op의 성질이라고 생각했다. 예를 들어 + 연산과 max 연산은 모두 (int, int) -> int 연산이지만, + 연산의 항등원은 0이고 max 연산의 항등원은 -inf다. 

 

Implementation

이제 세그먼트 트리를 구축해 보자. 우리의 세그먼트 트리는 2개의 타입 인자를 받는다. 첫 번째 타입 인자는 세그먼트 트리의 노드에 들어갈 셀 C이고, 두 번째 타입 인자는 방금 전에 정의한 OpType concept를 만족하는 이항 연산 Op이다. 따라서 템플릿 선언은 다음과 같을 것이다.

template<typename C, OpType<C> Op>
struct Seg {
    std::vector<C> data;
}

이제 세그먼트 트리를 구현한다. 제네릭 세그먼트 트리를 구현하고 있는 만큼, 다음 주의사항을 숙지하는 것이 좋다:

  • 이항 연산 Op을 사용하는 방법은 Op()(lhs, rhs)다. 사이에 괄호가 하나 껴 있는 것은 Op의 생성자를 불러주는 것이다.
  • 빈 노드를 init한다거나, range query를 수행할 때 결과를 저장하는 변수를 init한다거나 할 때 C ret = 0; 같은 문장을 쓰는 실수를 하지 않도록 하자. max나 min 연산자는 0이 항등원이 아니며, C가 직접 정의한 구조체라면 해당 코드는 컴파일조차 되지 않을 가능성이 높다. 올바른 방식은 C ret = Op::identity();다. 
  • C 타입을 함수 인자로 받을 땐 가급적이면 참조자로 받을 것. C = int라면 참조자를 안 쓰는 게 더 빠르겠지만, 우리의 세그먼트 트리는 C가 64byte짜리 무거운 구조체일 가능성도 포함한다.

이외 모든 구현은 일반적인 세그먼트 트리의 구현을 따라가도 좋다. 나는 다음과 같이 구현했다. 

 

template<typename C, Optype<C> Op>
struct Seg {
    std::vector<C> data;
    
    explicit Seg(const std::vector<C>& vec) {
        int n = 1<<std::bit_width(vec.size());
        data = std::vector<C>(2 * n, Op::identity());
        
        for (std::size_t i = 0; i < vec.size(); ++i) {
            data[n + i] = vec[i];
        }
        
        for (int i = n - 1; i > 0; --i) {
            data[i] = Op()(data[2 * i], data[2 * i + 1]);
        }
    }
    
    [[nodiscard]] int size() const {
        return (int) data.size() / 2;
    }
    
    void update(int idx, const C& val) {
        idx += size();
        data[idx] = val;
        idx /= 2;
        
        for (; idx; idx /= 2) {
            data[idx] = Op()(data[2 * idx], data[2 * idx + 1]);
        }
    }
    
    [[nodiscard]] C get(int idx) const {
        return data[size() + idx];
    }
    
    [[nodiscard]] C reduce(int l, int r) const {
        l += size();
        r += size();
        
        C ret = Op::identity();
        for (; l <= r; l /= 2, r /= 2) {
            if (l % 2 == 1) ret = Op()(data[l++], ret);
            if (r % 2 == 0) ret = Op()(ret, data[r--]);
        }
        return ret;
    }
};

 

Application

가장 쉬운 문제인 구간 합 구하기: C = int로 놓고 Op를 다음과 같이 구현한다.

 

struct Op {
    int operator()(int l, int r) {
        return l + r;
    }
    
    static int identity() {
        return 0;
    }
};

 

수열과 쿼리 16: 가장 작은 값이 아니라 가장 작은 값의 인덱스를 출력하라는 조건이 까다롭다. C = struct {int value, int idx}로 놓으면 C와 C 사이에서 작동하는 이항연산 min을 정의해줄 수 있다.

struct Cell {
    int value;
    int idx;
};

이제 이항연산을 다음과 같이 정의한다.

struct Op {
    Cell operator()(const Cell& l, const Cell& r) {
        if ((l.value < r.value) || (l.value == r.value && l.idx < r.idx)) {
            return Cell{l.value, l.idx};
        } else {
            return Cell{r.value, r.idx};
        }
    }
    
    static Cell identity() {
        return Cell{std::numeric_limits<int>::max(), std::numeric_limits<int>::max()};
    }
};

이제 이항연산을 할 때마다 value뿐만 아니라 그 value의 index도 같이 살아남는다. 출력할 때는 셀 하나를 받아온 뒤 그 셀의 인덱스만 남겨서 출력하면 된다.

댓글