코딩고치

[파이썬][알고리즘] 최소 신장 트리 본문

파이썬/알고리즘

[파이썬][알고리즘] 최소 신장 트리

코딩고치 2020. 5. 15. 23:07

최소 신장 트리

신장 트리

  • 그래프의 모든 노드가 서로 연결됨과 동시에 트리의 속성을 가지는 그래프.
  • 조건
    • 모든 노드가 연결
    • 트리의 속성 (사이클을 가지지 않음)

최소 신장 트리

크루스칼 알고리즘

Union-Find 알고리즘

  • 사이클이 생기는지 안 생기는지 확인하는 알고리즘
  • 연결된 노드를 찾거나 노드들을 연결할 때 사용
  • Disjoint Set
    • 전체를 하나의 집합으로 보고 그 안의 부분 집합들을 만들어 공통된 원소가 없도록 연결
      1. 초기화
        • n개의 원소를 개별 집합으로 만듦
      2. Union
        • 두 개별 집합을 하나로 합침 (두 트리를 하나로 만듦)
      3. Find
        • 두 개의 노드가 있을 때 같은 부분 집합 안에 존재하는지 아닌지 판별
        • 각 집합의 루트 노드가 같은지 다른지를 확인

Union-Find 알고리즘의 주의사항

  • 잘못 구현하면 링크드 리스트와 같은 형태가 됨
    • union by rank와 path compression을 이용하여 최적화

Union-by-rank

  • 각 트리에 대해 높이 (rank)를 기억
  • 두 트리의 높이가 다르면 높이가 작은 트리를 큰 트리에 붙임 ( 높이가 큰 트리의 루트 노드가 합친 집합의 루트 노드가 되도록 함)
  • 높이가 같은 경우 한쪽의 트리를 1 증가시키고 다른 쪽의 트리를 해당 트리에 연결
    • 높이 h인 트리를 만들려면 높이 h - 1인 트리 2개를 연결해야 함
    • 높이 h - 1인 트리를 만들 때 n개의 원소가 있어야 한다면 높이 h인 트리를 만들려면 최소 2n개가 있어야 함
    • 연산 복잡도는 O(N)이 아닌 O(logN)으로 낮출 수 있음

path compression

  • Find를 실행한 노드에서 거쳐간 노드를 루트 노드에 전부 연결
  • Union-by-rank와 path compression을 이용하면 시간 복잡도는 O(Mlog*N)
    • O(1)

크루스칼 알고리즘 코드

출처: https://ko.wikipedia.org/wiki/%ED%81%AC%EB%9F%AC%EC%8A%A4%EC%BB%AC_%EC%95%8C%EA%B3%A0%EB%A6%AC%EC%A6%98

 

graph = {
    # 노드를 리스트로
    'vertices': ['A', 'B', 'C', 'D', 'E', 'F', 'G'],
    # 간선을 리스트로 (weight, 각 끝점)
    'edges': [
        (7, 'A', 'B'),
        (5, 'A', 'D'),
        (7, 'B', 'A'),
        (9, 'B', 'D'),
        (8, 'B', 'C'),
        (7, 'B', 'E'),
        (8, 'C', 'B'),
        (5, 'C', 'E'),
        (5, 'D', 'A'),
        (9, 'D', 'B'),
        (15, 'D', 'E'),
        (6, 'D', 'F'),
        (7, 'E', 'B'),
        (5, 'E', 'C'),
        (15, 'E', 'D'),
        (8, 'E', 'F'),
        (9, 'E', 'G'),
        (6, 'F', 'D'),
        (8, 'F', 'E'),
        (11, 'F', 'G'),
        (9, 'G', 'E'),
        (11, 'G', 'F')        
    ]
}
# 부모 노드 값 저장
parent = dict()
# 각각의 노드의 높이 번호
rank = dict()

def initialization(node):
    parent[node] = node
    rank[node] = 0

def find(node):
    # path compression
    if parent[node] != node:
        parent[node] = find(parent[node])
    return parent[node]

def union(node_a, node_b):
    # union-by-rank
    root_a = find(node_a)
    root_b = find(node_b)

    if rank[root_a] > rank[root_b]:
        parent[root_b] = root_a
    else:
        parent[root_a] = root_b

        if rank[root_a] == rank[root_b]:
            rank[root_b] += 1

def kruskal(graph):
    mst = []

    # 초기화
    for node in graph['vertices']:
        initialization(node)

    # 간선을 오름차순으로 정렬
    edges = graph['edges']
    edges.sort()

    # 사이클 확인 후 연결
    for edge in edges:
        weight, node_a, node_b = edge
        if find(node_a) != find(node_b):
            union(node_a, node_b)
            mst.append(edge)

    return mst
kruskal(graph)
[(5, 'A', 'D'),
 (5, 'C', 'E'),
 (6, 'D', 'F'),
 (7, 'A', 'B'),
 (7, 'B', 'E'),
 (9, 'E', 'G')]

시간 복잡도

  • O(ElogE)
    • 초기화하는 과정 O(V)
    • sort 하는 과정 O(ElogE)
    • 간선 확인
      • O(E) + O(1) -> O(E)

프림 알고리즘

  • 시작 정점을 선택 후 최소 간선으로 연결된 노드를 연결하고 이 노드에서 다시 최소 간선으로 연결된 노드를 연결하는 방식으로 확장
  • 크루스칼 알고리즘과의 차이
    • 크루스칼 알고리즘: 가중치가 가장 작은 간선에서 시작
    • 특정 노드에서 시작, 해당 노드에서 가중치가 가장 작은 간선을 통해 노드를 연결해 나감
  • 참고: https://ko.wikipedia.org/wiki/%ED%94%84%EB%A6%BC_%EC%95%8C%EA%B3%A0%EB%A6%AC%EC%A6%98

