자료구조와 알고리즘/Graph

Segment Tree and Range Minimum Query

그레이트쪼 2016. 9. 22. 00:01

Segment Tree 소개

  • 1차원 배열에서 특정 구간의 합을 구하면 어떻게 하겠는가? 아마 가장 간단한 방법으로 구간을 차례로 탐색하면 될 것이다. 이 경우 시간 복잡도는 O(n)이다. 하지만 질문을 여러 번 한다면 그때마다 O(n)의 시간을 소모하게 된다.

  • 다른 방법으로는 첫 번째 원소에서부터 해당 원소까지의 합을 미리 계산해 놓는다.. S1, S2, S3, ..., Sn. 질문이 주어졌을 때 구간이 i ~ j까지라면 Sj - Si 를 하면 O(1)만에 답을 구할 수 있다. 이 방법은 일종의 DP (memorization) 방법이다. 하지만 원소의 값이 update 된다면 합을 다시 구해야 하기 때문에 O(n)이 소요된다. 따라서 update가 자주 일어난다면 첫 번째 방법 대비 별 이득이 없다.
  • 그렇다면 update와 sum을 모두 O(log n)에 할 수 는 없을까? 바로 Segment Tree를 활용하면 된다.
  • Segment tree는 배열의 모든 요소가 leaf node에 배치된다. 그리고 internal node는 leaf node들을 merge한 sum 값을 갖는다.

  • Segment tree는 Full Binary Tree이다. (즉 자식이 없으면 없었지 있으면 2개다) 그리고 애초 배열의 개수인 n개의 leaf node를 가지고 n-1개의 internal node를 가진다. (따라서 tree array의 size는 2n-1이면 된다)
  • Tree array의 size를 좀 더 수학적으로 기술하자면 tree의 높이(h)는 log2N의 ceiling이고 tree의 size는 2 x 2h -1 이다.



Segment tree의 구현

  • Segment tree를 만들 때에는 조각에 원소가 하나밖에 없을 때까지 계속해서 배열을 분할해 가는 방법을 사용한다.
  • Segment tree가 Full Binary Tree라는 점을 이용하여 tree의 구현은 array를 이용한다. 왼쪽 자식의 index은 2*i+1, 오른쪽 자식의 index는 2*i+2, 부모의 index는 (i-1)/2이다.
  • 주어진 범위에서의 값을 찾는 것은 아래와 같다.

int getSum(node, l, r)

{

    if (node 범위가 l ~ r 사이면)

        return node ;

    else if (node 범위가 완전히 l ~ r 범위 바깥이면)

        return 0;

    else

        return getSum(node's left child, l, r) + getSum(node's right child, l, r);

}


Minimum Range Query (RMQ)

  • Segment tree는 sum을 구하는 것보다 min값을 구하는 데 더 유용하다.
  • 보통 배열의 부분 구간에 대한 min값을 구하는 방법은 차례로 탐색하는 방법이다. 이 경우 시간 복잡도는 O(n). 하지만 역시 질문을 여러 번 한다면 그때마다 O(n)의 시간을 소모하게 된다.
  • 다른 방법으로는 미리 계산해 놓는 memoization이 있는데 그나마 우아했던 Sum 때의 방법과 달리 모든 원소 조합에 대해 min값을 미리 구해 놓는 수 밖에 없다. 즉 2D 배열을 만들고 (i, j) 엔트리에 i ~ j 구간의 최소값을 저장한다. 이 과정이 O(n2)이다. 한 번만 구해 놓는다면 질문에 대해 O(1)만에 답을 구할 수 있긴하지만 그래도 O(n2)는 좀 아니다. 게다가 원소의 값이 update 되기라도 한다면 또다시 O(n2)이 소요된다.
  • 따라서 segment tree를 이용해서 update와 min을 모두 O(log n)에 구하자.

  • 앞서 sum에 대한 segment tree 코드에서 internal node에 sum 대신 min을 저장하면된다. 또 query 과정에서는 분할된 영역들의 min 중에서 min을 선택하면 된다.

