2D 세그먼트 트리

오늘 배운 것

1. 2D 세그먼트 트리 (2D Segment Tree)

2D 세그먼트 트리는 N×M 크기의 고정된 2차원 배열에서 특정 직사각형 범위의 합이나 최댓값 등을 빠르게 구하기 위해 사용합니다. ‘세그먼트 트리의 트리’ 구조로, Y축을 관리하는 바깥쪽 세그먼트 트리의 각 노드가 X축을 관리하는 안쪽 세그먼트 트리를 갖는 형태입니다.

점 업데이트와 범위 쿼리는 모두 $O(\log N \times \log M)$의 시간 복잡도를 가집니다.

Java 예제 코드 (점 업데이트, 범위 합 쿼리)
/**
 * 2D 세그먼트 트리 구현
 * N x M 크기의 2차원 배열을 가정합니다.
 */
public class SegmentTree2D {

    private final int N, M;
    private final long[][] tree; // 2D 트리를 1D 배열로 펼쳐서 관리
    private final int[][] arr;

    public SegmentTree2D(int[][] arr) {
        this.N = arr.length;
        this.M = arr[0].length;
        this.arr = arr;
        // 2D 트리는 1D 트리의 4*N개 노드 각각이 4*M 크기의 배열을 가지는 것과 유사
        this.tree = new long[4 * N][4 * M];
        buildY(1, 0, N - 1);
    }

    // Y축 트리 빌드
    private void buildY(int nodeY, int startY, int endY) {
        if (startY == endY) {
            buildX(nodeY, 1, 0, M - 1, startY);
        } else {
            int midY = (startY + endY) / 2;
            buildY(2 * nodeY, startY, midY);
            buildY(2 * nodeY + 1, midY + 1, endY);
            // 자식 Y노드들의 X트리 값을 합쳐 현재 Y노드의 X트리를 만듦
            for (int i = 1; i < 4 * M; i++) {
                tree[nodeY][i] = tree[2 * nodeY][i] + tree[2 * nodeY + 1][i];
            }
        }
    }

    // X축 트리 빌드
    private void buildX(int nodeY, int nodeX, int startX, int endX, int y) {
        if (startX == endX) {
            tree[nodeY][nodeX] = arr[y][startX];
        } else {
            int midX = (startX + endX) / 2;
            buildX(nodeY, 2 * nodeX, startX, midX, y);
            buildX(nodeY, 2 * nodeX + 1, midX + 1, endX, y);
            tree[nodeY][nodeX] = tree[nodeY][2 * nodeX] + tree[nodeY][2 * nodeX + 1];
        }
    }

    // 점 (y, x)의 값을 val로 업데이트
    public void update(int y, int x, int val) {
        updateY(1, 0, N - 1, y, x, val);
    }

    private void updateY(int nodeY, int startY, int endY, int y, int x, int val) {
        if (startY == endY) {
            updateX(nodeY, 1, 0, M - 1, x, val);
        } else {
            int midY = (startY + endY) / 2;
            if (y <= midY) {
                updateY(2 * nodeY, startY, midY, y, x, val);
            } else {
                updateY(2 * nodeY + 1, midY + 1, endY, y, x, val);
            }
            // 해당 Y노드의 X트리 전체를 갱신
            for (int i = 1; i < 4 * M; i++) {
                tree[nodeY][i] = tree[2 * nodeY][i] + tree[2 * nodeY + 1][i];
            }
        }
    }

    private void updateX(int nodeY, int nodeX, int startX, int endX, int x, int val) {
        if (startX == endX) {
            tree[nodeY][nodeX] = val;
        } else {
            int midX = (startX + endX) / 2;
            if (x <= midX) {
                updateX(nodeY, 2 * nodeX, startX, midX, x, val);
            } else {
                updateX(nodeY, 2 * nodeX + 1, midX + 1, endX, x, val);
            }
            tree[nodeY][nodeX] = tree[nodeY][2 * nodeX] + tree[nodeY][2 * nodeX + 1];
        }
    }

