상세 컨텐츠

본문 제목

20297번: Confuzzle - Centroid Decomposition

알고리즘/baekjoon

by oVeron 2024. 8. 13. 12:51

본문

728x90
반응형

Centroid Decomposition(센트로이드 분할) 기본 문제.

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef tuple<int, int, int> tiii;

const int INF = 1e9;
const int MAXN = 100001;
int N;
int c[MAXN];
vector<int> graph[MAXN];

int subSize[MAXN];
bool cache[MAXN]; //cache[u] : 정점 u가 centroid가 되었는지를 확인한다.

void buildSubSize(int u, int p) //정점 u를 root로 하는 서브트리의 크기를 DFS를 이용해 구한다.
{
    subSize[u] = 1;
    for(int v : graph[u])
    {
        if(v == p) continue;
        if(cache[v]) continue; //이미 centroid가 된 정점이므로 무시
        
        buildSubSize(v, u);
        subSize[u] += subSize[v];
    }
}

int getCtr(int u, int p, int sz) //centroid가 되는 정점을 구한다.
{
    for(int v : graph[u])
    {
        if(v == p) continue;
        if(cache[v]) continue; 
        //서브트리의 크기가 전체 트리의 크기의 절반이 넘으면, 해당 트리로 들어간다.
        if(subSize[v] > sz / 2) return getCtr(v, u, sz);
    }
    return u; //모든 서브트리의 크기가 전체 트리의 크기의 절반 이하면, 해당 노드가 centroid이다.
}

//centroid와 c[u] 사이의 최소 거리를 DFS로 구한다.
void DFS(int u, int p, int d, map<int, int>& distI)
{
    if(distI.find(c[u]) == distI.end()) distI[c[u]] = d;
    else distI[c[u]] = min(distI[c[u]], d);
    
    for(int v : graph[u])
    {
        if(v == p) continue;
        if(cache[v]) continue;
        DFS(v, u, d+1, distI);
    }
}

int solve(int u)
{
    buildSubSize(u, 0);
    int ctr = getCtr(u, 0, subSize[u]); //ctr : centroid
    cache[ctr] = true;
    
    //distO : 탐색하고 있는 서브트리의 반대편에 존재하는 c[i]와 centroid 사이의 거리
    map<int, int> distO;
    distO[c[ctr]] = 0; //c[ctr]은 centroid와 같으므로 거리는 0이다.
    
    int ans = INF;
    for(int v : graph[ctr])
    {
        if(cache[v]) continue;
        
        //distI : 탐색하고 있는 서브트리에 존재하는 c[i]와 centroid 사이의 거리
        map<int, int> distI;
        DFS(v, ctr, 1, distI);
        
        //distI[c[u]] + distO[c[u]]의 최솟값이 답이 된다.
        for(auto [node, d] : distI)
        {
            if(distO.find(node) == distO.end()) continue;
            ans = min(ans, d + distO[node]);
        }
        
        //현재 서브트리를 탐색한 후 다음 서브트리를 탐색하면, 현재 서브트리는 
        //다음 서브트리에 대하여 반대 서브트리가 되므로, 지금까지 구한 거리를 distO에 갱신한다.
        for(auto [node, d] : distI)
        {
            if(distO.find(node) == distO.end()) distO[node] = d;
            else distO[node] = min(distO[node], d);
        }
    }
    
    //u의 모든 서브트리에 대하여 위의 작업을 진행한다.
    //서브트리의 크기는 전체 트리의 크기의 절반 이하이므로, 시간 복잡도는 O(NlogN)이다.
    for(int v : graph[ctr])
    {
        if(cache[v]) continue;
        ans = min(ans, solve(v));
    }
    return ans;
}

int main()
{
	ios::sync_with_stdio(0);
	cin.tie(0), cout.tie(0);
	
	cin >> N;
	for(int i=1; i<=N; i++) cin >> c[i];
	for(int i=1; i<N; i++)
	{
	    int u, v; cin >> u >> v;
	    graph[u].push_back(v);
	    graph[v].push_back(u);
	}
	
	cout << solve(1);
}
728x90
반응형

관련글 더보기

댓글 영역