백준 1197번 : 최소 스패닝 트리 in Python
이 문제는 그래프가 주어졌을 때 최소 스패닝 트리를 구하는 문제이다. 최소 스패닝 트리(minimum spanning tree, MST)란 주어진 그래프의 모든 정점을 연결하는 부분 그래프 중 그 가중치의 합이 최소인 트리를 말한다.
이 문제는 전형적인 그래프 문제 중 하나로, 이를 해결하는 알고리즘에는 Kruskal's algorithm과 Prim's algorithm이 있다. 두 알고리즘의 핵심은 최소 스패닝 트리를 만들기 위해서 가중치가 가장 작은 간선을 우선적으로 살펴본다는 점이다. 다만 Kruskal's algorithm은 모든 간선을 가중치가 작은 것부터 살펴보면서 트리를 만들어나가고, Prim's algorithm은 하나의 정점에서 시작하는 부분 그래프에 가중치가 가장 작은 간선을 계속 추가해 트리를 만들어나간다. 즉, Kruskal's algorithm은 초기에 각 정점들을 하나의 트리로 간주하고 이를 계속 합쳐나가면서 하나의 트리로 만든다면, Prim's algorithm은 하나의 정점에서 시작하는 트리에 간선을 계속 추가해 트리를 점점 키워나간다.
정점의 개수가 V, 간선의 개수가 E라고 할 때, 두 알고리즘은 binary heap을 사용하는 경우 모두 O(E log V) 만큼의 시간이 소요된다. 그러므로 두 알고리즘은 효율성 면에서 큰 차이가 없다.
여기서는 Kruskal's algorithm을 이용해 최소 스패닝 트리를 구한다. 이 알고리즘에서는 임의의 두 정점이 서로 연결되어 있는지를 수시로 확인해야 하기 때문에 분리 집합에서 사용했던 find 함수와 union 함수를 사용하고, 이를 보조하는 parent 배열 역시 필요하다. 구현 방식은 이전의 집합의 표현 문제와 같다.
그 다음 Kruskal's algorithm에서는 가중치가 가장 작은 간선을 여러 번 뽑아야 하기 때문에 이를 보조하는 min-heap을 만들고 여기에 각 간선에 대응하는 가중치와 간선의 양 끝 정점을 포함하는 배열을 넣는다. 그리고 heap에서 가중치가 가장 작은 간선을 하나씩 꺼내 그 간선의 양 끝 정점이 같은 트리 안에 있지 않은 경우에만 두 트리를 합친다. 그러면 모든 간선을 다 살폈을 때 하나의 트리가 나오게 되고, 이 트리는 최소 스패닝 트리가 된다. 그러므로 이 트리 내의 간선들의 가중치의 합을 출력하면 된다.
이를 코드로 구현하면 다음과 같다.
import sys
import heapq
input = sys.stdin.readline
sys.setrecursionlimit(10**5)
def find(u):
if parent[u] == u: return u
parent[u] = find(parent[u])
return parent[u]
def union(u, v):
root_u = find(u)
root_v = find(v)
parent[root_u] = root_v
V, E = map(int, input().split())
parent = [i for i in range(V+1)]
# 간선들을 가중치 순서대로 뽑기 위한 heap
heap = []
for _ in range(E):
A, B, C = map(int, input().split())
# 간선들을 가중치 순서대로 뽑기 위해 heap에 간선을 넣을 때 [가중치, 간선의 양 끝 점]으로 넣는다.
heapq.heappush(heap, [C, A, B])
# 최소 스패닝 트리의 가중치의 합을 나타내는 변수
W = 0
while len(heap) > 0:
w, u, v = heapq.heappop(heap)
# 같은 트리 내의 정점들을 연결하는 것을 방지한다.
if find(u) != find(v):
union(u, v)
W += w
print(W)