2D 세그먼트 트리
오늘 배운 것
1. 2D 세그먼트 트리 (2D Segment Tree)
2D 세그먼트 트리는 N×M 크기의 고정된 2차원 배열에서 특정 직사각형 범위의 합이나 최댓값 등을 빠르게 구하기 위해 사용합니다. ‘세그먼트 트리의 트리’ 구조로, Y축을 관리하는 바깥쪽 세그먼트 트리의 각 노드가 X축을 관리하는 안쪽 세그먼트 트리를 갖는 형태입니다.
점 업데이트와 범위 쿼리는 모두 $O(\log N \times \log M)$의 시간 복잡도를 가집니다.
Java 예제 코드 (점 업데이트, 범위 합 쿼리)
/**
* 2D 세그먼트 트리 구현
* N x M 크기의 2차원 배열을 가정합니다.
*/
public class SegmentTree2D {
private final int N, M;
private final long[][] tree; // 2D 트리를 1D 배열로 펼쳐서 관리
private final int[][] arr;
public SegmentTree2D(int[][] arr) {
this.N = arr.length;
this.M = arr[0].length;
this.arr = arr;
// 2D 트리는 1D 트리의 4*N개 노드 각각이 4*M 크기의 배열을 가지는 것과 유사
this.tree = new long[4 * N][4 * M];
buildY(1, 0, N - 1);
}
// Y축 트리 빌드
private void buildY(int nodeY, int startY, int endY) {
if (startY == endY) {
buildX(nodeY, 1, 0, M - 1, startY);
} else {
int midY = (startY + endY) / 2;
buildY(2 * nodeY, startY, midY);
buildY(2 * nodeY + 1, midY + 1, endY);
// 자식 Y노드들의 X트리 값을 합쳐 현재 Y노드의 X트리를 만듦
for (int i = 1; i < 4 * M; i++) {
tree[nodeY][i] = tree[2 * nodeY][i] + tree[2 * nodeY + 1][i];
}
}
}
// X축 트리 빌드
private void buildX(int nodeY, int nodeX, int startX, int endX, int y) {
if (startX == endX) {
tree[nodeY][nodeX] = arr[y][startX];
} else {
int midX = (startX + endX) / 2;
buildX(nodeY, 2 * nodeX, startX, midX, y);
buildX(nodeY, 2 * nodeX + 1, midX + 1, endX, y);
tree[nodeY][nodeX] = tree[nodeY][2 * nodeX] + tree[nodeY][2 * nodeX + 1];
}
}
// 점 (y, x)의 값을 val로 업데이트
public void update(int y, int x, int val) {
updateY(1, 0, N - 1, y, x, val);
}
private void updateY(int nodeY, int startY, int endY, int y, int x, int val) {
if (startY == endY) {
updateX(nodeY, 1, 0, M - 1, x, val);
} else {
int midY = (startY + endY) / 2;
if (y <= midY) {
updateY(2 * nodeY, startY, midY, y, x, val);
} else {
updateY(2 * nodeY + 1, midY + 1, endY, y, x, val);
}
// 해당 Y노드의 X트리 전체를 갱신
for (int i = 1; i < 4 * M; i++) {
tree[nodeY][i] = tree[2 * nodeY][i] + tree[2 * nodeY + 1][i];
}
}
}
private void updateX(int nodeY, int nodeX, int startX, int endX, int x, int val) {
if (startX == endX) {
tree[nodeY][nodeX] = val;
} else {
int midX = (startX + endX) / 2;
if (x <= midX) {
updateX(nodeY, 2 * nodeX, startX, midX, x, val);
} else {
updateX(nodeY, 2 * nodeX + 1, midX + 1, endX, x, val);
}
tree[nodeY][nodeX] = tree[nodeY][2 * nodeX] + tree[nodeY][2 * nodeX + 1];
}
}
// 범위 (y1, x1) ~ (y2, x2)의 합 쿼리
public long query(int y1, int x1, int y2, int x2) {
return queryY(1, 0, N - 1, y1, y2, x1, x2);
}
private long queryY(int nodeY, int startY, int endY, int y1, int y2, int x1, int x2) {
if (y1 > endY || y2 < startY) return 0;
if (y1 <= startY && endY <= y2) {
return queryX(nodeY, 1, 0, M - 1, x1, x2);
}
int midY = (startY + endY) / 2;
long leftSum = queryY(2 * nodeY, startY, midY, y1, y2, x1, x2);
long rightSum = queryY(2 * nodeY + 1, midY + 1, endY, y1, y2, x1, x2);
return leftSum + rightSum;
}
private long queryX(int nodeY, int nodeX, int startX, int endX, int x1, int x2) {
if (x1 > endX || x2 < startX) return 0;
if (x1 <= startX && endX <= x2) {
return tree[nodeY][nodeX];
}
int midX = (startX + endX) / 2;
long leftSum = queryX(nodeY, 2 * nodeX, startX, midX, x1, x2);
long rightSum = queryX(nodeY, 2 * nodeX + 1, midX + 1, endX, x1, x2);
return leftSum + rightSum;
}
}
2. 동적 세그먼트 트리
좌표의 범위가 10^9 처럼 매우 크지만, 실제로 사용하는 점의 개수는 적을 때(희소할 때, sparse) 사용합니다. 모든 범위를 배열로 미리 할당하는 대신, 값이 추가되는 경로의 노드만 필요할 때 동적으로 생성합니다. Node 객체와 자식 노드 참조(포인터)를 이용해 구현합니다.
/**
* 동적 세그먼트 트리 구현
* 매우 큰 좌표 범위(0 ~ MAX_COORD)에서 점 업데이트와 범위 합 쿼리를 지원
*/
public class DynamicSegmentTree {
// 노드 객체
static class Node {
long value;
Node left = null;
Node right = null;
}
private final Node root;
private final int MAX_COORD;
public DynamicSegmentTree(int maxCoord) {
this.root = new Node();
this.MAX_COORD = maxCoord;
}
// index에 value를 더하는 업데이트
public void update(int index, int value) {
update(root, 0, MAX_COORD, index, value);
}
private void update(Node node, int start, int end, int index, int value) {
if (start == end) {
node.value += value;
return;
}
int mid = start + (end - start) / 2;
if (index <= mid) {
// 왼쪽 자식 노드가 없으면 새로 생성
if (node.left == null) node.left = new Node();
update(node.left, start, mid, index, value);
} else {
// 오른쪽 자식 노드가 없으면 새로 생성
if (node.right == null) node.right = new Node();
update(node.right, mid + 1, end, index, value);
}
long leftVal = (node.left != null) ? node.left.value : 0;
long rightVal = (node.right != null) ? node.right.value : 0;
node.value = leftVal + rightVal;
}
// left ~ right 범위의 합 쿼리
public long query(int left, int right) {
return query(root, 0, MAX_COORD, left, right);
}
private long query(Node node, int start, int end, int left, int right) {
// 해당 경로의 노드가 생성된 적 없으면 값은 0
if (node == null || right < start || left > end) {
return 0;
}
if (left <= start && end <= right) {
return node.value;
}
int mid = start + (end - start) / 2;
long leftSum = query(node.left, start, mid, left, right);
long rightSum = query(node.right, mid + 1, end, left, right);
return leftSum + rightSum;
}
}
내일 할 일
- 스프링 공부