ACM

ACM数据结构(一)——主席树

2018-01-17  本文已影响383人  hymscott

让我们来看一个经典的问题吧:

给定一个[1,n]的区间,m次操作,操作种类如下:

1 L R:查询[L,R]的区间和

2 L R X:将[L,R]的值加上X

这种经典问题,想必大家学过线段树后都可以轻松解决。然而如果再增加一种操作:

3 K:回退到第K次修改操作的结果

可见,如果题目要求回溯到历史版本,那么普通的线段树就不能解决了,因为在每次更新操作后,线段树存储的内容就发生了改变,如果不进行特殊记录,那么这种改变将是永久的。因此,对于这种类型的题目,我们可以用到今天要讨论的数据结构——主席树来进行解决。


主席树,严格来讲应该叫:函数式线段树,是基于线段树的一种数据结构,常用于处理一些在线问题,关于在线离线的概念参考上一篇文章:在线和离线算法。事实上,主席树就是多个线段树的集合体。

主席树的实质,就是以最初的线段树作为模板,通过"结点复用“的方式,实现存储多个线段树。

对于文章开始的问题,观察后可以发现,在2操作进行后,在上一次修改后的线段树上,最多修改O(logn)个结点(最远是从根节点到叶子节点)。如果每次单独新建一个线段树,则会造成重复存储,如图所示:

原始线段树 修改[6,7] 修改[3,5] 修改[1,9]

浅蓝色的结点是当前修改操作时访问的结点,白色结点为上一棵线段树的结点。

如果对每次修改操作无差别复制一棵线段树,那么用于存储节点的开销是巨大的,因为对于单次修改,大部分结点都不曾被访问修改。

通过“结点复用”的方式,我们可以将这多棵线段树压缩成如下形式:


开辟新结点 结点复用

因此第i个线段树只要通过保留除修改路径外的第i-1棵线段树的结点,再新增加至多O(logn)个结点。

rt[i]保存第i次操作的线段树的根节点,这样,回退到第k次操作等价于rt[i]=rt[k],我们的问题就迎刃而解啦。


那么,怎么来建立一棵主席树呢?针对文章开始的题目,下面给出实现步骤:

1. 创建根节点、左右儿子结点数组

int tot=0,rt[maxn*20],lson[maxn*20],rson[maxn*20],v[maxn*20],lz[maxn*20],a[maxn];

tot是每次新建的结点编号。
rt[i]是第i棵线段树的根节点的编号。
lson[x]和rson[x]是结点x的左右儿子结点的编号。
v[x]是结点x代表的区间的和。
lz[x]是结点x的懒惰(lazy)值。
a[i]是初始的第i个位置的值。
因为结点每次至多更新O(logn)个,所以数组范围应该在原来的20-50倍左右。

2.区间更新的pushup和pushdown

void push_up(int x){
    v[x]=v[lson[x]]+v[rson[x]];
}

void push_down(int x,int len){
    if(lz[x]){
        v[lson[x]]+=(len>>1)*lz[x];
        v[rson[x]]+=(len-(len>>1))*lz[x];
        lz[lson[x]]+=(len>>1)*lz[x];
        lz[rson[x]]+=(len-(len>>1))*lz[x];
        lz[x]=0;
    }
}

区间更新基础,不会的可以先了解线段树的区间更新写法。

3. 建树

void build(int &x,int l,int r){
    x=++tot;
    lz[x]=0;
    if(l==r){
        v[x]=a[l];
        return;
    }
    int mid=l+r>>1;
    build(lson[x],l,mid);
    build(rson[x],mid+1,r);
    push_up(x);
}

和线段树的思想是一样的,只是在调用过程中,我们以引用的形式,实现对rt,lson,rson的更新。
建树的调用如下:

build(rt[0],1,n);

3. 更新

void update(int L,int R,int l,int r,int &x,int last,int val){
    x=++tot;
    lson[x]=lson[last];rson[x]=rson[last];lz[x]=lz[last];v[x]=v[last];
    if(L<=l&&R>=r){
        v[x]+=(r-l+1)*val;lz[x]+=val;
        return;
    }
    push_down(x,r-l+1);
    int mid=l+r>>1;
    if(L<=mid) update(L,R,l,mid,lson[x],lson[last],val);
    if(R>mid) update(L,R,mid+1,r,rson[x],rson[last],val);
    push_up(x);
}

