也谈线段树

2020-11-26  本文已影响0人  乔治yuanbo

国内大佬们写的很难理解,找了个外国友人的文章,一下就看懂了。本文参考:
geeksforgeeks基础线段树
geeksforgeeks懒标记区间更新
要掌握线段树,得一步一步来。一上来就lazytag,很难理解。

一、普通单点修改

如果修改的单点属于当前树上节点覆盖的范围,直接改,然后改左右子树。没有什么pushup和pushdown。

//ss、se分别是当前树上节点覆盖范围开始和结束下标
//si是树上元素在树的数组里的下标,i是原数组下标,diff是加多少
//调用的时候从根开始update(1,n,1,5,20)
void update(int ss, int se,, int si, int i,  int diff) 
{ 
    if (i < ss || i > se) 
        return; 
    st[si] = st[si] + diff; 
    if (se != ss) 
    { 
        int mid = getMid(ss, se); 
        update(ss, mid,  2*si, i, diff); 
        update(mid+1, se,,2*si+1, i, diff); 
    } 
}

二、普通区间修改

区间修改,先看树上节点覆盖的范围和修改的范围有没有交集,没有就什么都不干;有的话分两种情况,一是到了叶子,直接更新;二是没到叶子节点,又分两种情况,1是节点覆盖范围被修改范围完全覆盖;2是不完全覆盖,不管哪种情况,做法都一样,直接更新左右子树,更新完以后,重新计算左右子树的值,更新当前节点值。也没有什么pushup和pushdown。

//us和ue分别是更新区间的下标开始、结束
void update(int ss, int se, int si, int us, int ue, int diff){
    if (ss > ue || se < us) return;
    if(ss == se){
        st[si].v = st[si].v + diff;
        return;
    }
    int mid = getmid(ss, se);
    update(ss, mid, si * 2, us, ue, diff);
    update(mid+1, se, si * 2 + 1, us, ue, diff);
    st[si].v = st[si*2].v + st[si*2+1].v;
}

仔细体会这种方式,类似深度优先遍历,从根直接到叶子节点,叶子节点更新完成后,一层一层往上更新中间节点,最后更新根。

三、懒标记区间修改

暴力区间修改太慢了,最坏情况下,如果更新整个数组,复杂度O(nlogn),比直接在原数组上更新还慢,所以必须改进。
改进办法是加入懒标记,首先必须明确最重要的一点,当一个树上节点覆盖范围完全被更新区间包含时,这个节点和所有这个节点的子孙都需要更新;反之如果一个树上节点覆盖范围和更区间部分重合,则肯定有一部分子孙需要更新,另一部分绝不需要更新。我们的做法是,该更新还是更新,直接更新就行((se-ss+1)x diff),而不像上面暴力更新那样,深度优先到叶子上,从叶子一层一层往上更新。直接更新完以后,给子孙设置懒标记,被设置懒标记的节点,先不要动,等以后更新或者查询的时候,再处理。
一个节点的懒标记,延迟的是这个节点和它的所有子孙的更新。当一个节点遇到更新和查询操作时,有懒标记的话就先消化懒标记,然后把懒标记下传(也就是他们说的pushdown)给子孙,最后正常更新。
更新完一个节点后,也需要下传懒标记,停止更新进程,把子孙的更新推迟。
两种情况需要下传懒标记,一是自己消化懒标记时,二是自己更新时。下传懒标记的时候注意判断自己是不是叶子,不是才下传,是的话下传就数组越界了。
总之,懒标记是爸爸给他的,不是自己给自己的。懒标记的消化,在更新和查询操作中。懒标记消化分3步:更新自己、传给儿子、还原初始状态(还原或清零)。
举个例子,首先更新1-3,有个节点覆盖1-3,先把它更新,懒标记下传给1-2的爸爸,和3,结束。这时要查询2-4,需要查询2和3,这两个节点上都有懒标记,先消化,再返回。
看代码:

//洛谷p3373线段树模板2
#include <cstdio>
#define MAXN 100000
typedef long long ll;
using namespace std;

