본문 바로가기
Computer Science/PS

[백준 16998] It's a Mod, Mod, Mod, Mod World [Diamond IV]

by invrtd.h 2023. 2. 19.

https://www.acmicpc.net/problem/16998

 

16998번: It’s a Mod, Mod, Mod, Mod World

You are given multiple problems with three integers p, q, and n. Find \(\displaystyle\sum_{i=1}^{n}{((p \cdot i) \text{ mod } q)}\). That is, the first n multiples of p, modulo q, summed. Note that the overall sum has no modulus.

www.acmicpc.net

 

 \( \sum_1^n {\left( pi \bmod q \right)} \)의 값을 구하는 문제. 50만 개 정도의 쿼리가 주어지는데 제한시간이 5초이므로 한 번 계산할 때마다 \( O \left( \log 10^6 \right) \) 정도의 시간을 써서 문제를 풀어야 한다. 제약조건만 보면 lookup table을 만드는 것도 고려해볼 수 있지만 이 문제에서는 lookup table이 별 소용이 없다.

 

 직접 몇 항을 써보면서 규칙을 찾아보려고 하면 알 수 있겠지만 거의 무질서 수준으로 규칙이 없다. 따라서 식을 어떻게든 잘 변형해줘야 한다. 먼저 떠올릴 수 있는 것이 다음의 자명한 공식이다.

 

$$ a \bmod b = a - \lfloor {a \over b} \rfloor \times b $$

 

 이렇게 바꿔 주면 바닥 함수는 나머지 함수의 변화에 비해 그나마 규칙적인 것처럼 보인다는 것을 알 수 있을 것이다. 그래서 맨 처음에 주어진 공식은 다음과 같이 바꿔 쓸 수 있다.

 

$$ \sum_1^n pi - \lfloor {pi \over q} \rfloor \times q = {p \times n \times (n - 1) \over 2} - q \times \sum_1^n \lfloor {ip \over q} \rfloor $$

 

 

 sum 바깥에 있는 친구들은 다 상수 시간에 계산 가능하니까 sum 안쪽에 있는 것만 계산해 주면 된다. 여기서부터는 상수 시간 안에 해결되는 일반화된 공식 같은 것이 없어서 재귀함수를 돌려 줘야 한다. 몇 가지 간단한 숫자들에 대해서 생각을 해 보면, 예컨대 n = 5, p = 4, q = 7일 때 다음과 같은 그림을 그릴 수 있다.

 

 여기서 동그라미 쳐진 격자점의 개수가 바로 \( \sum_1^n \lfloor {ip \over q} \rfloor \)이다. 그런데 그 왼쪽에 x 표시된 격자점도 있다. 이 격자점의 개수도 사실 동그라미 쳐진 격자점의 개수와 비슷한 방식으로 구할 수 있다. x축과 y축을 뒤집으면 되기 때문이다. x 격자점의 수는 n, p, q로 나타내면

 

 $$ \sum_1^{\lfloor {np \over q} \rfloor} \lfloor {iq \over p} \rfloor $$

 

 이고, o 격자점과 x 격자점의 수의 합은

 

 $$ n \times \lfloor {np \over q} \rfloor + \lfloor {n \over q} \rfloor $$

 

 이다. (문제: 2번째 항이 왜 들어갈까요?) 바닥 함수의 특성에 의해 분자가 분모보다 크면 (정수 + 진분수) 꼴로 분리해서 정수 부분을 바닥 함수 밖으로 빼 줄 수 있는데 이 절차가 유클리드 호제법의 절차와 상당히 비슷하다. 따라서 시간복잡도도 유클리드 호제법의 시간복잡도를 따라간다. 절차를 q = 1 또는 n = 0 또는 p = 0이 될 때까지 반복하면 끝.

 

#include <bits/stdc++.h>

using i64 = long long;

i64 div_sum(i64 p, i64 q, i64 n) {
    if (n == 0 or p == 0) {
        return 0;
    }
    if (q == 1) {
        return p * n * (n + 1) / 2;
    }
    if (p > q) {
        return div_sum(p % q, q, n) + n * (n + 1) / 2 * (p / q);
    }
    return n * (n * p / q) + (n / q) - div_sum(q, p, n * p / q);
}

i64 mod_sum(i64 p, i64 q, i64 n) {
    i64 gcd = std::gcd(p, q);
    return p * n * (n + 1) / 2 - q * div_sum(p / gcd, q / gcd, n);
}

void solve() {
    int t; std::cin >> t;
    
    while (t --> 0) {
        i64 p, q, n;
        std::cin >> p >> q >> n;
        std::cout << mod_sum(p, q, n) << '\n';
    }
}

#ifndef LOCAL
int main() {
    std::cin.tie(nullptr)->sync_with_stdio(false);
    solve();
}
#endif

댓글