第1行开辟了新的结点,第2行进行了结点复用,last就是上一棵线段树的结点,从根节点向下更新。
更新的调用如下:

update(x,y,1,n,rt[i],rt[i-1],w);

4. 查询

int query(int L,int R,int l,int r,int x){
    if(L<=l&&R>=r){
        return v[x];
    }
    push_down(x,r-l+1);
    int mid=l+r>>1,sum=0;
    if(L<=mid) sum+=query(L,R,l,mid,lson[x]);
    if(R>mid) sum+=query(L,R,mid+1,r,rson[x]);
    push_up(x);
    return sum;
}

查询就是简单的区间查询。
查询的调用如下:

query(x,y,1,n,rt[i]);

5. 实现

#include <iostream>
#include <cstdio>
#include <string>
#include <cstring>
#include <algorithm>
#include <queue>
#include <vector>
#include <cmath>
#include <functional>
#include <map>
#include <stack>
#include <ctime>
#include <sstream>
#include <bitset>

//#include<ext/pb_ds/assoc_container.hpp>

//#include <bits/stdc++.h>

#define REP(i,j,k) for(int (i)=(j);(i)<=(k);(i)++)
#define ERP(i,j,k) for(int (i)=(j);(i)>=(k);(i)--)
#define MEM(a,b) memset(a,b,sizeof(a))
#define NE putchar('\n')
#define SP putchar(' ')
#define fi first
#define sc second
#define mkp make_pair
#define pb push_back
#define all(a) a.begin(),a.end()
//#define lson l,mid,x<<1
//#define rson mid+1,r,x<<1|1
#define lowbit(x) ((x)&(-(x)))
#define lc(a) ch[(a)][0]
#define mod_add(a,b,m) (a+b>m?a+b-m:a+b)
#define mod_sub(a,b,m) (a-b<0?a-b+m:a-b)

using namespace std;
//using namespace __gnu_pbds;
typedef double DB;
typedef long double LDB;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PI;
typedef pair<ll,ll> PLL;

const DB eps=1e-6;
const DB Pi=acos(-1.0);
const ll mod=1e9+7;
const ull ha1=1e9+7;
const ull ha2=1e9+9;
const int maxn=1e5+10;
const int maxm=1e6+10;
const int inf=1e9+10;

//IO挂
template<typename Type>inline void read(Type&in){
    in=0;Type f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){in=in*10+ch-'0';ch=getchar();}
    in*=f;
}

template<typename Type>inline void out(Type o){
    if(o<0){putchar('-');o=-o;}
    if(o>=10) out(o/10);
    putchar(o%10+'0');
}

/*Header*/
//printf("%d%c",a[i]," \n"[i==n]);

int tot=0,rt[maxn*20],lson[maxn*20],rson[maxn*20],v[maxn*20],lz[maxn*20],a[maxn];

void push_up(int x){
    v[x]=v[lson[x]]+v[rson[x]];
}

void push_down(int x,int len){
    if(lz[x]){
        v[lson[x]]+=(len>>1)*lz[x];
        v[rson[x]]+=(len-(len>>1))*lz[x];
        lz[lson[x]]+=(len>>1)*lz[x];
        lz[rson[x]]+=(len-(len>>1))*lz[x];
        lz[x]=0;
    }
}

void build(int &x,int l,int r){
    x=++tot;
    lz[x]=0;
    if(l==r){
        v[x]=a[l];
        return;
    }
    int mid=l+r>>1;
    build(lson[x],l,mid);
    build(rson[x],mid+1,r);
    push_up(x);
}

void update(int L,int R,int l,int r,int &x,int last,int val){
    x=++tot;
    lson[x]=lson[last];rson[x]=rson[last];lz[x]=lz[last];v[x]=v[last];
    if(L<=l&&R>=r){
        v[x]+=(r-l+1)*val;lz[x]+=val;
        return;
    }
    push_down(x,r-l+1);
    int mid=l+r>>1;
    if(L<=mid) update(L,R,l,mid,lson[x],lson[last],val);
    if(R>mid) update(L,R,mid+1,r,rson[x],rson[last],val);
    push_up(x);
}

