Segment Tree

Let me start with a common problem:

  1. Assume we have an array array[0,1,…,n-1] of fixed length n. We need to:
  • Get the sum of the range from i to j, where $i,j \in [0,n-1]$ and $i \le j$.
  • Update an element.
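For reference, here is a minimal sketch of the most direct solution, which keeps the array as-is. The class name `NaiveRangeSum` is hypothetical (not from the article); the method names `getSum` and `update` simply mirror the two operations above. A query walks every element in the range, so it costs O(n) in the worst case, while an update is O(1).

```java
// Hypothetical baseline: the array is stored unchanged.
class NaiveRangeSum {
    private final int[] array;

    NaiveRangeSum(int[] array) {
        // Copy so that later changes to the caller's array do not leak in.
        this.array = array.clone();
    }

    // Sum of array[i..j]: visits j - i + 1 elements, O(n) in the worst case.
    int getSum(int i, int j) {
        int sum = 0;
        for (int k = i; k <= j; k++) {
            sum += array[k];
        }
        return sum;
    }

    // Overwrites a single element: O(1).
    void update(int i, int value) {
        array[i] = value;
    }

    public static void main(String[] args) {
        NaiveRangeSum naive = new NaiveRangeSum(new int[]{2, 4, 6, 8, 10});
        System.out.println(naive.getSum(1, 3)); // 18
        naive.update(2, 1);
        System.out.println(naive.getSum(1, 3)); // 13
    }
}
```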

A brute-force improvement is to preprocess the array into prefix sums: store prefix[i] = array[0] + … + array[i] for every i. The sum of the range from i to j is then prefix[j] - prefix[i-1] (taking prefix[-1] = 0), so a query costs O(1). The catch is the update: changing array[i] changes every prefix[k] with k >= i, so a single update costs O(n) in the worst case.
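A minimal sketch of the prefix-sum approach follows. The class name `PrefixSumRange` is hypothetical. One assumption differs from the text: the sketch uses a shifted array of length n + 1 where prefix[k] holds array[0] + … + array[k-1] and prefix[0] = 0, which removes the special case for i = 0 (the query becomes prefix[j + 1] - prefix[i]).

```java
// Hypothetical prefix-sum structure: O(1) query, O(n) update.
class PrefixSumRange {
    private final int[] array;
    // prefix[k] = array[0] + ... + array[k - 1], with prefix[0] = 0.
    private final int[] prefix;

    PrefixSumRange(int[] array) {
        this.array = array.clone();
        this.prefix = new int[array.length + 1];
        for (int k = 0; k < array.length; k++) {
            prefix[k + 1] = prefix[k] + array[k];
        }
    }

    // Sum of array[i..j] in O(1).
    int getSum(int i, int j) {
        return prefix[j + 1] - prefix[i];
    }

    // Changing array[i] shifts every prefix entry that includes it: O(n) in the worst case.
    void update(int i, int value) {
        int delta = value - array[i];
        array[i] = value;
        for (int k = i + 1; k < prefix.length; k++) {
            prefix[k] += delta;
        }
    }

    public static void main(String[] args) {
        PrefixSumRange ps = new PrefixSumRange(new int[]{2, 4, 6, 8, 10});
        System.out.println(ps.getSum(1, 3)); // 18
        ps.update(2, 1);
        System.out.println(ps.getSum(1, 3)); // 13
        System.out.println(ps.getSum(0, 4)); // 25
    }
}
```

The update loop is exactly the linear cost the next paragraph complains about.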

But when updates are frequent, this linear cost dominates. Can we do better than linear time per update?

The answer is yes (otherwise this article would be pointless): a segment tree supports both the range-sum query and the update in O(log n) time.

Suppose the array is [2,4,6,8,10]. Its segment tree looks like this:

[Figure: Segment Tree for the array [2,4,6,8,10]]
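Since the figure does not survive here, the following sketch rebuilds the same tree in text form. It assumes the halving rule used by the SegmentTree class later in this article (a node for [l, r] splits at mid = l + (r - l)/2); the class and method names `PrintSegmentTree`, `build`, and `render` are hypothetical. Each output line shows a node's range and the sum it stores, indented by depth.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper that prints a segment tree in preorder.
class PrintSegmentTree {
    // A node covers [start, end] and stores the sum of that range.
    static class Node {
        final int start, end;
        int sum;
        Node left, right;

        Node(int start, int end) {
            this.start = start;
            this.end = end;
        }
    }

    // Splits [l, r] at mid = l + (r - l) / 2 until each range holds one element.
    static Node build(int[] array, int l, int r) {
        Node node = new Node(l, r);
        if (l == r) {
            node.sum = array[l];
            return node;
        }
        int mid = l + (r - l) / 2;
        node.left = build(array, l, mid);
        node.right = build(array, mid + 1, r);
        node.sum = node.left.sum + node.right.sum;
        return node;
    }

    // Preorder walk: two spaces of indentation per level.
    static void render(Node node, int depth, List<String> lines) {
        if (node == null) return;
        lines.add("  ".repeat(depth) + "[" + node.start + "," + node.end + "] sum=" + node.sum);
        render(node.left, depth + 1, lines);
        render(node.right, depth + 1, lines);
    }

    public static void main(String[] args) {
        List<String> lines = new ArrayList<>();
        render(build(new int[]{2, 4, 6, 8, 10}, 0, 4), 0, lines);
        lines.forEach(System.out::println);
    }
}
```

Running it prints:

```
[0,4] sum=30
  [0,2] sum=12
    [0,1] sum=6
      [0,0] sum=2
      [1,1] sum=4
    [2,2] sum=6
  [3,4] sum=18
    [3,3] sum=8
    [4,4] sum=10
```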

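How do we construct the segment tree?
We start with the root, which covers the whole range [0,1,…,n-1], and keep splitting ranges in half until each range holds a single element. A node covering [l, r] gets a left child covering [l, mid] and a right child covering [mid+1, r], where mid = l + (r-l)/2; for the root this gives [0,1,…,(n-1)/2] and [(n-1)/2+1,…,n-1]. Every node stores the sum of its range, so a parent's sum is the sum of its two children's sums.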
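The O(log n) bound comes from this halving: an update touches one node per level, so its cost is the height of the tree. The sketch below (hypothetical class `SegmentTreeShape`) measures the shape produced by the splitting rule without storing any sums. Under this rule a range of size n has a left half of size ⌈n/2⌉, so the height works out to ⌈log2 n⌉, and since every inner node has exactly two children, the tree has 2n - 1 nodes.

```java
// Hypothetical helper that measures the tree built by the halving rule.
class SegmentTreeShape {
    // Number of nodes in the tree over [l, r].
    static int countNodes(int l, int r) {
        if (l == r) return 1;
        int mid = l + (r - l) / 2;
        return 1 + countNodes(l, mid) + countNodes(mid + 1, r);
    }

    // Height (edges on the longest root-to-leaf path) of the tree over [l, r].
    static int height(int l, int r) {
        if (l == r) return 0;
        int mid = l + (r - l) / 2;
        return 1 + Math.max(height(l, mid), height(mid + 1, r));
    }

    public static void main(String[] args) {
        for (int n : new int[]{1, 5, 8, 1000}) {
            System.out.println("n=" + n + " nodes=" + countNodes(0, n - 1)
                    + " height=" + height(0, n - 1));
        }
    }
}
```

For n = 1000 it reports 1999 nodes and height 10, so an update visits about ten nodes instead of up to a thousand prefix entries.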

Talk is cheap. Here is the code.

public class SegmentTree {

    // Each node covers the index range [start, end] and stores the sum of that range.
    class SegmentTreeNode {
        SegmentTreeNode left, right;
        int sum, start, end;

        public SegmentTreeNode(int start, int end) {
            this.sum = 0;
            this.start = start;
            this.end = end;
        }
    }

    SegmentTreeNode root = null;

    public SegmentTree(int[] array) {
        // An empty array yields an empty tree (root stays null).
        if (array.length > 0) {
            root = constructTree(0, array.length - 1, array);
        }
    }

    // Builds the subtree for [l, r]: leaves hold single elements,
    // inner nodes hold the sum of their two halves.
    private SegmentTreeNode constructTree(int l, int r, int[] array) {
        SegmentTreeNode cur = new SegmentTreeNode(l, r);
        if (l == r) {
            cur.sum = array[l];
            return cur;
        }
        int mid = l + (r - l) / 2;
        cur.left = constructTree(l, mid, array);
        cur.right = constructTree(mid + 1, r, array);
        cur.sum = cur.left.sum + cur.right.sum;
        return cur;
    }

    // Sets array[i] = value, assuming 0 <= i < n.
    public void update(int i, int value) {
        if (root != null) {
            update(root, i, value);
        }
    }

    // Walks one root-to-leaf path, then recomputes the sums on the way back up: O(log n).
    private void update(SegmentTreeNode cur, int i, int val) {
        if (cur.start == cur.end) {
            cur.sum = val;
            return;
        }
        int mid = cur.start + (cur.end - cur.start) / 2;
        if (i <= mid) {
            update(cur.left, i, val);
        } else {
            update(cur.right, i, val);
        }
        cur.sum = cur.left.sum + cur.right.sum;
    }

    // Returns array[i] + ... + array[j], assuming 0 <= i <= j < n.
    public int getSum(int i, int j) {
        return getSum(root, i, j);
    }

    // Assumes the query range [l, r] lies inside [cur.start, cur.end].
    private int getSum(SegmentTreeNode cur, int l, int r) {
        if (cur == null) return 0;
        // The query covers this node's whole range: the stored sum is the answer.
        if (cur.start == l && cur.end == r) return cur.sum;
        int mid = cur.start + (cur.end - cur.start) / 2;
        if (r <= mid) {
            // Entirely in the left half: keep the query range unchanged.
            return getSum(cur.left, l, r);
        } else if (l > mid) {
            // Entirely in the right half: keep the query range unchanged.
            return getSum(cur.right, l, r);
        } else {
            // Straddles mid: split the query into [l, mid] and [mid + 1, r].
            return getSum(cur.left, l, mid) + getSum(cur.right, mid + 1, r);
        }
    }
}
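A common alternative layout, not the one used above, stores the tree in a plain int array instead of node objects: the root is at index 1, and the node at index k has its children at 2k and 2k + 1. The sketch below is a minimal version of that variant under the same halving rule, with a hypothetical class name `ArraySegmentTree`. It assumes 4n slots are enough, which holds for this recursive layout, and it trades the per-node object overhead for a single allocation.

```java
// Hypothetical array-backed segment tree for range sums.
class ArraySegmentTree {
    private final int n;
    // tree[k] holds the sum of node k's range; children of k are 2k and 2k + 1.
    private final int[] tree;

    ArraySegmentTree(int[] array) {
        n = array.length;
        tree = new int[Math.max(1, 4 * n)];
        if (n > 0) build(array, 1, 0, n - 1);
    }

    private void build(int[] array, int node, int l, int r) {
        if (l == r) {
            tree[node] = array[l];
            return;
        }
        int mid = l + (r - l) / 2;
        build(array, 2 * node, l, mid);
        build(array, 2 * node + 1, mid + 1, r);
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }

    // Sets array[i] = value, assuming 0 <= i < n.
    void update(int i, int value) {
        update(1, 0, n - 1, i, value);
    }

    private void update(int node, int l, int r, int i, int value) {
        if (l == r) {
            tree[node] = value;
            return;
        }
        int mid = l + (r - l) / 2;
        if (i <= mid) update(2 * node, l, mid, i, value);
        else update(2 * node + 1, mid + 1, r, i, value);
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }

    // Returns array[i] + ... + array[j], assuming 0 <= i <= j < n.
    int getSum(int i, int j) {
        return getSum(1, 0, n - 1, i, j);
    }

    private int getSum(int node, int l, int r, int i, int j) {
        if (i == l && j == r) return tree[node];
        int mid = l + (r - l) / 2;
        if (j <= mid) return getSum(2 * node, l, mid, i, j);
        if (i > mid) return getSum(2 * node + 1, mid + 1, r, i, j);
        return getSum(2 * node, l, mid, i, mid) + getSum(2 * node + 1, mid + 1, r, mid + 1, j);
    }

    public static void main(String[] args) {
        ArraySegmentTree st = new ArraySegmentTree(new int[]{2, 4, 6, 8, 10});
        System.out.println(st.getSum(1, 3)); // 18
        st.update(2, 1);
        System.out.println(st.getSum(1, 3)); // 13
        System.out.println(st.getSum(0, 4)); // 25
    }
}
```

The query and update logic is the same as in the node-based class; only the way children are located changes.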