STUDY/Algorithm

[백준] 9370 미확인 도착지, python

sinawi95 2022. 1. 5. 10:56
728x90

https://www.acmicpc.net/problem/9370

 

9370번: 미확인 도착지

(취익)B100 요원, 요란한 옷차림을 한 서커스 예술가 한 쌍이 한 도시의 거리들을 이동하고 있다. 너의 임무는 그들이 어디로 가고 있는지 알아내는 것이다. 우리가 알아낸 것은 그들이 s지점에서

www.acmicpc.net

문제를 이해하기까지 꽤 오래걸렸다. 근데 알고리즘이 어렵진 않았다.

문제를 요약하자면 각 목적지 별로 최단 경로를 찾는데 그 최단 경로 안에 g와 h 사이 도로를 포함하는 후보만 출력하는 것이다.

출발지, 교차로 g, h 에서 각 한 번 씩 최단 거리를 탐색하면 충분히 찾을수 있다. 각 지점에서 다익스트라로 최단거리를 파악하고, '출발지에서 목적지 까지'의 최단거리와 '출발지 - g - h - 목적지' 혹은 '출발지 - h - g - 목적지'의 거리가 같은지 확인하면 찾을수 있다.

 

import sys; input = sys.stdin.readline
from heapq import heappop, heappush
INF = 987654321


def check_visit_all(dest_list, visit):
    for item in dest_list:
        if not visit[item]:
            return False
    return True


def dijkstra(start, size, destination):
    visit = [False for _ in range(size + 1)]
    dist_list = [INF for _ in range(size + 1)]
    mh = [(0, start)]
    while mh:
        dist, cur = heappop(mh)
        if visit[cur]: continue
        visit[cur] = True
        dist_list[cur] = dist
        if check_visit_all(destination, visit):
            break
        for adj_dist, adj in map_linked_list[cur]:
            if dist + adj_dist < dist_list[adj]:
                heappush(mh, (dist + adj_dist, adj))

    return dist_list

T = int(input())
for _ in range(T):
    # 0 입력
    n, m, t = map(int, input().split())
    s, g, h = map(int, input().split())
    map_linked_list = [list() for _ in range(n + 1)]
    dist_tmp = 0
    for i in range(m):
        a, b, d = map(int, input().split())
        if (a == g and b == h) or (a == h and b == g):
            dist_tmp = d
        map_linked_list[a].append((d, b))
        map_linked_list[b].append((d, a))
    destinations = set()
    for i in range(t):
        destinations.add(int(input()))

    # 1 탐색
    s_list = dijkstra(s, n, destinations)
    h_list = dijkstra(h, n, destinations)
    g_list = dijkstra(g, n, destinations)
    answer = set()
    for dest in destinations:
        tmp1 = s_list[g] + dist_tmp + h_list[dest]
        tmp2 = s_list[h] + dist_tmp + g_list[dest]
        if s_list[dest] == tmp1 or s_list[dest] == tmp2:
            answer.add(dest)

    # 2 출력
    print(*sorted(answer))