机器学习与深度学习原理与应用程序员首页投稿(暂停使用,暂停投稿)

线段树

2017-12-20  本文已影响39人  苏南走了
线段树

本文主要介绍线段树的思路、构造方法以及一些简单的应用。
线段树的应用比较集中,主要是用来处理数组结构里的一些快速查找和修改。以查找数组某区间的最大值这个需求为例子,构建线段树的时间复杂度为 O(N),之后每次查找和修改操作的时间复杂度都是 O(log N)。
线段树的构造比较简单,查找和修改对于不同类的问题略有不同。如果是查找最大值最小值这样的,在节点分叉后,我们需要根据分叉部分判断是进入左分支还是右分支,并更新结果。如果是查找某区间的和,则需要根据分叉节点的位置分别计算左部分和右部分的结果,最后相加即可。

下面给出一些例子。其实例子代码都差不多。。。

线段树的基本结构例子
#coding=utf8
class SegmentTreeNode:
    """A segment tree node covering the closed index range [start, end].

    The value passed as ``maxv`` is kept in the ``count`` attribute. Both
    child links start out as ``None`` and are attached by the tree builder.
    """

    def __init__(self,start,end,maxv):
        # Children are filled in later by whoever builds the tree.
        self.left = self.right = None
        self.start, self.end = start, end
        self.count = maxv
class Solution:
    """
    For every query value, count the array elements that are smaller than it.

    A segment tree is built over the value range 0..MAX_VALUE; each node's
    ``count`` holds how many array elements have a value inside that node's
    range.

    @param: A: An integer array
    @param: queries: The query list
    @return: The number of elements in the array that are smaller than the given integer
    """

    # Leaves of the counting tree cover the values 0..MAX_VALUE inclusive.
    MAX_VALUE = 10000

    def countOfSmallerNumber(self, A, queries):
        """Return, for each query q, how many elements of A are smaller than q.

        Elements of A are used as tree indices, so they must be integers in
        0..MAX_VALUE.
        """
        root = self.build([0] * (self.MAX_VALUE + 1))
        for value in A:
            self.modifyt(root, value)
        # The values smaller than q are exactly those in the range [0, q - 1].
        return [self.queryt(root, 0, q - 1) for q in queries]

    def build(self,A):
        """Build a segment tree over all of A and return its root (None if A is empty)."""
        return self.buildhelper(0,len(A)-1,A)

    def buildhelper(self,left,right,A):
        """Recursively build the subtree for A[left..right]; return None for an empty range."""
        if left>right:
            return None
        root = SegmentTreeNode(left,right,A[left])
        if left==right:
            return root
        mid = (left + right) // 2
        root.left = self.buildhelper(left,mid,A)
        root.right = self.buildhelper(mid+1,right,A)
        # Keep the aggregate in ``count``, the attribute queryt/modifyt read.
        root.count = root.left.count + root.right.count
        return root

    def queryt(self,root,start,end):
        """Return the sum of ``count`` over the part of [start, end] covered by root.

        An empty or non-overlapping range (for example [0, -1], produced by a
        query value of 0) yields 0.
        """
        if root is None or start > root.end or end < root.start:
            return 0
        if start<=root.start and end>=root.end:
            return root.count
        return (self.queryt(root.left,start,end)
                + self.queryt(root.right,start,end))

    def modifyt(self,root,index):
        """Add one to the count stored at ``index`` and refresh its ancestors.

        ``index`` must lie within the range covered by ``root``.
        """
        if root.start == root.end and root.start == index:
            root.count += 1
            return
        mid = (root.start + root.end) // 2
        child = root.left if index<=mid else root.right
        self.modifyt(child,index)
        root.count = root.left.count + root.right.count

    def query(self,root,start,end):
        """Return the largest ``count`` among the nodes that tile [start, end].

        Returns -9999 when the range does not overlap the tree at all; stored
        values are assumed to be larger than that sentinel.
        """
        if root is None or start > root.end or end < root.start:
            return -9999
        if start<=root.start and end>=root.end:
            return root.count
        return max(self.query(root.left,start,end),
                   self.query(root.right,start,end))

def modify(root,index,value):
    """Set the leaf at ``index`` to ``value`` in a max segment tree.

    Every node keeps the maximum of its range in ``count`` (the value
    attribute defined by SegmentTreeNode); after the leaf is overwritten the
    maxima of all nodes on the path back to ``root`` are recomputed.
    ``index`` must lie within the range covered by ``root``.
    """
    if root.start == root.end and root.start == index:
        root.count = value
        return
    mid = (root.start + root.end) // 2
    child = root.left if index<=mid else root.right
    modify(child,index,value)
    # Leaf and ancestors must use the same attribute, otherwise the update is lost.
    root.count = max(root.left.count,root.right.count)
求区间的和
"""
Definition of Interval.
class Interval(object):
    def __init__(self, start, end):
        self.start = start
        self.end = end
"""


class Solution:
    """
    Answer range-sum queries over an integer list with a segment tree.

    Each node's ``count`` holds the sum of the list elements in its range.

    @param: A: An integer list
    @param: queries: A query list
    @return: The result list
    """
    def intervalSum(self, A, queries):
        """Return the sum of A[q.start..q.end] (inclusive) for every query q."""
        root = self.build(A)
        return [self.query(root,q.start,q.end) for q in queries]

    def build(self,A):
        """Build a sum segment tree over all of A and return its root (None if A is empty)."""
        return self.buildhelper(0,len(A)-1,A)

    def buildhelper(self,left,right,A):
        """Recursively build the subtree for A[left..right]; return None for an empty range."""
        if left>right:
            return None
        root = SegmentTreeNode(left,right,A[left])
        if left==right:
            return root
        mid = (left + right) // 2
        root.left = self.buildhelper(left,mid,A)
        root.right = self.buildhelper(mid+1,right,A)
        # SegmentTreeNode keeps its value in ``count``; aggregate into the same field.
        root.count = root.left.count + root.right.count
        return root

    def query(self,root,start,end):
        """Return the sum of the elements of [start, end] that lie under root.

        A missing node or a non-overlapping range contributes 0.
        """
        if root is None:
            return 0
        if root.start > end or root.end < start:
            return 0
        if start<=root.start and end>=root.end:
            return root.count
        return (self.query(root.left,start,end)
                + self.query(root.right,start,end))

