数据结构

「数据结构进阶」例题之数据结构相关思想

2019-02-14  本文已影响10人  云中翻月
0x40「数据结构进阶」例题

分块

分块思想的本质一句话概括:大段维护,小段朴素。
它同样能够在O(n\sqrt{n})的时间内维护区间求和,单点修改,区间最值,单点查值,区间修改(区间修改用延迟标记add实现)
代码如下(以POJ3468 A Simple Problem with Integers为例)
预处理每个块的左右端点

t = sqrt(n*1.0);
for (int i = 1; i <= t; i++) {
    L[i] = (i - 1)*sqrt(n*1.0) + 1;
    R[i] = i*sqrt(n*1.0);
}
if (R[t] < n) t++, L[t] = R[t - 1] + 1, R[t] = n; //考虑不完整的块

预处理块内维护信息

for (int i = 1; i <= t; i++)
    for (int j = L[i]; j <= R[i]; j++) {
        pos[j] = i;
        sum[i] += a[j];
    }

读取指令

while (m--) {
    char op[3];
    int l, r, d;
    scanf("%s%d%d", op, &l, &r);
    if (op[0] == 'C') {
        scanf("%d", &d);
        change(l, r, d);
    }
    else printf("%lld\n", ask(l, r));
}

区间修改

void change(int l, int r, long long d) {
    int p = pos[l], q = pos[r];
    if (p == q) {
        for (int i = l; i <= r; i++) a[i] += d;
        sum[p] += d*(r - l + 1);
    }
    else {
        for (int i = p + 1; i <= q - 1; i++) add[i] += d;
        for (int i = l; i <= R[p]; i++) a[i] += d;
        sum[p] += d*(R[p] - l + 1);
        for (int i = L[q]; i <= r; i++) a[i] += d;
        sum[q] += d*(r - L[q] + 1);
    }
}

区间求和

long long ask(int l, int r) {
    int p = pos[l], q = pos[r];
    long long ans = 0;
    if (p == q) {
        for (int i = l; i <= r; i++) ans += a[i];
        ans += add[p] * (r - l + 1);
    }
    else {
        for (int i = p + 1; i <= q - 1; i++)
            ans += sum[i] + add[i] * (R[i] - L[i] + 1);
        for (int i = l; i <= R[p]; i++) ans += a[i];
        ans += add[p] * (R[p] - l + 1);
        for (int i = L[q]; i <= r; i++) ans += a[i];
        ans += add[q] * (r - L[q] + 1);
    }
    return ans;
}

