S E P H ' S

[자료구조] 13. 세그먼트 트리(Segment Tree) 본문

Programing & Coding/Data Structure

[자료구조] 13. 세그먼트 트리(Segment Tree)

yoseph0310 2023. 6. 29. 17:48

 

 

2042번: 구간 합 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

BOJ 2042 구간 합 구하기 풀이를 하면서 정리를 시작하게 됐다. 이 문제는 세그먼트 트리에 대해 이해하고 있어야 접근이 가능하다. 우선 세그먼트 트리에 대해서 정리하고 문제 풀이를 시작해보겠다.


세그먼트 트리 (Segment Tree)

특정 구간 내 데이터에 대한 연산(쿼리)를 빠르게 구할 수 있는 트리이다. 예를 들어 특정 구간 합, 최소값, 최대값, 평균값 등을 구하는데 용이하다. 시간 복잡도는 아래와 같다.

 

데이터 변경 : O(logN)
연산 : O(logN)
데이터 변경할때마다 M번의 연산 : O((logN + logN) * M) = O(MlogN))

 

구조

기본적으로 세그먼트 트리는 이진트리 구조를 가진다. 그래서 이진트리의 특징대로 다음과 같은 특징 또한 가진다.

 

  • 왼쪽 자식노드 index = 부모 노드 index * 2
  • 오른쪽 자식노드 index = 부모 노드 index * 2 + 1

이 인덱스를 통해서 여러 연산을 진행하거나 접근할 수 있다. 그러면 세그먼트 트리의 크기는 어떻게 될까? 그것은 트리의 높이로 구할 수 있다. 트리의 높이는 Root Node 에서 가장 긴 경로를 뜻한다. 위 트리에서 가장 긴 경로는 3이다.

 

배열의 길이를 8이라고 해보자. 이진 트리에 데이터를 저장해야 한다고 할때 각 배열의 원소들을 단말노드(위 그림에서 8 ~ 15번)라고 생각해보자. 그렇다면 각각의 노드들과 차수가 1인 노드(4 ~ 7번, 2 ~3번, 1번)들의 개수를 보면 8 - 4 - 2 - 1 이다.

 

노드들의 개수와 경로로 알 수 있는 것은 2^h = 배열의 길이 이다. 이 경우 전체 노드 개수는 2^(h+1)-1 이다. 그래서 세그먼트 트리를 저장할 배열의 크기는 2^(h+1) 이면 충분하다. 그리고 인덱스를 위해서 루트 노드의 인덱스를 보통 1로 사용하기 때문에 0을 하나 둔다. 그러므로 2^(h+1) 이면 충분한 것이다.

 

세그먼트 트리 구현

생성 및 구성

class SegmentTree {
    long tree[];		// 각 원소가 담길 트리
    int treeSize;		// 트리 크기
    
    public SegmentTree(int arrSize) {
        // 트리 높이 구하기
        int h = (int) Math.ceil(Math.log(arrSize) / Math.log(2));
        
        // 높이를 이용해 배열 사이즈 구하기
        this.treeSize = (int) Math.pow(2, h+1);
        
        // 배열 생성
        tree = new long[treeSize];
    }
    
     /**
     * 1. 생성 및 구성
     *
     * @param arr : 원소 배열
     * @param node : 현재 노드
     * @param start : 현재구간 배열 시작
     * @param end : 현재구간 배열 끝
     *
     * @return : 원소 배열 값 or 자식노드의 합
     */
    public long init(long[] arr, int node, int start, int end) {
        // 배열의 시작과 끝이 같으면 단말노드임 즉, 원소를 그대로 담는다.
        if (start == end) return tree[node] = arr[start];
        
        // 단말 노드가 아니면 자식노드의 합을 담는다.
        int mid = (start + end) / 2;
        return tree[node] = init(arr, node * 2, start, mid) + init(arr, node * 2 + 1, mid + 1, end);
    }
}

 

 

 

 

데이터 변경