int query(int L,int R,int l,int r,int x){
    if(L<=l&&R>=r){
        return v[x];
    }
    push_down(x,r-l+1);
    int mid=l+r>>1,sum=0;
    if(L<=mid) sum+=query(L,R,l,mid,lson[x]);
    if(R>mid) sum+=query(L,R,mid+1,r,rson[x]);
    push_up(x);
    return sum;
}

int x,y,w;

int main(){
    int n,k,opt;
    cin>>n>>k;
    for(int i=1;i<=n;i++){
        cin>>a[i];
    }
    build(rt[0],1,n);
    for(int i=1;i<=k;i++){
        cin>>opt;
        if(opt==1){
            rt[i]=rt[i-1];
            cin>>x>>y;
            cout<<query(x,y,1,n,rt[i])<<endl;
        }
        else if(opt==2){
            cin>>x>>y>>w;
            update(x,y,1,n,rt[i],rt[i-1],w);
        }
        else{
            cin>>x;
            rt[i]=rt[x];
        }
    }
    return 0;
}

对于第i个操作,方式1通过rt[i-1]更新rt[i],方式2通过引用更新rt[i],方式3通过rt[x]更新rt[i]。

6. 测试一下~

input.txt
10 8
1 2 3 4 5 6 7 8 9 10
2 6 7 2
1 6 7
2 3 5 4
1 3 5
2 1 9 5
1 1 9
3 3
1 1 10

output.txt
17
24
106
71

正确无误~(blink)


那么,主席树的入门就到这里了,下面给出poj 2104(静态区间求第K大)的主席树代码,作为参考啦~

#include <bits/stdc++.h>
#include <cstdio>

#define fi first
#define sc second
#define mkp make_pair
#define pb push_back
#define all(a) a.begin(),a.end()

using namespace std;
typedef long long ll;
typedef pair<int,int> PI;
typedef pair<ll,ll> PLL;

const double eps=1e-8;
const double pi=acos(-1);
const int mod=1e9+7;

/*Header*/

const int maxn=1e5+10;

int rt[maxn*20],lson[maxn*20],rson[maxn*20],sum[maxn*20];
int a[maxn],b[maxn];
int tot;

int n,q;

void build(int &x,int l,int r){
    x=++tot;
    sum[x]=0;
    if(l==r) return;
    int mid=(l+r)>>1;
    build(lson[x],l,mid);
    build(rson[x],mid+1,r);
}

void update(int &x,int last,int l,int r,int pos){
    x=++tot;
    lson[x]=lson[last];
    rson[x]=rson[last];
    sum[x]=sum[last]+1;
    if(l==r) return;
    int mid=(l+r)>>1;
    if(pos<=mid) update(lson[x],lson[last],l,mid,pos);
    else update(rson[x],rson[last],mid+1,r,pos);
}

int query(int L,int R,int l,int r,int k){
    if(l==r) return l;
    int mid=(l+r)>>1;
    int cnt=sum[lson[R]]-sum[lson[L]];
    if(k<=cnt) return query(lson[L],lson[R],l,mid,k);
    else return query(rson[L],rson[R],mid+1,r,k-cnt);
}

int main(){
    int T;
    scanf("%d",&T);
    while(T--){
        scanf("%d %d",&n,&q);
        for(int i=1;i<=n;i++){
            scanf("%d",&a[i]);
            b[i]=a[i];
        }
        sort(b+1,b+1+n);
        int m=unique(b+1,b+1+n)-(b+1);
        tot=0;
        build(rt[0],1,m);
        for(int i=1;i<=n;i++) a[i]=lower_bound(b+1,b+1+m,a[i])-b;
        for(int i=1;i<=n;i++) update(rt[i],rt[i-1],1,m,a[i]);
        int x,y,k,ans;
        while(q--){
            scanf("%d %d %d",&x,&y,&k);
            ans=query(rt[x-1],rt[y],1,m,k);
            printf("%d\n",b[ans]);
        }
    }
    return 0;
}
上一篇下一篇

猜你喜欢

热点阅读