제가 안까먹으려고 하나의 글에 모두 정리해둡니다.
아래 justicehui님의 splay tree글을 많이 참고하여 작성합니다.(사실상 정리)
https://justicehui.github.io/hard-algorithm/2018/11/12/SplayTree1/
https://justicehui.github.io/hard-algorithm/2018/11/13/SplayTree2/
https://justicehui.github.io/hard-algorithm/2019/10/22/SplayTree3/
https://justicehui.github.io/hard-algorithm/2019/10/23/SplayTree4/
Splay Tree?
스플레이 트리는 BBST의 한 종류로 삽입/삭제/검색 모두 amortized O(logN)에 해주는 자료구조이며 다른 BBST에 비해 구현이 쉽다는 특징이 있습니다. 또한 구간 뒤집기 연산도 가능합니다.
스플레이 트리 노드 구조체
struct node {
node* l;
node* r;
node* p;
int key;
}
왼쪽 자식, 오른쪽 자식, 부모 노드, key값을 갖도록 합니다.
Rotate
rotate 함수를 통해 특정 노드를 자신의 부모 노드와 자리를 바꾸도록 합니다.
위 그림처럼 X노드를 P노드와 자리를 바꾸고 트리의 형태에 따라 옮긴 후 노드간의 관계를 재설정 하도록하는 함수를 만들것입니다.
void rotate(node* x){ // 위로 올릴 노드
node* p = x -> p; // x의 부모 노드
node* b = NULL; // b노드
if(!p) return; // x가 루트라면 리턴
if (x == p->l){ // x가 p의 왼쪽 자식
p -> l = b = x -> r; // p의 왼쪽을 b노드로
x -> r = p; // x의 오른쪽 자식을 p노드로
}else{ // x가 p의 오른쪽 자식
p -> r = b = x -> l; // p의 오른쪽을 b노드로
x -> l = p; // x의 왼쪽 자식을 p노드로
}
// 부모노드 재설정
x -> p = p -> p; // x의 부모는 p의 부모
p -> p = x; // p의 부모는 x
if (b) b -> p = p; // b노드가 존재한다면 b의 부모는 p
// x가 루트노드가 아닐경우
// p노드가 p노드의 부모노드의 왼쪽이었다면 x를 왼쪽으로, 오른쪽이었다면 오른쪽으로
(x -> p ? p == x -> p -> l ? x -> p -> l : x -> p -> r : tree) = x;
}
node* b 는 위 그림에서의 B노드입니다.
Splay
이제 Splay 함수를 구현할것입니다. Splay 함수는 특정 노드를 루트노드까지 올리는 작업을 하게 될것입니다.
함수는 다음과 같은 순서로 진행됩니다.
1. x가 루트라면 종료합니다.
2. x의 부모가 루트라면 rotate(x)를 하고 종료합니다. (Zig step)
3. x의 조부모(부모의 부모)를 g라고 할 때 상황에 따라 나누어집니다.
3- 1. g-p 방향과 p-x 방향이 같다면, rotate(p) -> rotate(x)를 수행합니다. (Zig-Zig step)
3 - 2. 방향이 다르다면, rotate(x)를 두 번 수행합니다. (Zig-Zag step)
이걸 x가 루트가 될 때까지 반복합니다.
zig, zig-zig, zig-zeg 과정은 아래 사진과 같이 진행됩니다.
1. zig step
2. zig-zig step
x와 p가 둘 다 왼쪽 자식이거나 둘 다 오른쪽 자식일 경우
rotate(p) 후에 rotate(x)를 해줍니다.
3. zig-zag step
함수로 구현하면 다음과 같습니다.
void splay(node* x){
while(x->p){ // x가 루트가 될때까지 반복
node* p = x -> p;
node* g = p -> p;
if (g){
if ((x == p -> l) == (p == g -> l)) rotate(p); // zig-zig
else rotate(x); // zig-zag;
}
rotate(x);
}
}
이제 insert, find, delete 함수를 차례차례 보겠습니다.
Insert
삽입 방식은 일반적인 BST와 같고 마지막에 splay를 해서 루트로 올려줍니다.
void insert(int key){
node* p =tree;
node** pp; // 넣을 위치
if(!p){ // 빈 트리인 경우
node* x = new node;
tree = x;
x -> l = x -> r = x -> p = NULL;
x -> key = key;
return;
}
while(true){ // 삽입할 위치를 찾을때까지 반복
if (key == p -> key) return; // 중복값이 있을 경우 종료
if (key < p -> key){ // 현재 노드의 값보다 작을경우
if (!p -> l){ // 왼쪽이 비어있으면 왼쪽에 삽입
pp = &p->l;
break;
}
p = p->l; // 비어있지 않으면 왼쪽 자식으로 이동
}else{ // 현재 노드의 값보다 크다
if (!p->r){
pp = &p->r;
break;
}
p = p->r;
}
}
// 위에서 삽입할 위치를 찾음
node* x = new node;
*pp = x;
x -> l = x -> r = NULL;
x-> p = p;
x-> key = key;
// splay
splay(x);
}
find
find도 또한 비슷하게 진행되며
마지막에 splay를 해주어 다른 연산을 쉽게 해주도록 합니다.
bool find(int key){
node* p = tree;
if(!p) return false; // 비어있는 트리
while(p){
if (key == p -> key) break; // 찾았다
if (key < p -> key){ // 찾으려는게 현재 값보다 작음
if (!p -> l) break; // 왼쪽이 없다 -> 탐색 실패
p = p -> l; // 왼쪽으로 이동
}else{ // 찾으려는게 현재 값보다 큼
if (!p -> r) break; // 오른쪽이 없다 -> 탐색 실패
p = p -> r; // 오른쪽으로 이동
}
}
splay(p); // 마지막에 탐색한 노드를 루트로 올림
return key == p->key; // 탐색한 노드의 값이 찾는값이랑 같은지 return
}
Delete
삭제연산입니다.
void del(int key){
if (!find(key)) return; // 삭제할 값이 없으면 루트로 이동 후 종료
node* p = tree;
if (p -> l && p -> r){ // 자식이 두개
tree = p->l; // 왼쪽 자식이 새로운 루트
tree->p = NULL;
// 오른쪽 서브트리를 왼쪽 서브트리 아래에 삽입
node* x = tree;
while(x -> r) x = x -> r;
x -> r = p -> r;
p -> r -> p = x;
delete p;
return;
}
if(p->l){ //자식이 왼쪽만 있는 경우
tree = p->l;
tree->p = NULL; //왼쪽 자식이 새로운 루트
delete p; //노드 삭제
return;
}
if(p->r){ //자식이 오른쪽만 있는 경우
tree = p->r;
tree->p = NULL; //오른쪽 자식이 새로운 루트
delete p; //노드 삭제
return;
}
// 노드가 자신 하나인경우
delete p;
tree = NULL;
}
K-th element
k번째 원소를 찾는 함수도 구현해봅시다
k번째 원소를 찾으려면 각 노드를 루트로하는 서브트리의 크기를 알아야 합니다.
구조체에 변수 하나를 추가해줍니다.
struct node{
node* l;
node* r;
node* p;
int key, cnt;
}
cnt변수의 값은 rotate가 될 때마다 계속 바뀔 것입니다.
cnt변수의 갱신을 담당할 update 함수를 만들어 줍니다.
void update(node* x){
x -> cnt = 1;
if (x -> l) x->cnt += x->l->cnt;
if (x -> r) x->cnt += x->r->cnt;
}
rotate 할 때마다 cnt값이 바뀐다고 했으니 rotate의 마지막 부분에 update함수를 추가해주면 됩니다.
void rotate(node* x){ // 위로 올릴 노드
node* p = x -> p; // x의 부모 노드
node* b = NULL; // b노드
if(!p) return; // x가 루트라면 리턴
if (x == p->l){ // x가 p의 왼쪽 자식
p -> l = b = x -> r; // p의 왼쪽을 b노드로
x -> r = p; // x의 오른쪽 자식을 p노드로
}else{ // x가 p의 오른쪽 자식
p -> r = b = x -> l; // p의 오른쪽을 b노드로
x -> l = p; // x의 왼쪽 자식을 p노드로
}
// 부모노드 재설정
x -> p = p -> p; // x의 부모는 p의 부모
p -> p = x; // p의 부모는 x
if (b) b -> p = p; // b노드가 존재한다면 b의 부모는 p
// x가 루트노드가 아닐경우
// p노드가 p노드의 부모노드의 왼쪽이었다면 x를 왼쪽으로, 오른쪽이었다면 오른쪽으로
(x -> p ? p == x -> p -> l ? x -> p -> l : x -> p -> r : tree) = x;
update(p), update(x);
}
roatate가 끝난 시점에서, p노드는 x노드의 자식노드이기 때문에
p노드에 대해 먼저 update후에 x노드를 update 해줘야 합니다.
이제 모든 노드의 서브트리의 크기를 cnt에 저장해 뒀으니 k-th element를 구할 수 있습니다.
void kth(int k){
node* x = tree;
while(1){
while(x->l && x->l->cnt > k) x = x->l;
if (x -> l) x -= x->l->cnt;
if(!k--)break;
x = x->r;
}
splay(x);
}
여기까지가 기본연산들에 대해 알아봤습니다.
이제 문제에 어떻게 적용할 수 있는지 알아보겠습니다.
대부분 배열에서의 쿼리 문제를 다룰 것입니다.
BBST이기 때문에 중위순회를 하면 항상 같은 순서로 탐색한다는 점을 이용합니다.
따라서 배열의 인덱스를 노드의 key로 생각해주면 배열의 원소를 스플레이트리로 관리할 수 있습니다.
간단한 구간합과 RMQ, 그리고 구간을 뒤집는 flip연산을 살펴보겠습니다.
splay연산을 통해 어떤 구간 [s, e]를 하나의 노드로 모으는 것에 집중해봅니다.
구간 [s, e]를 하나의 노드에 모으는 방법을 알아봅시다
s - 1번째 노드를 splay 해주면
s-1번째 노드가 루트가 되고 왼쪽 자식은 s - 2까지, 오른쪽 자식은 s부터의 구간에 대한 정보를 담고 있을것입니다.
그런 다음에 e + 1번 노드를 s - 1번 노드 오른쪽에 붙여주면
이렇게 될 것이고
루트의 오른쪽 자식의 왼쪽 자식에 [s, e] 구간에 대한 정보가 있을것입니다.
이것을 위해 splay 함수를 수정해서 루트로 올리는 함수에서 특정 노드를 다른 노드의 자식으로 가도록 수정합니다.
void splay(node* x, node* g = nullptr){
node* y;
while(x -> p != g){
node* p = x -> p;
if (p -> p == g){ // zig step
rotate(x);
break;
}
auto pp = p -> p;
if ((p -> l == x) == (pp -> l == p)){ // zig-zig step
rotate(p);
rotate(x);
}else{ // zig-zag step
rotate(x);
rotate(x);
}
}
if(!g) tree = x;
}
구간 [s, e] 를 모으는 gather 함수는 다음과 같습니다.
node* gather(int s, int e){ // [s, e]구간을 관리하는 노드를 리턴
kth(e + 1);
auto tmp = tree;
kth(s - 1);
splay(tmp, tree);
return tree->r->l;
}
이제 구간을 뒤집는 flip 함수를 보겠습니다.
구간을 뒤집는 연산은 왼쪽 서브트리와 오른쪽 서브트리를 swap하는 것을 재귀적으로 해주면 됩니다.
재귀적으로 처리하면 느리기 때문에 flip됐는지 여부를 lazy하게 처리해줍니다.
구조체에 변수를 추가합니다.
struct node{
// 다른 변수들
bool flip;
}
그리고 flip됐는지 여부를 전파할 함수를 만듭니다
void push(node *x){
if (!x -> flip) return;
swap(x->l, x->r);
if (x -> l) x->l->flip = !x->l->flip;
if (x->r) x->r->flip = !x->r->flip;
x->flip = false;
}
flip함수는 다음과 같이 간단하게 구현이 가능합니다.
void flip(int s, int e){
node* x = gather(s, e);
x->flip = !x->flip;
}
실제 문제에 적용하면 또 코드가 구현하기 편하도록 수정이 됩니다.
이것은 문제 풀이 카테고리에 문제를 직접 풀며 올려보겠습니다.
전체 소스
struct node{
node* l;
node* r;
node* p;
int key, cnt;
bool flip;
}*tree;
void update(node* x);
void rotate(node* x){ // 위로 올릴 노드
node* p = x -> p; // x의 부모 노드
node* b = NULL; // b노드
if(!p) return; // x가 루트라면 리턴
if (x == p->l){ // x가 p의 왼쪽 자식
p -> l = b = x -> r; // p의 왼쪽을 b노드로
x -> r = p; // x의 오른쪽 자식을 p노드로
}else{ // x가 p의 오른쪽 자식
p -> r = b = x -> l; // p의 오른쪽을 b노드로
x -> l = p; // x의 왼쪽 자식을 p노드로
}
// 부모노드 재설정
x -> p = p -> p; // x의 부모는 p의 부모
p -> p = x; // p의 부모는 x
if (b) b -> p = p; // b노드가 존재한다면 b의 부모는 p
// x가 루트노드가 아닐경우
// p노드가 p노드의 부모노드의 왼쪽이었다면 x를 왼쪽으로, 오른쪽이었다면 오른쪽으로
(x -> p ? p == x -> p -> l ? x -> p -> l : x -> p -> r : tree) = x;
update(p), update(x);
}
void splay(node* x, node* g = nullptr){
node* y;
while(x -> p != g){
node* p = x -> p;
if (p -> p == g){
rotate(x);
break;
}
auto pp = p -> p;
if ((p -> l == x) == (pp -> l == p)){
rotate(p);
rotate(x);
}else{
rotate(x);
rotate(x);
}
}
if(!g) tree = x;
}
void insert(int key){
node* p =tree;
node** pp; // 넣을 위치
if(!p){ // 빈 트리인 경우
node* x = new node;
tree = x;
x -> l = x -> r = x -> p = NULL;
x -> key = key;
return;
}
while(true){ // 삽입할 위치를 찾을때까지 반복
if (key == p -> key) return; // 중복값이 있을 경우 종료
if (key < p -> key){ // 현재 노드의 값보다 작을경우
if (!p -> l){ // 왼쪽이 비어있으면 왼쪽에 삽입
pp = &p->l;
break;
}
p = p->l; // 비어있지 않으면 왼쪽 자식으로 이동
}else{ // 현재 노드의 값보다 크다
if (!p->r){
pp = &p->r;
break;
}
p = p->r;
}
}
// 위에서 삽입할 위치를 찾음
node* x = new node;
*pp = x;
x -> l = x -> r = NULL;
x-> p = p;
x-> key = key;
// splay
splay(x);
}
bool find(int key){
node* p = tree;
if(!p) return false; // 비어있는 트리
while(p){
if (key == p -> key) break; // 찾았다
if (key < p -> key){
if (!p -> l) break;
p = p -> l;
}else{
if (!p -> r) break;
p = p -> r;
}
}
splay(p);
return key == p->key;
}
void del(int key){
if (!find(key)) return; // 삭제할 값이 없으면 루트로 이동 후 종료
node* p = tree;
if (p -> l && p -> r){ // 자식이 두개
tree = p->l; // 왼쪽 자식이 새로운 루트
tree->p = NULL;
// 오른쪽 서브트리를 왼쪽 서브트리 아래에 삽입
node* x = tree;
while(x -> r) x = x -> r;
x -> r = p -> r;
p -> r -> p = x;
delete p;
return;
}
if(p->l){ //자식이 왼쪽만 있는 경우
tree = p->l;
tree->p = NULL; //왼쪽 자식이 새로운 루트
delete p; //노드 삭제
return;
}
if(p->r){ //자식이 오른쪽만 있는 경우
tree = p->r;
tree->p = NULL; //오른쪽 자식이 새로운 루트
delete p; //노드 삭제
return;
}
// 노드가 자신 하나인경우
delete p;
tree = NULL;
}
void update(node* x){
x -> cnt = 1;
if (x -> l) x->cnt += x->l->cnt;
if (x -> r) x->cnt += x->r->cnt;
}
void kth(int k){
node* x = tree;
while(1){
while(x->l && x->l->cnt > k) x = x->l;
if (x -> l) x -= x->l->cnt;
if(!k--)break;
x = x->r;
}
splay(x);
}
node* gather(int s, int e){ // [s, e]구간을 관리하는 노드를 리턴
kth(e + 1);
auto tmp = tree;
kth(s - 1);
splay(tmp, tree);
return tree->r->l;
}
void push(node *x){
if (!x -> flip) return;
swap(x->l, x->r);
if (x -> l) x->l->flip = !x->l->flip;
if (x->r) x->r->flip = !x->r->flip;
x->flip = false;
}
void flip(int s, int e){
node* x = gather(s, e);
x->flip = !x->flip;
}
사진 출처 :
https://justicehui.github.io/hard-algorithm/2019/10/23/SplayTree4/