ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • 세그먼트 트리란?
    알고리즘/알고리즘 2022. 3. 13. 15:39
    728x90

    세그먼트 트리란?

    여러 개의 데이터가 연속적으로 존재할 때 특정한 범위의 데이터의 합을 구하기 위한 자료구조입니다.

     

    배열에서 특정 구간의 합을 가장 빠르게 구하기 위한 방법은 무엇일까요?

    예시 데이터 : 5 8 7 3 2 5 1 8 9 8 7 3

     

    여기에서 인덱스 1부터 10까지 데이터의 합을 구하려면 어떻게 할 수 있을까요?

    간단하게 인덱스 1부터 10까지 데이터를 다 더해준다면 데이터의 개수에 의존하여 O(N)의 시간 복잡도가 나옵니다.

     

    이것을 트리 구조를 이용해서 구한다면 O(logN)의 시간복잡도로 부분합을 구할 수 있습니다.

     

    다음 그림을 보면 조금 더 이해하기 쉽습니다.

    https://www.youtube.com/watch?v=075fcq7oCC8

    이처럼 더한값을 다시 재사용하면서 최종적으로 0~14의 연산 결과를 얻기 때문에 O(logN) 시간을 보장합니다.

     

    Segment Tree의 실제 구조

    https://www.youtube.com/watch?v=075fcq7oCC8

    부모로부터 왼쪽 자식은 번호 * 2  오른쪽 자식은 번호*2 +1을 가집니다.

    반대로 자식 번호/2를 하면 부모 번호로 이동할 수 있습니다.

    이는 완전 트리(complete tree)의 특성을 이용했기 때문에 가능합니다.

     

     

    그러면 실제로 Segment Tree를 어떻게 만들까요?

    각 segment에 부여된 번호는 segment 저장을 위한 배열의 index로 사용됩니다.

    예를 들어 배열의 1번 index에는 0~14까지의 구간합이 저장됩니다.

    예를 들어 배열의 11번 index에는 6~7까지의 구간합이 저장됩니다.

     

    이제 Segment tree를 생성해 보겠습니다.

    루트부터 rage를 반절씩 나눠가면서 생성합니다.

    static int[] numbers = {1,2,3,4,5,6,7,8};
    static int[] segmentTree = new int[numbers.length * 4]; 
    //tree size를 N*4를 하는 이유는 2의 제곱 형태의 길이를 가지기 때문에 4를 곱하면 모든 범위를 커버할 수 있습니다.
    
    static int makeSegmentTree(int start, int end, int index) {
    	if(start ==end) { 
    		return segmentTree[index] = numbers[start];
    	}
    	int mid = (start + end) /2;
    	return segmentTree[index] = makeSegmentTree(start, mid, index*2) + makeSegmentTree(mid+1, end, index*2 + 1);
    }

     

    그러면 생성한 Segment tree를 가지고 특정 구간의 연산 결과를 구해보겠습니다.

    https://m.blog.naver.com/ndb796/221282210534

     

    만약 4~8범위의 구간합을 구하려고 한다면 다음과 같이 세 노드의 합만 구해 주면 됩니다.

    구간의 합은 '범위 안에 있는 경우'에 한해서만 더해주면 되고 그 밖의 경우는 고려하지 않습니다.

    	// start = 배열의 시작 인덱스, end = 배열의 끝 인덱스, index = 노드의 번호
    	// left, right = 구간 합을 구하고자 하는 범위
    	static long sum(int start, int end, int index, int left, int right) {
    		// 범위 밖에 있는 경우
    		if (left > end || right < start)
    			return 0;
    		// 범위 안에 있는 경우
    		if (left <= start && end <= right)
    			return segmentTree[index];
    		// 그렇지 않다면 두 부분을 나누어 합을 구하기
    		int mid = (start + end) / 2;
    		return sum(start, mid, index * 2, left, right) + sum(mid + 1, end, index * 2 + 1, left, right);
    
    	}

     

    특정 원소의 값을 수정하려면 어떻게 해야 할까요?

    특정 원소의 값을 수정하기 위해서는 해당 원소를 포함하고 있는 모든 구간의 노드를 갱신해야 합니다.

    예를 들어 인덱스 7의 노드를 수정한다고 하면 다음과 같이 5개의 구간 합 노드를 모두 수정하면 됩니다.

     

    https://m.blog.naver.com/ndb796/221282210534

    즉, 재귀적으로 수정할 노드가 범위 안에 있는 경우에 한해서만 수정하면 됩니다.

     

    방법 1

    public static void update(int start, int end, int index, int originIndex, long diff) {
    		// 범위 밖에 있는경우에는 진행하지 않습니다.
    		if (!(start <= originIndex && originIndex <= end))
    			return;
    		
    		// 범위 안에 있는경우에는 값을 갱신합니다.
    		segmentTree[index] += diff;
    
    		//리프노드일 경우에는 더이상 진행x
    		if (start == end)
    			return;
    		//리프노드가 아닌경우에는 리프노드까지 진행
    		int mid = (start + end) / 2;
    		update(start, mid, index * 2, originIndex, diff);
    		update(mid + 1, end, index * 2 + 1, originIndex, diff);
    
    	}

     

     

    방법 2

    	static long update(int start, int end, int index, int originIndex, long newValue) {
        		//범위 밖이라면 tree값 그대로 반환
    		if (originIndex < start || originIndex > end) {
    			return segmentTree[index];
    		}
    
    		// 리프 노드라면 NewValue로 초기화
    		if (start == end) {
    			return segmentTree[index] = newValue;
    		}
    		
            	//재귀적으로 올라가면서 tree값 재구성
    		int mid = (start + end) / 2;
    		return segmentTree[index] = update(start, mid, index * 2, originIndex, newValue)
    				+ update(mid + 1, end, index * 2 + 1, originIndex, newValue);
    	}

     

     

    출처

    https://m.blog.naver.com/ndb796/221282210534

     

    41. 세그먼트 트리(Segment Tree)

    이번 시간에 다룰 내용은 여러 개의 데이터가 연속적으로 존재할 때 특정한 범위의 데이터의 합을 구하는 ...

    blog.naver.com

    https://www.youtube.com/watch?v=ahFB9eCnI6c 

     

    댓글

Designed by Tistory.