본문 바로가기
Computer Science/PS

라빈-카프 알고리즘 구현 꿀팁

by invrtd.h 2023. 2. 7.

라빈-카프의 전체적인 동작 원리를 설명하는 글은 아니다.

 

  1. 문자열 s = a0a1a2...를 해싱으로 f(s)로 보내버릴 때, 다항식을 f = a0 + a1 * c + a2 * c^2 + ... 이렇게 인덱스를 맞춰서 잡기보다는 거꾸로 f = an + a_(n-1) * c + ... 이렇게 잡는 게 더 구현이 쉽다. hashval = (hashval * c + s[i]) % M을 재귀적으로 반복해서 구현할 수 있기 때문.
  2. c의 모든 거듭제곱을 전처리하지 않아도 된다. C = c ^ k만 전처리해서 어디다가 저장해 두고 있어도 된다.
  3. 첫 번째 해시값을 구했으면 한 칸씩 옮기면서 다음 해시값을 hashval = (hashval * c + s[i] - C * s[i - k]) % M으로 O(1)에 구할 수 있다. 그런데 여기서 조심해야 할 것은 음수에다가 대고 나머지 연산을 함부로 돌렸다가 값이 음수가 될 수도 있다는 것이다. 따라서 실제 계산은 hashval = (hashval * c + s[i] + (M - C) * s[i - k]) % M으로 해 줘야 한다. 상당히 귀찮다.
  4. 해시 충돌이 한 번이라도 일어나기 위해서는 O(sqrt(H)) 크기의 해시값이 필요하다고 한다. H는 해시 값으로 가능한 모든 수이므로 이 코드에서는 M과 같다. 그런데 int64 오버플로를 막기 위해 M을 보통 1,000,000,007 정도로 설정하는데 그러면 s.size() > 40,000 정도에서 해시 충돌이 날 가능성이 높다. 이를 해결하기 위해서는 서로 다른 2개의 M을 잡고, 각각 해싱을 돌린 뒤, 최종 해시 값을 h1 << 32 + h2로 잡는다.
  5. 이 과정에서 클로저 개념을 알고 있으면 함수 인자의 개수를 적절히 유지하면서 코드 중복을 많이 줄일 수 있다. 그런데 M이 상수이므로, C++에서는 템플릿을 대신 사용 가능하다.
  6. 코루틴을 알면 더 쉽게 구현할 수 있을 것 같기도 하다. 근데 코루틴 자체가 C++20에서 도입된 데다가 아직 완전하지 못하다 보니...

 

전체 소스 코드는 다음과 같다.

 

#include <bits/stdc++.h>

constexpr long long M1 = 1'000'000'007;
constexpr long long M2 = 998'244'353;
constexpr long long c = 128;

template<long long M>
auto hash_division(const std::string &s, int k) {
    int C = 1;
    for (int i = 0; i < k; ++i) {
        C = c * C % M;
    }
    
    long long hashval = 0;
    for (int i = 0; i < k; ++i) {
        hashval = (hashval * c + s[i]) % M;
    }
    
    std::vector<long long> ret(s.size() - k + 1);
    ret[0] = hashval;
    
    for (int i = k; i < s.size(); ++i) {
        hashval = (hashval * c + s[i] + (M - C) * s[i - k]) % M;
        ret[i - k + 1] = hashval;
    }
    
    return ret;
}

template<long long N1, long long N2>
auto hash_division_with_2_ints(const std::string &s, int k) {
    auto ret1 = hash_division<N1>(s, k);
    auto ret2 = hash_division<N2>(s, k);
    
    auto ret = std::vector<long long>(s.size() - k + 1);
    for (int i = 0; i < ret.size(); ++i) {
        ret[i] = (ret1[i] << 32) | ret2[i];
    }
    
    return ret;
}

int find(std::string a, std::string b) {
    const int B = b.size();
    auto adiv = hash_division_with_2_ints<M1, M2>(a, B);
    long long b_hash = hash_division_with_2_ints<M1, M2>(b, B)[0];
    
    int ret = 0;
    for (long long n : adiv) {
        if (n == b_hash) {
            ++ret;
        }
    }
    
    return ret;
}

 

댓글