S E P H ' S
[자료구조] 13. 세그먼트 트리(Segment Tree) 본문
BOJ 2042 구간 합 구하기 풀이를 하면서 정리를 시작하게 됐다. 이 문제는 세그먼트 트리에 대해 이해하고 있어야 접근이 가능하다. 우선 세그먼트 트리에 대해서 정리하고 문제 풀이를 시작해보겠다.
세그먼트 트리 (Segment Tree)
특정 구간 내 데이터에 대한 연산(쿼리)를 빠르게 구할 수 있는 트리이다. 예를 들어 특정 구간 합, 최소값, 최대값, 평균값 등을 구하는데 용이하다. 시간 복잡도는 아래와 같다.
데이터 변경 : O(logN)
연산 : O(logN)
데이터 변경할때마다 M번의 연산 : O((logN + logN) * M) = O(MlogN))
구조
기본적으로 세그먼트 트리는 이진트리 구조를 가진다. 그래서 이진트리의 특징대로 다음과 같은 특징 또한 가진다.
- 왼쪽 자식노드 index = 부모 노드 index * 2
- 오른쪽 자식노드 index = 부모 노드 index * 2 + 1
이 인덱스를 통해서 여러 연산을 진행하거나 접근할 수 있다. 그러면 세그먼트 트리의 크기는 어떻게 될까? 그것은 트리의 높이로 구할 수 있다. 트리의 높이는 Root Node 에서 가장 긴 경로를 뜻한다. 위 트리에서 가장 긴 경로는 3이다.
배열의 길이를 8이라고 해보자. 이진 트리에 데이터를 저장해야 한다고 할때 각 배열의 원소들을 단말노드(위 그림에서 8 ~ 15번)라고 생각해보자. 그렇다면 각각의 노드들과 차수가 1인 노드(4 ~ 7번, 2 ~3번, 1번)들의 개수를 보면 8 - 4 - 2 - 1 이다.
노드들의 개수와 경로로 알 수 있는 것은 2^h = 배열의 길이 이다. 이 경우 전체 노드 개수는 2^(h+1)-1 이다. 그래서 세그먼트 트리를 저장할 배열의 크기는 2^(h+1) 이면 충분하다. 그리고 인덱스를 위해서 루트 노드의 인덱스를 보통 1로 사용하기 때문에 0을 하나 둔다. 그러므로 2^(h+1) 이면 충분한 것이다.
세그먼트 트리 구현
생성 및 구성
class SegmentTree {
long tree[]; // 각 원소가 담길 트리
int treeSize; // 트리 크기
public SegmentTree(int arrSize) {
// 트리 높이 구하기
int h = (int) Math.ceil(Math.log(arrSize) / Math.log(2));
// 높이를 이용해 배열 사이즈 구하기
this.treeSize = (int) Math.pow(2, h+1);
// 배열 생성
tree = new long[treeSize];
}
/**
* 1. 생성 및 구성
*
* @param arr : 원소 배열
* @param node : 현재 노드
* @param start : 현재구간 배열 시작
* @param end : 현재구간 배열 끝
*
* @return : 원소 배열 값 or 자식노드의 합
*/
public long init(long[] arr, int node, int start, int end) {
// 배열의 시작과 끝이 같으면 단말노드임 즉, 원소를 그대로 담는다.
if (start == end) return tree[node] = arr[start];
// 단말 노드가 아니면 자식노드의 합을 담는다.
int mid = (start + end) / 2;
return tree[node] = init(arr, node * 2, start, mid) + init(arr, node * 2 + 1, mid + 1, end);
}
}
데이터 변경
3번째 원소의 값을 7에서 9로 변경해보자. 그렇다면 3번째 원소와 관련된 노드들이 모두 변경되어야 한다. 우선 원래 값과 변경할 값의 차이를 구해주고 변경된 차이만큼 더해주고 자식노드들을 확인한다. 만약 자식노드의 합 범위가 변경된 원소와 관련 없으면 따로 확인 하지 않는다. 이렇게 관련된 자식노드들을 모두 변경해 나간다. 그리고 원래 배열의 값도 변경해야 한다.
/**
* 2. 데이터 변경
*
* @param node : 현재 노드 idx
* @param start : 배열의 시작
* @param end : 배열의 끝
* @param idx : 변경된 데이터의 idx
* @param diff : 원래 데이터 값과 변경 데이터값의 차이
*/
public void update(int node, int start, int end, int idx, long diff) {
// 만약 변경할 idx가 범위 밖이면 확인할 필요가 없다.
if (idx < start || end < idx) return;
// 차이를 저장한다.
tree[node] += diff;
// 단말 노드가 아니라면 아래 자식 노드들도 확인을 거친다.
if (start != end) {
int mid = (start + end) / 2;
update(node * 2, start, mid, idx, diff);
update(node * 2 + 1, mid + 1, end, idx, diff);
}
}
구간 합 구하기
만약 3번에서 5번까지의 합을 구한다고 생각해보자. 그림에서처럼 3~4번 까지의 합과 5번 원소의 합만 구하면 정답을 얻을 수 있다. 이 역시 루프 노드에서 부터 확인한다. 자식노드로 내려가면서 확인하다가 현재 배열에서 찾고자 하는 범위를 벗어나면 0을 반환하도록 한다.
그리고 현재 찾고자 하는 범위에 포함되면 현재 값을 그대로 반환한다. 아니라면 그 밑의 자식을 확인하도록 한다. 최종적으로 모두 찾게되면 재귀를 타고 올라오면서 더한 값이 반환되도록 하면 된다.
/**
* 3. 구간 합 구하기
*
* @param node : 현재 노드
* @param start : 배열의 시작
* @param end : 배열의 끝
* @param left : 원하는 누적합의 시작
* @param right : 원하는 누적합의 끝
* @return : 누적합
*/
public long sum(int node, int start, int end, int left, int right) {
// 범위를 벗어나게 되면 더할 필요가 없다.
if (left > end || start > right) return 0;
// 범위 내에 포함된다면 더 내려가지 않고 바로 리턴하면 된다.
if (left <= start && end <= right) return tree[node];
// 그 외의 경우에는 지속적으로 탐색을 진행한다.
int mid = (start + end) / 2;
return sum(node * 2, start, mid, left, right) + sum(node * 2 + 1, mid + 1, end, left, right);
}
최종 코드
class SegmentTree {
long tree[]; // 각 원소가 담길 트리
int treeSize; // 트리 크기
public SegmentTree(int arrSize) {
// 트리 높이 구하기
int h = (int) Math.ceil(Math.log(arrSize) / Math.log(2));
// 높이를 이용해 배열 사이즈 구하기
this.treeSize = (int) Math.pow(2, h+1);
// 배열 생성
tree = new long[treeSize];
}
/**
* 1. 생성 및 구성
*
* @param arr : 원소 배열
* @param node : 현재 노드
* @param start : 현재구간 배열 시작
* @param end : 현재구간 배열 끝
*
* @return : 원소 배열 값 or 자식노드의 합
*/
public long init(long[] arr, int node, int start, int end) {
// 배열의 시작과 끝이 같으면 단말노드임 즉, 원소를 그대로 담는다.
if (start == end) return tree[node] = arr[start];
// 단말 노드가 아니면 자식노드의 합을 담는다.
int mid = (start + end) / 2;
return tree[node] = init(arr, node * 2, start, mid) + init(arr, node * 2 + 1, mid + 1, end);
}
/**
* 2. 데이터 변경
*
* @param node : 현재 노드 idx
* @param start : 배열의 시작
* @param end : 배열의 끝
* @param idx : 변경된 데이터의 idx
* @param diff : 원래 데이터 값과 변경 데이터값의 차이
*/
public void update(int node, int start, int end, int idx, long diff) {
// 만약 변경할 idx가 범위 밖이면 확인할 필요가 없다.
if (idx < start || end < idx) return;
// 차이를 저장한다.
tree[node] += diff;
// 단말 노드가 아니라면 아래 자식 노드들도 확인을 거친다.
if (start != end) {
int mid = (start + end) / 2;
update(node * 2, start, mid, idx, diff);
update(node * 2 + 1, mid + 1, end, idx, diff);
}
}
/**
* 3. 구간 합 구하기
*
* @param node : 현재 노드
* @param start : 배열의 시작
* @param end : 배열의 끝
* @param left : 원하는 누적합의 시작
* @param right : 원하는 누적합의 끝
* @return : 누적합
*/
public long sum(int node, int start, int end, int left, int right) {
// 범위를 벗어나게 되면 더할 필요가 없다.
if (left > end || start > right) return 0;
// 범위 내에 포함된다면 더 내려가지 않고 바로 리턴하면 된다.
if (left <= start && end <= right) return tree[node];
// 그 외의 경우에는 지속적으로 탐색을 진행한다.
int mid = (start + end) / 2;
return sum(node * 2, start, mid, left, right) + sum(node * 2 + 1, mid + 1, end, left, right);
}
}
문제 풀이
이제 BOJ 2042 문제를 풀어보자. 문제의 입력과 요구하는 것은 다음과 같다.
N : 숫자의 개수
M : 변경(update)이 일어나는 횟수
K : 구간합(sum)이 일어나는 횟수
a : 변경 (1), 구간합(2)
- a 가 1 일 경우 : b 위치에 c로 update
- a 가 2 일 경우 : b ~ c 위치의 구간합 구하고 출력하기
즉 입력을 받을 때 a의 입력에 따라서 세그먼트 트리의 연산을 진행하고 출력하면 되는 세그먼트 트리 입문용 문제이다. 코드는 다음과 같다.
public class Main {
static long[] arr, tree;
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
int N = Integer.parseInt(st.nextToken());
int M = Integer.parseInt(st.nextToken());
int K = Integer.parseInt(st.nextToken());
arr = new long[N + 1];
for (int i = 1; i <= N; i++) {
arr[i] = Long.parseLong(br.readLine());
}
// 트리 크기 구하기
int k = (int) Math.ceil(Math.log(N) / Math.log(2)) + 1;
int size = (int) Math.pow(2, k);
tree = new long[size];
init(1, N, 1);
StringBuilder sb = new StringBuilder();
for(int i = 0; i < M + K; i++) {
st = new StringTokenizer(br.readLine());
int a = Integer.parseInt(st.nextToken());
int b = Integer.parseInt(st.nextToken());
long c = Long.parseLong(st.nextToken());
if (a == 1) {
long diff = c - arr[b];
arr[b] = c;
update(1, N, 1, b, diff);
} else if (a == 2) {
sb.append(sum(1, N, 1, b, (int)c) + "\n");
}
}
System.out.println(sb.toString());
}
static long init(int start, int end, int node) {
if (start == end) return tree[node] = arr[start];
int mid = (start + end) / 2;
return tree[node] = init(start, mid, node * 2) + init(mid + 1, end, node * 2 + 1);
}
static long sum(int start, int end, int node, int left, int right) {
if (left > end || start > right) return 0;
if (left <= start && end <= right) return tree[node];
int mid = (start + end) / 2;
return sum(start, mid, node * 2, left, right) + sum(mid + 1, end, node * 2 + 1, left, right);
}
static void update(int start, int end, int node, int idx, long diff) {
if (idx < start || end < idx) return;
tree[node] += diff;
if (start != end) {
int mid = (start + end) / 2;
update(start, mid, node * 2, idx, diff);
update(mid + 1, end, node * 2 + 1, idx, diff);
}
}
}
'Programing & Coding > Data Structure' 카테고리의 다른 글
[자료구조] 15. HashSet (해시 셋) (1) | 2023.07.16 |
---|---|
[자료구조] 14. Java 셋 인터페이스 (Set Interface) (0) | 2023.07.13 |
[자료구조] 12. 우선순위 큐(Priority Queue) (0) | 2023.04.13 |
[자료구조] 11. 연결리스트 덱(LinkedList Deque) (1) | 2023.04.13 |
[자료구조] 10. 힙 (Heap) (0) | 2023.04.11 |