平衡树acm算法解析

splay学习笔记(上)

2018-08-20  本文已影响0人  MasterHZJ

前不久其实已经学过splay了,但是总觉得似乎不能灵活地改造它,于是重新学习了一波。

感谢https://oi.men.ci/splay-notes-1/
关于splay的解释这里说的也比较清楚
下面我们分成单点版和区间版讨论

入门题 Tyvj 1728 普通平衡树

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入x数
2. 删除x数(若有多个相同的数,因只删除一个)
3. 查询x数的排名(若有多个相同的数,因输出最小的排名)
4. 查询排名为x的数
5. 求x的前驱(前驱定义为小于x,且最大的数)
6. 求x的后继(后继定义为大于x,且最小的数)

#include <bits/stdc++.h>
#define rep(i, a, b) for (int i=a; i<=b; i++)
#define drep(i, a, b) for (int i=b; i>=a; i--)
typedef long long LL;
using namespace std;

struct Splay {
    struct node {
        node *fa, *ch[2], **root;
        int x, size, cnt;
        node(node **root, node *fa, int x): root(root), fa(fa), x(x), cnt(1), size(1) {
            ch[0]=ch[1]=NULL;
        }
        inline int relation() {
            return this == fa->ch[0] ? 0 : 1;
        }
        inline void maintain() {
            size = cnt;
            if (ch[0]) size += ch[0]->size;
            if (ch[1]) size += ch[1]->size;
        }

        void rotate() {
            node *old=fa;
            int r=relation();

            fa=old->fa;
            if (old->fa) old->fa->ch[old->relation()]=this;
            if (ch[r^1]) ch[r^1]->fa=old;
            old->ch[r]=ch[r^1];
            old->fa=this;
            ch[r^1]=old;

            old->maintain();
            maintain();
            if (fa==NULL) *root=this;
        }
        void splay(node *target=NULL) {
            while (fa!=target) {
                if (fa->fa==target) rotate();
                else if (fa->relation()==relation()) {
                    fa->rotate();
                    rotate();
                }else {
                    rotate();
                    rotate();
                }
            }
        }
        node *pred() {
            node *v=ch[0];
            while (v->ch[1]) v=v->ch[1];
            return v;
        }//前驱precursor
        node *succ() {
            node *v = ch[1];
            while (v->ch[0]) v=v->ch[0];
            return v;
        }

        int rank() {
            return ch[0] ? ch[0]->size : 0;
        }
    } *root;
    Splay():root(NULL) {
        insert(INT_MAX);
        insert(INT_MIN);
    }
    node *insert(int x) {
        node **v = &root, *fa=NULL;
        while (*v!=NULL && (*v)->x!=x) {
            fa=*v;
            fa->size++;
            if (x<fa->x) v=&fa->ch[0]; else v=&fa->ch[1];
        }
        if (*v!=NULL) {
            (*v)->cnt++;
            (*v)->size++;
        }else (*v) = new node(&root, fa, x);
        (*v)->splay();
        return root;
    }
    node *find(int x) {
        node *v=root;
        while (v!=NULL && v->x != x) if (x<v->x) v=v->ch[0]; else v=v->ch[1];
        if (v) v->splay();
        return v;
    }
    void erase(node *v) {
        node *pred=v->pred(), *succ=v->succ();
        pred->splay();
        succ->splay(pred);
        if (v->size>1) {
            v->size--, v->cnt--;
        }else {
            delete succ->ch[0];
            succ->ch[0]=NULL;
        }
        succ->size--, pred->size--;
    }
    void erase(int x) {
        node *v=find(x);
        if (!v) return;
        erase(v);
    }
    int pred(int x) {
        node *v = find(x);
        if (v==NULL) {
            v=insert(x);
            int res=v->pred()->x;
            erase(v);
            return res;
        }else return v->pred()->x;
    }
    int succ(int x) {
        node *v=find(x);
        if (v==NULL) {
            v=insert(x);
            int res=v->succ()->x;
            erase(v);
            return res;
        }else return v->succ()->x;
    }
    int rank(int x) {
        node *v=find(x);
        if (v==NULL) {
            v=insert(x);
            int res=v->rank();
            erase(v);
            return res;
        }else return v->rank();
    }
    int select(int k) {
        node *v = root;
        while (!(k >= v->rank() && k < v->rank() + v->cnt)){
            if (k<v->rank()) v=v->ch[0]; else {
                k-=v->rank()+v->cnt;
                v=v->ch[1];
            }
        }
        v->splay();
        return v->x;
    }
}splay;

int main() {
    int n;
    scanf("%d", &n);
    while (n--) {
        int opt, x;
        scanf("%d %d", &opt, &x);
        if (opt==1) splay.insert(x);
        if (opt==2) splay.erase(x);
        if (opt==3) printf("%d\n", splay.rank(x));
        if (opt==4) printf("%d\n", splay.select(x));
        if (opt==5) printf("%d\n", splay.pred(x));
        if (opt==6) printf("%d\n", splay.succ(x));
    }
    return 0;
}

如果是打acm的话不太懂splay的原理其实没有太大关系,这个板子已经把splay的基本操作封装再node结构体里面了,可以理解成splay是一个一直在维护平衡的名次树。所以起码要理解名次树的性质:

