알고리즘 문제

백준 4803번 : 트리 in Python

YJH3968 2021. 5. 22. 23:19
728x90
4803번: 트리
 
www.acmicpc.net

이 문제는 그래프가 주어질 때 그래프 내에 트리가 몇 개가 있는지 세는 문제이다.

이 문제를 해결하기 위해서는 트리를 찾는 방법을 생각할 필요가 있다. 트리는 사이클이 없는 연결된 요소이므로, 그래프 탐색을 이용해 연결된 요소를 찾되 만약 연결된 요소 내에 사이클이 있으면 그 요소는 트리가 아니므로 트리 개수에서 제외시켜야 한다. 그러므로 이 문제의 핵심은 연결된 요소 내에서 사이클이 있는지를 판별하는 것이다.

그렇다면 어떻게 연결된 요소 내에 사이클이 있는지 판별할 수 있을까? 우선 연결된 요소를 파악하기 위해서는 그래프의 각 점에 대해 BFS나 DFS를 적용해야 한다. 특별히 트리에 DFS를 적용한다고 가정하자. 그러면 임의의 정점에서 시작해도 연결된 요소 내에서 만약 특정 정점 u에 대해 DFS의 수행을 끝낸 경우 더 이상 간선을 통해 u에 접근하는 일이 없다. 반면에 사이클에 DFS를 적용하면, 특정 정점 u와 인접한 정점 v, w에 대해 DFS를 u에 적용해서 v를 먼저 방문하는 경우 w까지 방문한 뒤 다시 돌아오는 작업을 수행한 뒤, u는 다음으로 v를 방문하려 할 것이다. 그런데 v는 이미 DFS를 끝마친 상태이다. 즉, 사이클 내에서는 위의 트리의 경우와는 달리 v가 DFS를 끝낸 상태임에도 불구하고 u에서 v로 방문을 시도한다는 점에서 차이가 발생한다. 이를 이용해 연결된 요소가 사이클인지를 판단한다.

우선 각 정점들을 방문했는지에 대한 배열 visited을 만들고 배열의 모든 값을 방문하지 않았다는 의미의 -1로 초기화한다. 기존 DFS를 위한 함수에 사이클을 판별하는 기능을 추가하기 위해 tree_check라는 boolean 값을 True로 초기화한다. 이후 정점을 방문한다는 표시로 visited 배열의 정점에 대응하는 값을 0으로 바꾸고, 정점과 인접한 점들 중 방문하지 않은 점에 대해 DFS 함수를 재귀적으로 호출하는 것은 같다. 그러나 DFS를 다 수행한 뒤에 visited 배열의 정점에 대응하는 값을 1로 바꾼다. 만약 지금 방문한 정점이 들어있는 연결된 요소가 트리라면 이 정점에 절대 방문하는 일이 없다. 그러나 방문하는 일이 생긴다면 이는 사이클이 존재한다는 뜻이므로 tree_check의 값을 False로 바꾼다. 그리고 방문하지 않은 정점에 대해 재귀적으로 호출한 결과와 and 연산자를 수행해 인접한 정점들에 대해 DFS를 수행했을 때 사이클이 나오는지도 조사한다. 그리고 DFS를 다 수행한 뒤에 tree_check 값을 반환한다.

이제 모든 정점에 대해 DFS 함수를 실행해 반환된 값이 True인 경우에만 tree의 개수를 1 증가시키면 tree의 개수를 셀 수 있다.

이를 코드로 구현하면 다음과 같다.

import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**5)

# DFS를 구현한 함수
def DFS(n):
    # 이미 방문한 정점인 경우 해당 정점이 포함된 연결된 요소가 트리인지 이미 판별한 상태이다.
    if visited[n] >= 0: return False
    # 해당 정점이 포함된 요소가 트리인지 판별하는 변수
    tree_check = True
    visited[n] = 0
    for v in graph[n]:
        # 인접한 정점이 이미 DFS를 다 끝낸 상태라면 사이클이 있다는 뜻이다.
        if visited[v] == 1:
            tree_check = False
        elif visited[v] == -1:
            tree_check = DFS(v) and tree_check
    # DFS 수행을 다 끝냈다는 의미로 값을 1로 바꾼다.
    visited[n] = 1
    return tree_check
case = 0
while True:
    case += 1 
    n, m = map(int, input().split())
    if n == 0 and m == 0: break
    graph = [[] for _ in range(n+1)]
    for _ in range(m):
        a, b = map(int, input().split())
        graph[a].append(b)
        graph[b].append(a)
    tree_count = 0
    visited = [-1]*(n+1)
    for i in range(1, n+1):
        # 각 정점에 대해 DFS를 수행한 결과 트리를 찾았다면 트리의 개수를 1 증가시킨다.
        if DFS(i):
            tree_count += 1
    if tree_count == 0:
        print("Case "+str(case)+": No trees.")
    elif tree_count == 1:
        print("Case "+str(case)+": There is one tree.")
    else:
        print("Case "+str(case)+": A forest of "+str(tree_count)+" trees.")
728x90