[模版]树链剖分

2018-08-13  本文已影响0人  emiya_d8a0

题面见(模版)树链剖分.

初看题面, 我不知道该怎么做, 我想我需要进行一次树链剖分, 在进行一次dfs序, 然后用线段树维护链操作, 用树状数组维护子树操作.

后来稍微想了一下, 树链剖分本来就是一种dfs序, 而子树操作只是需要多知道节点在dfs序中离开的时间.

这么想了之后, 题目就变得清晰了起来.

敲完代码之后, 还有三处错误, 经过样例改正了两个. 但是提交只能过2个测试数据, 虽然下载了一个测试数据, 但是出于某些原因, 没有看. 经过肉眼看了很多遍之后最后发现了错误之处.

  1. 错把mod作为了val;
  2. 标记下传时, 没有将标记下传;
  3. modify函数的最后没有update操作;
  4. 叶子节点的结束时间没有统计.

树链剖分并没有什么特别难的, 这道题也只是线段树操作比较多, 我最后在肯定了树剖部分没有写错之后对照了之前的树剖代码, 专心从线段树里找到了错误.

这道题出现的三个问题, 全都是出在线段树部分, 一是因为我有几天没有敲线段树的模版的, 二是我没有用之前用习惯的结构体线段树. 这需要提醒自己.

#include <iostream>
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
const int MAXN = 1e5 + 600;
using namespace std;
int n, m, root, p;
int tIme;
int deep[MAXN], faz[MAXN], top[MAXN], son[MAXN];
int in[MAXN], out[MAXN], dfn[MAXN], size[MAXN];
int weight[MAXN];
int ihead[MAXN], edgeCnt;
struct Edge{
    int to, next;
} edges[MAXN << 1 | 1];
int s[MAXN << 2 | 1];
int add[MAXN << 2 | 1];
int ql, qh, val;
void init(){
    tIme = edgeCnt = 0;
     memset(deep, -1, sizeof(deep));
    // memset(faz, -1, sizeof(faz));
}
void insert(int u, int v){
    Edge & e = edges[++edgeCnt];
    e.to = v;
    e.next = ihead[u];
    ihead[u] = edgeCnt;
}
void dfs_first(int u, int father){
    deep[u] = deep[father] + 1;
    faz[u] = father;
    size[u] = 1;
    for(int i = ihead[u]; i; i = edges[i].next){
        Edge & e = edges[i];
        if(e.to == father)
            continue;

        dfs_first(e.to, u);
        size[u] += size[e.to];
        if(size[son[u]] < size[e.to])
            son[u] = e.to;
    }
}
void dfs_second(int u, int t){
    top[u] = t;
    in[u] = ++tIme;
    dfn[tIme] = u;
    if(!son[u]){
        out[u] = tIme;// 0. 最初略过了叶子节点的结束时间
        return;
    }

    dfs_second(son[u], t);
    for(int i = ihead[u]; i; i = edges[i].next){
        Edge & e = edges[i];
        if(e.to == son[u] || e.to == faz[u])
            continue;

        dfs_second(e.to, e.to);
    }
    out[u] = tIme;
}
inline void update(int o){
    s[o] = (s[o << 1] + s[o << 1 | 1]) % p;
}
inline void spread(int o, int lo ,int hi){
    if(add[o]){
        int mi = (lo + hi ) >> 1;
        s[o << 1] += (mi - lo) * add[o];// 1. 乘了p, 调了一会
        s[o << 1 | 1] += (hi - mi) * add[o];
        s[o << 1] %= p;
        s[o << 1 | 1] %= p;

        add[o << 1] += add[o];// 2. 标记没有下传, 调了一会
        add[o << 1 | 1] += add[o];

        add[o] = 0;
    }
}
void build(int o, int lo, int hi){
    if(hi - lo < 2){
        s[o] = weight[dfn[lo]] % p;
        add[o] = 0;
        return;
    }
    int mi = (lo + hi) >> 1;
    build(o << 1, lo, mi);
    build(o << 1 | 1, mi, hi);

    update(o);
}
void modify(int o, int lo, int hi){
    if(qh <= lo || hi <= ql){
        return;
    }
    if(ql <= lo && hi <= qh){
        s[o] += (hi - lo) * val;
        add[o] += val;
        s[o] %= p;

        return;
    }
    spread(o, lo, hi);
    int mi = (lo + hi) >> 1;

    modify(o << 1, lo, mi);
    modify(o << 1 | 1, mi, hi);

    update(o);// 3.没有更新操作, 调了几个小时
}
int query(int o, int lo, int hi){
    if(qh <= lo || hi <= ql){
        return 0;
    }
    if(ql <= lo && hi <= qh){
        return s[o];
    }
    spread(o, lo, hi);
    int mi = (lo + hi) >> 1;

    return (query(o << 1, lo, mi) + query(o << 1 | 1, mi, hi)) % p;
}
inline int querY(int lo, int hi){
    ql = lo, qh = hi;
    return query(1, 1, tIme) % p;
}
int getSum(int lo, int hi){
    int fx = top[lo], fy = top[hi];
    int rst = 0;
    while(fx != fy){
        if(deep[fx] < deep[fy])
            swap(fx, fy), swap(lo, hi);
        rst += querY(in[fx], in[lo] + 1);
        rst %= p;
        lo = faz[fx], fx = top[lo];
    }
    if(deep[lo] > deep[hi])
        swap(lo, hi);
    rst += querY(in[lo], in[hi] + 1);
    return rst % p;
}
void modifY(int lo, int hi){
    int fx = top[lo], fy = top[hi];
    while(fx != fy){
        if(deep[fx] < deep[fy])
            swap(fx, fy), swap(lo, hi);
        ql = in[fx], qh = in[lo] + 1;
        modify(1, 1, tIme);

        lo = faz[fx], fx = top[lo];
    }
    if(deep[lo] > deep[hi])
        swap(lo, hi);

    ql = in[lo], qh = in[hi] + 1;
    modify(1, 1, tIme);
}
void print(int o, int lo, int hi){
    if(hi - lo < 2){
        printf(" %d", s[o]);
        return;
    }
    spread(o, lo, hi);

    int mi = (lo + hi) >> 1;
    print(o << 1, lo, mi);
    print(o << 1 | 1, mi, hi);
}
int main(){
    init();

    scanf("%d%d%d%d", &n, &m, &root, &p);
    for(int i = 1; i <= n; ++i){
        scanf("%d", weight + i);
    }
    int u, v;
    for(int i = 1; i < n; ++i){
        scanf("%d%d", &u, &v);
        insert(u, v);
        insert(v, u);
    }
    dfs_first(root, 0);
    dfs_second(root, root);
    build(1, 1, ++tIme);

    int opt, x, y, z;
    while(m--){
        scanf("%d", &opt);
        switch(opt){
            case 1:
                scanf("%d%d%d", &x, &y, &z);
                val = z;
                modifY(x, y);
                break;
            case 2:
                scanf("%d%d", &x, &y);
                printf("%d\n", getSum(x, y));
                break;
            case 3:
                scanf("%d%d", &x, &z);
                val = z;
                ql = in[x], qh = out[x] + 1;
                modify(1, 1, tIme);
                break;
            case 4:
                scanf("%d", &x);
                ql = in[x], qh = out[x] + 1;
                printf("%d\n", query(1, 1, tIme));
                break;
        }
    }
    return 0;
}

上一篇下一篇

猜你喜欢

热点阅读