https://boj.ma/24501/t 

 

24501번: blobaww

 

boj.ma

 

E, S, M으로만 이루어진 2차원 배열에서 ESM을 만드는 모든 경우의 수를 찾는 문제입니다.

단, E의 좌표(x1, y1) S의 좌표(x2, y2) M의 좌표(x3, y3)가 있을때

x1 <= x2 <= x3, y1 <= y2 <= y3를 만족해야 합니다.

 

완전탐색부터 떠올려봅니다.

배열을 순회하며 E를 찾고, E를 찾았으면 E의 좌표보다 크거나 같은 모든 좌표를 순회하며 S를 찾고, 또 S의 좌표보다 크거나 같은 모든 좌표를 탐색하며 M을 찾으며 정답을 세어주면 됩니다.

 

하지만 N, M <= 3000이기 때문에 최악에 O(NM^3) 정도 되어 시간초과일 것입니다.

 

S를 기준으로 생각해 봅니다.

배열을 순회하며 S를 만날 때마다 배열에서 S의 위치의 왼쪽 위에 있는 모든 E의 갯수와

S의 오른쪽 아래 있는 모든 M의 갯수를 곱하면 만들 수 있는 ESM의 갯수를 알 수 있습니다.

 

정확히 말하면

x행 y열에 S가 있다고 하면,

(1, 1) ~ (x, y)에 존재하는 모든 E의 갯수 * (x, y) ~ (N ,M)에 존재하는 모든 M의 갯수

를 정답에 누적해 주면 됩니다.

 

이것 역시 완전탐색으로 찾게 되면 시간초과입니다.

 

"2차원 배열에서 특정 구간에 존재하는 숫자의 합을 빠르게 구하기"는 누적합으로 해줄 수 있습니다.

E가 존재하면 1을 기록해놓은 배열, M이 존재하면 1을 기록해놓은 배열을 만들어두고

각 배열에 대한 누적합 배열을 만들어줍니다.

 

E의 개수에 대한 누적합과 M의 개수에 대한 누적합만 전처리 해주면 총 시간복잡도 O(NM)에 정답을 구할 수 있습니다.

 

누적합 계산을 편하게 하기위해 인덱스는 1부터 시작했습니다.

 

코드(c++)

#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#define endl "\n"
#define NM 100101
#define MAX 1010100
#define BIAS 1048576
#define MOD 1000000007
#define X first
#define Y second
#define INF 0x3f3f3f3f
#define FOR(i) for(int _=0;_<(i);_++)
#define pii pair<int, int>
#define pll pair<ll, ll>
#define all(v) v.begin(), v.end()
#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL);
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef __int128 i128;
using namespace std;
// pbds
using namespace __gnu_pbds;
// multi-set
using ordered_set_equal = tree<int, null_type, less_equal<int>, rb_tree_tag,tree_order_statistics_node_update>;
// set
using ordered_set = tree<int, null_type, less<int>, rb_tree_tag,tree_order_statistics_node_update>;


int main() {
    fastio
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
#endif

    int n, m;
    cin >> n >> m;
    vector<vector<char>> arr(n + 1, vector<char>(m + 1));
    vector<vector<ll>> arr_e(n + 1, vector<ll>(m + 1, 0));
    vector<vector<ll>> arr_m(n + 1, vector<ll>(m + 1, 0));

    for (int i = 1 ;i <= n ; i++){
        for (int j = 1 ; j <= m ; j++){
            cin >> arr[i][j];
            if (arr[i][j] == 'E') arr_e[i][j]++;
            if (arr[i][j] == 'M') arr_m[i][j]++;
        }
    }


    vector<vector<ll>> psum_e(n + 1, vector<ll>(m + 1, 0));
    vector<vector<ll>> psum_m(n + 1, vector<ll>(m + 1, 0));

    for (int i = 1 ;i <= n ; i++){
        for (int j = 1 ; j <= m ; j++){
            psum_e[i][j] = psum_e[i - 1][j] + psum_e[i][j - 1] - psum_e[i - 1][j - 1] + arr_e[i][j];
            psum_m[i][j] = psum_m[i - 1][j] + psum_m[i][j - 1] - psum_m[i - 1][j - 1] + arr_m[i][j];
        }
    }
    
    ll ans = 0;
    for (int i = 1; i <= n; i++){
        for (int j =1 ; j <= m ; j++){
            if (arr[i][j] == 'S'){
                ll e_cnt = psum_e[i][j];
                ll m_cnt = psum_m[n][m] - psum_m[i - 1][m] - psum_m[n][j - 1] + psum_m[i - 1][j - 1];
                ans += e_cnt * m_cnt;
                ans %= MOD;
            }
        }
    }
    cout << ans << endl;

    return 0;
}

 

 

 

+ Recent posts