읽기 전

  • 불필요한 코드나 잘못 작성된 내용에 대한 지적은 언제나 환영합니다.
  • 개인적으로 사용해보면서 배운 점을 정리한 글입니다.

세그먼트 트리의 비재귀적(반복문) 구현

백준 문제를 풀다가 세그먼트 트리로 푸는 문제인데 시간 제약이 빠듯한 문제를 만났다. 도저히 통과가 안되길래 질문글을 봤더니 pypy3로 제출해야 함은 물론이거니와 그마저도 재귀는 시간이 오래 걸려 비재귀 방식으로 구현해야 한단다. 찾아본 결과 python으로 세그먼트 트리를 비재귀적 방식을 사용해 정리한 포스팅을 찾을 수 없어 개인적으로 정리하게 되었다. 구현 방식은 비교적 간단한 편이다. 재귀로 구현한 포스팅은 세그먼트 트리(Segment Tree)에 정리한 바 있다.

재귀 vs 비재귀(반복) - Recursion VS Loop

세그먼트 트리의 재귀적 구현 방식의 핵심은 분할 정복이었다. 따라서 리프 노드의 위치가 모여있지 않고 떨어진 경우가 있는 반면 비재귀적인 방식은 먼저 정해진 위치에 기존 배열 원소를 리프 노드로 두고 힙을 역으로 구성하듯이 접근한다. 배열 nums = [1, 3, 5, 2, 6, 4]를 예를 들어 구현해보자.

세그먼트 트리 생성

Data Structure_Non-Recursive_Segment_Tree_001

먼저 주어진 배열 길이 N에 대해 각 배열 원소의 좌표를 더한 위치에 값을 저장한다. 그리고 트리에서의 양쪽 자식노드는 [i * 2], [i * 2 + 1]이므로 해당 위치의 값을 더해 N - 1부터 1좌표까지 역으로 완성해나간다. 그러면 1좌표에 배열의 모든 원소 합이 들어가게 된다. Python코드로 구현해보자.

Python 코드

def init(tree, N):
    for i in range(N - 1, 0, -1):
        tree[i] = tree[i << 1] + tree[i << 1 | 1]

nums = [1, 3, 5, 2, 6, 4]
print(f'nums: {nums}')
segment_tree = [0] * (2 * len(nums))
N = len(nums)
for i in range(len(nums)):
    segment_tree[N + i] = nums[i]
init(segment_tree, N)
print(f'segment tree : {segment_tree}')

Data Structure_Non-Recursive_Segment_Tree_002

<<는 우측으로 1비트 옮김을 의미하고 현재 값에 2를 곱하는 연산과 결과가 동일하다. |는 or 연산으로 <<연산 후 0이 된 1의 자리에 or연산을 적용하면 1을 더한 값과 동일해진다. tree[i] = tree[i * 2] + tree[i * 2 + 1]로 바꿔도 동작 결과는 동일하다.

세그먼트 트리 쿼리

반복문을 사용한 세그먼트 트리의 쿼리 범위는 좌/우측 좌표가 주어졌을 때 [left, right)이다. right좌표를 포함하지 않기 때문에 right좌표까지 포함한 값을 구하고 싶다면 right + 1을 주어야 하고 우측 좌표가 배열 길이 N보다 크다면 N으로 바꿔 조회 시 오류가 없게끔 한다. 1좌표부터 3좌표까지의 합을 구한다고 가정하면 함수에 left = 1, right = 4를 입력해야 한다. Python 코드로 구현해보자.

Python 코드

def query(tree, N, left, right):
    result = 0
    left += N
    right += N
    while left < right:
        if left % 2 == 1:
            result += tree[left]
            left += 1
        if right % 2 == 1:
            result += tree[right - 1]
            right -= 1
        left //= 2
        right //= 2
    return result

print(f'idx 1 to 3 : {nums[1:4]}')
print(f'sum 1 to 3 : {query(segment_tree, N, 1, 4)}')

