레이블이 14603 소금과 후추(Large)인 게시물을 표시합니다. 모든 게시물 표시
레이블이 14603 소금과 후추(Large)인 게시물을 표시합니다. 모든 게시물 표시

2018년 1월 4일 목요일

14603 소금과 후추(Large)

14603 소금과 후추(Large) https://www.acmicpc.net/problem/14603

행렬이 주어질 때 그 행렬을 $w$X$w$마다 압축하는 문제이다.
압축한다는것은 $w$X$w$행렬에서 중앙값을 추출한다는 것이다.

naive하게 생각하면 $O(nmw^{2})$이라 당연하게 시간초과가 발생할 것이다.

segment treeplane sweeping을 이용하면 $O(nmw$ $lg(k))$만에 해결할 수 있다.
($k$는 여기서 segment tree의 크기)

$n,m,w$가 각각 5,6,3이라면 위와 같이 훑어서 전체값을 구할 수 있다.
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <vector>
using namespace std;
int N, M, L, K;
int arr[301][301];
vector<vector<int>> ans;
struct Segment {
    vector<int> tree;
    int size;
    Segment() {}
    Segment(int N) :size(N + 1) { tree.resize(4 * (N + 1), 0); }
    int update(int idx, int val, int node, int nl, int nr) {
        if (idx < nl || idx > nr) return tree[node];
        if (nl == nr) return tree[node] += val;
        int mid = (nl + nr) >> 1;
        return tree[node] = update(idx, val, node * 2, nl, mid) + update(idx, val, node * 2 + 1, mid + 1, nr);
    }
    void update(int idx, int val) {
        update(idx, val, 10size - 1);
    }
    int query(int val, int node, int nl, int nr) {
        if (nl == nr) return nl;
        int mid = (nl + nr) >> 1;
        if (tree[node * 2>= val) return query(val, node * 2, nl, mid);
        else return query(val - tree[node * 2], node * 2 + 1, mid + 1, nr);
    }
    int query(int val) {
        return query(val, 10size - 1);
    }
};
int main() {
    scanf("%d%d%d%d"&N, &M, &L, &K);
    Segment seg(L);
    for (int n = 0;n < N;n++for (int m = 0;m < M;m++scanf("%d"&arr[n][m]);
    for (int n = 0;n < K;n++for (int m = 0;m < K;m++)seg.update(arr[n][m], 1);
    for (int n = 0;n < N - K + 1;n++) {
        if (!(n & 1)) {        // ->
            vector<int> tmp;
            if (n != 0)
                for (int m = 0;m < K;m++) seg.update(arr[n - 1][m], -1), seg.update(arr[n + K - 1][m], 1);
            for (int m = 0;m < M - K + 1;m++) {
                if (m == 0) {}
                else {
                    for (int y = n;y < n + K;y++) seg.update(arr[y][m - 1], -1);
                    for (int y = n;y < n + K;y++) seg.update(arr[y][m + K - 1], 1);
                }
                tmp.push_back(seg.query((K*+ 1/ 2));
            }
            ans.push_back(tmp);
        }
        else {            // <-
            vector<int> tmp;
            for (int m = M - K;m < M;m++) seg.update(arr[n - 1][m], -1), seg.update(arr[n + K - 1][m], 1);
            for (int m = M - K;m >= 0;m--) {
                if (m == M - K) {}
                else {
                    for (int y = n;y < n + K;y++) seg.update(arr[y][m + K], -1);
                    for (int y = n;y < n + K;y++) seg.update(arr[y][m], 1);
                }
                tmp.push_back(seg.query((K*+ 1/ 2));
            }
            reverse(tmp.begin(), tmp.end());
            ans.push_back(tmp);
        }
    }
    for (auto &n : ans) {
        for (auto &m : n) printf("%d ", m);
        printf("\n");
    }
    return 0;
}
cs

푸니까 1초제한에서 888ms가 나와서 조금 찝찝하긴 하지만 다른분들 풀이랑 별 차이가 없는듯 하다.