其他分享
首页 > 其他分享> > 万恶的平衡树

万恶的平衡树

作者:互联网

以“普通平衡树”模板题为例（需支持插入、删除、查询排名、查询第 k 小、求前驱、求后继六种操作）

有旋Treap

#include <iostream>
#include <cstdlib>
using namespace std;
// Fast reader: skips any non-digit characters, remembers whether a '-' was seen among them,
// then parses the following run of decimal digits.
// Returns the parsed (possibly negated) value; returns 0 once input is exhausted instead of
// spinning forever on EOF.
inline int in(){
	int x = 0;
	bool f = 1;
	int c = getchar(); // int, not char, so EOF is distinguishable from real characters
	while(c != EOF && (c > '9' || c < '0')){
		if(c == '-') f = 0;
		c = getchar();
	}
	while(c <= '9' && c >= '0'){
		x = x * 10 + (c - '0');
		c = getchar();
	}
	return f ? x : -x;
}
// Node-pool capacity: node 0 is the shared "null" child and real nodes start at index 1.
constexpr int N = 100000 + 10;
// Sentinel keys -inf / +inf inserted by build() so predecessor/successor always exist.
constexpr int inf = 0x7fffffff;

// One treap node stored in the global pool a[].
struct treap{
	int l, r; // left / right child indices (0 = no child)
	int val;  // BST key
	int dat;  // random heap priority
	int sz;   // number of keys in this subtree, counting multiplicity
	int cnt;  // copies of val held by this node; the node is removed when it drops to 0
};
treap a[N];

int n;    // number of operations to process
int tot;  // highest node index handed out so far
int root; // index of the current treap root
// Allocates the next pool slot as a single-copy leaf holding val, with a random priority.
// Returns the new node's index.
inline int New(int val){
	const int id = ++tot;
	a[id].val = val;
	a[id].dat = rand();
	a[id].sz = 1;
	a[id].cnt = 1;
	return id;
}
// Recomputes rt's subtree size from its children (index 0 contributes size 0) plus its own count.
inline void update(int rt){
	const int left = a[rt].l;
	const int right = a[rt].r;
	a[rt].sz = a[rt].cnt + a[left].sz + a[right].sz;
}
// Seeds the treap with the two sentinels: node 1 holds -inf and is the root,
// node 2 holds +inf as its right child. gtp/gtn rely on these fixed indices.
inline void build(){
	New(-inf); // becomes node 1
	New(inf);  // becomes node 2
	root = 1;
	a[root].r = 2;
	update(root);
}
// Left rotation: rt's right child is lifted into rt's position and the old rt becomes its
// left child. Sizes are refreshed bottom-up, then the caller's link is redirected.
inline void zag(int &rt){
	const int old = rt;
	const int up = a[old].r;
	a[old].r = a[up].l;
	a[up].l = old;
	update(old);
	update(up);
	rt = up;
}
// Right rotation: rt's left child is lifted into rt's position and the old rt becomes its
// right child. Sizes are refreshed bottom-up, then the caller's link is redirected.
inline void zig(int &rt){
	const int old = rt;
	const int up = a[old].l;
	a[old].l = a[up].r;
	a[up].r = old;
	update(old);
	update(up);
	rt = up;
}
// Inserts one copy of val into the subtree referenced by rt.
// A duplicate key only bumps cnt. Otherwise the key goes into the proper child, and a rotation
// restores the max-heap order on dat if that child's priority now exceeds rt's.
inline void insert(int &rt,int val){
	if(rt == 0){
		rt = New(val);
		return;
	}
	if(a[rt].val == val){
		++a[rt].cnt;
		update(rt);
		return;
	}
	const bool goRight = a[rt].val < val;
	int &child = goRight ? a[rt].r : a[rt].l;
	insert(child, val);
	if(a[child].dat > a[rt].dat){
		if(goRight) zag(rt);
		else zig(rt);
	}
	update(rt);
}
// Removes one copy of val from the subtree referenced by rt (no-op if val is absent).
// A node whose last copy is removed is rotated downward, lifting the child with the higher
// priority (or the only child), until it is a leaf and can be unlinked.
inline void dlt(int &rt,int val){
	if(!rt) return;
	if(val != a[rt].val){
		dlt(val < a[rt].val ? a[rt].l : a[rt].r, val);
		update(rt);
		return;
	}
	if(a[rt].cnt > 1){
		--a[rt].cnt;
		update(rt);
		return;
	}
	const int lc = a[rt].l;
	const int rc = a[rt].r;
	if(!lc && !rc){
		rt = 0; // leaf: simply detach it
		return;
	}
	if(!rc || a[lc].dat > a[rc].dat){
		zig(rt);
		dlt(a[rt].r, val);
	}else{
		zag(rt);
		dlt(a[rt].l, val);
	}
	update(rt);
}
// Rank of val inside the subtree rt: (number of stored keys smaller than val, counting
// multiplicity) + 1. This holds whether or not val itself is stored.
// Over the whole tree the -inf sentinel counts as one smaller key, hence main's "- 1".
inline int gtr(int rt,int val){
	// Fell off the tree: val is absent and every smaller key has already been counted by
	// the callers, so val would be the next position (+1). Returning 0 here was off by one.
	if(!rt) return 1;
	if(val == a[rt].val) return a[a[rt].l].sz + 1;
	if(val < a[rt].val) return gtr(a[rt].l,val);
	return gtr(a[rt].r,val) + a[a[rt].l].sz + a[rt].cnt;
}
// Returns the key with the given 1-based rank (multiplicity counted) in subtree rt,
// or 0 when rank exceeds the subtree size.
inline int gtv(int rt,int rank){
	while(rt){
		const int leftSize = a[a[rt].l].sz;
		if(rank <= leftSize){
			rt = a[rt].l;
			continue;
		}
		const int throughHere = leftSize + a[rt].cnt;
		if(rank <= throughHere) return a[rt].val;
		rank -= throughHere;
		rt = a[rt].r;
	}
	return 0;
}
inline int gtp(int val){
	int ans = 1;//a[1].val = -inf
	int rt = root;
	while(rt){
		if(val == a[rt].val){
			if(a[rt].l){
				rt = a[rt].l;
				while(a[rt].r) rt = a[rt].r;
				ans = rt;
			}
			break;
		}
		if(a[rt].val < val && a[rt].val > a[ans].val) ans = rt;
		if(val < a[rt].val) rt = a[rt].l;
		else rt = a[rt].r;
	}
	return a[ans].val;
}
inline int gtn(int val){
	int ans = 2;//a[2].val = inf
	int rt = root;
	while(rt){
		if(val == a[rt].val){
			if(a[rt].r){
				rt = a[rt].r;
				while(a[rt].l) rt = a[rt].l;
				ans = rt;
			}
			break;
		}
		if(a[rt].val > val && a[rt].val < a[ans].val) ans = rt;
		if(val < a[rt].val) rt = a[rt].l;
		else rt = a[rt].r;
	}
	return a[ans].val;
}
// Reads n operations "opt x" and dispatches them to the treap.
// Ranks are shifted by one to hide the -inf sentinel.
int main(){
	build();
	n = in();
	while(n--){
		const int opt = in();
		const int x = in();
		switch(opt){
			case 1: insert(root,x); break;
			case 2: dlt(root,x); break;
			case 3: printf("%d\n",gtr(root,x)-1); break;
			case 4: printf("%d\n",gtv(root,x+1)); break;
			case 5: printf("%d\n",gtp(x)); break;
			case 6: printf("%d\n",gtn(x)); break;
			default: break;
		}
	}
	return 0;
}