3번째 원소의 값을 7에서 9로 변경해보자. 그렇다면 3번째 원소와 관련된 노드들이 모두 변경되어야 한다. 우선 원래 값과 변경할 값의 차이를 구해주고 변경된 차이만큼 더해주고 자식노드들을 확인한다. 만약 자식노드의 합 범위가 변경된 원소와 관련 없으면 따로 확인 하지 않는다. 이렇게 관련된 자식노드들을 모두 변경해 나간다. 그리고 원래 배열의 값도 변경해야 한다.

 

/**
 * 2. 데이터 변경
 *
 * @param node : 현재 노드 idx
 * @param start : 배열의 시작
 * @param end : 배열의 끝
 * @param idx : 변경된 데이터의 idx
 * @param diff : 원래 데이터 값과 변경 데이터값의 차이
 */
public void update(int node, int start, int end, int idx, long diff) {
    // 만약 변경할 idx가 범위 밖이면 확인할 필요가 없다.
    if (idx < start || end < idx) return;
    
    // 차이를 저장한다.
    tree[node] += diff;
    
    // 단말 노드가 아니라면 아래 자식 노드들도 확인을 거친다.
    if (start != end) {
        int mid = (start + end) / 2;
        
        update(node * 2, start, mid, idx, diff);
        update(node * 2 + 1, mid + 1, end, idx, diff);
    }
}

 

구간 합 구하기

만약 3번에서 5번까지의 합을 구한다고 생각해보자. 그림에서처럼 3~4번 까지의 합과 5번 원소의 합만 구하면 정답을 얻을 수 있다. 이 역시 루프 노드에서 부터 확인한다. 자식노드로 내려가면서 확인하다가 현재 배열에서 찾고자 하는 범위를 벗어나면 0을 반환하도록 한다.

 

그리고 현재 찾고자 하는 범위에 포함되면 현재 값을 그대로 반환한다. 아니라면 그 밑의 자식을 확인하도록 한다. 최종적으로 모두 찾게되면 재귀를 타고 올라오면서 더한 값이 반환되도록 하면 된다.

 

/**
 * 3. 구간 합 구하기
 *
 * @param node : 현재 노드
 * @param start : 배열의 시작
 * @param end : 배열의 끝
 * @param left : 원하는 누적합의 시작
 * @param right : 원하는 누적합의 끝
 * @return : 누적합
 */
 public long sum(int node, int start, int end, int left, int right) {
     // 범위를 벗어나게 되면 더할 필요가 없다.
     if (left > end || start > right) return 0;
     
     // 범위 내에 포함된다면 더 내려가지 않고 바로 리턴하면 된다.
     if (left <= start && end <= right) return tree[node];
     
     // 그 외의 경우에는 지속적으로 탐색을 진행한다.
     int mid = (start + end) / 2;
     return sum(node * 2, start, mid, left, right) + sum(node * 2 + 1, mid + 1, end, left, right);
 }

 

최종 코드

class SegmentTree {
    long tree[];		// 각 원소가 담길 트리
    int treeSize;		// 트리 크기
    
    public SegmentTree(int arrSize) {
        // 트리 높이 구하기
        int h = (int) Math.ceil(Math.log(arrSize) / Math.log(2));
        
        // 높이를 이용해 배열 사이즈 구하기
        this.treeSize = (int) Math.pow(2, h+1);
        
        // 배열 생성
        tree = new long[treeSize];
    }
    
     /**
     * 1. 생성 및 구성
     *
     * @param arr : 원소 배열
     * @param node : 현재 노드
     * @param start : 현재구간 배열 시작
     * @param end : 현재구간 배열 끝
     *
     * @return : 원소 배열 값 or 자식노드의 합
     */
    public long init(long[] arr, int node, int start, int end) {
        // 배열의 시작과 끝이 같으면 단말노드임 즉, 원소를 그대로 담는다.
        if (start == end) return tree[node] = arr[start];
        
        // 단말 노드가 아니면 자식노드의 합을 담는다.
        int mid = (start + end) / 2;
        return tree[node] = init(arr, node * 2, start, mid) + init(arr, node * 2 + 1, mid + 1, end);
    }
    
