线段树
2017-12-20 本文已影响39人
苏南走了
线段树
主要涉及线段树的思路、构造和一些简单的应用。
线段树的应用比较集中,主要是用来处理数组结构里的一些快速查找和修改。以查找数组某区间的最大值这个需求为例子,构建线段树的时间复杂度为 O(N),之后每次查找和修改操作的时间复杂度都是 O(log N)。
线段树的构造比较简单,查找和修改对于不同类的问题略有不同。如果是查找最大值最小值这样的,在节点分叉后,我们需要根据分叉部分判断是进入左分支还是右分支,并更新结果。如果是查找某区间的和,则需要根据分叉节点的位置分别计算左部分和右部分的结果,最后相加即可。
下面给出一些例子。其实例子代码都差不多。。。
线段树的基本结构例子
#coding=utf8
class SegmentTreeNode:
    """A single segment-tree node for the index range ``start``..``end``.

    The third constructor argument is called ``maxv`` but is kept on the
    ``count`` attribute. Both child links start out as ``None``.
    """

    def __init__(self, start, end, maxv):
        # Range boundaries handled by this node.
        self.start, self.end = start, end
        # Aggregated value for the range (stored under the name ``count``).
        self.count = maxv
        # Children are attached later by whoever builds the tree.
        self.left = self.right = None
class Solution:
    """Segment-tree helpers: a counting tree plus generic max-tree operations.

    ``build`` / ``modifyt`` / ``queryt`` maintain a counting tree whose nodes
    keep their aggregate in ``count``.  ``query`` / ``modify`` are separate
    helpers for a max tree whose nodes keep their aggregate in ``max``; the
    ``build`` method of this class does not produce such a tree.
    """

    def countOfSmallerNumber(self, A, queries):
        """Count, for every query, the elements of A strictly smaller than it.

        @param: A: An integer array
        @param: queries: The query list
        @return: The number of elements in the array that are smaller than
                 the given integer, one entry per query
        """
        if not A:
            return [0 for _ in queries]
        # One leaf per value in min(A)..max(A); leaf i counts how many
        # elements equal ``lowest + i``.
        lowest = min(A)
        root = self.build([0] * (max(A) - lowest + 1))
        for value in A:
            self.modifyt(root, value - lowest)
        # Elements smaller than q live in leaves 0 .. q - 1 - lowest; queryt
        # returns 0 when that range is empty and clips it to the tree.
        return [self.queryt(root, 0, q - 1 - lowest) for q in queries]

    def build(self, A):
        """Build a counting tree over the positions of A (None if A is empty)."""
        return self.buildhelper(0, len(A) - 1, A)

    def buildhelper(self, left, right, A):
        """Recursively build the subtree for positions ``left``..``right``."""
        if left > right:
            return None
        root = SegmentTreeNode(left, right, A[left])
        if left == right:
            return root
        mid = (left + right) // 2
        root.left = self.buildhelper(left, mid, A)
        root.right = self.buildhelper(mid + 1, right, A)
        # Store the sum where modifyt/queryt read it.
        root.count = root.left.count + root.right.count
        return root

    def queryt(self, root, start, end):
        """Return the sum of ``count`` over leaves in ``start``..``end``.

        Ranges that are empty or do not overlap the node contribute 0.
        """
        if root is None or start > end or end < root.start or start > root.end:
            return 0
        if start <= root.start and root.end <= end:
            return root.count
        return (self.queryt(root.left, start, end)
                + self.queryt(root.right, start, end))

    def modifyt(self, root, index):
        """Increment the leaf at ``index`` and refresh counts on the way up.

        Raises IndexError when ``index`` is not covered by ``root``.
        """
        if root is None or not root.start <= index <= root.end:
            raise IndexError("index %d is outside the segment tree" % index)
        if root.start == root.end:
            root.count += 1
            return
        mid = (root.start + root.end) // 2
        self.modifyt(root.left if index <= mid else root.right, index)
        root.count = root.left.count + root.right.count

    def query(self, root, start, end):
        """Return the maximum ``max`` value over ``start``..``end``.

        Returns ``float("-inf")`` when the range is empty or does not
        overlap the node.
        """
        if root is None or start > end or end < root.start or start > root.end:
            return float("-inf")
        if start <= root.start and root.end <= end:
            return root.max
        mid = (root.start + root.end) // 2
        best = float("-inf")
        if start <= mid:
            best = max(best, self.query(root.left, start, end))
        if end > mid:
            best = max(best, self.query(root.right, start, end))
        return best

    @staticmethod
    def modify(root, index, value):
        """Set the leaf at ``index`` to ``value`` in a max tree.

        Every node on the path from the leaf to ``root`` gets its ``max``
        recomputed.  Raises IndexError when no leaf matches ``index``.
        """
        # Walk down iteratively so the method does not depend on looking up
        # the class by name.
        path = []
        node = root
        while node.start != node.end:
            path.append(node)
            mid = (node.start + node.end) // 2
            node = node.left if index <= mid else node.right
        if node.start != index:
            raise IndexError("index %d is outside the segment tree" % index)
        node.max = value
        for parent in reversed(path):
            parent.max = max(parent.left.max, parent.right.max)