int getMin(node, l, r)

{

    if (node 범위가 l ~ r 사이면)

        return node ;

    else if (node 범위가 완전히 l ~ r 범위 바깥이면)

        return INFINITE;

    else

        return min(getMin(node's left child, l, r), getMin(node's right child, l, r));

}


RMQ Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include <stdio.h>
#include <math.h>
 
int minVal(int x, int y) { return (x < y) ? x : y; }
const int INT_MAX = ~(1 << 31);
 
// A recursive function to get the minimum value in a given range of array indexes. 
// l & r: Searching space
// index: Index of current node in the segment tree. (Initially 0)
// lb & rb: Query range
int _getMin(int* st, int l, int r, int index, int lb, int rb)
{
    // If segment of this node is a part of given range,
    // then return the min of the segment
    if (lb <= l && r <= rb) {
        return st[index];
    }
 
    // If segment of this node is outside the given range
    if (r < lb || rb < l)
        return INT_MAX;
 
    // If a part of this segment overlaps with the given range
    int m = (l + r) / 2;
    return minVal(_getMin(st, l, m, index * 2 + 1, lb, rb),
                   _getMin(st, m + 1, r, index * 2 + 2, lb, rb));
}
 
// Get the minimum number of the specified range. 
int getMin(int* st, int lb, int rb)
{
    return _getMin(st, 0, n - 10, lb, rb);
}
 
// A recursive function to update the specified value. 
// l & r: Updating range
// index: Index of current node in the segment tree. (Initially 0)
// i & value: Updating position and value
int _updateValue(int* st, int l, int r, int index, int i, int value)
{
    // If the input index is in range of this node, then update 
    // the value of the node and its children
    if (l <= i && i <= r) {
        // In case of a leaf node, copy it to the tree
        if (l == r) {
            st[index] = value;
        }
        // Otherwise, divide this into subproblems
        else {
            int m = (l + r) / 2;
            st[index] = minVal(_updateValue(st, l, m, index * 2 + 1, i, value),
                                 _updateValue(st, m + 1, r, index * 2 + 2, i, value));
        }
    }
    // Skip, if the input index lies outside the range of this segment
 
    return st[index];
}
 
// Update the specified position of the array and corresponding segment tree. 
int updateValue(int* st, int arr[], int i, int value)
{
    arr[i] = value;
    return _updateValue(st, 0, n - 10, i, value);
}
 
// A recursive function to build a segment tree. 
// l & r: Range of array to build
// index: Index of current node in the segment tree. (Initially 0)
int _buildTree(int arr[], int* st, int l, int r, int index)
{
    if (l == r) {
        // In case of a leaf node, copy it to the tree
        st[index] = arr[l];
    }
    else {
        // Divide and merge in the top-down manner
        int m = (l + r) / 2;
        st[index] = minVal(_buildTree(arr, st, l, m, index * 2 + 1),
                             _buildTree(arr, st, m + 1, r, index * 2 + 2));
    }
 
    return st[index];
}
 
// Build a segment tree from the specified array. 
int* buildTree(int arr[], int n)
{
    int h = (int)(ceil(log2(n))); // Height of segment tree
    int size = 2 * (int)pow(2, h) - 1// Maximum size of segment tree
    int* st = new int[size];
 
    _buildTree(arr, st, 0, n - 10);
 
    return st;
}
 