//线段树节点,v表示值,lza加法懒标记,lzm乘法懒标记
struct node {
    ll v, lza, lzm; 
} st[MAXN*4+1];
int a[MAXN+1];
int n, m, p;
inline int getmid(int s, int e){
    return s + (e - s) / 2;
}
inline int left(int si){
    return si * 2;
}
inline int right(int si){
    return si * 2 + 1;
}
ll build(int ss, int se, int si){
    st[si].lzm = 1;
    if (ss == se) {
        return st[si].v = a[ss] % p;
    }
    int mid = getmid(ss, se);
    return st[si].v = (build(ss, mid, si * 2) + build(mid+1, se, si * 2 + 1)) % p;
}
void update(int ss, int se, int si, int us, int ue, int op, int opt){
    if (st[si].lzm != 1){
        st[si].v = st[si].v * st[si].lzm % p;//消化
        if(ss != se){//下传
            st[left(si)].lzm = st[left(si)].lzm * st[si].lzm % p;
            st[left(si)].lza = st[left(si)].lza * st[si].lzm % p;
            st[right(si)].lzm = st[right(si)].lzm * st[si].lzm % p;
            st[right(si)].lza = st[right(si)].lza * st[si].lzm % p;
        }
        st[si].lzm = 1;//还原
    }
    if (st[si].lza != 0){
        st[si].v = (st[si].v + (se - ss + 1) * st[si].lza) % p;
        if (ss != se){
            st[left(si)].lza = (st[left(si)].lza + st[si].lza) % p;
            st[right(si)].lza = (st[right(si)].lza + st[si].lza) % p;
        }
        st[si].lza = 0;
    }
    if (ss > ue || se < us) return;
    if (ss >= us && se <= ue){//完全在更新范围内
        //先更新自己
        if (op == 1){
            st[si].v = st[si].v * opt % p;
        } else if (op == 2){
            st[si].v = (st[si].v + (se - ss + 1) * opt) % p;
        }
        if (ss != se){//给儿孙设置懒标记
            if(op == 1){
                st[left(si)].lzm = st[left(si)].lzm * opt % p;
                st[left(si)].lza = st[left(si)].lza * opt % p;
                st[right(si)].lzm = st[right(si)].lzm * opt % p;
                st[right(si)].lza = st[right(si)].lza * opt % p;
            } else {
                st[left(si)].lza += opt;
                st[right(si)].lza += opt;
            }
        }
        return;
    }

    int mid = getmid(ss, se);
    update(ss, mid, left(si), us, ue, op, opt);
    update(mid+1, se, right(si), us, ue, op, opt);
    st[si].v = (st[left(si)].v + st[right(si)].v) % p;
}
ll query(int ss, int se, int si, int qs, int qe){
    if (st[si].lzm != 1){
        st[si].v = st[si].v * st[si].lzm % p;//消化
        if(ss != se){//下传
            st[left(si)].lzm = st[left(si)].lzm * st[si].lzm % p;
            st[left(si)].lza = st[left(si)].lza * st[si].lzm % p;
            st[right(si)].lzm = st[right(si)].lzm * st[si].lzm % p;
            st[right(si)].lza = st[right(si)].lza * st[si].lzm % p;
        }
        st[si].lzm = 1;//还原
    }
    if (st[si].lza != 0){
        st[si].v = (st[si].v + (se - ss + 1) * st[si].lza) % p;
        if (ss != se){
            st[left(si)].lza = (st[left(si)].lza + st[si].lza) % p;
            st[right(si)].lza = (st[right(si)].lza + st[si].lza) % p;
        }
        st[si].lza = 0;
    }
    if(ss >= qs && se <= qe){
        return st[si].v;
    }
    if (ss > qe || se < qs) return 0;
    int mid = getmid(ss, se);
    return (query(ss, mid, si * 2, qs, qe) + query(mid + 1, se, si * 2 + 1, qs, qe)) % p;
}
int main(){
    // freopen("P3373_2.in", "r", stdin);
    scanf("%d%d%d", &n, &m, &p);
    for(int i = 1; i <= n; i++){
        scanf("%d", a + i);
    }
    build(1, n, 1);
    int op, x, y, k;
    while(m--){
        scanf("%d", &op);
        if (op == 1 || op == 2){
            scanf("%d%d%d", &x, &y, &k);
            update(1, n, 1, x, y, op, k);
        } else {
            scanf("%d%d", &x, &y);
            printf("%lld\n", query(1, n, 1, x, y));
        }
    }
    return 0;
}
上一篇下一篇

猜你喜欢

热点阅读