求区间的和
"""
Definition of Interval.
class Interval(object):
def __init__(self, start, end):
self.start = start
self.end = end
"""
class Solution:
    """Answer range-sum queries over a fixed list with a segment tree.

    Every node keeps the sum of its range in the attribute ``max``; the
    attribute name is kept as-is even though the value is a sum.
    """

    def intervalSum(self, A, queries):
        """Return the sum of A over each queried interval.

        @param: A: An integer list
        @param: queries: An query list (objects with ``start`` and ``end``)
        @return: The result list
        """
        root = self.build(A)
        return [self.query(root, q.start, q.end) for q in queries]

    def build(self, A):
        """Build the sum tree over the positions of A (None if A is empty)."""
        return self.buildhelper(0, len(A) - 1, A)

    def buildhelper(self, left, right, A):
        """Recursively build the subtree for positions ``left``..``right``."""
        if left > right:
            return None
        root = SegmentTreeNode(left, right, A[left])
        if left == right:
            # The node class may keep its third constructor argument under
            # another attribute name, so set the leaf sum explicitly.
            root.max = A[left]
            return root
        mid = (left + right) // 2
        root.left = self.buildhelper(left, mid, A)
        root.right = self.buildhelper(mid + 1, right, A)
        root.max = root.left.max + root.right.max
        return root

    def query(self, root, start, end):
        """Return the sum of the leaves of ``root`` inside ``start``..``end``.

        A missing node or a node that does not overlap the range yields 0.
        """
        if root is None or root.start > end or root.end < start:
            return 0
        if start <= root.start and root.end <= end:
            return root.max
        return (self.query(root.left, start, end)
                + self.query(root.right, start, end))
求区间的最小值
"""
Definition of Interval.
class Interval(object):
def __init__(self, start, end):
self.start = start
self.end = end
"""
class Solution:
    """Answer range-minimum queries over a fixed list with a segment tree.

    Every node keeps the minimum of its range in the attribute ``max``; the
    attribute name is kept as-is even though the value is a minimum.
    """

    def intervalMinNumber(self, A, queries):
        """Return the minimum of A over each queried interval.

        @param: A: An integer list
        @param: queries: An query list (objects with ``start`` and ``end``)
        @return: The result list
        """
        root = self.build(A)
        return [self.query(root, q.start, q.end) for q in queries]

    def build(self, A):
        """Build the min tree over the positions of A (None if A is empty)."""
        return self.buildhelper(0, len(A) - 1, A)

    def buildhelper(self, left, right, A):
        """Recursively build the subtree for positions ``left``..``right``."""
        if left > right:
            return None
        root = SegmentTreeNode(left, right, A[left])
        if left == right:
            # The node class may keep its third constructor argument under
            # another attribute name, so set the leaf minimum explicitly.
            root.max = A[left]
            return root
        mid = (left + right) // 2
        root.left = self.buildhelper(left, mid, A)
        root.right = self.buildhelper(mid + 1, right, A)
        root.max = min(root.left.max, root.right.max)
        return root

    def query(self, root, start, end):
        """Return the minimum over the leaves of ``root`` in ``start``..``end``.

        A missing node, an empty range, or a range that does not overlap the
        node yields 0.
        """
        if root is None or start > end or root.start > end or root.end < start:
            return 0
        if start <= root.start and root.end <= end:
            return root.max
        mid = (root.start + root.end) // 2
        # Infinity is the neutral element of min, so no stored value can be
        # clipped by the starting sentinel.
        smallest = float("inf")
        if start <= mid:
            smallest = min(smallest, self.query(root.left, start, end))
        if end > mid:
            smallest = min(smallest, self.query(root.right, start, end))
        return smallest
