A Quick Note on Segment Trees

  1. What is a segment tree

    • A binary tree in which each node represents a segment (contiguous subarray) of the original array.
    • Each node stores aggregate information about its segment, such as the sum, minimum, or maximum.
    • The root represents the entire array, and each child represents half of its parent's segment.
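To make the layout concrete, here is a small sketch that follows the same recursion as the implementation in section 3 (root at index 0, children of node i at 2i+1 and 2i+2, split at the midpoint) and lists the segment each node covers. The class `SegmentLayout` and its `describe` method are hypothetical helpers for illustration only.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: lists the segment [start, end] covered by each tree node,
// using the same 0-based layout as the SegmentTree class (children of i are 2i+1 and 2i+2).
class SegmentLayout {
    static List<String> describe(int n) {
        List<String> out = new ArrayList<>();
        if (n > 0) {
            walk(0, 0, n - 1, out); // the root covers the whole array
        }
        return out;
    }

    // Pre-order walk: record this node, then its left half, then its right half.
    private static void walk(int node, int start, int end, List<String> out) {
        out.add("node " + node + " -> [" + start + ", " + end + "]");
        if (start == end) {
            return; // leaf: covers a single element
        }
        int mid = start + (end - start) / 2;
        walk(2 * node + 1, start, mid, out);   // left child: first half
        walk(2 * node + 2, mid + 1, end, out); // right child: second half
    }

    public static void main(String[] args) {
        describe(6).forEach(System.out::println);
    }
}
```

For n = 6 this lists 11 nodes: 6 leaves plus 5 internal nodes, since every internal node has exactly two children. The largest index used is 12, well below the 4n slots the implementation allocates.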
  2. Why use a segment tree

    • It is a data structure for efficiently processing range queries.

    • It answers minimum, maximum, or sum queries over any subarray in O(log n) time.

    • It supports range sums in O(log n) time while still allowing a single element to be updated in O(log n) time. (A plain prefix-sum array answers sums in O(1) but needs O(n) work per update.)
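The same structure handles minimum (or maximum) queries; only the combine step and the value returned for a non-overlapping node change. A minimal sketch, assuming a hypothetical `MinSegmentTree` class that mirrors the sum implementation in section 3: each node stores `Math.min` of its children, and `Integer.MAX_VALUE` serves as the neutral value because it never lowers a minimum.

```java
// Hypothetical MinSegmentTree: same recursive layout as the sum tree,
// but each node stores the minimum of its segment instead of the sum.
class MinSegmentTree {
    private final int[] tree;
    private final int n;

    MinSegmentTree(int[] arr) {
        n = arr.length;
        tree = new int[4 * n];
        if (n > 0) {
            build(arr, 0, 0, n - 1);
        }
    }

    private void build(int[] arr, int node, int start, int end) {
        if (start == end) { // leaf: a single element
            tree[node] = arr[start];
            return;
        }
        int mid = start + (end - start) / 2;
        build(arr, 2 * node + 1, start, mid);
        build(arr, 2 * node + 2, mid + 1, end);
        tree[node] = Math.min(tree[2 * node + 1], tree[2 * node + 2]);
    }

    // Minimum of arr[queryStart .. queryEnd], both ends inclusive.
    int rangeMin(int queryStart, int queryEnd) {
        return rangeMin(0, 0, n - 1, queryStart, queryEnd);
    }

    private int rangeMin(int node, int start, int end, int queryStart, int queryEnd) {
        if (start > queryEnd || end < queryStart) {
            return Integer.MAX_VALUE; // no overlap: neutral value for min
        }
        if (start >= queryStart && end <= queryEnd) {
            return tree[node]; // segment fully inside the query range
        }
        int mid = start + (end - start) / 2;
        return Math.min(rangeMin(2 * node + 1, start, mid, queryStart, queryEnd),
                        rangeMin(2 * node + 2, mid + 1, end, queryStart, queryEnd));
    }

    // Point update: set arr[pos] = val, then recompute minimums on the way back up.
    void update(int pos, int val) {
        update(0, 0, n - 1, pos, val);
    }

    private void update(int node, int start, int end, int pos, int val) {
        if (start == end) {
            tree[node] = val;
            return;
        }
        int mid = start + (end - start) / 2;
        if (pos <= mid) {
            update(2 * node + 1, start, mid, pos, val);
        } else {
            update(2 * node + 2, mid + 1, end, pos, val);
        }
        tree[node] = Math.min(tree[2 * node + 1], tree[2 * node + 2]);
    }

    public static void main(String[] args) {
        MinSegmentTree t = new MinSegmentTree(new int[]{5, 2, 8, 6, 3, 7});
        System.out.println(t.rangeMin(0, 5)); // 2
        System.out.println(t.rangeMin(2, 5)); // 3
        t.update(4, 9);
        System.out.println(t.rangeMin(2, 5)); // 6
    }
}
```

Swapping `Math.min` for `Math.max` and `Integer.MAX_VALUE` for `Integer.MIN_VALUE` gives a range-maximum tree.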

  3. Implementation and comments

    import static org.junit.Assert.assertEquals;
    
    
    public class SegmentTree {
        int[] nums;
        int[] tree;
        int n;
    
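        public SegmentTree(int[] arr) {
            n = arr.length;
            tree = new int[4*n]; // 4n slots are always enough for this recursive layout
            nums = arr;
            if (n > 0) { // an empty array has no root to build
                buildTree(0, 0, n - 1);
            }
        }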
    
    
        /**
         * Recursively builds the subtree rooted at treeIndex; each node stores the sum of its segment.
         * @param treeIndex The index of node in tree we are building.
         * @param start The startIndex of elements in nums/arr represented by current node.
         * @param end The endIndex of elements in nums/arr represented by current node.
         */
        private void buildTree(int treeIndex, int start, int end) {
            if (start == end) { // leaf node;
                tree[treeIndex] = nums[start];
            } else {
                int mid = start + (end - start) / 2;
                buildTree(2*treeIndex+1, start, mid);
                buildTree(2*treeIndex+2, mid+1, end);
                tree[treeIndex] = tree[2*treeIndex+1] + tree[2*treeIndex+2];
            }
        }
    
    
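        /**
         * Returns the sum of nums[queryStart..queryEnd], both ends inclusive.
         *
         * @param queryStart The range start we are going to query.
         * @param queryEnd The range end we are going to query.
         * @return The sum of the elements in [queryStart, queryEnd].
         */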
        public int rangeSum(int queryStart, int queryEnd) {
            return rangeSum(0, 0, n-1, queryStart, queryEnd);
        }
    
    
        /**
         * Recursive helper for rangeSum.
         *
         * @param treeIndex The index of the tree node we are currently querying.
         * @param start The startIndex of elements in nums/arr represented by current node.
         * @param end The endIndex of elements in nums/arr represented by current node.
         * @param queryStart The range start we are going to query.
         * @param queryEnd The range end we are going to query.
         * @return The sum of the part of [queryStart, queryEnd] that overlaps [start, end].
         */
        private int rangeSum(int treeIndex, int start, int end, int queryStart, int queryEnd) {
            if (start > queryEnd || end < queryStart) { // no overlap with the query range
                return 0; // identity value for sum
            } else if (start >= queryStart && end <= queryEnd) { // segment fully inside the query range
                return tree[treeIndex];
            } else { // partial overlap: recurse into the child (or children) that intersect the query
                int mid = start + (end - start) / 2;
                if (queryEnd <= mid) {
                    return rangeSum(2*treeIndex+1, start, mid, queryStart, queryEnd);
                } else if (queryStart > mid) {
                    return rangeSum(2*treeIndex+2, mid+1, end, queryStart, queryEnd);
                } else {
                    int leftSum = rangeSum(2 * treeIndex + 1, start, mid, queryStart, queryEnd);
                    int rightSum = rangeSum(2 * treeIndex + 2, mid + 1, end, queryStart, queryEnd);
                    return leftSum + rightSum;
                }
            }
        }
    
    
        /**
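         * Sets the element at arrPos to newVal and refreshes the sums on the path from that leaf to the root.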
         * @param arrPos The index of element in nums/arr that we are going to update.
         * @param newVal The new value of the element we are going to update.
         */
        public void update(int arrPos, int newVal) {
            update(0, 0, n-1, arrPos, newVal);
        }
    
    
        /**
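         * Recursive helper for update: descends to the leaf for arrPos, then recomputes the sums on the way back up.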
         * @param treeIndex The index of the tree node we are going to update value.
         * @param start The startIndex of elements in nums/arr represented by current node.
         * @param end The endIndex of elements in nums/arr represented by current node.
         * @param arrPos The index of element in nums/arr that we are going to update.
         * @param newVal The new value of the element we are going to update.
         */
        private void update(int treeIndex, int start, int end, int arrPos, int newVal) {
            if (start == end) { // leaf node;
                tree[treeIndex] = newVal;
            } else {
                int mid = start + (end - start) / 2;
                if (arrPos <= mid) {
                    update(2*treeIndex+1, start, mid, arrPos, newVal);
                } else {
                    update(2*treeIndex+2, mid+1, end, arrPos, newVal);
                }
                tree[treeIndex] = tree[2*treeIndex+1] + tree[2*treeIndex+2];
            }
        }
    
        
        public static void main(String[] args) {
            int[] arr = {1, 3, 5, 7, 9, 11};
            SegmentTree segmentTree = new SegmentTree(arr);
    
            // Test rangeSum
            assertEquals(1, segmentTree.rangeSum(0, 0));
            assertEquals(3, segmentTree.rangeSum(1, 1));
            assertEquals(36, segmentTree.rangeSum(0, 5));
            assertEquals(32, segmentTree.rangeSum(2, 5));
    
            // Test update
            segmentTree.update(3, 8);
            assertEquals(33, segmentTree.rangeSum(2, 5));
    
            segmentTree.update(0, 2);
            assertEquals(38, segmentTree.rangeSum(0, 5));
        }
    }
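For comparison, the same range sums can also be computed without recursion using a bottom-up layout. This is a different technique from the recursive tree above: leaves are stored at `tree[n..2n-1]`, internal node `i` has children `2i` and `2i+1`, and only 2n slots are needed. This is a sketch; the class name `IterativeSegmentTree` is hypothetical, and the query walks a half-open interval internally while keeping the inclusive `rangeSum(queryStart, queryEnd)` signature used above.

```java
// Bottom-up (non-recursive) segment tree for range sums: a different layout
// from the recursive SegmentTree. Leaves live at tree[n .. 2n-1], and internal
// node i (1 <= i < n) has children 2i and 2i+1; tree[0] is unused.
class IterativeSegmentTree {
    private final int n;
    private final int[] tree;

    IterativeSegmentTree(int[] arr) {
        n = arr.length;
        tree = new int[2 * n];
        System.arraycopy(arr, 0, tree, n, n);   // copy the leaves
        for (int i = n - 1; i > 0; i--) {       // build parents bottom-up
            tree[i] = tree[2 * i] + tree[2 * i + 1];
        }
    }

    // Point update: overwrite the leaf, then recompute every ancestor up to the root.
    void update(int arrPos, int newVal) {
        int i = arrPos + n;
        tree[i] = newVal;
        for (i /= 2; i >= 1; i /= 2) {
            tree[i] = tree[2 * i] + tree[2 * i + 1];
        }
    }

    // Sum of arr[queryStart .. queryEnd], both ends inclusive.
    int rangeSum(int queryStart, int queryEnd) {
        int sum = 0;
        int lo = queryStart + n;     // inclusive leaf index
        int hi = queryEnd + n + 1;   // exclusive leaf index
        while (lo < hi) {
            if ((lo & 1) == 1) {
                sum += tree[lo++];   // lo is a right child: take it alone and step right
            }
            if ((hi & 1) == 1) {
                sum += tree[--hi];   // hi - 1 is a left child: take it alone and step left
            }
            lo /= 2;                 // move both boundaries up one level
            hi /= 2;
        }
        return sum;
    }

    public static void main(String[] args) {
        IterativeSegmentTree t = new IterativeSegmentTree(new int[]{1, 3, 5, 7, 9, 11});
        System.out.println(t.rangeSum(0, 5)); // 36
        System.out.println(t.rangeSum(2, 5)); // 32
        t.update(3, 8);
        System.out.println(t.rangeSum(2, 5)); // 33
    }
}
```

This sketch relies on addition being commutative; other combine functions may need more care with the order in which the left and right pieces are merged.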
    
  4. Lazy propagation

  • For range updates, we first mark the affected nodes (recording the pending addition) instead of applying the update to every element immediately. A mark is pushed down to a node's children only when a later query or update needs to visit those children.
  • This laziness is needed because a range update applied element by element may touch almost every node, so a single update costs close to O(n). With lazy marks, both range updates and range queries take O(log n).
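A minimal sketch of the idea above, assuming the update is "add a value to every element in [l, r]" and the query is a range sum. The class `LazySegmentTree` and the helpers `apply` and `pushDown` are hypothetical names. `lazy[node]` records an addition already counted in `tree[node]` but not yet passed to its children; `pushDown` hands it over only when a later call must descend through that node. Sums are kept as `long` because each pending addition is multiplied by the segment length.

```java
// Hypothetical lazy segment tree: range add + range sum, same 0-based layout
// as SegmentTree (children of node i are 2i+1 and 2i+2).
class LazySegmentTree {
    private final int n;
    private final long[] tree; // sum of each node's segment (pending additions included)
    private final long[] lazy; // addition not yet pushed to the node's children

    LazySegmentTree(int[] arr) {
        n = arr.length;
        tree = new long[4 * n];
        lazy = new long[4 * n];
        if (n > 0) build(arr, 0, 0, n - 1);
    }

    private void build(int[] arr, int node, int start, int end) {
        if (start == end) { tree[node] = arr[start]; return; }
        int mid = start + (end - start) / 2;
        build(arr, 2 * node + 1, start, mid);
        build(arr, 2 * node + 2, mid + 1, end);
        tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
    }

    // Mark a node covering [start, end]: update its sum now, remember the addition for its children.
    private void apply(int node, int start, int end, long val) {
        tree[node] += val * (end - start + 1);
        lazy[node] += val;
    }

    // Hand a node's pending addition to its two children, then clear the mark.
    private void pushDown(int node, int start, int end) {
        if (lazy[node] != 0) {
            int mid = start + (end - start) / 2;
            apply(2 * node + 1, start, mid, lazy[node]);
            apply(2 * node + 2, mid + 1, end, lazy[node]);
            lazy[node] = 0;
        }
    }

    // Add val to every element in [l, r], both ends inclusive.
    void rangeAdd(int l, int r, long val) { rangeAdd(0, 0, n - 1, l, r, val); }

    private void rangeAdd(int node, int start, int end, int l, int r, long val) {
        if (r < start || end < l) return;     // no overlap
        if (l <= start && end <= r) {         // full overlap: mark and stop here
            apply(node, start, end, val);
            return;
        }
        pushDown(node, start, end);           // partial overlap: settle pending work first
        int mid = start + (end - start) / 2;
        rangeAdd(2 * node + 1, start, mid, l, r, val);
        rangeAdd(2 * node + 2, mid + 1, end, l, r, val);
        tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
    }

    // Sum of the elements in [l, r], both ends inclusive.
    long rangeSum(int l, int r) { return rangeSum(0, 0, n - 1, l, r); }

    private long rangeSum(int node, int start, int end, int l, int r) {
        if (r < start || end < l) return 0;   // no overlap
        if (l <= start && end <= r) return tree[node];
        pushDown(node, start, end);           // children must be up to date before reading them
        int mid = start + (end - start) / 2;
        return rangeSum(2 * node + 1, start, mid, l, r)
             + rangeSum(2 * node + 2, mid + 1, end, l, r);
    }

    public static void main(String[] args) {
        LazySegmentTree t = new LazySegmentTree(new int[]{1, 3, 5, 7, 9, 11});
        t.rangeAdd(1, 4, 10);                 // array becomes {1, 13, 15, 17, 19, 11}
        System.out.println(t.rangeSum(0, 5)); // 76
        System.out.println(t.rangeSum(2, 3)); // 32
    }
}
```

Both `rangeAdd` and `rangeSum` stop at nodes that lie fully inside [l, r], so each call visits O(log n) nodes instead of every element in the range.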