프림 알고리즘 로직

  1. 임의의 노드를 선택 후 연결된 노드 집합에 입력
  2. 선택된 정점에 연결된 간선을 리스트로 만듦
  3. 간선 리스트에서 가장 작은 가중치를 가진 간선부터 선택
    • 이 간선에 연결된 노드가 연결된 노드 집합에 있으면 스킵 (Cycle 방지)
    • 연결된 노드 집합에 없으면 해당 간선을 선택 후 최소 신장 트리에 정보 입력
  4. 추출한 간선은 리스트에서 제거
  5. 간선 리스트에 간선이 없을 때까지 반복

프림 알고리즘 코드 구현

참고: collections 라이브러리의 defaultdict

  • key에 대한 값을 지정하지 않았을 때 빈 리스트로 초기화함
from collections import defaultdict

list_dict = defaultdict(list)
print(list_dict['key1'])

list_dict2 = dict()
print(list_dict2['key1'])
[]



---------------------------------------------------------------------------

KeyError                                  Traceback (most recent call last)

<ipython-input-7-e03bfe85bbf3> in <module>
      5 
      6 list_dict2 = dict()
----> 7 print(list_dict2['key1'])


KeyError: 'key1'
edges = [
    (7, 'A', 'B'), (5, 'A', 'D'),
    (8, 'B', 'C'), (9, 'B', 'D'), (7, 'B', 'E'),
    (5, 'C', 'E'),
    (15, 'D', 'E'), (6, 'D', 'F'),
    (8, 'E', 'F'), (9, 'E', 'G'),
    (11, 'F', 'G')
]
from collections import defaultdict
from heapq import *

def prim(first_node, edges):
    mst = []
    # 해당 노드에 해당 간선을 추가
    adjacent_edges = defaultdict(list)
    for weight, node1, node2 in edges:
        adjacent_edges[node1].append((weight, node1, node2))
        adjacent_edges[node2].append((weight, node2, node1))

    # 처음 선택한 노드를 연결된 노드 집합에 삽입
    connected = set(first_node)
    # 선탠된 노드에 연결된 간선을 간선 리스트에 삽입
    candidated_edge = adjacent_edges[first_node]
    # 오름 차순으로 정렬
    heapify(candidated_edge)

    while candidated_edge:
        weight, node1, node2 = heappop(candidated_edge)
        # 사이클 있는지 확인 후 연결
        if node2 not in connected:
            connected.add(node2)
            mst.append((weight, node1, node2))

            for edge in adjacent_edges[node2]:
                if edge[2] not in connected:
                    heappush(candidated_edge, edge)

    return mst
prim('A', edges)
[(5, 'A', 'D'),
 (6, 'D', 'F'),
 (7, 'A', 'B'),
 (7, 'B', 'E'),
 (5, 'E', 'C'),
 (9, 'E', 'G')]

시간 복잡도

  • 최악의 경우 while 구문에서 모든 간선에 대하여 반복
    • 최소 힙 구조 사용 -> O(ElogE)

개선된 프림 알고리즘

  • 간선이 아닌 노드를 중심으로 우선순위 큐를 만들어 풀어 나감
    • 초기화
      • 선택한 [노드:key] 구조를 만든 후 key값을 0으로 입력한 후 나머지 노드의 key값은 무한대로 설정
      • 모든 [노드:key] 값을 큐에 넣음
    • 가장 key값이 적은 [노드:key]를 pop으로 추출
    • 해당 노드의 인접한 노드들에서 key값과 가중치의 값을 비교하여 가중치 값이 작으면 해당 key값을 가중치 값으로 업데이트
    • 업데이트 후 우선순위 큐에서 key값이 가장 작은 노드를 루트 노드로 올라오도록 해야 함
      • heapdict 라이브러리 이용
from heapdict import heapdict

def prim(graph, first):
    mst = []
    keys = heapdict()
    previous = dict()   
    total_weight = 0

    #초기화
    for node in graph.keys():
        keys[node] = float('inf')
        previous[node] = None
    keys[first], previous[first] = 0, first

    while keys:
        current_node, current_key = keys.popitem()
        mst.append([previous[current_node], current_node, current_key])
        total_weight += current_key
        for adjacent, weight in graph[current_node].items():
            if adjacent in keys and weight < keys[adjacent]:
                keys[adjacent] = weight
                previous[adjacent] = current_node
    return mst, total_weight
graph = {
    'A': {'B': 7, 'D': 5},
    'B': {'A': 7, 'D': 9, 'C': 8, 'E': 7},
    'C': {'B': 8, 'E': 5},
    'D': {'A': 5, 'B': 9, 'E': 15, 'F': 6},
    'E': {'B': 7, 'C': 5, 'F': 8, 'G': 9},
    'F': {'D': 6, 'E': 8, 'G': 11},
    'G': {'E': 9, 'F': 11}
}
mst, total_weight = prim(graph, 'A')
print(mst)
print(total_weight)
[['A', 'A', 0], ['A', 'D', 5], ['D', 'F', 6], ['A', 'B', 7], ['B', 'E', 7], ['E', 'C', 5], ['E', 'G', 9]]
39

개선된 프림 알고리즘의 시간 복잡도

  • 최초 key 생성: O(V)

  • while 구문: O(VlogV)

    • while문: O(V)
    • heap에서 keys.poptiem() 실행: O(logV)
  • for문: O(ElogV)

    • for문은 총 간선 수만큼 실행: O(E)
    • key값을 업데이트할 때마다 heap구조 변경: O(logV)
  • 총 시간 복잡도: O(V + VlogV + ElogV)

  • E > V 이므로 시간 복잡도는 O(ElogV)

Comments