문제
1번부터 n번까지 n 개의 정점으로 이루어진 트리가 주어집니다.
m 개의 정점쌍 각각에 대해, 두 정점 사이의 거리를 구하는 프로그램을 작성하세요.
입력 형식
- 첫 번째 줄에 정점의 개수 n과 거리를 구하고자 하는 정점쌍의 개수 m이 주어집니다.그다음 줄부터 m 개의 줄에 걸쳐, 한 줄에 정점쌍 하나씩, 거리를 구하고자 하는 두 정점의 번호가 공백으로 구분되어 주어집니다.
- 2 ≤ n ≤ 1,000
- 1 ≤ m ≤ 1,000
- 트리의 간선의 길이는 1 이상 1,000 이하입니다.
- 그다음 줄부터 n-1 개의 줄에 걸쳐, 트리에서 간선으로 연결된 두 정점의 번호와 그 간선의 길이가 공백으로 구분되어 주어집니다.
출력 형식
첫 번째 줄부터 m 개의 줄에 걸쳐 차례대로, 두 정점의 거리를 한 줄에 하나씩 출력합니다.
입출력 예제
예제1
입력:
5 3
3 1 2
1 2 4
3 5 1
3 4 5
1 5
2 4
3 1
출력:
3
11
2
설명
각 정점을 시작으로 하는 DFS를 N번 진행하여 모든 쌍 간의 거리를 계산한다
트리가 구해진 직후 각 정점을 시작으로 하는 DFS를 N번 진행하여 모든 쌍 간의 거리를 계산한다.
특정 시작점에 대해 DFS를 진행하게 되면 O(N)의 시간복잡도로 모든 정점까지의 거리를 구할 수 있게 되므로
총 O(N^2)의 시간에 모든 정점간의 거리를 구할 수 있게 된다.
이후 m개의 질문에 대해 미리 구해놓은 거리를 O(1)로 출력해주면 질문에 대한 처리를 O(M)이다.
따라서 총 시간복잡도는 O(N^2 + M)
'i노드를 시작으로 모든 노드들의 거리를 계산하여 저장' 하는게 포인트
import sys
sys.stdin=open('input.txt', 'r')
def DFS(start,a):
for b,w in graph[a]:
if visited[b]: # 이미 방문한 노드는 스킵
continue
visited[b] = True # a노드에 연결된 b 노드들 모두 방문
dist[start][b] = dist[start][a] + w # "시작노드부터 b 노드까지 거리" == "시작노드부터 a노드까지 간거리 + w"
DFS(start, b)
if __name__=="__main__":
n,m=map(int, input().split())
graph = [[] for _ in range(n+1)]
visited = [False]*(n+1)
dist = [[0]*(n+1) for _ in range(n+1)]
for _ in range(n-1):
a,b,w=map(int, input().split())
graph[a].append((b,w))
graph[b].append((a,w))
for i in range(1,n+1):
for j in range(1, n+1):
visited[j] = False
visited[i] = True
# i는 True로 방문 시작이고 i랑 연결된 j노드들은 모두 False
DFS(i,i) # dist를 통해 i노드를 시작으로 모든 나머지 노드들간의 거리를 계산
for _ in range(m):
a,b = map(int, input().split()) # a에서 b까지 가는데 간선의 거리
print(dist[a][b])