알고리즘 문제

백준 1717번 : 집합의 표현 in Python

YJH3968 2021. 5. 23. 20:51
728x90
1717번: 집합의 표현
 
www.acmicpc.net

이 문제는 집합을 코드 상에서 어떻게 표현하는지에 관한 문제이다.

여기서 집합을 사용하는 목적은 임의의 두 원소가 주어졌을 때 두 원소가 들어있는 집합을 합한다던가(합집합 연산을 의미한다.), 아니면 두 원소가 같은 집합 내에 있는지를 판별하기 위함이다. 

이를 구현하기 위해서는 가장 단순하게는 그냥 파이썬 내의 집합 자료 구조를 사용해서 위 연산을 구현하면 되지만, 문제는 합집합 연산의 경우 어떤 집합을 합쳐야 하는지를 알려주는게 아니라 집합 내의 한 원소만을 알려주기 때문에 모든 집합들을 뒤져서 그 원소가 들어있는 집합을 찾아야 한다. 그러므로 이러한 방식은 매우 비효율적인 방법이 된다.

이러한 집합을 표현하는 방법으로는 집합을 하나의 트리로 구현하는 것이다. 즉, 집합의 루트 원소를 정하고 루트 원소에서 자식 원소들을 추가함으로써 하나의 트리 형태로 집합을 표현한다. 그러면 만약 임의의 원소에 대해 그 원소가 어느 집합에 들어있는지를 알고 싶을 때는 원소의 부모 원소를 찾는 과정을 루트 원소에 도달할 때까지 반복하면 된다.

그러면 임의의 두 원소에 대해 합집합 연산을 수행할 경우를 생각해보자. 우선 두 원소가 같은 집합에 있는지를 검사하기 위해 각 원소가 포함된 트리의 루트 원소를 찾고 해당 루트 원소가 같은지를 검사하면 되므로 이를 수행하는데 걸리는 시간은 각 원소가 포함된 트리의 높이만큼일 것이다. 그리고 합집합 연산을 할 경우 두 트리에 대해 한 트리가 다른 트리의 자식 트리로 만들면 된다. 그러므로 이를 수행하는데 걸리는 총 시간은 각 원소가 포함된 트리의 높이에 따라 달라진다.

그렇다면 트리의 높이를 최대한 낮추는 것이 중요한데, 이를 위해서는 두 집합을 합할 때 높이가 더 작은 트리가 높이가 더 큰 트리의 자식 트리로 하는 방법이 있고, 트리의 루트를 찾는 과정에서 만난 모든 원소들을 루트의 직속 자식 원소들로 만드는 방법이 있다. 전자의 방법을 union by rank, 후자의 방법을 path compression이 있다. 이 두 방법을 이용하면 트리의 높이를 효과적으로 낮출 수 있어 합집합 연산이나 원소가 들어있는 집합을 찾는 연산에 걸리는 시간을 줄일 수 있다. 여기서는 path compression만 해도 충분한 시간 내에 문제를 해결할 수 있으므로 이 방법만 사용한다.

지금까지 한 내용을 구현하기 위해 크게 두 함수를 구현하는데, 하나는 원소가 주어질 때 그 원소가 들어있는 집합(의 루트 원소)을 구하기 위해 사용하는 find 함수, 다른 하나는 두 원소가 주어질 때 두 원소가 들어있는 집합을 합치는 union 함수이다.

우선 find 함수를 구현할 때는 base case로 만약 인자로 넣은 원소가 루트 원소면 자기 자신을 반환하면 된다. 그렇지 않은 경우에는 인자로 주어진 u의 부모 원소를 다시 find 함수에 넣어 재귀적으로 호출한다. 그러면 이 원소가 루트 원소가 될 때까지 재귀적으로 호출되고, 루트 원소가 반환된 경우 이전에 호출된 함수들이 해당 원소의 부모 원소를 루트 원소로 가지게 한 뒤 그 루트 원소를 반환한다. 그러므로 u의 부모 원소 역시 루트 원소가 되고, 최종적으로 루트 원소를 반환한다. 함수가 끝나면 u에서 루트 원소까지의 path 내의 모든 원소들의 부모 원소가 루트 원소가 된다.

이제 합집합 연산을 구현하는 것은 쉽다. find 함수를 이용해 각 원소의 루트 원소를 구한 다음 한 루트 원소의 부모 원소를 다른 루트 원소로 만들면 된다.

집합을 표현할 때는 각 원소의 부모 원소를 저장하는 배열을 만들고, 처음에는 모든 원소가 루트 원소가 되므로 이를 표현하기 위해 부모 원소가 자기 자신이 되도록 만든다. 즉, 해당 원소가 루트 원소인지를 판별하기 위해서는 원소의 부모 원소가 자기 자신인지 검사하면 된다.

이를 코드로 구현하면 다음과 같다.

import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**5)

n, m = map(int, input().split())
# 해당 원소가 어느 집합에 포함되어 있는지를 알려주는 함수
def find(u):
    if parent[u] == u: return u
    parent[u] = find(parent[u])
    return parent[u]
# 두 원소가 들어있는 집합을 합하는 함수
def union(u, v):
    root_u = find(u)
    root_v = find(v)
    parent[root_u] = root_v
# 각 원소의 부모 원소를 저장하는 함수. 초기값은 자기 자신으로 한다.
parent = [i for i in range(n+1)]
for _ in range(m):
    operator, a, b = map(int, input().split())
    # 합집합 연산
    if operator == 0 and find(a) != find(b):
        union(a, b)
    # 두 원소가 같은 집합에 포함되어 있는지를 확인하는 연산
    elif operator == 1:
        if find(a) == find(b):
            print("YES")
        else:
            print("NO")

 

 

728x90