求区间的最小值
"""
Definition of Interval.
class Interval(object):
    def __init__(self, start, end):
        self.start = start
        self.end = end
"""


class Solution:
    """
    Answer range-minimum queries over an integer list with a segment tree.

    Each node's ``count`` holds the minimum of the list elements in its range.

    @param: A: An integer list
    @param: queries: A query list
    @return: The result list
    """
    def intervalMinNumber(self, A, queries):
        """Return the minimum of A[q.start..q.end] (inclusive) for every query q."""
        root = self.build(A)
        return [self.query(root,q.start,q.end) for q in queries]

    def build(self,A):
        """Build a min segment tree over all of A and return its root (None if A is empty)."""
        return self.buildhelper(0,len(A)-1,A)

    def buildhelper(self,left,right,A):
        """Recursively build the subtree for A[left..right]; return None for an empty range."""
        if left>right:
            return None
        root = SegmentTreeNode(left,right,A[left])
        if left==right:
            return root
        mid = (left + right) // 2
        root.left = self.buildhelper(left,mid,A)
        root.right = self.buildhelper(mid+1,right,A)
        # SegmentTreeNode keeps its value in ``count``; aggregate into the same field.
        root.count = min(root.left.count,root.right.count)
        return root

    def query(self,root,start,end):
        """Return the minimum over the part of [start, end] covered by root.

        Returns 0 when the tree is empty or the range does not overlap it.
        """
        best = self._range_min(root,start,end)
        return 0 if best is None else best

    def _range_min(self,root,start,end):
        """Return the minimum over [start, end] under root, or None if nothing overlaps."""
        if root is None or root.start > end or root.end < start:
            return None
        if start<=root.start and end>=root.end:
            return root.count
        # Combine only the children that overlap, so no numeric sentinel can
        # cap or replace a real value.
        found = [
            value
            for value in (self._range_min(root.left,start,end),
                          self._range_min(root.right,start,end))
            if value is not None
        ]
        return min(found) if found else None

求区间比其小的数的个数

这种题目的思路:一般数的范围是固定的,比如 0 到 10000,那么事先建立好一棵覆盖 0 到 10000 的线段树(共 10001 个叶子节点,初始计数全为 0),而后遍历数组,把每个数在线段树中对应位置的计数加一,如此便可以在后续快速查找。

class Solution:
    """
    For every query value, count the array elements that are smaller than it.

    A segment tree is built over the value range 0..MAX_VALUE; each node's
    ``count`` holds how many array elements have a value inside that node's
    range.

    @param: A: An integer array
    @param: queries: The query list
    @return: The number of elements in the array that are smaller than the given integer
    """

    # Leaves of the counting tree cover the values 0..MAX_VALUE inclusive.
    MAX_VALUE = 10000

    def countOfSmallerNumber(self, A, queries):
        """Return, for each query q, how many elements of A are smaller than q.

        Elements of A are used as tree indices, so they must be integers in
        0..MAX_VALUE.
        """
        root = self.build([0] * (self.MAX_VALUE + 1))
        for value in A:
            self.modifyt(root, value)
        # The values smaller than q are exactly those in the range [0, q - 1].
        return [self.queryt(root, 0, q - 1) for q in queries]

    def build(self,A):
        """Build a segment tree over all of A and return its root (None if A is empty)."""
        return self.buildhelper(0,len(A)-1,A)

    def buildhelper(self,left,right,A):
        """Recursively build the subtree for A[left..right]; return None for an empty range."""
        if left>right:
            return None
        root = SegmentTreeNode(left,right,A[left])
        if left==right:
            return root
        mid = (left + right) // 2
        root.left = self.buildhelper(left,mid,A)
        root.right = self.buildhelper(mid+1,right,A)
        # Keep the aggregate in ``count``, the attribute queryt/modifyt read.
        root.count = root.left.count + root.right.count
        return root

    def queryt(self,root,start,end):
        """Return the sum of ``count`` over the part of [start, end] covered by root.

        An empty or non-overlapping range (for example [0, -1], produced by a
        query value of 0) yields 0.
        """
        if root is None or start > root.end or end < root.start:
            return 0
        if start<=root.start and end>=root.end:
            return root.count
        return (self.queryt(root.left,start,end)
                + self.queryt(root.right,start,end))

    def modifyt(self,root,index):
        """Add one to the count stored at ``index`` and refresh its ancestors.

        ``index`` must lie within the range covered by ``root``.
        """
        if root.start == root.end and root.start == index:
            root.count += 1
            return
        mid = (root.start + root.end) // 2
        child = root.left if index<=mid else root.right
        self.modifyt(child,index)
        root.count = root.left.count + root.right.count

    def query(self,root,start,end):
        """Return the largest ``count`` among the nodes that tile [start, end].

        Returns -9999 when the range does not overlap the tree at all; stored
        values are assumed to be larger than that sentinel.
        """
        if root is None or start > root.end or end < root.start:
            return -9999
        if start<=root.start and end>=root.end:
            return root.count
        return max(self.query(root.left,start,end),
                   self.query(root.right,start,end))

上一篇 下一篇

猜你喜欢

热点阅读