万能的Splay-处理区间问题
为什么会用到Splay呢,Splay虽然很强大但是不)同问题上替代品也有不少。其实主要是最近看到动态树、Link-cut Tree,里面需要维护一条路径的信息,而且路劲还是可分割,可合并的,这里Splay就是最佳的组件了。
Splay如果仅仅作为一棵二叉平衡查找树(BST)来使用的话的确效率是不尽人意,不过正如Splay的中文名—伸展树,它的发明不仅仅是作为一棵平庸的BST来使用的,较高的常数和看起来不怎么爽的平摊O(logN)的效率带来的是其非常丰富的功能。比如现在所讲的—如何用Splay处理区间问题。
给个例题,POJ3468,题目意思是:给一个连续的区间[1..N],区间里每个元素上面都有一个数值。每次有两种询问,第一种是给定区间[A,B],把里面所有元素的数值加上某个数字C,第二种询问是给定区间[A,B],求区间里所有元素数值的和。
其实这题用线段树还是能非常轻松的过掉的。不过今天就要自虐一下写个Splay。
首先呢,我们以元素的序号为关键字建一个二叉查找树(平衡不平衡这里其实无所谓,反正Splay几下就均摊平衡了。),树上的每个节点代表一个元素,为了方便起见我们可以再插入两个元素0和N+1(等下你就知道为什么需要这俩东西了。)。
接下来的问题呢,是如何把一段连续的区间在树上展现出来,一般人都会比较混沌,因为树上的结构看起来或许很乱。但是Splay在这里就发挥了他十分大的优势,比如,询问区间为[A,B],我们可以先把元素A-1旋到树根,这样询问区间里的所有元素都在树根的右子树里,接下来的手法类似,就是单独取下那棵右子树([0..A-1]区间里的元素就这样被我们忽略了),然后把B+1选到树根上,这样树根的左子树就是代表着我们需要处理的区间[A,B]了。
接下来的事情就比较简单了,对于第一种询问,我们可以对于每个节点多记录下一个新信息size,表示以该节点为根的子树下有多少个节点,然后再用类似线段树的标记传递的方法覆盖下去。对于第二种询问,我们可以对每个节点维护sum,表示以该节点为根的子树下所有元素的和是多少。然后问题就此解决了哈哈~
(Splay的基础知识参考NOI WinterCamp2004杨思雨的论文)
#include <cstdio> #include <cstring> #include <cstdlib> #include <iostream> using namespace std; typedef struct TSplayNode { int key, cover, size; __int64 data, sum; TSplayNode *lch, *rch, *pnt; }*PSplayNode; int n, q; PSplayNode root; void CoverExpand(PSplayNode cur) { if (cur->lch != NULL) { cur->lch->cover += cur->cover; cur->lch->data += cur->cover; cur->lch->sum += ((__int64)cur->cover) * cur->lch->size; } if (cur->rch != NULL) { cur->rch->cover += cur->cover; cur->rch->data += cur->cover; cur->rch->sum += ((__int64)cur->cover) * cur->rch->size; } cur->cover = 0; } void Update(PSplayNode cur) { cur->size = 1; if (cur->lch != NULL) cur->size += cur->lch->size; if (cur->rch != NULL) cur->size += cur->rch->size; if (cur->cover != 0) CoverExpand(cur); cur->sum = cur->data; if (cur->lch != NULL) cur->sum += cur->lch->sum; if (cur->rch != NULL) cur->sum += cur->rch->sum; } void LeftRotate(PSplayNode cur) { PSplayNode pnt = cur->pnt, anc = pnt->pnt; if (anc != NULL && anc->cover != 0) CoverExpand(anc); if (pnt->cover != 0) CoverExpand(pnt); if (cur->cover != 0) CoverExpand(cur); pnt->lch = cur->rch; if (cur->rch != NULL) cur->rch->pnt = pnt; cur->rch = pnt; pnt->pnt = cur; cur->pnt = anc; if (anc != NULL) { if (anc->lch == pnt) anc->lch = cur; else anc->rch = cur; } Update(pnt); Update(cur); } void RightRotate(PSplayNode cur) { PSplayNode pnt = cur->pnt, anc = pnt->pnt; if (anc != NULL && anc->cover != 0) CoverExpand(anc); if (pnt->cover != 0) CoverExpand(pnt); if (cur->cover != 0) CoverExpand(cur); pnt->rch = cur->lch; if (cur->lch != NULL) cur->lch->pnt = pnt; cur->lch = pnt; pnt->pnt = cur; cur->pnt = anc; if (anc != NULL) { if (anc->lch == pnt) anc->lch = cur; else anc->rch = cur; } Update(pnt); Update(cur); } void splay(PSplayNode cur) { PSplayNode pnt, anc; while (cur->pnt != NULL) { pnt = cur->pnt; anc = pnt->pnt; if (anc != NULL && anc->cover != 0) CoverExpand(anc); if (pnt->cover != 0) CoverExpand(pnt); if (cur->cover != 0) CoverExpand(cur); if (anc == NULL) { if (pnt->lch == cur) LeftRotate(cur); else RightRotate(cur); } else { if (pnt->lch == cur) LeftRotate(cur); else RightRotate(cur); if (anc->lch == cur) LeftRotate(cur); else RightRotate(cur); } } root = cur; } PSplayNode find(PSplayNode cur, int key) { while (cur->key != key) { if (cur->cover != 0) CoverExpand(cur); if (key < cur->key) { if (cur->lch == NULL) break; cur = cur->lch; } else if (key > cur->key) { if (cur->rch == NULL) break; cur = cur->rch; } } return cur; } void Insert(int key, int data) { PSplayNode cur = new(TSplayNode); memset(cur, 0, sizeof(TSplayNode)); cur->key = key; cur->data = cur->sum = data; cur->size = 1; if (root == NULL) root = cur; else { splay(find(root, key)); if (root->key < key) { cur->rch = root->rch; if (root->rch != NULL) root->rch->pnt = cur; root->rch = NULL; cur->lch = root; } else if (root->key > key) { cur->lch = root->lch; if (root->lch != NULL) root->lch->pnt = cur; root->lch = NULL; cur->rch = root; } root->pnt = cur; Update(root); Update(cur); root = cur; } } int main() { scanf("%d%d", &n, &q); root = NULL; Insert(0, 0); for (int i = 0; i < n; ++ i) { int a; scanf("%d", &a); Insert(i + 1, a); } Insert(n + 1, 0); scanf("\n"); for (int i = 0; i < q; ++ i) { char c; int a, b, p; PSplayNode tmp; scanf("%c", &c); if (c == 'Q') { scanf("%d%d\n", &a, &b); splay(find(root, a - 1)); tmp = root; root->rch->pnt = NULL; splay(find(root->rch, b + 1)); cout << root->lch->sum << endl; root->pnt = tmp; tmp->rch = root; root = tmp; } else { scanf("%d%d%d\n", &a, &b, &p); splay(find(root, a - 1)); tmp = root; root->rch->pnt = NULL; splay(find(root->rch, b + 1)); root->lch->cover += p; root->lch->data += p; root->lch->sum += ((__int64)p) * root->lch->size; root->pnt = tmp; tmp->rch = root; root = tmp; } } }
