본문 바로가기
Computer Science/PS

[백준 25952] Rectangles (ICPC 2022 Seoul Internet; Diamond V)

by invrtd.h 2023. 3. 2.

https://www.acmicpc.net/problem/25952

 

25952번: Rectangles

Your program is to read from standard input. The input starts with a line containing an integer $n$ ($1 ≤ n ≤ 70\,000$), where $n$ is the number of points given in the plane. In the following $n$ lines, each line contains two integers that represent, r

www.acmicpc.net

 

 n = 70 000이라는 제약조건이 상당히 생소한 문제다. 보통 n = 100 000 정도를 주면 O(n log n) ~ O(n)이고 n = 5 000 정도를 주면 O(n^2)인데 이건 뭐 하라는 건지 싶다. O(n^1.5)가 정해라는데, 난 아무리 머리를 굴려보아도 O(n^1.5) 풀이가 어떻게 돌아가는지 감을 못 잡겠다. 이 글에서는 O(n^1.666) 풀이를 소개한다. 

 

 집합 P(y)를, y좌표가 y인 모든 점들의 x좌표를 저장해 놓은 집합이라 하자. 우리는 P(y1)과 P(y2)에 모두 x1과 x2가 포함되어 있도록 하는 x1, x2, y1, y2 쌍의 개수를 세어야 한다. 이제 변수 N과 M을 다음과 같이 정의한다.

 N := n

 M := P(y)의 크기의 최댓값

 K := P(y)들의 개수

 M은 1부터 N까지 다양한 값을 가질 수 있다. 이제 다음의 두 풀이를 생각하자.

 

 풀이 1)

 모든 P(y1), P(y2) 쌍 (y1 < y2)에 대해서, P(y1) 교집합 P(y2)의 원소의 개수를 나이브하게 계산한다(...). 이 과정은 hashmap 등을 써서 linear time에 해결 가능하다. 원소의 개수가 x개라면, 그때마다 결괏값에 xC2를 더한다.

 이 풀이의 시간복잡도는 O(K^2 * M)이다. 최악의 경우 시간복잡도는 O(N^3)이다.

 

 풀이 2)

 가장 먼저 두 점의 pair를 저장하는 multiset U를 정의한다.

 P(1), P(2), ...에 대해서, 각각의 P(y)에서 두 점을 선택할 수 있는 모든 경우의 수를 담은 집합 X를 생각하자. 예를 들어

 P(3) = {3, 4, 5, 6}이면, X = {{3, 4}, {3, 5}, {3, 6}, {4, 5}, {4, 6}, {5, 6}}

 X를 구했으면, U에 X의 모든 원소를 넣는다. 절차를 반복하면 U에는 겹치는 원소가 있을 수 있다. x개가 겹치는 원소를 발견하면, 그때마다 결괏값에 xC2를 더한다.

 이 풀이의 시간복잡도는 O(N * M^2)이다. 최악의 경우 시간복잡도는 O(N^3)이다.

 

 핵심 아이디어는 풀이 1이 최악의 경우일 때 풀이 2를 쓰고, 풀이 2가 최악의 경우일 때 풀이 1을 써서 어떻게든 최악을 피해가는 것이다. 이를 위해 다음의 관찰이 중요하다.

  1. 풀이 1은 size(P(y))의 값이 클 때 효율적이다. 왜냐하면 이 값이 크면 K가 작아질 가능성이 높기 때문이다.
  2. 풀이 2는 size(P(y))의 값이 작을 때 효율적이다. 시간복잡도만 봐도 자명하다.

 따라서 다음과 같은 설계를 생각한다.

 각각의 P(y)들을 large와 small로 분류한다. (기준은, 정해진 상수 R에 대해, P(y) >= R?)

 large-large, large-small 관계는 풀이 1을 써서 계산한다. small-small 관계는 풀이 2를 써서 계산한다.

 

 그러면 다음과 같은 일이 생긴다.

  1. large-large 관계 + large-small 관계: large에 속하는 P의 개수는 O(N/R)개다. 각각의 P(y)마다 모든 점에 대해서 hashmap을 돌려서 그 점의 x좌표가 P(y)에 존재하는지 아닌지를 판정하게 된다. 따라서 전체 시간복잡도는 O(N^2 / R).
  2. small-small 관계: small에 속하는 P의 개수는 어쨌든 O(N)개고... M = R을 넣으면, 시간복잡도는 O(N * R^2)이다.

 이제 R = N^0.333을 넣으면, 어떤 경우에서든 시간복잡도 O(N^1.666)을 보장할 수 있다. 나는 첫 시도로 R = 70 000^0.333 = 41을 넣었는데 거의 2초를 다 채워서 돌았고, R = 160을 쓰니까 200ms 정도까지 빨라졌다. 이런 식으로 case-work 최적화하는 문제들이 다 그렇지만 최적의 R값을 찾으려면 결국 프로그램을 여러 번 돌려봐야 한다.

 

 

#include <bits/stdc++.h>

void solve() {
    int n;
    std::cin >> n;
    
    constexpr int SQ = 160;
    
    std::unordered_map<int, std::unordered_set<int>> points;
    // points[y] = {x1, x2, ...}
    for (int i = 0; i < n; ++i) {
        int x, y;
        std::cin >> x >> y;
        
        if (auto it = points.find(y); it != points.end()) {
            it->second.insert(x);
        } else {
            points.insert({y, std::unordered_set<int>{x}});
        }
    }
    
    std::vector<std::unordered_set<int> *> large, small;
    for (auto &[y, xset] : points) {
        if (xset.size() > SQ) {
            large.push_back(&xset);
        } else {
            small.push_back(&xset);
        }
    }
    
    long long large_large_sum = 0;
    long long large_small_sum = 0;
    long long small_small_sum = 0;
    
    for (const auto *p : large) {
        const auto &large_xset = *p;
        for (const auto &[y, xset] : points) {
            if (&large_xset == &xset) {
                continue;
            }
            long long count = 0;
            for (int x : xset) {
                if (large_xset.find(x) != large_xset.end()) {
                    ++count;
                }
            }
            long long to_add = count * (count - 1) / 2;
            if (xset.size() > SQ) {
                large_large_sum += to_add;
            } else {
                large_small_sum += to_add;
            }
        }
    }
    
    std::unordered_map<long long, long long> count;
    for (const auto *p : small) {
        const auto &small_xset = *p;
        std::vector<long long> xvector(small_xset.begin(), small_xset.end());
        int ss = (int) xvector.size();
        for (int i = 0; i < ss; ++i) {
            for (int j = i + 1; j < ss; ++j) {
                long long l = std::min(xvector[i], xvector[j]);
                long long r = std::max(xvector[i], xvector[j]);
                auto val = (l << 32) + r;
                if (auto it = count.find(val); it != count.end()) {
                    it->second++;
                } else {
                    count.insert({val, 1});
                }
            }
        }
    }
    
    for (auto [key, value] : count) {
        small_small_sum += value * (value - 1) / 2;
    }
    
    std::cout << large_large_sum / 2 + large_small_sum
            + small_small_sum << '\n';
}

#ifndef LOCAL
int main() {
    std::cin.tie(nullptr)->sync_with_stdio(false);
    solve();
}
#endif

댓글