[模版]树链剖分
2018-08-13 本文已影响0人
emiya_d8a0
题面见(模版)树链剖分.
初看题面, 我不知道该怎么做, 我想我需要进行一次树链剖分, 在进行一次dfs序, 然后用线段树维护链操作, 用树状数组维护子树操作.
后来稍微想了一下, 树链剖分本来就是一种dfs序, 而子树操作只是需要多知道节点在dfs序中离开的时间.
这么想了之后, 题目就变得清晰了起来.
敲完代码之后, 还有三处错误, 经过样例改正了两个. 但是提交只能过2个测试数据, 虽然下载了一个测试数据, 但是出于某些原因, 没有看. 经过肉眼看了很多遍之后最后发现了错误之处.
- 错把mod作为了val;
- 标记下传时, 没有将标记下传;
- modify函数的最后没有update操作;
- 叶子节点的结束时间没有统计.
树链剖分并没有什么特别难的, 这道题也只是线段树操作比较多, 我最后在肯定了树剖部分没有写错之后对照了之前的树剖代码, 专心从线段树里找到了错误.
这道题出现的三个问题, 全都是出在线段树部分, 一是因为我有几天没有敲线段树的模版的, 二是我没有用之前用习惯的结构体线段树. 这需要提醒自己.
#include <iostream>
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
const int MAXN = 1e5 + 600;
using namespace std;
int n, m, root, p;
int tIme;
int deep[MAXN], faz[MAXN], top[MAXN], son[MAXN];
int in[MAXN], out[MAXN], dfn[MAXN], size[MAXN];
int weight[MAXN];
int ihead[MAXN], edgeCnt;
struct Edge{
int to, next;
} edges[MAXN << 1 | 1];
int s[MAXN << 2 | 1];
int add[MAXN << 2 | 1];
int ql, qh, val;
void init(){
tIme = edgeCnt = 0;
memset(deep, -1, sizeof(deep));
// memset(faz, -1, sizeof(faz));
}
void insert(int u, int v){
Edge & e = edges[++edgeCnt];
e.to = v;
e.next = ihead[u];
ihead[u] = edgeCnt;
}
void dfs_first(int u, int father){
deep[u] = deep[father] + 1;
faz[u] = father;
size[u] = 1;
for(int i = ihead[u]; i; i = edges[i].next){
Edge & e = edges[i];
if(e.to == father)
continue;
dfs_first(e.to, u);
size[u] += size[e.to];
if(size[son[u]] < size[e.to])
son[u] = e.to;
}
}
void dfs_second(int u, int t){
top[u] = t;
in[u] = ++tIme;
dfn[tIme] = u;
if(!son[u]){
out[u] = tIme;// 0. 最初略过了叶子节点的结束时间
return;
}
dfs_second(son[u], t);
for(int i = ihead[u]; i; i = edges[i].next){
Edge & e = edges[i];
if(e.to == son[u] || e.to == faz[u])
continue;
dfs_second(e.to, e.to);
}
out[u] = tIme;
}
inline void update(int o){
s[o] = (s[o << 1] + s[o << 1 | 1]) % p;
}
inline void spread(int o, int lo ,int hi){
if(add[o]){
int mi = (lo + hi ) >> 1;
s[o << 1] += (mi - lo) * add[o];// 1. 乘了p, 调了一会
s[o << 1 | 1] += (hi - mi) * add[o];
s[o << 1] %= p;
s[o << 1 | 1] %= p;
add[o << 1] += add[o];// 2. 标记没有下传, 调了一会
add[o << 1 | 1] += add[o];
add[o] = 0;
}
}
void build(int o, int lo, int hi){
if(hi - lo < 2){
s[o] = weight[dfn[lo]] % p;
add[o] = 0;
return;
}
int mi = (lo + hi) >> 1;
build(o << 1, lo, mi);
build(o << 1 | 1, mi, hi);
update(o);
}
void modify(int o, int lo, int hi){
if(qh <= lo || hi <= ql){
return;
}
if(ql <= lo && hi <= qh){
s[o] += (hi - lo) * val;
add[o] += val;
s[o] %= p;
return;
}
spread(o, lo, hi);
int mi = (lo + hi) >> 1;
modify(o << 1, lo, mi);
modify(o << 1 | 1, mi, hi);
update(o);// 3.没有更新操作, 调了几个小时
}
int query(int o, int lo, int hi){
if(qh <= lo || hi <= ql){
return 0;
}
if(ql <= lo && hi <= qh){
return s[o];
}
spread(o, lo, hi);
int mi = (lo + hi) >> 1;
return (query(o << 1, lo, mi) + query(o << 1 | 1, mi, hi)) % p;
}
inline int querY(int lo, int hi){
ql = lo, qh = hi;
return query(1, 1, tIme) % p;
}
int getSum(int lo, int hi){
int fx = top[lo], fy = top[hi];
int rst = 0;
while(fx != fy){
if(deep[fx] < deep[fy])
swap(fx, fy), swap(lo, hi);
rst += querY(in[fx], in[lo] + 1);
rst %= p;
lo = faz[fx], fx = top[lo];
}
if(deep[lo] > deep[hi])
swap(lo, hi);
rst += querY(in[lo], in[hi] + 1);
return rst % p;
}
void modifY(int lo, int hi){
int fx = top[lo], fy = top[hi];
while(fx != fy){
if(deep[fx] < deep[fy])
swap(fx, fy), swap(lo, hi);
ql = in[fx], qh = in[lo] + 1;
modify(1, 1, tIme);
lo = faz[fx], fx = top[lo];
}
if(deep[lo] > deep[hi])
swap(lo, hi);
ql = in[lo], qh = in[hi] + 1;
modify(1, 1, tIme);
}
void print(int o, int lo, int hi){
if(hi - lo < 2){
printf(" %d", s[o]);
return;
}
spread(o, lo, hi);
int mi = (lo + hi) >> 1;
print(o << 1, lo, mi);
print(o << 1 | 1, mi, hi);
}
int main(){
init();
scanf("%d%d%d%d", &n, &m, &root, &p);
for(int i = 1; i <= n; ++i){
scanf("%d", weight + i);
}
int u, v;
for(int i = 1; i < n; ++i){
scanf("%d%d", &u, &v);
insert(u, v);
insert(v, u);
}
dfs_first(root, 0);
dfs_second(root, root);
build(1, 1, ++tIme);
int opt, x, y, z;
while(m--){
scanf("%d", &opt);
switch(opt){
case 1:
scanf("%d%d%d", &x, &y, &z);
val = z;
modifY(x, y);
break;
case 2:
scanf("%d%d", &x, &y);
printf("%d\n", getSum(x, y));
break;
case 3:
scanf("%d%d", &x, &z);
val = z;
ql = in[x], qh = out[x] + 1;
modify(1, 1, tIme);
break;
case 4:
scanf("%d", &x);
ql = in[x], qh = out[x] + 1;
printf("%d\n", query(1, 1, tIme));
break;
}
}
return 0;
}