2018년 1월 4일 목요일

14603 소금과 후추(Large)

14603 소금과 후추(Large) https://www.acmicpc.net/problem/14603

행렬이 주어질 때 그 행렬을 $w$X$w$마다 압축하는 문제이다.
압축한다는것은 $w$X$w$행렬에서 중앙값을 추출한다는 것이다.

naive하게 생각하면 $O(nmw^{2})$이라 당연하게 시간초과가 발생할 것이다.

segment treeplane sweeping을 이용하면 $O(nmw$ $lg(k))$만에 해결할 수 있다.
($k$는 여기서 segment tree의 크기)

$n,m,w$가 각각 5,6,3이라면 위와 같이 훑어서 전체값을 구할 수 있다.
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <vector>
using namespace std;
int N, M, L, K;
int arr[301][301];
vector<vector<int>> ans;
struct Segment {
    vector<int> tree;
    int size;
    Segment() {}
    Segment(int N) :size(N + 1) { tree.resize(4 * (N + 1), 0); }
    int update(int idx, int val, int node, int nl, int nr) {
        if (idx < nl || idx > nr) return tree[node];
        if (nl == nr) return tree[node] += val;
        int mid = (nl + nr) >> 1;
        return tree[node] = update(idx, val, node * 2, nl, mid) + update(idx, val, node * 2 + 1, mid + 1, nr);
    }
    void update(int idx, int val) {
        update(idx, val, 10size - 1);
    }
    int query(int val, int node, int nl, int nr) {
        if (nl == nr) return nl;
        int mid = (nl + nr) >> 1;
        if (tree[node * 2>= val) return query(val, node * 2, nl, mid);
        else return query(val - tree[node * 2], node * 2 + 1, mid + 1, nr);
    }
    int query(int val) {
        return query(val, 10size - 1);
    }
};
int main() {
    scanf("%d%d%d%d"&N, &M, &L, &K);
    Segment seg(L);
    for (int n = 0;n < N;n++for (int m = 0;m < M;m++scanf("%d"&arr[n][m]);
    for (int n = 0;n < K;n++for (int m = 0;m < K;m++)seg.update(arr[n][m], 1);
    for (int n = 0;n < N - K + 1;n++) {
        if (!(n & 1)) {        // ->
            vector<int> tmp;
            if (n != 0)
                for (int m = 0;m < K;m++) seg.update(arr[n - 1][m], -1), seg.update(arr[n + K - 1][m], 1);
            for (int m = 0;m < M - K + 1;m++) {
                if (m == 0) {}
                else {
                    for (int y = n;y < n + K;y++) seg.update(arr[y][m - 1], -1);
                    for (int y = n;y < n + K;y++) seg.update(arr[y][m + K - 1], 1);
                }
                tmp.push_back(seg.query((K*+ 1/ 2));
            }
            ans.push_back(tmp);
        }
        else {            // <-
            vector<int> tmp;
            for (int m = M - K;m < M;m++) seg.update(arr[n - 1][m], -1), seg.update(arr[n + K - 1][m], 1);
            for (int m = M - K;m >= 0;m--) {
                if (m == M - K) {}
                else {
                    for (int y = n;y < n + K;y++) seg.update(arr[y][m + K], -1);
                    for (int y = n;y < n + K;y++) seg.update(arr[y][m], 1);
                }
                tmp.push_back(seg.query((K*+ 1/ 2));
            }
            reverse(tmp.begin(), tmp.end());
            ans.push_back(tmp);
        }
    }
    for (auto &n : ans) {
        for (auto &m : n) printf("%d ", m);
        printf("\n");
    }
    return 0;
}
cs

푸니까 1초제한에서 888ms가 나와서 조금 찝찝하긴 하지만 다른분들 풀이랑 별 차이가 없는듯 하다.

댓글 3개:

  1. 어휴.. 저 문제.. multiset 썼더니 그대로 시간초과 나네요..
    대충 3-4배 가량 느린 거 같네요..

    결국 sum 구하는 Query가 있는 코드
    그대로 복붙해서 맞긴 했는데.. 제한을 5초 정도로 줘도 될 거 같긴 한데.. 너무 빡빡하네요.

    답글삭제
    답글
    1. 아무래도 백준에 많이 들어오는 질문 중 하나가 set이 왜 이렇게 느리냐는 건데..
      이걸 직접 체험할 줄은 몰랐네요..

      문제 의도가 무엇이였는지는 정확히 모르겠네요.
      시간을 측정해 보니.. multiset을 이용한 풀이는 제한 시간이 5초는 되어야 ac날 거 같더라고요.

      분명 O(4*(n-w+1)*(m-w+1)*w*log(w))를 의도한 것인 듯 싶은데..
      multiset을 써도.. 복잡도는 같거든요.. 단지 앞에 붙는 상수만 커질 뿐인데..
      어렵네요.. 상당히.. 이 문제.. 여러모로..

      삭제
    2. 생각보다 시간이 많이나와서 놀랐었어요 ㅋㅋ
      2초로 늘려줬으면 좋겠네여

      삭제