    /**
     * 2. 데이터 변경
     *
     * @param node : 현재 노드 idx
     * @param start : 배열의 시작
     * @param end : 배열의 끝
     * @param idx : 변경된 데이터의 idx
     * @param diff : 원래 데이터 값과 변경 데이터값의 차이
     */
    public void update(int node, int start, int end, int idx, long diff) {
        // 만약 변경할 idx가 범위 밖이면 확인할 필요가 없다.
        if (idx < start || end < idx) return;

        // 차이를 저장한다.
        tree[node] += diff;

        // 단말 노드가 아니라면 아래 자식 노드들도 확인을 거친다.
        if (start != end) {
            int mid = (start + end) / 2;

            update(node * 2, start, mid, idx, diff);
            update(node * 2 + 1, mid + 1, end, idx, diff);
        }
    }
    
    /**
     * 3. 구간 합 구하기
     *
     * @param node : 현재 노드
     * @param start : 배열의 시작
     * @param end : 배열의 끝
     * @param left : 원하는 누적합의 시작
     * @param right : 원하는 누적합의 끝
     * @return : 누적합
     */
     public long sum(int node, int start, int end, int left, int right) {
         // 범위를 벗어나게 되면 더할 필요가 없다.
         if (left > end || start > right) return 0;

         // 범위 내에 포함된다면 더 내려가지 않고 바로 리턴하면 된다.
         if (left <= start && end <= right) return tree[node];

         // 그 외의 경우에는 지속적으로 탐색을 진행한다.
         int mid = (start + end) / 2;
         return sum(node * 2, start, mid, left, right) + sum(node * 2 + 1, mid + 1, end, left, right);
     }
}

문제 풀이

이제 BOJ 2042 문제를 풀어보자. 문제의 입력과 요구하는 것은 다음과 같다.

N : 숫자의 개수
M : 변경(update)이 일어나는 횟수
K : 구간합(sum)이 일어나는 횟수

a : 변경 (1), 구간합(2)
- a 가 1 일 경우 : b 위치에 c로 update
- a 가 2 일 경우 : b ~ c 위치의 구간합 구하고 출력하기

즉 입력을 받을 때 a의 입력에 따라서 세그먼트 트리의 연산을 진행하고 출력하면 되는 세그먼트 트리 입문용 문제이다. 코드는 다음과 같다.

 

public class Main {

    static long[] arr, tree;
    
    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        
        int N = Integer.parseInt(st.nextToken());
        int M = Integer.parseInt(st.nextToken());
        int K = Integer.parseInt(st.nextToken());
        
        arr = new long[N + 1];
        for (int i = 1; i <= N; i++) {
            arr[i] = Long.parseLong(br.readLine());
        }
        
        // 트리 크기 구하기
        int k = (int) Math.ceil(Math.log(N) / Math.log(2)) + 1;
        int size = (int) Math.pow(2, k);
        
        tree = new long[size];
        
        init(1, N, 1);
        
        StringBuilder sb = new StringBuilder();
        for(int i = 0; i < M + K; i++) {
            st = new StringTokenizer(br.readLine());
            
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
            long c = Long.parseLong(st.nextToken());
            
            if (a == 1) {
                long diff = c - arr[b];
                arr[b] = c;
                update(1, N, 1, b, diff);
            } else if (a == 2) {
                sb.append(sum(1, N, 1, b, (int)c) + "\n");
            }
        }
        
        System.out.println(sb.toString());
    }
    
    static long init(int start, int end, int node) {
        if (start == end) return tree[node] = arr[start];
        
        int mid = (start + end) / 2;
        return tree[node] = init(start, mid, node * 2) + init(mid + 1, end, node * 2 + 1);
    }
    
    static long sum(int start, int end, int node, int left, int right) {
        if (left > end || start > right) return 0;
        
        if (left <= start && end <= right) return tree[node];
        
        int mid = (start + end) / 2;
        
        return sum(start, mid, node * 2, left, right) + sum(mid + 1, end, node * 2 + 1, left, right);
    }
    
    static void update(int start, int end, int node, int idx, long diff) {
        if (idx < start || end < idx) return;
        
        tree[node] += diff;
        
        if (start != end) {
            int mid = (start + end) / 2;
            
            update(start, mid, node * 2, idx, diff);
            update(mid + 1, end, node * 2 + 1, idx, diff);
        }
    }
}