문제는 간단하다. 매초마다 1은 132로 2는 211로 3은 232로 바뀐다. N초후에 왼쪽L번째부터 R번째까지 1,2,3숫자의 개수를 구하는 문제이다.
이 문제는 보고 어떻게 풀지 생각은 났는데 구현이 잘 안됬던 문제이다. 막연하게 연산시간을 생각하지 않고 짜면 금방 풀 수 있지만 당연히 TLE가 나올것이다. 그래서 시간을 줄이는 방법 두가지를 사용하였다.
<1>. L ~ R 이외의 범위 탐색X
1부터 시작한다면 위의 그림처럼 $N$초뒤에는 ${3}^{N}$의 숫자가 존재할 것이다. 이것은 어떤 시작점(1 또는 2또는 3)으로부터 $N$초후에 생기게 될 범위를 알 수 있다는 말이다.
예를들어 1에서 시작하여 2초뒤에 범위는 0~8(배열 0시작 기준)이다. 또 1초때 2번 index에서 시작하여 1초후의 범위는 6~8이다.
범위를 알게되면 굳이 탐색할 필요도 없는 공간에 대해 시간투자를 안해도 된다.
사실 $N$초후의 시작부터 마지막까지 1, 2, 3의 개수는 for문 N번만 돌리면 알 수 있다.
$a,b,c$를 각각 1,2,3의 개수라고 한다면 다음과 같이 나타낼 수 있다.
$a = a+2b$
$b = a+b+2c$
$c = a+c$
따라서 시작점으로부터 N초후의 범위가 완전히 L ~ R사이에 존재한다면 이것역시 탐색을 할 필요가 없게된다.
예를들어 $시작=1, L=3, R=6, N =2$이라고 하면
0초때는 0~8까지의 범위를 가질 것이다. 1초뒤에는 1, 3, 2의 숫자가 각각 0~2, 3~5, 6~8의 범위를 가지는데 온전한 범위 이내에 있는 3~5, 6~8은 탐색할 필요가 없다는것이다.
사실상 위의 예는 1초때 0~2의 범위도 범위 밖이므로 탐색할 필요가 없기때문에 1초까지만 탐색을 하면된다. (<1>의 경우)
#include <cstdio>
int X, L, R, N;
long long an1 = 0, an2 = 0, an3 = 0;
long long dp[4][21][4];
long long pow(int x, int y){
if (y == 0)
return 1;
else if (y == 1)
return x;
if (y & 1)
return pow(x, y - 1)*pow(x, 1);
else
return pow(x, y / 2)*pow(x, y / 2);
}
int sqrt(long long x){
int cnt = 0;
for (x; x != 1; x /= 3, cnt++);
return cnt;
}
void cul(){
long long a = 1, b = 0, c = 0;
long long na, nb, nc;
for (int n = 1; n <= 20; n++){
na = a + 2 * b;
nb = a + b + 2 * c;
nc = a + c;
a = na; b = nb; c = nc;
dp[1][n][1] = a; dp[1][n][2] = b; dp[1][n][3] = c;
}
a = 0; b = 1; c = 0;
for (int n = 1; n <= 20; n++){
na = a + 2 * b;
nb = a + b + 2 * c;
nc = a + c;
a = na; b = nb; c = nc;
dp[2][n][1] = a; dp[2][n][2] = b; dp[2][n][3] = c;
}
a = 0; b = 0; c = 1;
for (int n = 1; n <= 20; n++){
na = a + 2 * b;
nb = a + b + 2 * c;
nc = a + c;
a = na; b = nb; c = nc;
dp[3][n][1] = a; dp[3][n][2] = b; dp[3][n][3] = c;
}
}
void loop(int num, long long range, long long first, long long second){
if ((range != 1) && (R < first || (L > second)))
return;
if ((range != 1) && (L <= first && second <= R)){
int r = sqrt(range);
an1 += dp[num][r][1]; an2 += dp[num][r][2]; an3 += dp[num][r][3];
return;
}
if (range == 1){
if (L <= first && R >= first){
if (num == 1)
an1++;
else if (num == 2)
an2++;
else
an3++;
return;
}
else
return;
}
long long new_range = range / 3;
long long f1 = first, s1 = f1 + new_range - 1;
long long f2 = s1 + 1, s2 = f2 + new_range - 1;
long long f3 = s2 + 1, s3 = f3 + new_range - 1;
if (num == 1){
loop(1, new_range, f1, s1);
loop(3, new_range, f2, s2);
loop(2, new_range, f3, s3);
}
else if (num == 2){
loop(2, new_range, f1, s1);
loop(1, new_range, f2, s2);
loop(1, new_range, f3, s3);
}
else if (num == 3){
loop(2, new_range, f1, s1);
loop(3, new_range, f2, s2);
loop(2, new_range, f3, s3);
}
return;
}
int main(){
scanf("%d%d%d%d", &X, &L, &R, &N);
cul();
loop(X, pow(3, N), 0, pow(3, N) - 1);
printf("%d %d %d\n", an1, an2, an3);
return 0;
}
| cs |