본문 바로가기
Computer Science/PS

[백준 18292] NM과 K (2) (Platinum IV)

by invrtd.h 2023. 1. 29.

https://www.acmicpc.net/problem/18292

 

18292번: NM과 K (2)

크기가 N×M인 격자판의 각 칸에 정수가 하나씩 들어있다. 이 격자판에서 칸 K개를 선택할 것이고, 선택한 칸에 들어있는 수를 모두 더한 값의 최댓값을 구하려고 한다. 단, 선택한 두 칸이 인접

www.acmicpc.net

 

 1014번 컨닝과 비슷한 Bitmask DP 문제다. 티어도 컨닝과 똑같은 티어가 찍혔다.

 

 3차원 dp로 해결 가능한데 for문은 4중첩으로 돌려야 한다.

 dp[i][j_bits][k]를 다음과 같이 정의하자 : i번째 행까지 봤고, i번째 행에서 선택을 j_bits와 같이 했으며, 0..i행 통틀어서 k개의 선택을 했을 때, 점수의 최댓값

 그 다음 dp[i]에서 dp[i+1]로 가는 상태전이를 bottom-up으로 계산해주면 된다. 그러나 구현을 잘못하면 시간 초과가 날 수도 있기에 시간복잡도를 미리 계산해주는 센스가 필요하다.

 일단 상태전이의 개수는 I * (2^J) * K개다. 그리고 각각의 상태전이마다 다음 state로 넘어갈 수 있는 경우의 수가 2^J개다. 따라서 전체 시간복잡도는 O(I * (4^J) * K)이고, I = J = 10, K = 50 대입하면 값은 5억이 나와서 시간 초과다. 이 문제를 어떻게 해결할 수 있을까? 다음 state로 넘어갈 수 있는 경우의 수를 2^J보다 더 작게 할 수 있음을 관찰한다. 왜냐하면 인접한 두 칸을 고를 수 없다는 조건이 있으므로, 상태 전이의 수는 J를 인덱스로 하는 피보나치 수기 때문이다. 따라서 이 값을 1.6^J까지 줄일 수 있다. 이 과정을 백트래킹으로 해결 가능하다. 

 

 시간복잡도는 O(I * (3.2)^J * K)이다.

#ifndef LOCAL
#pragma GCC optimize("O3")
#endif

#include <bits/stdc++.h>

constexpr int NIL = -100'000'000;

auto next(unsigned int bits, int J) {
    std::vector<unsigned int> ret(1, 0);
    for (int bi = 0; bi < J; ++bi) {
        const int SZ = ret.size();
        for (int i = 0; i < SZ; ++i) {
            unsigned int x = ret[i] | (1 << bi);
            if (x & bits) {continue;}
            if (ret[i] & (1 << (bi - 1))) {continue;}
            ret.push_back(x);
        }
    }
    
    return ret;
}

void solve() {
    int I, J, K;
    std::cin >> I >> J >> K;
    
    std::vector grid(I, std::vector<int>(J));
    for (auto &v : grid) {
        for (int &i : v) {
            std::cin >> i;
        }
    }
    
    auto get_sum = [&grid, J](int i, unsigned int bits) {
        int ret = 0;
        for (int b = 0; b < J; ++b) {
            if (bits & (1 << b)) {
                ret += grid[i][b];
            }
        }
        return ret;
    };
    
    std::vector nexts_vec(1 << J, std::vector<unsigned int>());
    for (unsigned int b = 0; b < (1 << J); ++b) {
        nexts_vec[b] = next(b, J);
    }
    
    std::vector dp(1 << J, std::vector<int>(K + 1, NIL));
    auto temp = dp;
    
    const auto &next_0 = nexts_vec[0];
    for (unsigned int bits : next_0) {
        dp[bits][std::popcount(bits)] = get_sum(0, bits);
    }
    
    for (int i = 1; i < I; ++i) {
        temp = dp;
        
        for (unsigned int bits = 0; bits < (1 << J); ++bits) {
            const auto &nexts = nexts_vec[bits];
            for (int k = 0; k <= K; ++k) {
                if (dp[bits][k] == NIL) {continue;}
                
                for (unsigned int next_bits : nexts) {
                    int next_bitcount = k + std::popcount(next_bits);
                    if (next_bitcount > K) {continue;}
                    
                    temp[next_bits][next_bitcount] =
                        std::max(temp[next_bits][next_bitcount],
                                 dp[bits][k] + get_sum(i, next_bits));
                }
            }
        }
        
        std::swap(temp, dp);
    }
    
    int ret = NIL;
    for (const auto &v : dp) {
        ret = std::max(ret, v[K]);
    }
    std::cout << ret << '\n';
}

#ifndef LOCAL
int main() {
    std::cin.tie(nullptr)->sync_with_stdio(false);
    solve();
}
#endif

댓글