행렬이 주어질 때 그 행렬을 $w$X$w$마다 압축하는 문제이다.
압축한다는것은 $w$X$w$행렬에서 중앙값을 추출한다는 것이다.
naive하게 생각하면 $O(nmw^{2})$이라 당연하게 시간초과가 발생할 것이다.
segment tree와 plane sweeping을 이용하면 $O(nmw$ $lg(k))$만에 해결할 수 있다.
($k$는 여기서 segment tree의 크기)
$n,m,w$가 각각 5,6,3이라면 위와 같이 훑어서 전체값을 구할 수 있다.
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <vector>
using namespace std;
int N, M, L, K;
int arr[301][301];
vector<vector<int>> ans;
struct Segment {
vector<int> tree;
int size;
Segment() {}
Segment(int N) :size(N + 1) { tree.resize(4 * (N + 1), 0); }
int update(int idx, int val, int node, int nl, int nr) {
if (idx < nl || idx > nr) return tree[node];
if (nl == nr) return tree[node] += val;
int mid = (nl + nr) >> 1;
return tree[node] = update(idx, val, node * 2, nl, mid) + update(idx, val, node * 2 + 1, mid + 1, nr);
}
void update(int idx, int val) {
update(idx, val, 1, 0, size - 1);
}
int query(int val, int node, int nl, int nr) {
if (nl == nr) return nl;
int mid = (nl + nr) >> 1;
if (tree[node * 2] >= val) return query(val, node * 2, nl, mid);
else return query(val - tree[node * 2], node * 2 + 1, mid + 1, nr);
}
int query(int val) {
return query(val, 1, 0, size - 1);
}
};
int main() {
scanf("%d%d%d%d", &N, &M, &L, &K);
Segment seg(L);
for (int n = 0;n < N;n++) for (int m = 0;m < M;m++) scanf("%d", &arr[n][m]);
for (int n = 0;n < K;n++) for (int m = 0;m < K;m++)seg.update(arr[n][m], 1);
for (int n = 0;n < N - K + 1;n++) {
if (!(n & 1)) { // ->
vector<int> tmp;
if (n != 0)
for (int m = 0;m < K;m++) seg.update(arr[n - 1][m], -1), seg.update(arr[n + K - 1][m], 1);
for (int m = 0;m < M - K + 1;m++) {
if (m == 0) {}
else {
for (int y = n;y < n + K;y++) seg.update(arr[y][m - 1], -1);
for (int y = n;y < n + K;y++) seg.update(arr[y][m + K - 1], 1);
}
tmp.push_back(seg.query((K*K + 1) / 2));
}
ans.push_back(tmp);
}
else { // <-
vector<int> tmp;
for (int m = M - K;m < M;m++) seg.update(arr[n - 1][m], -1), seg.update(arr[n + K - 1][m], 1);
for (int m = M - K;m >= 0;m--) {
if (m == M - K) {}
else {
for (int y = n;y < n + K;y++) seg.update(arr[y][m + K], -1);
for (int y = n;y < n + K;y++) seg.update(arr[y][m], 1);
}
tmp.push_back(seg.query((K*K + 1) / 2));
}
reverse(tmp.begin(), tmp.end());
ans.push_back(tmp);
}
}
for (auto &n : ans) {
for (auto &m : n) printf("%d ", m);
printf("\n");
}
return 0;
}
| cs |
푸니까 1초제한에서 888ms가 나와서 조금 찝찝하긴 하지만 다른분들 풀이랑 별 차이가 없는듯 하다.
어휴.. 저 문제.. multiset 썼더니 그대로 시간초과 나네요..
답글삭제대충 3-4배 가량 느린 거 같네요..
결국 sum 구하는 Query가 있는 코드
그대로 복붙해서 맞긴 했는데.. 제한을 5초 정도로 줘도 될 거 같긴 한데.. 너무 빡빡하네요.
아무래도 백준에 많이 들어오는 질문 중 하나가 set이 왜 이렇게 느리냐는 건데..
삭제이걸 직접 체험할 줄은 몰랐네요..
문제 의도가 무엇이였는지는 정확히 모르겠네요.
시간을 측정해 보니.. multiset을 이용한 풀이는 제한 시간이 5초는 되어야 ac날 거 같더라고요.
분명 O(4*(n-w+1)*(m-w+1)*w*log(w))를 의도한 것인 듯 싶은데..
multiset을 써도.. 복잡도는 같거든요.. 단지 앞에 붙는 상수만 커질 뿐인데..
어렵네요.. 상당히.. 이 문제.. 여러모로..
생각보다 시간이 많이나와서 놀랐었어요 ㅋㅋ
삭제2초로 늘려줬으면 좋겠네여