STUDY/Algorithm

[백준] 21924 도시건설 python

sinawi95 2022. 1. 18. 10:47
728x90

https://www.acmicpc.net/problem/21924

 

21924번: 도시 건설

첫 번째 줄에 건물의 개수 $N$ $(3 \le N \le 10^5 )$와 도로의 개수 $M$ $(2 \le M \le min( {N(N-1) \over 2}, 5×10^5)) $가 주어진다. 두 번째 줄 부터 $M + 1$줄까지 건물의 번호 $a$, $b$ $(1 \le a, b \le N, a ≠ b)$와 두

www.acmicpc.net

최소 신장 트리 문제이다.

Prim 알고리즘으로 풀었다. 인접한 정점(건물)을 추가하는 연결리스트를 만들고 신장트리를 만든다. 그리고 최소 값인 도로끼리 연결해야하므로 heap 자료구조를 사용한다.

신장트리는 정점 1부터 시작해서 인접하되 방문하지 않은 정점들을 추가한다. heap을 사용하므로 신장트리에 연결된 정점중 가장 cost가 낮은 정점에 계속 방문하게 되어 반복이 끝나면 최소 신장트리를 만들수 있다.

마지막엔 모든 정점이 연결되었는지 확인하고(모든 정점을 방문했는지 확인) 하나라도 연결되지 않았으면 -1을 모두 연결되었으면 최소 비용을 반환한다.

그리고 원하는 답은 절약한 비용이므로 모든 간선을 추가했을 때의 cost에서 최소 비용을 뺀 값을 출력하면 된다.

 

# import sys; input = sys.stdin.readline
from heapq import heappop, heappush

def prim(N, adj_list):
    visit = [False for _ in range(N + 1)]

    h = [(0, 1)] # 1부터 N까지 모든 건물이 연결되어야하므로 1부터 시작
    ans = 0
    while h:
        cost, cur = heappop(h)
        if visit[cur]: continue
        visit[cur] = True
        ans += cost
        # print(adj_list[cur])
        for adj_cost, adj in adj_list[cur]:
            if not visit[adj]:
                heappush(h, (adj_cost, adj))
    for i in range(1, N + 1):
        if not visit[i]:
            return -1
    return ans

def kruskal(N, M, adj_list):
    pass

def main():
    # 0. 입력
    N, M = map(int, input().split())
    linked_list = [[] for _ in range(N + 1)]
    max_cost = 0
    for _ in range(M):
        a, b, c = map(int, input().split())
        linked_list[a].append((c, b))
        linked_list[b].append((c, a))
        max_cost += c

    # 1. 최소 신장트리
    ans = prim(N, linked_list)
    # 2. 출력
    if ans == -1:
        print(-1)
    else:
        print(max_cost - ans)

if __name__ == "__main__":
    main()

 

Kruskal은 cost를 기준으로 오름차순으로 정렬한뒤, 낮은 값부터 차례대로 추가하면된다.

선택한 건물이 사이클이 되는지 확인해야하는 과정이 필요했는데, 사이클 체크를 어떻게 하는지 잊어버려서 다시 찾아보았다.(참고했던 블로그는 아래 추가하겠다.)

그리고 heappop(O(log n)), heappush(O(log n))를 사용하는것보다 append(O(1)), sort(O(n logn)), indexing(O(1))하는게 더 빠르다

# import sys; input = sys.stdin.readline

def find(n):
    global parent
    if parent[n] != n:
        parent[n] = find(parent[n])
    return parent[n]

def union(a, b):
    global rank, parent
    if rank[a] > rank[b]:
        parent[b] = a
    else:
        parent[a] = b
        if rank[a] == rank[b]:
            rank[b] += 1

def cycle_check(a, b):
    global rank, parent
    a, b = find(a), find(b)
    if a == b:
        return True
    union(a, b)
    return False

def kruskal(N, M, adj_list):
    cost = 0
    cnt = 0
    for i in range(M):
        c, a, b = adj_list[i]
        if cycle_check(a, b): continue # 사이클이 생긴 경우
        cost += c
        cnt += 1

    if cnt == N - 1: # 건물이 N개이므로 N-1개의 도로가 있어야함
        return cost
    return -1

def main():
    # 0. 입력
    global parent, rank
    N, M = map(int, input().split())
    parent = list(range(N + 1))
    rank = [0 for _ in range(N + 1)]
    roads = []
    max_cost = 0
    for _ in range(M):
        a, b, c = map(int, input().split())
        roads.append((c, a, b))
        max_cost += c
    roads.sort()
    # 1. 최소 신장트리
    ans = kruskal(N, M, roads)
    # 2. 출력
    if ans == -1:
        print(-1)
    else:
        print(max_cost - ans)

if __name__ == "__main__":
    main()

 

그리고 prim은 pypy에서 빨랐고, kruskal은 python3에서 빨랐다. 왜 그런진 모르겠다.


참고글

https://gmlwjd9405.github.io/2018/08/29/algorithm-kruskal-mst.html

 

[알고리즘] Kruskal 알고리즘 이란 - Heee's Development Blog

Step by step goes a long way.

gmlwjd9405.github.io

https://gmlwjd9405.github.io/2018/08/31/algorithm-union-find.html

 

[알고리즘] Union-Find 알고리즘 - Heee's Development Blog

Step by step goes a long way.

gmlwjd9405.github.io

https://m.blog.naver.com/ndb796/221230994142

 

18. 크루스칼 알고리즘(Kruskal Algorithm)

이번 시간에 다루어 볼 내용은 바로 크루스칼 알고리즘입니다. 크루스칼 알고리즘은 가장 적은 비용으로 모...

blog.naver.com