求区间比其小的数的个数
这种题目的思路:一般数的取值范围是固定的,比如 0 到 10000,那么事先建立好一个覆盖该取值范围(共 10001 个叶子、计数全为 0)的线段树,而后遍历数组,把每个数在线段树中对应位置的计数加一,如此便可以在后续快速查找比给定数小的元素个数。
class Solution:
    """Count elements smaller than a query value using a counting segment tree.

    ``build`` / ``modifyt`` / ``queryt`` maintain a tree whose nodes keep
    their aggregate in ``count``.  ``query`` is a separate helper for a max
    tree whose nodes keep their aggregate in ``max``; the ``build`` method of
    this class does not produce such a tree.
    """

    def countOfSmallerNumber(self, A, queries):
        """Count, for every query, the elements of A strictly smaller than it.

        @param: A: An integer array
        @param: queries: The query list
        @return: The number of elements in the array that are smaller than
                 the given integer, one entry per query
        """
        if not A:
            return [0 for _ in queries]
        # Shift values so the smallest one maps to leaf 0; the tree then
        # needs exactly max(A) - min(A) + 1 leaves.
        offset = min(A)
        tree = self.build([0] * (max(A) - offset + 1))
        for number in A:
            self.modifyt(tree, number - offset)
        answers = []
        for q in queries:
            # Leaves 0 .. q - 1 - offset hold the values below q; queryt
            # returns 0 for an empty range and clips ranges past the tree.
            answers.append(self.queryt(tree, 0, q - 1 - offset))
        return answers

    def build(self, A):
        """Build a counting tree over the positions of A (None if A is empty)."""
        return self.buildhelper(0, len(A) - 1, A)

    def buildhelper(self, left, right, A):
        """Recursively build the subtree for positions ``left``..``right``."""
        if left > right:
            return None
        root = SegmentTreeNode(left, right, A[left])
        if left == right:
            return root
        mid = (left + right) // 2
        root.left = self.buildhelper(left, mid, A)
        root.right = self.buildhelper(mid + 1, right, A)
        # Store the sum where modifyt/queryt read it.
        root.count = root.left.count + root.right.count
        return root

    def queryt(self, root, start, end):
        """Return the sum of ``count`` over leaves in ``start``..``end``.

        Ranges that are empty or do not overlap the node contribute 0.
        """
        if root is None or start > end:
            return 0
        if end < root.start or root.end < start:
            return 0
        if start <= root.start and root.end <= end:
            return root.count
        left_part = self.queryt(root.left, start, end)
        right_part = self.queryt(root.right, start, end)
        return left_part + right_part

    def modifyt(self, root, index):
        """Increment the leaf at ``index`` and refresh counts on the way up.

        Raises IndexError when ``index`` is not covered by ``root``.
        """
        if root is None or index < root.start or index > root.end:
            raise IndexError("index %d is outside the segment tree" % index)
        if root.start == root.end:
            root.count += 1
            return
        mid = (root.start + root.end) // 2
        if index <= mid:
            self.modifyt(root.left, index)
        else:
            self.modifyt(root.right, index)
        root.count = root.left.count + root.right.count

    def query(self, root, start, end):
        """Return the maximum ``max`` value over ``start``..``end``.

        Returns ``float("-inf")`` when the range is empty or does not
        overlap the node.
        """
        if root is None or start > end or end < root.start or start > root.end:
            return float("-inf")
        if start <= root.start and root.end <= end:
            return root.max
        mid = (root.start + root.end) // 2
        best = float("-inf")
        if start <= mid:
            best = max(best, self.query(root.left, start, end))
        if end > mid:
            best = max(best, self.query(root.right, start, end))
        return best