읽기 전

  • 불필요한 코드나 잘못 작성된 내용에 대한 지적은 언제나 환영합니다.
  • 개인적으로 사용해보면서 배운 점을 정리한 글입니다.

프림 알고리즘(Prim's Algorithm)이란?

최소 신장 트리(Minimum Spanning Tree) 구현에 사용되는 알고리즘으로 시작 정점에서 정점을 추가해가며 단계적으로 트리를 확장하는 기법이다.

프림 알고리즘의 동작

프림 알고리즘은 매 순간 최선의 조건을 선택하는 그리디 알고리즘을 바탕에 둔다. 즉, 탐색 정점에 대해 연결된 인접 정점들 중 비용이 가장 적은 간선으로 연결된 정점을 선택한다.

  1. 시작 단계는 시작 노드만이 MST 집합에 속한다.
  2. 트리 집합에 속한 정점들과 인접한 정점들 중 가장 낮은 가중치의 간선과 연결된 정점에 대해 간선과 정점을 MST 트리 집합에 넣는다. (사이클을 막기 위해 연결된 정점이 이미 트리가 속한다면 그 다음 순서를 넣는다.)
  3. 2번 과정을 MST 집합의 원소 개수가 그래프의 정점의 개수가 될 때까지 반복한다. (간선의 가중치를 더해서 최소 신장 트리 비용 산출)

Algorithm_Prim's_Algorithm_001

위 그래프의 최소 신장 트리를 프림 알고리즘으로 구해보자. 시작 정점은 A라 한다.

Algorithm_Prim's_Algorithm_002

A와 인접한 노드 B, C 중 C가 가장 가중치가 낮은 간선으로 연결되어 있으니 C를 집합에 넣고 비용에 AC 가중치를 더한다.

Algorithm_Prim's_Algorithm_003

AC와 인접한 노드들 중 가장 낮은 가중치로 연결된 정점은 B다. 집합에 B를 넣고 CB 가중치를 더한다.

Algorithm_Prim's_Algorithm_004

A, C, B와 인접한 노드들 중 가장 낮은 가중치로 연결된 정점은 D다. 집합에 D를 넣고 CD 가중치를 더한다.

Algorithm_Prim's_Algorithm_005

A, C, B, D와 인접한 노드들 중 가장 낮은 가중치로 연결된 정점은 E다. 집합에 E를 넣고 DE 가중치를 더한다.

Algorithm_Prim's_Algorithm_006

A, C, B, D, E와 인접한 노드들 중 가장 낮은 가중치로 연결된 정점 F를 집합에 넣고 DF 가중치를 더한다. 트리의 집합에 속한 원소의 개수가 N이 되었으므로 탐색을 중단한다. 탐색 결과 최소 신장 트리 구축의 비용은 13으로 확인되었다.

프림 알고리즘의 구현

동작 과정을 살펴본 결과 인접 정점들 중 가중치가 가장 낮은 정점을 찾는 과정이 시간복잡도를 결정할 것으로 보인다. 그렇다면 집합 내 정점들을 순회하면서 우선순위 큐에 삽입한 뒤 pop하여 구현하면 도움이 되겠다.

Algorithm_Prim's_Algorithm_007

python 코드

from collections import defaultdict
import heapq

def mst():
    V, E = 6, 9
    edges = [[1, 2, 6], [1, 3, 3], [2, 3, 2], [2, 4, 5],
             [3, 4, 3], [3, 5, 4], [4, 5, 2], [4, 6, 3], [5, 6, 5]]
    graph = defaultdict(list)
    for srt, dst, weight in edges:
        graph[srt].append((dst, weight))
        graph[dst].append((srt, weight))
    mst_graph = [[0] * V for _ in range(V)]
    mst_nodes = [-1 for _ in range(V)]
    visited = [True for _ in range(V)]
    q = [(0, 1, 1)]
    while q:
        cost, node, prev = heapq.heappop(q)
        if visited[node - 1] is False:
            continue
        visited[node - 1] = False
        mst_graph[node - 1][prev - 1] = 1
        mst_graph[prev - 1][node - 1] = 1
        mst_nodes[node - 1] = cost
        for dst, weight in graph[node]:
            if visited[dst - 1] is True:
                heapq.heappush(q, (weight, dst, node))
    print(f'MST cost is {sum(mst_nodes)}')
    mst_graph[0][0] = 1
    for row in mst_graph:
        print(*row)

mst()

Algorithm_Prim's_Algorithm_008

프림 알고리즘 시간복잡도

모든 노드에 대해 탐색을 진행하므로 $O(V)$이다. 그리고 우선순위 큐를 사용하여 매 노드마다 최소 간선을 찾는 시간은 $O(log\ V)$이다. 따라서 탐색과정에는 $O(Vlog\ V)$가 소요된다. 그리고 각 노드의 인접 간선을 찾는 시간은 모든 노드의 차수와 같으므로 $O(\sum_{v=1}^V degree(v))$$=O(2E)$$=O(E)$다. 그리고 각 간선에 대해 힙에 넣는 과정이 $O(log\ V)$가 되어 우선순위 큐 구성에 $O(Elog\ V)$가 소요된다. 따라서, $O(Vlog\ V + Elog\ V)$로 $O(Elog\ V)$가 되겠다.($\because$ E가 일반적으로 V보다 크기 때문)

만약 우선순위 큐가 아니라 배열로 구현한다면 각 정점에 최소 간선을 갖는 정점 탐색을 매번 정점마다 수행하므로 $O(|V|^2)$가 되고 탐색 결과를 기반으로 각 정점의 최소 비용 연결 정점 탐색에는 $O(1)$이 소요된다. 따라서 시간복잡도는 $O(V^2)$이다.

관련 링크

최소 신장 트리(Minimum Spanning Tree)

크루스칼 알고리즘(Kruskal's Algorithm)

읽기 전

  • 불필요한 코드나 잘못 작성된 내용에 대한 지적은 언제나 환영합니다.
  • 개인적으로 사용해보면서 배운 점을 정리한 글입니다.

신장 트리(Spanning Tree)란?

스패닝 트리라고도 부른다. 모든 임의의 정점이 연결된 그래프인 연결 그래프의 부분 그래프다. 모든 정점이 간선으로 연결되어 있지만 사이클이 존재하지 않는 그래프다.

Algorithm_Minimum_Spanning_Tree_001

연결 그래프가 주어졌을 때 생성할 수 있는 신장 트리는 유일하지 않음을 알 수 있다.

최소 신장 트리(Minimum Spanning Tree)란?

최소 비용 신장 트리(Minimum cost Spanning Tree)라고도 부르는 사람들이 있지만 대부분 최소 신장 트리로 사용한다. 그래프의 간선에 가중치 등 값이 부여될 때 이를 가중치 그래프라고 하며 그 그래프가 유향 그래프일 경우 네트워크라 부른다. 가중치가 부여된 그래프의 신장 트리를 구성하는 비용은 신장 트리를 구성하는 모든 가중치의 합이다. 즉, 최소 신장 트리는 신장 트리를 구성하는 간선들의 가중치 합이 가장 작은 신장 트리 이다. 그리고 최소 신장 트리를 구현하는 알고리즘에는 주로 2가지 알고리즘을 소개한다. 따로 자세히 정리할 크루스칼 알고리즘과 프림 알고리즘이다.

MST의 특징

  • 신장 트리들 중 간선의 가중치 합이 최소이다.
  • n개의 정점을 갖는 그래프에 대해 n - 1 개의 간선만을 사용해야 한다.
  • 사이클이 존재해서는 안된다.
  • 최소 신장 트리는 최단 경로를 보장하지 않는다.

MST 사용 문제 예시

도로, 통신 등 모든 노드를 방문해야 하는 문제들 중 구축 비용이나 소요 시간 등을 최소로 해야하는 문제들에 적용될 수 있다.

MST의 구현 알고리즘

크루스칼 알고리즘 (Kruskal's Algorithm)

그리디 알고리즘 기반으로 구현한다. MST가 최소 비용의 간선으로 구성되며 사이클을 이루지 않는다는 특성에 기인해 각 단계에서 기존 선택된 간선들과 사이클을 이루지 않는 최소 비용 간선을 선택한다. 선택 시 사이클 형성 여부를 체크하는 로직을 작성해두고 사용해야 한다.

  • 그래프의 간선들을 가중치를 기준으로 오름차순 정렬한다.
  • 정렬된 간선들을 순서대로 탐색하며 사이클을 형성하지 않는 간선을 선택한다.
  • 선택된 간선을 MST 집합에 넣는다. 집합의 원소 개수가 $V - 1$개가 될 때까지 반복한다.

구체적인 구현 방법, 시간 복잡도는 크루스칼 알고리즘(Kruskal's Algorithm)에 정리하였다.

프림 알고리즘 (Prim's Algorithm)

그리디 알고리즘 기반으로 구현한다. 다만 크루스칼 알고리즘과 동작 방식은 유사하나 간선 선택을 중심으로 동작했던 크루스칼 알고리즘과는 달리 정점을 기준으로 탐색을 진행한다. 시작 정점에서 출발하여 신장 트리 집합을 단계적으로 확장해나간다.

  • 시작 시 시작 정점만이 MST 집합에 포함된다.
  • 인접한 정점들 중 최소 비용 간선으로 연결된 정점을 MST 집합에 삽입한다. (삽입 시 사이클 형성 여부 체크)
  • 모든 정점이 연결되도록 $V - 1$개의 간선을 갖게될 때까지 반복한다.

구체적인 구현 방법, 시간 복잡도는 프림 알고리즘(Prim's Algorithm에 정리하였다.

관련 링크

크루스칼 알고리즘(Kruskal's Algorithm

프림 알고리즘(Prim's Algorithm

'Algorithms > Data Structure' 카테고리의 다른 글

세그먼트 트리 비재귀 구현  (0) 2021.08.16
세그먼트 트리(Segment Tree)  (0) 2021.08.14
서로소 집합(Disjoint Set), 유니온 파인드(Union-Find)  (0) 2021.08.02
트라이(Trie)  (0) 2021.07.05
힙(Heap)  (0) 2021.07.02

읽기 전

  • 불필요한 코드나 잘못 작성된 내용에 대한 지적은 언제나 환영합니다.
  • 개인적으로 배운 점을 정리한 글입니다.

유니온 파인드란 서로소 집합(Disjoint Set)을 표현할 때 사용하는 알고리즘이다. 그렇다면 Disjoint Set의 정의부터 명확히 해야한다.

서로소 집합(Disjoint Set)이란?

공통 원소가 없이 "상호 배타적인" 부분집합들로 나눠진 원소들에 대한 정보를 저장하고 조작하는 자료구조이다.

정의에서 말하는 기능을 지원하기 위해서는 다음의 세 가지 연산을 지원해야 한다.

  • 초기화 : N개의 원소가 각각의 집합에 속하도록 초기화한다.
  • Union(합치기) 연산 : 두 원소가 주어졌을 때 두 원소가 속한 집합을 하나로 합친다.
  • Find(찾기) 연산 : 어떤 원소가 주어졌을 때 해당 원소가 속한 집합을 반환한다.

세 가지 연산 중 Union 연산과 Find 연산을 지원해야 하므로 Disjoint Set 자료구조는 Union-Find 알고리즘이라고도 불리게 된다. 따라서 두 가지 개념은 따로 논의될 수 없으므로 자료구조에는 Disjoint Set을 그걸 표현하는 알고리즘을 Union-Find 알고리즘이라 칭하면 되겠다. (유의어 정도로 간주하자.)

Union-Find 알고리즘의 구현

배열 vs 트리

Union-Find의 구현은 주로 트리 기반으로 구현한다, 그렇다면 이유를 정리해보자.

배열로 Union-Find 구현

  • Array[i] : i좌표 원소가 속하는 집합의 번호라 하자.
  • 초기화 : Array[i] = i로 각자 자기자신에 속하도록 초기화한다.
  • Union : 두 집합을 합하기 위해 배열을 순회하면서 하나의 집합을 다른 하나의 집합 번호로 교체한다.

Data Structure_Disjoint_Set_Union_Find_001

배열을 처음부터 끝까지 순회하므로 $O(N)$이 된다.

  • Find : 해당 좌표를 검색하면 되므로 O(1)이다.

Union-Find 배열 기반 구현 python 코드

class disjointSet:
    def __init__(self, n):
        self.data = [i for i in range(n)]
        self.size = n

    def find(self, idx):
        return self.data[idx]

    def union(self, x, y):
        x, y = self.find(x), self.find(y)
        if x == y:
            return
        for i in range(len(self.data)):
            if self.data[i] == y:
                self.data[i] = x
        self.size -= 1


s = disjointSet(10)
s.union(0, 1)
s.union(2, 3)
s.union(1, 2)
s.union(0, 1)
s.union(4, 5)
s.union(5, 6)
s.union(7, 8)
s.union(7, 9)

print(s.data)
print(s.size)

Data Structure_Disjoint_Set_Union_Find_002

트리로 Union-Find 구현

두 개의 Disjoint Set이 있다고 하자.

Data Structure_Disjoint_Set_Union_Find_003

노드의 구조는 단순하게 부모 노드가 누구인지만 정보를 담고 있으면 된다. 각 노드의 부모노드 정보만 가지고 있고 각 Disjoint Set 트리의 루트 노드는 루트노드임을 알려주는 지표를 갖는다. 자식 노드를 찾을 일은 없기 때문에 그에 대한 정보는 담지 않는다.

  • 초기화 : 각자 다른 집합이 된다. 모두가 자기자신을 루트노드로 갖는다.

Data Structure_Disjoint_Set_Union_Find_004

  • Union : 각 트리의 루트 노드를 찾은 뒤 다르다면 한쪽 트리의 루트 노드를 다른 한 쪽의 루트 노드로 바꿔 자식으로 넣음으로써 트리를 합한다. ($O(1)$, Find 연산의 시간복잡도에 전적으로 의존)

Data Structure_Disjoint_Set_Union_Find_005

  • Find : 각 노드에 저장된 부모 노드 정보를 따라가서 자기자신을 부모로 갖는 루트 노드를 찾는다. ($O(h)$, 트리의 높이와 시간복잡도가 비례)

Data Structure_Disjoint_Set_Union_Find_006

위 그림에 따라 트리가 최악의 경우라도 더 유리함을 알 수 있다.

Union-Find 트리 기반 구현 python 코드 - 기본형

class disjointSet:
    def __init__(self, n):
        self.data = [-1 for _ in range(n)]
        self.size = n

    def find(self, idx):
        parent = self.data[idx]
        if parent < 0:
            return idx
        return self.find(parent)

    def union(self, x, y):
        x, y = self.find(x), self.find(y)
        if x == y:
            return
        self.data[y] = x
        self.size -= 1


s = disjointSet(10)
s.union(0, 1)
s.union(2, 3)
s.union(1, 2)
s.union(0, 1)
s.union(4, 5)
s.union(5, 6)
s.union(7, 8)
s.union(7, 9)

print(s.data)
print(s.size)

Data Structure_Disjoint_Set_Union_Find_007

위의 코드가 기본적인 기능만을 지원했을 때의 코드다. 여기서 조금 더 고도화하여 두 개의 Disjoint Set이 있을 때 크기가 작은 집합이 큰 집합에 더해지거나 높이가 낮은 트리가 큰 높이를 갖는 트리에 더해지는 게 균형이 맞아보인다.

Union-Find 최적화

Union 연산 최적화하기

트리로 구성했을 때 시간복잡도가 $O(log\ h)$라 했지만 BST의 경우에도 그렇고 사향트리를 형성해버린 경우 원소의 개수가 N일 때 높이가 N인 연결리스트 꼴이 되어버린다. 만약 그렇다면 Find 연산의 시간복잡도는 $O(N)$이 되어버려 배열로 구현했을 때보다 효율이 나빠진다.

Data Structure_Disjoint_Set_Union_Find_008

이를 방지하기 위해 높이 정보를 담아 해결할 수 있다. union-by-rank라 하는데 union-by-height라고도 부른다. 번외로 union-by-size도 구현해보자.

union by size

class disjointSet:
    def __init__(self, n):
        self.data = [-1 for _ in range(n)]
        self.size = n

    def find(self, idx):
        parent = self.data[idx]
        if parent < 0:
            return idx
        return self.find(parent)

    def union(self, x, y):
        x, y = self.find(x), self.find(y)
        if x == y:
            return
        if self.data[x] < self.data[y]:
            self.data[x] += self.data[y]
            self.data[y] = x
        else:
            self.data[y] += self.data[x]
            self.data[x] = y
        self.size -= 1


s = disjointSet(10)
s.union(0, 1)
s.union(2, 3)
s.union(1, 2)
s.union(0, 1)
s.union(4, 5)
s.union(5, 6)
s.union(7, 8)
s.union(7, 9)

print(s.data)
print(s.size)

Data Structure_Disjoint_Set_Union_Find_009

union by rank(union by height)

class disjointSet:
    def __init__(self, n):
        self.data = [-1 for _ in range(n)]
        self.size = n

    def find(self, idx):
        parent = self.data[idx]
        if parent < 0:
            return idx
        return self.find(parent)

    def union(self, x, y):
        x, y = self.find(x), self.find(y)
        if x == y:
            return
        if self.data[x] < self.data[y]:
            self.data[y] = x
        elif self.data[x] > self.data[y]:
            self.data[x] = y
        else:
            self.data[x] -= 1
            self.data[y] = x
        self.size -= 1


s = disjointSet(10)
s.union(0, 1)
s.union(2, 3)
s.union(1, 2)
s.union(0, 1)
s.union(4, 5)
s.union(5, 6)
s.union(7, 8)
s.union(7, 9)

print(s.data)
print(s.size)

Data Structure_Disjoint_Set_Union_Find_010

Find 연산 최적화하기

트리의서의 Union-Find는 전적으로 트리의 높이에 의존하는데 결국 Find 연산의 시간복잡도가 개선되어야 한다. Union-Find 알고리즘에서의 Find 연산을 수행하면서 트리의 높이를 낮추는 과정을 "Path Compression(경로 압축)"이라 부른다.

  • find(node) 실행
  • node가 루트 노드가 아니라면 별도 공간에 임시 저장
  • 루트 노드를 찾을 때까지 1, 2 재귀 반복
  • 3의 결과로 루트 노드를 찾으면 임시로 저장해둔 노드들을 루트 노드의 자식으로 저장

기존 find 함수를 루트 노드를 탐색하면서 임시 공간에 저장하는 부분과 찾고 나서 경로를 압축하는 두 개의 파트로 구분해야 한다.

Path Compression 적용

class disjointSet:
    def __init__(self, n):
        self.data = [-1 for _ in range(n)]
        self.size = n

    def upward(self, buf, idx):
        parent = self.data[idx]
        if parent < 0:
            return idx
        buf.append(idx)
        return self.upward(buf, parent)

    def find(self, idx):
        buf = []
        result = self.upward(buf, idx)
        for i in buf:
            self.data[i] = result
        return result

    def union(self, x, y):
        x, y = self.find(x), self.find(y)
        if x == y:
            return
        if self.data[x] < self.data[y]:
            self.data[y] = x
        elif self.data[x] > self.data[y]:
            self.data[x] = y
        else:
            self.data[x] -= 1
            self.data[y] = x
        self.size -= 1

s = disjointSet(10)
s.union(0, 1)
s.union(2, 3)
s.union(1, 2)
s.union(0, 1)
s.union(4, 5)
s.union(5, 6)
s.union(7, 8)
s.union(7, 9)

print(s.data)
print(s.size) 

Union-Find 알고리즘의 시간복잡도

Union-Find의 시간복잡도는 전적으로 Find 연산의 시간복잡도에 종속되며 Find 연산의 시간복잡도는 트리의 높이 h에 의해 결정됨을 확인하였다. 그러나 Find 연산 수행 시 Path Compression이 수행되어 트리의 높이 변화가 발생한다.

증명과정은 따로 정리해야 할 정도로 길어 결론만 말하면 union-by rankpath compression이 모두 적용됐을 때 평균 시간복잡도는 $O(\alpha(N))$이라고 한다.

$\alpha(N)$은 애커만(Ackermann) 역함수로 매우 빠르게 증가하는 애커만 함수로부터 정의된다.

  • $1 \le N < 3$인 경우 $\alpha(N) = 1$
  • $3 \le N < 7$인 경우 $\alpha(N) = 2$
  • $7 \le N < 63$인 경우 $\alpha(N) = 3$
  • $63 \le N < 2^{2^{2^{65536}}}$인 경우 $\alpha(N) = 4$

4번째 조건의 상한은 무한에 가까운 수로 일반적인 경우 $\alpha(N) = 4$ 로 간주하여 상수와 다를 바 없다. 따라서 Union-Find 알고리즘은 상수시간에 수행이 완료되어 굉장히 빠름을 알 수 있다. 굳이 따지자면 Path Compression을 하느라 임시 공간에 저장했던 노드들을 패치하는 시간 정도를 고려할 수 있겠다.

'Algorithms > Data Structure' 카테고리의 다른 글

세그먼트 트리(Segment Tree)  (0) 2021.08.14
최소 신장 트리(Minimum Spanning Tree)  (1) 2021.08.02
트라이(Trie)  (0) 2021.07.05
힙(Heap)  (0) 2021.07.02
AVL 트리(Andelson-Velsky and Landis Tree)  (0) 2021.06.29

읽기 전

  • 불필요한 코드나 잘못 작성된 내용에 대한 지적은 언제나 환영합니다.
  • 개인적으로 사용해보면서 배운 점을 정리한 글입니다.

문제 링크

BOJ #10217 KCM Travel

문제 풀이

시간 제약이 굉장히 빡빡했던 문제였다. 이 문제 때문에 Python 함수 코드가 일반 코드보다 빠른 이유 포스팅을 작성했었다. 우선 다익스트라로 풀긴 푸는데 비용 제약이 추가되었음을 확인될 수 있다. 따라서 단순히 시간이 빠르다고 경로를 채택했다간 나중에는 비용제약으로 경로 없음을 출력하는 오류가 발생할 여지가 있으나 조금 다르게 생각해야 한다.

다익스트라 알고리즘으로 특정 비용을 써서 임의의 정점까지 최단 거리로 간다고 할 때 그 이후의 비용에 대해서도 전부 해당 경로를 이용하게 될 것이다. 이 점을 이용해보자.

  • [정점][비용]의 2차월 배열 생성
  • 모든 주소를 INF로 초기화하되 [i][i]는 0으로 설정
  • q를 선언하고 안에 (시간, 비용, 정점)을 입력, 초기값은 (0, 0, 1)이 되겠다.
  • q가 값을 갖는 동안 popleft하여 시간, 비용, 탐색할 정점을 꺼낸다.
  • 만약 [정점][비용]이 꺼낸 시간보다 작으면 의미 없는 탐색이므로 다음 순회로 넘어간다.
  • 그게 아니면 탐색할 정점에 연결된 간선들에 대해 현재 시간과 비용에 간선의 시간과 비용을 더한다. 만약 간선의 비용을 더한 결과가 M이하거나 [간선의 도착지][신규 비용]의 소요시간과 비교했을 때 더 작다면 의미있는 탐색이므로 q에 더한다. (신규 시간, 신규 비용, 신규 도착지)
  • q에 위 item을 더하기 전 조건문에서 기존 값과 비교를 한 뒤 더 작은 값임을 확인받았다. 따라서, 이후 cost에 대해 그보다 더 작은 값이 나올 때까지 신규 시간으로 바꿔주면 된다.
  • 모든 탐색이 종료된 후 [정점][주어진 비용]은 주어진 비용 내에서 도착지까지 갈 수 있는 최소 시간이다.
  • 갱신되지 않아 INF면 최단 경로가 존재하지 않음을 의미하고 그렇지 않다면 그대로 출력한다.

참고로 해당 로직을 함수 안에 넣고 실행하면 통과하나 그렇지 않으면 시간 초과를 출력한다. 원인에 대해 찾아본 결과를 나름대로 Python 함수 코드가 일반 코드보다 빠른 이유에 정리하였다.

python 코드

import sys
from collections import deque
input = sys.stdin.readline

def solve():
    INF = float('inf')
    for _ in range(int(input())):
        N, M, K = map(int, input().split())
    edges = [[] for _ in range(N + 1)]
    for _ in range(K):
        u, v, c, d = map(int, input().split())
        edges[u].append((v, c, d))
    dist = [[INF] * (M + 1) for _ in range(N + 1)]
    q = deque()
    q.append((0, 0, 1))
    while q:
        time, cost, node = q.popleft()
        if time > dist[node][cost]:
            continue
        for city, c, t in edges[node]:
            alt_t, alt_c = time + t, cost + c
            if alt_c <= M and alt_t < dist[city][alt_c]:
                for i in range(alt_c, M + 1):
                    if alt_t < dist[city][i]:
                        dist[city][i] = alt_t
                    else:
                        break
                q.append((alt_t, alt_c, city))
    s = dist[N][M]
    if s == INF:
        print("Poor KCM")
    else:
        print(s)

solve()

'Algorithms > Baekjoon' 카테고리의 다른 글

BOJ #1107 리모컨  (0) 2021.08.10
BOJ #11051 이항 계수 2  (0) 2021.08.10
BOJ #9251 LCS  (0) 2021.07.28
BOJ #11054 가장 긴 바이토닉 부분 수열  (0) 2021.07.28
BOJ #11444 피보나치 수 6  (0) 2021.07.28

읽기 전

  • 불필요한 코드나 잘못 작성된 내용에 대한 지적은 언제나 환영합니다.
  • 개인적으로 사용해보면서 배운 점을 정리한 글입니다.

문제 제기

알고리즘 문제를 풀다보면 항상 시간 제약에 민감해질 수밖에 없는데 그 중에서도 Python은 C나 Java에 비해 속도가 느려 체감이 더 크다. BOJ #10217 KCM Travel문제를 풀던 중 Python으로 해결 시 같은 로직임에도 함수 안에 넣고 넣지 않고로 통과 여부가 결정됨을 확인했다. 그 원인에 대해 찾아본 결과 Why does Python code run faster in a function?에서 질문/답변을 주고받은 글이 있어 개인적으로 정리해보려 한다.

시간 측정

$10^8$번 순회를 도는 시간을 체크해보자.

python 코드

from time import time


def check():
    start = time()
    for i in range(10 ** 8):
        pass
    return time() - start


print(check())
start = time()
for i in range(10 ** 8):
    pass
print(time() - start)

Development_Python_function code vs global code_001

실행 결과 약 2배의 차이가 난다. 물론 코드를 실행하는 머신의 상태에 따라 가변적이겠으나 일반적인 경우에도 차이가 남을 확인할 수 있을 것이다. 글 원문에서는 프로세서의 차이인지 거의 4배에 가까운 속도 차이를 보여주었다. 이 정도면 실무에서든 문제 풀이에서든 충분히 의미 있는 시간 차이로 보인다.

function 코드와 global 코드와의 차이

결론부터 말하면 Python의 코어를 담당하는 CPython의 구현 방식으로 인해 차이가 발생한다. 함수가 컴파일되면 크기가 정해진 배열에 로컬 변수들을 저장한다. 이 과정으로 인해 함수에 동적으로 변수를 추가할 수 없는 것이다. 따라서, 함수 내부의 특정 변수를 global하게 접근하기 위해선 global 변수를 명시적으로 붙여주어야 한다. 그렇지 않으면 함수 내부 변수를 저장할 때 STORE_FAST opcode를 사용하는데 global 변수는 STORE_NAME opcode를 사용하기 때문에 접근할 수 없기 때문이다. 이러한 변수 저장 방식의 차이가 속도 차이의 요인으로 작용한다.

STORE_FASTSTORE_NAME에 왜 차이가 발생하는가

위에서 function을 compile하면서 크기가 정해진 배열에 저장하니 global 변수와는 달리 호출이 빠르다는 점은 인지했다. 그러나 별도로 STORE_NAME opcode와 STORE_FAST opcode에 대한 요인도 언급된다. 일반적으로 반복문은 FOR_ITER opcode를 호출하는데 반복 순회 시 loop의 top에 FOR_ITER이 위치하게 되고 그 이후 STORE_NAME opcode가 올 것이라 "예측"하게 된다. 만약 함수 내부에 변수를 넣어 내부 변수로 처리되었다면 바로 STORE_FAST opcode로 점프하여 검증 과정을 생락하기 때문에 사실상 1개의 opcode로 작용한다.

만약 global 레벨에서 STORE_NAME opcode가 루프에 사용된 경우 예측에 실패하기 때문에 내부 변수와는 달리 skip 과정이 없다. 만약 아래와 같이 함수의 변수를 global로 선언하여 실행하면 확실히 시간 지연이 발생함을 확인할 수 있다.

from time import time

def check():
    start = time()
    global x
    for i in range(10 ** 8):
        x += 1
    return time() - start

x = 0
print(check())
start = time()
k = 0
for j in range(10 ** 8):
    k += 1
print(time() - start)

Development_Python_function code vs global code_002

함수 내부의 변수를 global 변수로 선언해서 다루면 갑자기 소요시간이 급증함을 확인할 수 있다. 사실상 global 코드와 동일한 시간을 갖는 걸 볼 수 있다.

결론

시간 지연을 고려하여 모든 기능 단위 코드는 함수에 넣어 작성하고 혹여 global한 값을 반복문에 넣지 않았는지 체크해야 한다. 꼭 필요하다면 메모리 제약조건을 확인한 후 내부 변수로 받아서 실행함이 속도에 유리할 것이다.

'Development > Python' 카테고리의 다른 글

Python GIL(Global Interpreter Lock)  (10) 2021.11.10

+ Recent posts