但是,由于分块的复杂度略大于线段树的复杂度,所以实现上述操作在极限数据的情况下,分块算法的耗时是线段树算法的两倍,已经接近极限时限,所以如果仅仅是维护以上的几个功能,分块其实不如线段树。
那么,分块有什么独特之处?
1 它维护的信息不一定需要满足区间加法
2 它可以在块之间和块内维护不同的信息。(例题:CH #46A)
3 它可以离线对询问分块。
下面我们通过几道例题详细解释。

例题

4401 蒲公英
本题要求强制在线维护区间众数。
分析众数的性质。首先,它不满足区间加法,即两个区间的众数可能都不是合并后的大区间的众数。其次,假设一个区间[l,r]包括前后两个不完整的块[l,L)和(R,r],以及中间若干个完整的块[L,R],那么这个区间的众数这可能是[L,R]的众数或者[l,L)和(R,r]里出现的数
因此,我们用分块预处理出所有以段边界为端点的区间的众数,若块大小为T,那么这样的区间数有O(T^{2}),预处理出数组c[i][j][k]表示第i块到第j块中数字k出现的次数(由于k的范围较大,这里需要离散化),f[i][j]表示第i块到第j块中的众数出现次数,d[i][j]表示第i块到第j块中的众数,时间复杂度为O(nT^{2})
对于每个询问,我们朴素扫描两头不完整的区间,向c数组里累加,计算完答案后又复原。
总时间复杂度为O(nT^{2}+m\frac{n}{T})
nT^{2}=m\frac{n}{T}T=\sqrt[3]{n}
代码如下

/*

*/
#define method_1
#ifdef method_1
/*

*/
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<set>
#include<map>
#include<queue>
#include<stack>
#include<vector>
#include<cstring>
#include<cstdlib>
#include<iomanip>
#define D(x) cout<<#x<<" = "<<x<<"  "
#define E cout<<endl
using namespace std;
typedef long long ll;
typedef pair<int,int>pii;
const int maxn=40005+5;
const int maxk=40+5;
const int INF=0x3f3f3f3f;
int n,m,a[maxn],ans=0,num,l[maxn],r[maxn],b[maxn],tot,c[maxk][maxk][maxn],f[maxk][maxk],d[maxk][maxk],L,R,cnt,pos;
void pre() {
    num=int(pow(n*1.0,1.0/3)); //num为块数
    int sz;
    if(num) sz=n/num;
    for(int i=1; i<=num; i++) {
        l[i]=(i-1)*sz+1;
        r[i]=i*sz;
    }
    if(r[num]<n) l[num+1]=r[num]+1,r[++num]=n;
    memcpy(b,a,sizeof(a));
    sort(b+1,b+n+1);
    tot=unique(b+1,b+n+1)-b-1;
    for(int i=1; i<=n; i++) {
        a[i]=lower_bound(b+1,b+tot+1,a[i])-b;
    }
    for(int i=1;i<=num; i++) {
        for(int j=i; j<=num; j++) {
            for(int k=l[i]; k<=r[j]; k++) {
                c[i][j][a[k]]++;
            }
            for(int k=1; k<=tot; k++) {
                if(c[i][j][k]>f[i][j]||c[i][j][k]==f[i][j]&&k<d[i][j]) {
                    d[i][j]=k;
                    f[i][j]=c[i][j][k];
                }
            }
        }
    }
}
void update(int i) {
    c[L][R][a[i]]++;
    if(c[L][R][a[i]]>cnt||c[L][R][a[i]]==cnt&&a[i]<pos) {
        pos=a[i];
        cnt=c[L][R][a[i]];
    }
}
int solve(int x,int y) {
    if(x>y) swap(x,y);
    int l1,r1;
    for(int i=1; i<=num; i++) {
        if(x<=r[i]) {
            l1=i;
            break;
        }
    }
    for(int i=num; i>=1; i--) {
        if(y>=l[i]) {
            r1=i;
            break;
        }
    }
    if(l1+1<=r1-1) L=l1+1,R=r1-1;
    else L=R=0;
    cnt=f[L][R],pos=d[L][R];
    if(l1==r1) {
        for(int i=x; i<=y; i++) update(i);
        for(int i=x; i<=y; i++) c[L][R][a[i]]--;
    }
    else{
        for(int i=x;i<=r[l1];i++) update(i);
        for(int i=l[r1];i<=y;i++) update(i);
        for(int i=x;i<=r[l1];i++) c[L][R][a[i]]--;
        for(int i=l[r1];i<=y;i++) c[L][R][a[i]]--;
    }
    return b[pos];  
}
int main() {
    ios::sync_with_stdio(false);
//  freopen("蒲公英.in","r",stdin);
    cin>>n>>m;
    for(int i=1; i<=n; i++) {
        cin>>a[i];
    }
    pre();
    int x,y;
    while(m--) {
        cin>>x>>y;
        x=(x+ans-1)%n+1;
        y=(y+ans-1)%n+1;
        ans=solve(x,y);
        cout<<ans<<endl;
    }
    return 0;
}
#endif

4402 小Z的袜子
考虑对询问分块,把询问按左端点升序排序,然后每块内部按右端点升序排序,那么块内相邻两个询问左端点变化在O(\sqrt{n})内,右端点变化单调,那么我们就能够根据上一次询问,每次O(\sqrt{n})的时间处理左端点少去的部分和多出的部分以及右端点多出的部分,块内右端点变化范围为O(\sqrt{n}),那么总的时间复杂度就是O(n\sqrt{n})
具体地说,对于每块的第一个询问朴素计算,得到数组cnt,表示该块第一个询问区间[l,r]中颜色为c的袜子有cnt[c]只。另外,我们记录变量ans,保存\sum_{c}cnt[c]*(cnt[c]-1)/2,实时维护ans的变化,那么每次询问的答案就是\frac{ans}{C_{r-l+1}^{2}}
代码如下

/*

*/
#define method_1
#ifdef method_1
/*

*/
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<set>
#include<map>
#include<queue>
#include<stack>
#include<vector>
#include<cstring>
#include<cstdlib>
#include<iomanip>
#define D(x) cout<<#x<<" = "<<x<<"  "
#define E cout<<endl
using namespace std;
typedef long long ll;
typedef pair<int,int>pii;
const int maxn=50000+5;
const int INF=0x3f3f3f3f;
struct node{
    int l,r,id;
}q[maxn];
int n,m;
ll cnt[maxn],c[maxn],l[maxn],r[maxn],num,sz,up[maxn],down[maxn],ans=0;
ll gcd(ll a,ll b){
    return !b?a:gcd(b,a%b);
}
bool cmp1(node a,node b){
    return a.l<b.l;
}
bool cmp2(node a,node b){
    return a.r<b.r;
}
void cal(int x,int y,int z){
    for(int i=x;i<=y;i++){
        ans-=cnt[c[i]]*(cnt[c[i]]-1);
        cnt[c[i]]+=z;
        ans+=cnt[c[i]]*(cnt[c[i]]-1);
    }
}
int main() {
    ios::sync_with_stdio(false);
//  freopen("小Z的袜子.in","r",stdin);
    cin>>n>>m;
    for(int i=1;i<=n;i++){
        cin>>c[i];
    }
    for(int i=1;i<=m;i++){
        cin>>q[i].l>>q[i].r;
        q[i].id=i;
    }
    num=sqrt(m);
    sz=m/num;
    sort(q+1,q+m+1,cmp1);
    for(int i=1;i<=num;i++){
        l[i]=(i-1)*sz+1;
        r[i]=i*sz;
    }
    if(r[num]<m){
        l[num+1]=r[num]+1;
        r[++num]=n;
    }
    for(int i=1;i<=num;i++){
        sort(q+l[i],q+r[i]+1,cmp2);
    }
    for(int i=1;i<=num;i++){
        int nl=q[l[i]].l,nr=nl-1;//nr<nl 保证第一次是朴素处理的区间 
//[nl,nr]表示已经处理好的区间,初始区间为空,即nr<nl
        memset(cnt,0,sizeof(cnt));
        ans=0;
        for(int j=l[i];j<=r[i];j++){ //枚举块内的每一个询问 
            if(q[j].l>nl) cal(nl,q[j].l-1,-1);//这两行处理左端点多/少出来的部分 
            if(q[j].l<nl) cal(q[j].l,nl-1,1);
            if(q[j].r>nr) cal(nr+1,q[j].r,1);//处理右端点多出的部分 
            up[q[j].id]=ans;
            down[q[j].id]=q[j].r-q[j].l+1;
            nl=q[j].l;
            nr=q[j].r;
        }
    }
    for(int i=1;i<=m;i++){
        if(up[i]==0){
            cout<<"0/1"<<endl;
        }
        else{
            down[i]*=(down[i]-1);
            int d=gcd(up[i],down[i]);
            cout<<up[i]/d<<"/"<<down[i]/d<<endl;
        }
    }
    return 0;
}
#endif

点分治

之前讲的都是对一维序列区间[l,r]上的操作,若给定树上两个节点x,y,那么序列上的区间就对应了树上两点间的路径。点分治就是一种树上静态路径统计算法。

例题

POJ1741 Tree
若当前节点p为根节点,那么树上路径有两类:
1 经过p
2 包含于p的某个子树内(不经过p)
第二类路径显然可以作为一个第一类路径递归的子问题思考,我们着重考虑第一类路径的计数。
预处理出每个点到根节点的距离dis,那么第一类路径中的(x,y)满足以下条件:
1 x和y在p的不同子树里
2 dis[x]+dis[y]\leq K
统计方法一:建立一个树状数组,依次处理p的每棵子树s_{i}:对与s_{i}中的节点x,答案累加ask(K-dis[x]),就是符合条件的y的个数。然后处理完子树的所有节点后,对于s_{i}的每个节点x,执行操作add(dis[x],1),表示与p距离dis[x]的节点多了一个。逐个子树统计的情况下,保证了x和y不在同一个子树里,查询前缀和ask(K-dis[x])保证了dis[x]+dis[y]\leq K。但是这里路径长度过大,树状数组空间无法承受,若用平衡树代替树状数组代码实现难度较高。
统计方法二:将树上的每个点放入一个数组,按照dis升序排序,使用两个指针L,R分别从前、后扫描数组。显然,L从左向右扫描的过程中,满足dis[L]+dis[R]\leq K的R单调递减,因此,我们只要每次将答案中累加R-L即可。当然,这种方法可能导致x和y在同一个子树中的情况出现,产生重复计数,因此在递归其子树时要排除这种情况(容斥原理)。
PS:点分治算法的时间复杂度为O(Tnlogn),其中T为递归深度。为了防止树是一条链的情况导致点分治每次都以链的一端为根而遍历深度过大(时间复杂度退化为O(n^{2}logn)),我们每次点分治的根节点都取当前子树的重心,这时期望时间复杂度为O(nlog^{2}n)
代码如下

/*

*/
#define method_1
#ifdef method_1
/*

*/
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<set>
#include<map>
#include<queue>
#include<stack>
#include<vector>
#include<cstring>
#include<cstdlib>
#include<iomanip>
#define D(x) cout<<#x<<" = "<<x<<"  "
#define E cout<<endl
using namespace std;
typedef long long ll;
typedef pair<int,int>pii;
const int maxn=10000+5;
const int INF=0x3f3f3f3f;
struct node{
    int from,to,v;
}edge[maxn<<1];
int n,k,vis[maxn],head[maxn],dis[maxn],root,maxs,s[maxn],size,mxson[maxn],cnt,tot;
ll ans;
void add(int from,int to,int v){
    edge[++cnt].from=head[from];
    head[from]=cnt;
    edge[cnt].to=to;
    edge[cnt].v=v;
}
void getroot(int x,int fa){
    s[x]=1;
    mxson[x]=0;
//  vis[x]=1;
    for(int i=head[x];i;i=edge[i].from){
        int y=edge[i].to;
        if(!vis[y]&&y!=fa){
            getroot(y,x);
            s[x]+=s[y];
            mxson[x]=max(mxson[x],s[y]);
        }
    }
    mxson[x]=max(mxson[x],size-s[x]);
    if(mxson[x]<maxs){
        maxs=mxson[x];
        root=x;
    }
}
void getdis(int x,int fa,int dist){
    dis[++tot]=dist;
    for(int i=head[x];i;i=edge[i].from){
        int y=edge[i].to;
        int v=edge[i].v;
        if(!vis[y]&&y!=fa){
            getdis(y,x,dist+v);
        }
    }
}
int consolate(int x,int len){
    tot=0;
    memset(dis,0,sizeof(dis));
//  memset(vis,0,sizeof(vis));
    getdis(x,0,len);
    sort(dis+1,dis+tot+1);
    /*
    for(int i=1;i<=n;i++){
        D(dis[i]);
        E;
    }
    E;
    */
    int L=1,R=tot,temp=0;
    while(L<=R){
        if(dis[L]+dis[R]<=k){
            temp+=R-L;
            L++;
        }
        else R--;
    }
    return temp;
}
void divide(int rt){
    ans+=consolate(rt,0);
    vis[rt]=1;
    for(int i=head[rt];i;i=edge[i].from){
        int y=edge[i].to;
        int v=edge[i].v;
        if(!vis[y]){
            ans-=consolate(y,v);
            size=s[y];
            maxs=INF;
            getroot(y,0);
            divide(root);
        }
    }
}
int main() {
    ios::sync_with_stdio(false);
//  freopen("Tree.in","r",stdin);
    while(cin>>n>>k&&n&&k){
//      cin>>n>>k;
        int x,y,z;
        cnt=0;
        memset(head,0,sizeof(head));
        for(int i=1;i<=n-1;i++){
            cin>>x>>y>>z;
            add(x,y,z);
            add(y,x,z);
        }
        maxs=INF;
        ans=0;
        size=n;
        memset(vis,0,sizeof(vis));
        getroot(1,0);
//      for(int i=1;i<=n;i++) D(s[i]);
//      E;
        divide(root);
        cout<<ans<<endl;
    }
    return 0;
}
#endif
上一篇下一篇

猜你喜欢

热点阅读