Segment Tree: LeetCode 307

2019-03-29  本文已影响0人  vickeex

basic

回顾下页式(page-oritented)存储: 有一个page为树的根页; 页中包含多个键以及子页的索引; 子页则负责以父页中某两个相邻键为边界的连续范围内的所有键.

Segment Tree与其类似, 每个节点代表一个区间中的某统计值(如: 最大/小值, 和, 积等), 节点的子节点按序负责父节点区间中的某一子区间(一般均匀分配).
以binary segment tree为例, 每个节点的区间分为两个子区间, 节点值为该区间的和, 示意图如下:

故我们还需要存储每个节点所负责区间的起止位置, 即除了val(此处为sum)还需有 start/end变量.

operations

基于这种结构, 提供的操作(API):

code

以 [LeetCode 307. Range Sum Query - Mutable][2] 为例, Segment Tree的Python实现如下:

class SegmentNode:
    def __init__(self, start, end):
        self.start, self.end, self.sum = start, end, 0  # the start/end/sum of the interval
        self.left, self.right = None, None  # left/right interval


class NumArray:
    def __init__(self, nums: list):
        def buildTree(l, r):
            if l > r:  # irregular parameters
                return None
            if l == r:  # leaf node
                n = SegmentNode(l, r)
                n.sum = nums[l]
                return n
            mid, root = (l + r) // 2, SegmentNode(l, r)
            root.left, root.right = buildTree(l, mid), buildTree(mid + 1, r)  # recursively build the tree
            root.sum = root.left.sum + root.right.sum  # update the sum from children
            return root

        self.root = buildTree(0, len(nums) - 1)

    def update(self, i: int, val: int) -> None:
        def updateTree(root, i, val):
            if root.start == root.end:  # the leaf node to update
                root.sum = val
                return val
            mid = (root.start + root.end) // 2  # then recursively update the tree
            if i <= mid:
                updateTree(root.left, i, val)
            else:
                updateTree(root.right, i, val)
            root.sum = root.left.sum + root.right.sum  # update the sum from children
            return root.sum

        updateTree(self.root, i, val)

    def sumRange(self, i: int, j: int) -> int:
        def findNode(root, x, y):
            if root.start == x and root.end == y:  # just the interval
                return root.sum
            mid = (root.start + root.end) // 2
            if y <= mid:
                return findNode(root.left, x, y)  # the interval belongs to left
            if x > mid:
                return findNode(root.right, x, y)  # the interval belongs to right
            return findNode(root.left, x, mid) + findNode(root.right, mid + 1, y)  # cross left and right

        return findNode(self.root, i, j)

上一篇 下一篇

猜你喜欢

热点阅读