int main()
{
    int arr[] = { 37519 };
    int n = sizeof(arr) / sizeof(int);
 
    int* st = buildTree(arr, n);
 
    int min = getMin(st, 13);
    printf("Min of [1 ~ 3] is %d\n", min);
 
    // Update: set arr[3] = 10 and update corresponding segment tree nodes
    updateValue(st, arr, 310);
 
    min = getMin(st, 13);
    printf("Min of [1 ~ 3] is %d\n", min);
 
    return 0;

  • Build up된 segment tree를 보면 leaf node의 순서가 유지되는 것은 아니다.
  • Recursive 구조이기 때문에 top-down으로 분할하여 내려간다. 실제로 tree를 형성하는 것은 더 이상 분할되지 않을 때인데 분할 된 것 중 뒷부분이 먼저 된다. 이후 call stack을 따라 올라오면서 merge된다고 볼 수 있다.
  • 최초로 tree를 형성하는 것은 배열의 뒷부분인 1과 9이다. 1과 9는 st[]의 index 5, 6에 들어가고 이들의 internal node는 index 2에 들어간다.

  • 앞부분의 3, 7, 5는 한번 더 분할 가능하고 먼저 5는 단독으로 leaf node가 된다. st[]의 index 4에 들어간다.

  • 앞부분의 앞부분인 3, 7이 tree를 형성한다. Leaf node 3, 7은 st[]의 index 7, 8에 들어가고 이들의 internal node는 st[]의 index 3에 들어간다. (깊은 depth가 st[]의 뒤로 간다)

  • 앞부분이 최종 merge된다. (이들의 internal node가 계산된다)

  • 전체 tree가 merge된다. (root가 계산된다)



Minimum Range Query for 2D array

  • 이번엔 segment tree를 활용해서 2D array의 특정 영역에 대한 최소값을 구해보자. Array는 정방형으로 N x N을 가정하자.
  • 2D array의 경우 4분할을 해야한다는 점이 특징이고 나머지 원리는 동일하다. 따라서 사용되는 tree도 binary tree가 아니라 quadtree이다.

  • st[]의 size는 어떻게 구하나? 어짜피 한변이 N이고 N기준에서는 2분할이다. (N x N 기준으로 4분할) 따라서 height는 1D segment tree 때와 동일하고 node당 child 수만 4로 계산하면 된다.

int getMin(node, query scope)

{

    if (node's scope query scope 안쪽이면)

        return node ;

    else if (node's scope query scope 완전한 바깥이면)

        return INFINITE;

    else

        return min(getMin(node's left-top child, query scope),

                    getMin(node's right-top child, query scope),

                    getMin(node's left-bottom child, query scope),

                    getMin(node's right-bottom child, query scope);

}


2D RMQ Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#include <stdio.h>
#include <math.h>
 
const int INT_MAX = ~(1 << 31);
int minVal(int a, int b, int c, int d)
{
 int min = (a < b) ? a : b;
 min = (c < min) ? c : min;
 min = (d < min) ? d : min;
 return min;
}
 
// A recursive function to get the minimum value in a given range of array indexes. 
// (ax, ay) & (bx, by): node's scope (searching space)
// index: Index of current node in the segment tree. (Initially 0)
// (x, y) & (v, w): Query range
int _getMin(int* st, int ax, int ay, int bx, int by, int index, int x, int y, int v, int w)
{
    // If segment of this node is outside the given range
    if (ax > bx || ay > by || bx < x || v < ax || by < y || w < ay)
        return INT_MAX;
 
    // If segment of this node is a part of given range,
    // then return the min of the segment
    if (x <= ax && bx <= v && y <= ay && by <= w) {
        return st[index];
    }
 
    // If a part of this segment overlaps with the given range
    int mx = (ax + bx) / 2;
    int my = (ay + by) / 2;
    return minVal(_getMin(st, ax, ay, mx, my, index * 4 + 1, x, y, v, w),
                   _getMin(st, ax, my + 1, mx, by, index * 4 + 2, x, y, v, w),
                   _getMin(st, mx + 1, ay, bx, my, index * 4 + 3, x, y, v, w),
                   _getMin(st, mx + 1, my + 1, bx, by, index * 4 + 4, x, y, v, w));
}
 
// Get the minimum number of the specified range. 
int getMin(int* st, int x, int y, int v, int w)
{
    return _getMin(st, 00, n - 1, n - 10, x, y, v, w);
}
 
// A recursive function to update the specified value. 
// (ax, ay) & (bx, by): node's scope (updating range)
// index: Index of current node in the segment tree. (Initially 0)
// (x, y) & (v, w): Updating position and value
int _updateValue(int* st, int ax, int ay, int bx, int by, int index, int x, int y, int value)
{
    // Out of scope
    if (ax > bx || ay > by)
        return INT_MAX;
 
    // If the input index is in range of this node, then update 
    // the value of the node and its children
    if (ax <= x && x <= bx && ay <= y && y <= by) {
        // In case of a leaf node, copy it to the tree
        if (ax == bx && ay == by) {
            st[index] = value;
        }
        // Otherwise, divide this into subproblems
        else {
            int mx = (ax + bx) / 2;
            int my = (ay + by) / 2;
            st[index] = minVal(_updateValue(st, ax, ay, mx, my, index * 4 + 1, x, y, value),
                _updateValue(st, ax, my + 1, mx, by, index * 4 + 2, x, y, value),
                _updateValue(st, mx + 1, ay, bx, my, index * 4 + 3, x, y, value),
                _updateValue(st, mx + 1, my + 1, bx, by, index * 4 + 4, x, y, value));
        }
    }
    // Skip, if the input index lies outside the range of this segment
 
    return st[index];
}
 
// Update the specified position of the array and corresponding segment tree. 
int updateValue(int* st, int arr[], int x, int y, int value)
{
    arr[x][y] = 10;
    return _updateValue(st, 00, n - 1, n - 10, x, y, value);
}
 
// A recursive function to build a segment tree. 
// (ax, ay) & (bx, by): Range of array to build
// index: Index of current node in the segment tree. (Initially 0)
int _buildTree(int arr[], int* st, int ax, int ay, int bx, int by, int index)
{
    // Out of scope
    if (ax > bx || ay > by)
        return INT_MAX;
 
    if (ax == bx && ay == by) {
        // In case of a leaf node, copy it to the tree
        st[index] = arr[ax][ay];
    }
    else {
        // Divide and merge in the top-down manner
        int mx = (ax + bx) / 2;
        int my = (ay + by) / 2;
        st[index] = minVal(_buildTree(arr, st, ax, ay, mx, my, index * 4 + 1),
                             _buildTree(arr, st, ax, my + 1, mx, by, index * 4 + 2),
                             _buildTree(arr, st, mx + 1, ay, bx, my, index * 4 + 3),
                             _buildTree(arr, st, mx + 1, my + 1, bx, by, index * 4 + 4));
    }
 
    return st[index];
}
 
// Build a segment tree from the specified array. 
int* buildTree(int arr[], int n)
{
    int h = (int)(ceil(log2(n))); // Height of segment tree
    int size = 4 * (int)pow(2, h) - 1// Maximum size of segment tree
    int* st = new int[size];
 
    _buildTree(arr, st, 00, n - 1, n - 10);
 
    return st;
}
 
int main()
{
    int arr[][] = { {12345},
                      {678910},
                      {12345},
                      {678910},
                      {12345} };
    int n = sizeof(arr[]) / sizeof(int);
 
    int* st = buildTree(arr, n);
 
    int min = getMin(st, 1133);
    printf("Min of [(1, 1) ~ (3, 3)] is %d\n", min);
 
    // Update: set arr[2][1] = 8 and update corresponding segment tree nodes
    updateValue(st, arr, 218);
 
    min = getMin(st, 1133);
    printf("Min of [(1, 1) ~ (3, 3)] is %d\n", min);
 
    return 0;
}

  • Top-down 방식으로 4분할을 하기 때문에 반대로 tree를 형성하는 것은 bottom-up방식이다. 아래와 같이 자른다. 자른 뒤에는 4분면 -> 3분면 -> 2분면 -> 1분면 순으로 분할 계산한다.
  • 첫번째 분할된 모습이다.

  • 두번째 분할된 모습이다. 더 이상 쪼갤 수 없는 4분면의 9, 10, 4, 5가 가장 먼저 sub tree를 형성한다.

  • 세번째 분할된 모습이다. 이제 모든 원소들이 더 이상 쪼갤 수 없기 때문에 차례대로 sub tree를 형성한다.