Splay

#include <iostream>
#include <cstdlib>
// Field shorthands for the node pool a[] (declared below). Only function-like uses such as
// val(x) expand, so plain identifiers named val/cnt elsewhere are left untouched.
#define fa(x) a[x].fa // parent index; 0 means no parent
#define ls(x) a[x].ch[0] // left child index; 0 means empty
#define rs(x) a[x].ch[1] // right child index; 0 means empty
#define sz(x) a[x].sz // subtree total: sz(ls) + sz(rs) + cnt, as computed by update()
#define cnt(x) a[x].cnt // number of copies of this node's key
#define val(x) a[x].val // the node's key
using namespace std;
// Fast reader: skips any non-digit characters, remembers whether a '-' was seen among them,
// then parses the following run of decimal digits.
// Returns the parsed (possibly negated) value; returns 0 once input is exhausted instead of
// spinning forever on EOF.
inline int in(){
    int x = 0;
    bool f = 1;
    int c = getchar(); // int, not char, so EOF is distinguishable from real characters
    while(c != EOF && (c > '9' || c < '0')){
        if(c == '-') f = 0;
        c = getchar();
    }
    while(c <= '9' && c >= '0'){
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return f ? x : -x;
}
const int N = 1e5+100;
// Sentinel "infinity" = INT_MAX (0x7fffffff, eight hex digits), matching the treap version.
const int inf = 0x7fffffff;
// One splay-tree node in the global pool a[]; index 0 serves as the null node.
struct splay{
	int fa;    // parent index, 0 for the root
	int ch[2]; // ch[0] = left child, ch[1] = right child, 0 = empty
	int val;   // key
	int cnt;   // number of copies of val held by this node
	int sz;    // total copies in this subtree (recomputed by update())
}a[N];
int n,tot,root,ans; // n: operation count, tot: last used node index, root: tree root; ans is unused
// Recomputes sz(rt) from both children (node 0 has size 0) plus rt's own multiplicity.
inline void update(int rt){
	const int children = sz(ls(rt)) + sz(rs(rt));
	sz(rt) = children + cnt(rt);
}
inline bool get(int rt){
	return rt == rs(fa(rt));
}
// Zeroes every field of node rt (links, key, count, size) so it no longer references the tree.
inline void clear(int rt){
    ls(rt) = 0;
    rs(rt) = 0;
    fa(rt) = 0;
    val(rt) = 0;
    cnt(rt) = 0;
    sz(rt) = 0;
}
// Rotates rt above its parent f, preserving in-order order: rt's inner subtree moves to f, f
// becomes rt's child, and rt takes f's place under the grandparent pa (if any).
// Sizes must be refreshed bottom-up: f is now rt's child, so f is updated first. The previous
// order (rt, then f) computed sz(rt) from f's stale size, leaving the splayed node's sz wrong.
inline void rotate(int rt){
	const int f = fa(rt);
	const int pa = fa(f);
	const int dir = get(rt);          // side of f that rt hangs on
	const int inner = a[rt].ch[dir^1]; // rt's subtree that switches parents
	a[f].ch[dir] = inner;
	if(inner) fa(inner) = f;
	a[rt].ch[dir^1] = f;
	fa(f) = rt;
	fa(rt) = pa;
	if(pa) a[pa].ch[f == rs(pa)] = rt;
	update(f);
	update(rt);
}
// Moves rt to the root with double rotations: when rt and its parent lean the same way the
// parent is rotated first (zig-zig), otherwise rt is rotated twice (zig-zag).
// Records rt as the new root.
inline void splay(int rt){
	while(int f = fa(rt)){
		if(fa(f)){
			const bool sameSide = get(f) == get(rt);
			rotate(sameSide ? f : rt);
		}
		rotate(rt);
	}
	root = rt;
}
// Inserts one copy of val. An existing node just gains a copy; otherwise a new leaf is
// attached where the BST descent falls off. Either way the touched node is splayed to the root.
inline void insert(int val){
	if(root == 0){
		const int id = ++tot;
		val(id) = val;
		cnt(id) = 1;
		root = id;
		update(root);
		return;
	}
	int cur = root;
	int parent = 0;
	for(;;){
		if(val(cur) == val){
			++cnt(cur);
			update(cur);
			update(parent);
			splay(cur);
			return;
		}
		parent = cur;
		const int dir = val(cur) < val;
		cur = a[cur].ch[dir];
		if(cur == 0){
			const int id = ++tot;
			val(id) = val;
			cnt(id) = 1;
			fa(id) = parent;
			a[parent].ch[dir] = id;
			update(id);
			update(parent);
			splay(id);
			return;
		}
	}
}
inline int rk(int val){//查询val的排名
    int res = 0,cnt = root;
    while(1){
        if(val < val(cnt)) cnt = ls(cnt);
        else{
            res += sz(ls(cnt));
            if(val == val(cnt)){
                splay(cnt);
                return res + 1;
            }
            res += cnt(cnt);
            cnt = rs(cnt);
        }
    }
}
inline int kth(int k){//查询排名为k的数
    int cnt = root;
    while(1){
        if(ls(cnt) && k <= sz(ls(cnt))) cnt = ls(cnt);
        else{
            k -= cnt(cnt) + sz(ls(cnt));
            if(k <= 0){
                splay(cnt);
                return val(cnt);
            }
            cnt = rs(cnt);
        }
    }
}
inline int pre(){
    int cnt = ls(root);
    if(!cnt) return cnt;
    while(rs(cnt)) cnt = rs(cnt);
    splay(cnt);
    return cnt;
}
inline int nxt(){
    int cnt = rs(root);
    if(!cnt) return cnt;
    while(ls(cnt)) cnt = ls(cnt);
    splay(cnt);
    return cnt;
}
inline void dlt(int val){
    rk(val);
    if(cnt(root) > 1){
        cnt(root)--;
        update(root);
        return;
    }
    if(!ls(root) && !rs(root)){
        clear(root);
        root = 0;
        return;
    }
    if(!ls(root)){
        int cnt = root;
        root = rs(root);
        fa(root) = 0;
        clear(cnt);
        return;
    }
    if(!rs(root)){
        int cnt = root;
        root = ls(root);
        fa(root) = 0;
        clear(cnt);
        return;
    }
    int cnt = root;
    int x = pre();
    fa(rs(cnt)) = x;
    rs(x) = rs(cnt);
    clear(cnt);
    update(root);
}
// Reads n operations "opt x" and dispatches them to the splay tree. Predecessor/successor
// queries temporarily insert x, so that the root's in-order neighbour is the answer, then remove it.
int main(){
	n = in();
	for(int i = 0; i < n; ++i){
		const int opt = in();
		const int x = in();
		switch(opt){
		case 1:
			insert(x);
			break;
		case 2:
			dlt(x);
			break;
		case 3:
			printf("%d\n",rk(x));
			break;
		case 4:
			printf("%d\n",kth(x));
			break;
		case 5:
			insert(x);
			printf("%d\n",val(pre()));
			dlt(x);
			break;
		case 6:
			insert(x);
			printf("%d\n",val(nxt()));
			dlt(x);
			break;
		default:
			break;
		}
	}
	return 0;
}

裂开……TAT

标签:rt,万恶,cnt,return,val,int,平衡,root
来源: https://www.cnblogs.com/Kamisato-Ayaka/p/16388089.html