    // 범위 (y1, x1) ~ (y2, x2)의 합 쿼리
    public long query(int y1, int x1, int y2, int x2) {
        return queryY(1, 0, N - 1, y1, y2, x1, x2);
    }

    private long queryY(int nodeY, int startY, int endY, int y1, int y2, int x1, int x2) {
        if (y1 > endY || y2 < startY) return 0;
        if (y1 <= startY && endY <= y2) {
            return queryX(nodeY, 1, 0, M - 1, x1, x2);
        }
        int midY = (startY + endY) / 2;
        long leftSum = queryY(2 * nodeY, startY, midY, y1, y2, x1, x2);
        long rightSum = queryY(2 * nodeY + 1, midY + 1, endY, y1, y2, x1, x2);
        return leftSum + rightSum;
    }

    private long queryX(int nodeY, int nodeX, int startX, int endX, int x1, int x2) {
        if (x1 > endX || x2 < startX) return 0;
        if (x1 <= startX && endX <= x2) {
            return tree[nodeY][nodeX];
        }
        int midX = (startX + endX) / 2;
        long leftSum = queryX(nodeY, 2 * nodeX, startX, midX, x1, x2);
        long rightSum = queryX(nodeY, 2 * nodeX + 1, midX + 1, endX, x1, x2);
        return leftSum + rightSum;
    }
}

2. 동적 세그먼트 트리

좌표의 범위가 10^9 처럼 매우 크지만, 실제로 사용하는 점의 개수는 적을 때(희소할 때, sparse) 사용합니다. 모든 범위를 배열로 미리 할당하는 대신, 값이 추가되는 경로의 노드만 필요할 때 동적으로 생성합니다. Node 객체와 자식 노드 참조(포인터)를 이용해 구현합니다.

/**
 * 동적 세그먼트 트리 구현
 * 매우 큰 좌표 범위(0 ~ MAX_COORD)에서 점 업데이트와 범위 합 쿼리를 지원
 */
public class DynamicSegmentTree {
    
    // 노드 객체
    static class Node {
        long value;
        Node left = null;
        Node right = null;
    }

    private final Node root;
    private final int MAX_COORD;

    public DynamicSegmentTree(int maxCoord) {
        this.root = new Node();
        this.MAX_COORD = maxCoord;
    }

    // index에 value를 더하는 업데이트
    public void update(int index, int value) {
        update(root, 0, MAX_COORD, index, value);
    }

    private void update(Node node, int start, int end, int index, int value) {
        if (start == end) {
            node.value += value;
            return;
        }

        int mid = start + (end - start) / 2;
        if (index <= mid) {
            // 왼쪽 자식 노드가 없으면 새로 생성
            if (node.left == null) node.left = new Node();
            update(node.left, start, mid, index, value);
        } else {
            // 오른쪽 자식 노드가 없으면 새로 생성
            if (node.right == null) node.right = new Node();
            update(node.right, mid + 1, end, index, value);
        }

        long leftVal = (node.left != null) ? node.left.value : 0;
        long rightVal = (node.right != null) ? node.right.value : 0;
        node.value = leftVal + rightVal;
    }

    // left ~ right 범위의 합 쿼리
    public long query(int left, int right) {
        return query(root, 0, MAX_COORD, left, right);
    }

    private long query(Node node, int start, int end, int left, int right) {
        // 해당 경로의 노드가 생성된 적 없으면 값은 0
        if (node == null || right < start || left > end) {
            return 0;
        }
        if (left <= start && end <= right) {
            return node.value;
        }

        int mid = start + (end - start) / 2;
        long leftSum = query(node.left, start, mid, left, right);
        long rightSum = query(node.right, mid + 1, end, left, right);
        return leftSum + rightSum;
    }
}

내일 할 일

  1. 스프링 공부

results matching ""

    No results matching ""