2017년 7월 25일 화요일

7812 중앙 트리

7812 중앙 트리 https://www.acmicpc.net/problem/7812

요즘 트리가 재미있다. 이문제도 재미있다. 
중앙 정점은 모든 정점으로 이르는 비용의 합이 가장 작은 정점이다.
모든 정점과 중앙 정점까지의 거리의 합을 구하는 문제이다.

codeforces풀기까지 시간이 얼마 남지 않아 제대로 풀질 못했었다.
떠오른 생각은 subtree를 이용하는 거였다.
up-subtreedown-subtree를 이용하여 DFS를 2번돌려 풀 수 있다.

DFS를 우선 한번 돌려 현재 위치와 down-subtree의 모든 정점에서의 거리합과 
정점의 갯수도 구해놓는다.

DFS를 다시 돌려 up-subtree의 거리합을 정점의 갯수와 트리의 특성상 
root노드는 down-subtree만을 갖는 점을 이용해서 구할 수 있다.
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;
typedef long long ll;
#define min(a,b) ((a)<(b)?(a):(b))
int N;
bool visit[10001];
int down_cnt[10001], up_cnt[10001];
ll down_cost[10001];
ll ans[10001];
vector<vector<pair<ll, int>>> adj;
vector<pair<ll, int>> reverse_tree;
void dfs(int here) {
    visit[here] = true;
    down_cnt[here] = 1;
    int cnt = 0;
    for (auto n : adj[here]) {
        int next = n.second;
        ll cost = n.first;
        if (!visit[next]) {
            reverse_tree[next] = { cost,here };
            dfs(next);
            down_cost[here] += down_cost[next] + cost*(down_cnt[next]);
            cnt += down_cnt[next];
        }
    }
    down_cnt[here] += cnt;
}
void dfs2(int here) {
    visit[here] = true;
    if (down_cnt[here] == N) ans[here] = down_cost[here];
    else {
        int other_subtree_cnt = N - down_cnt[here];
        int par = reverse_tree[here].second;
        ll parcost = reverse_tree[here].first;
        ans[here] = down_cost[here];
        ans[here] += ans[par] - down_cost[here] - down_cnt[here] * parcost;
        ans[here] += other_subtree_cnt * parcost;
    }
    for (auto n : adj[here]) {
        int next = n.second;
        if (!visit[next]) {
            dfs2(next);
        }
    }
}
int main() {
    while (scanf("%d"&N), N != 0) {
        memset(down_cost, 0sizeof down_cost);
        memset(visit, 0sizeof visit);
        memset(up_cnt, 0sizeof up_cnt);
        memset(down_cnt, 0sizeof down_cnt);
        memset(ans, 0sizeof ans);
        adj = vector<vector<pair<ll, int>>>(N + 1vector<pair<ll, int>>());
        reverse_tree = vector<pair<ll, int>>(N + 1);
        for (int n = 0;n < N - 1;n++) {
            int u, v, d;
            scanf("%d%d%d"&u, &v, &d);
            adj[u].push_back({ (ll)d,v });
            adj[v].push_back({ (ll)d,u });
        }
        dfs(0);
        memset(visit, 0sizeof visit);
        dfs2(0);
        ll ret = (ll)2e16;
        for (int n = 0;n < N;n++)
            ret = min(ret, ans[n]);
        printf("%lld\n", ret);
    }
    return 0;
}
cs

댓글 없음:

댓글 쓰기