左子树的值<根节点的值<右子树的值

上面的splay只支持单点操作,其实用线段树也可以实现(强烈推荐zkw的论文《统计的力量》,很精妙),下面我们来讨论区间操作的splay。

splay的区间操作对比线段树/数状数组,支持:

区间splay重要的操作是选择区间,比如要对区间[l,r]进行操作,我们要做的是将节点 l-1 Splay到根,再讲节点 r-1 splay到根节点的右儿子,那么根节点的右儿子的左子树就是区间[l, r] (根据 左子树的值<根节点的值<右子树的值 的性质)

其它区间求和,区间最小值,区间修改之类的类似与线段树,通过lazy标记来实现

我们来看一下这题

Tyvj1729 文艺平衡树

您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:翻转一个区间,例如原有序序列是5 4 3 2 1,翻转区间是[2,4]的话,结果是5 2 3 4 1

需要注意的是,在上一题中,我们节点的权值是数的大小,在这一题中,我们的节点的权值是数的位置

#include <bits/stdc++.h>
#define rep(i, a, b) for (int i=a; i<=b; i++)
#define drep(i, a, b) for (int i=b; i>=a; i--)
typedef long long LL;
using namespace std;

template <typename T>
struct Splay {
    struct node{
        node *ch[2], *parent, **root;
        T value;
        int size;
        bool bound, reverse;
        node(node *parent, node **root, const T &value, bool bound=false, bool reverse=false):parent(parent), root(root), value(value), reverse(false), size(1), bound(bound){
            ch[0]=ch[1]=NULL;
        }
        ~node(){
            if (ch[0]) delete(ch[0]);
            if (ch[1]) delete(ch[1]);
        }
        inline int relation(){return this==parent->ch[0]?0:1;}
        inline int lsize(){return ch[0] ? ch[0]->size : 0;}
        inline int rsize(){return ch[1] ? ch[1]->size : 0;}
        inline void maintain(){size = lsize() + rsize() +1;}
        inline node *grandparent(){return !parent ? NULL : parent->parent;}
        void *pushdown(){
            if (reverse){
                //swap(ch[0], ch[1]);
                node *tmp=ch[0];
                ch[0]=ch[1];
                ch[1]=tmp;
                if (ch[0]) ch[0]->reverse^=1;
                if (ch[1]) ch[1]->reverse^=1;
                reverse = false;
            }
        }
        void rotate(){
            parent->pushdown(), pushdown();
            node *old=parent;
            int x=relation();
            if (grandparent()) grandparent()->ch[old->relation()] = this;
            parent=grandparent();
            old->ch[x] = ch[x^1];
            if (ch[x^1]) ch[x^1]->parent = old;
            ch[x^1]=old;
            old->parent=this;
            old->maintain(), maintain();
            if (!parent) *root=this;
        }
        node *splay(node **target=NULL){
            if (!target) target=root;
            while (this!=*target){
                parent->pushdown();
                if (parent == *target) rotate();
                else if (parent->relation() == relation()) parent->rotate(), rotate();
                else rotate(), rotate();
            }
            return *target;
        }
    }*root;
    ~Splay() {
        if (root) delete root;
    }
    void build(const T *a, int n){
        root = build(a, 1, n, NULL);
        rep(i, 0, 1){
            node *bound_parent=NULL, **bound=&root;
            while (*bound){
                bound_parent = *bound;
                bound_parent->size++;
                bound = &(*bound)->ch[i];
            }
            *bound=new node(bound_parent, &root, 0, true);
        }
    }//插入边界值
    node *build (const T *a, int l, int r, node *parent){
        if (l>r) return NULL;
        int mid=(l+r)>>1;
        node *v=new node(parent, &root, a[mid-1]);
        v->ch[0] = build(a, l, mid - 1, v);
        v->ch[1] = build(a, mid + 1, r, v);
        v->maintain();
        return v;
    }
    node *select(int k) {
        k++;
        node *v = root;
        while (v->pushdown(), k != v->lsize() + 1)
            if (k < v->lsize() + 1) v = v->ch[0]; else k -= v->lsize() + 1, v = v->ch[1];
        return v->splay();
    }
    node *&select(int l, int r) {
        node *lbound=select(l-1), *rbound=select(r+1);
        lbound->splay();
        rbound->splay(&lbound->ch[1]);
        return rbound->ch[0];
    }
    void reverse(int l, int r) {
        node *range = select(l, r);
        range->reverse ^= 1;
    }
    void fetch(T *a) {
        dfs(a, root);
    }
    void dfs(T *&a, node *v) {
        if (v) {
            v->pushdown();
            dfs(a, v->ch[0]);
            if (!v->bound) *a++=v->value;
            dfs(a, v->ch[1]);
        }
    }
};
Splay<int>splay;
const int MAXN=101000;
int n, m;
int a[MAXN];
int main() {
    scanf("%d%d", &n, &m);
    for (int i=0; i<n; i++) a[i]=i+1;
    splay.build(a, n);

    for (int i=0; i<m; i++) {
        int l, r;
        scanf("%d%d", &l, &r);
        splay.reverse(l, r);
    }
    splay.fetch(a);
    for (int i=0; i<n; i++) printf("%d ", a[i]);
    return 0;
}
上一篇下一篇

猜你喜欢

热点阅读