Data Structure_Non-Recursive_Segment_Tree_003

Data Structure_Non-Recursive_Segment_Tree_004

위 그림이 쿼리 반복문의 첫 시작이다. left 좌표에 N을 더한 결과가 7이고 left % 2 == 1이므로 해당 좌표의 값 3을 더하고 left에 1을 더하여 2를 나누면 4가 된다. right 좌표에 N을 더한 결과가 10이므로 그대로 2를 나눠 5가 된다.

Data Structure_Non-Recursive_Segment_Tree_005

left가 4가 되었고 left % 2 == 0이 되어 그대로 2로 나눈 값 2가 된다. right는 5가 되었고 right % 2 == 1이므로 tree[right - 1]의 값 7을 더하고 right 좌표에 1을 빼고 2로 나누면 2가 된다. 다음 반복은 left < right가 성립하므로 진행되지 않는다. 쿼리 결과가 10인데 이는 1좌표에서 3좌표까지의 합 3 + 5 + 2와 동일하다.

세그먼트 트리 갱신

기존 배열의 특정 좌표의 값을 수정하고자 한다면 세그먼트 트리의 값도 수정해야 한다. 트리를 생성하는 과정과 유사하다. 기존 배열의 원소와 대응되는 트리 노드의 좌표를 알고 있으므로(N + i) 먼저 리프노드의 값을 반영하고 그 값을 기준으로 루트 노드에 다다를 때까지 반복한다.

자식 노드부터 출발했으므로 부모 노드의 값을 변경해야 한다. 부모 노드는 현재 좌표 idx 기준 우측으로 1비트 옮겨 2로 나눈 값과 동일하다. 그렇다면 현재 좌표와 부모 노드의 좌표는 구했으나 부모 노드의 다른 자식 노드의 좌표값을 구해야 한다. 산술적으로 현재 좌표가 2의 배수면 1을 더한 값이고 2의 배수가 아니라면 1을 뺀 값이다. 이를 간단하게 표현하면 1과 xor 연산한 값이라 말할 수 있겠다.

python 코드

def update(tree, N, i, val):
    tree[N +i] = val
    i += N
    while i > 1:
        tree[i >> 1] = tree[i] + tree[i ^ 1]
        i >>= 1

update(segment_tree, N, 1, 4)
nums[1] = 4
print(f'idx 1 to 3 : {nums[1:4]}')
print(f'sum 1 to 3 : {query(segment_tree, N, 1, 4)}')

Data Structure_Non-Recursive_Segment_Tree_006

전체 코드

from math import ceil, log


def init(tree, N):
    for i in range(N - 1, 0, -1):
        tree[i] = tree[i << 1] + tree[i << 1 | 1]


def query(tree, N, left, right):
    result = 0
    left += N
    right += N
    while left < right:
        if left % 2 == 1:
            result += tree[left]
            left += 1
        if right % 2 == 1:
            result += tree[right - 1]
            right -= 1
        left //= 2
        right //= 2
    return result


def update(tree, N, i, val):
    tree[N + i] = val
    i += N
    while i > 1:
        tree[i >> 1] = tree[i] + tree[i ^ 1]
        i >>= 1


nums = [1, 3, 5, 2, 6, 4]
print(f'nums: {nums}')
segment_tree = [0] * (2 * len(nums))
N = len(nums)
for i in range(len(nums)):
    segment_tree[N + i] = nums[i]
init(segment_tree, N)
print(f'segment tree : {segment_tree}')

print(f'idx 1 to 3 : {nums[1:4]}')
print(f'sum 1 to 3 : {query(segment_tree, N, 1, 4)}')

update(segment_tree, N, 1, 4)
nums[1] = 4
print(f'idx 1 to 3 : {nums[1:4]}')
print(f'sum 1 to 3 : {query(segment_tree, N, 1, 4)}')

+ Recent posts