C++을 컴파일할 때 G++ 컴파일러를 쓰는 사람은 PBDS(Policy-Based Data Structure)라고 해서 유용한 연산들을 많이 제공해 주는 자료 구조를 쓸 수 있다. 하지만 난 Xcode를 쓰기 때문에 PBDS를 못 쓴다! 그래서 그냥 내가 만들기로 했다. 밑의 코드는 AVL Tree를 통해 Ordered-Set을 구현한 것이다. Ordered-Set이란? 기존의 Balanced Binary Search Tree 기반의 Set에 몇몇 기능을 추가한 것이다. 주요 기능으로 다음 두 가지 기능이 있다.
- find-by-order : Set의 n번째 원소가 무엇인지를 O(log n) 안에 찾아준다.
- order-of-key : Set의 원소 k가 몇 번째로 작은 원소인지를 O(log n) 안에 찾아준다.
BBST의 각 노드들이 subtree의 원소의 개수를 갖고 있다면, 이 일은 어려운 일이 아닐 것이다. (...설명 작성 중)
template<class T>
class AVLTree;
template<class T>
class NODE {
T key;
int h, size, n_of_keys;
NODE *left, *right;
friend AVLTree<T>;
NODE(const T &_key, int _h = 0) : key(_key), h(_h), size(1), n_of_keys(1), left(NULL), right(NULL) {}
~NODE() {if (left) {delete left;} if (right) {delete right;}}
static NODE* insert(const T &_key, NODE *node) {
++node->size;
if (_key == node->key) {
++node->n_of_keys;
return node;
}
if (_key < node->key) {
if (node->left) {
node->left = insert(_key, node->left);
} else {
node->left = new NODE(_key);
}
} else {
if (node->right) {
node->right = insert(_key, node->right);
} else {
node->right = new NODE(_key);
}
}
node->recompute_height();
if (node->unbal()) {
return restruct(node);
} return node;
}
static NODE* remove(const T &_key, NODE *node) {
if (_key < node->key && node->left) {
node->left = remove(_key, node->left);
node->size = node->lsize() + node->rsize() + node->n_of_keys;
} else if (_key > node->key && node->right) {
node->right = remove(_key, node->right);
node->size = node->lsize() + node->rsize() + node->n_of_keys;
} else if (_key == node->key) {
--node->size;
--node->n_of_keys;
if (node->n_of_keys > 0) {
return node;
}
if (node->left && node->right) {
const NODE *temp = find_first(node->right);
node->key = temp->key;
node->n_of_keys = temp->n_of_keys;
node->right = remove_first(node->right, node->n_of_keys);
} else {
if (node->left) {
NODE *temp = node->left;
node->left = nullptr;
delete node;
return temp;
} else if (node->right) {
NODE *temp = node->right;
node->right = nullptr;
delete node;
return temp;
} else {
delete node;
return nullptr;
}
}
}
node->recompute_height();
if (node->unbal()) {
return restruct(node);
} return node;
}
static NODE* remove_nth(int idx, NODE *node) {
--node->size;
int lsz = node->lsize();
if (idx >= lsz + node->n_of_keys && node->right) {
node->right = remove_nth(idx - lsz - node->n_of_keys, node->right);
} else if (idx < lsz && node->left) {
node->left = remove_nth(idx, node->left);
} else {
cout << node->key << '\n';
--node->n_of_keys;
if (node->n_of_keys > 0) {
return node;
}
if (node->left && node->right) {
NODE *temp = find_first(node->right);
node->key = temp->key;
node->n_of_keys = temp->n_of_keys;
node->right = remove_first(node->right, node->n_of_keys);
} else {
if (node->left) {
NODE *temp = node->left;
node->left = nullptr;
delete node;
return temp;
} else if (node->right) {
NODE *temp = node->right;
node->right = nullptr;
delete node;
return temp;
} else {
delete node;
return nullptr;
}
}
}
node->recompute_height();
if (node->unbal()) {
return restruct(node);
} return node;
}
static const NODE* find_first(const NODE *node) {
while (node->left) {
node = node->left;
} return node;
}
static const NODE* find_last(const NODE *node) {
while (node->left) {
node = node->left;
} return node;
}
static NODE* remove_first(NODE *node, int sz) {
node->size -= sz;
if (node->left) {
node->left = remove_first(node->left, sz);
node->recompute_height();
if (node->unbal()) {
return restruct(node);
} return node;
} else if (node->right) {
NODE *temp = node->right;
node->right = nullptr;
delete node;
return temp;
} else {
delete node;
return nullptr;
}
}
bool __contains__(const T &_key) const {
if (_key == key) {
return true;
}
if (_key < key) {
if (left) {
return left->__contains__(_key);
} else {
return false;
}
} else {
if (right) {
return right->__contains__(_key);
} else {
return false;
}
}
}
int get_idx(const T &_key) const {
if (_key < key) {
if (left) {
int temp = left->get_idx(_key);
return temp;
} else {
return 0;
}
} else {
if (right) {
int temp = right->get_idx(_key);
return temp + lsize() + n_of_keys;
} else {
return lsize() + n_of_keys;
}
}
}
T get_nth(int idx) const {
int lsz = lsize();
if (idx > lsz) {
return right->get_nth(idx - lsz - n_of_keys);
} else if (idx < lsz) {
return left->get_nth(idx);
} else {
return key;
}
}
void recompute_height() {
int l = lh(), r = rh();
h = l < r ? r + 1 : l + 1;
}
bool unbal() const {
return lh() - rh() > 1 || rh() - lh() > 1;
}
int lh() const {
return left ? left->h : -1;
}
int rh() const {
return right ? right->h : -1;
}
int lsize() const {
return left ? left->size : 0;
}
int rsize() const {
return right ? right->size : 0;
}
static NODE* lrot(NODE *node) {
NODE *root = node->right;
node->right = root->left;
node->recompute_height();
root->left = node;
root->recompute_height();
node->size = node->lsize() + node->rsize() + node->n_of_keys;
root->size = root->lsize() + root->rsize() + root->n_of_keys;
return root;
}
static NODE* rrot(NODE *node) {
NODE *root = node->left;
node->left = root->right;
node->recompute_height();
root->right = node;
root->recompute_height();
node->size = node->lsize() + node->rsize() + node->n_of_keys;
root->size = root->lsize() + root->rsize() + root->n_of_keys;
return root;
}
static NODE* restruct(NODE *node) {
if (node->rh() > node->lh()) {
if (node->right->lh() > node->right->rh()) {
node->right = rrot(node->right);
} return lrot(node);
} else {
if (node->left->rh() > node->left->lh()) {
node->left = lrot(node->left);
} return rrot(node);
}
}
void print(int indent = 0) const {
if (left) {
left->print(indent + 2);
}
for (int i = 0; i < indent; ++i) {
cout << '_';
}
cout << "key " << key << " size " << size << " n " << n_of_keys << endl;
if (right) {
right->print(indent + 2);
}
}
};
template<class T>
class AVLTree {
NODE<T> *p;
public:
AVLTree() : p(nullptr) {}
~AVLTree() {if (p) {delete p;}}
AVLTree& insert(const T &k) {
if (p == nullptr) {
p = new NODE<T>(k);
} else {
p = NODE<T>::insert(k, p);
}
return *this;
}
AVLTree& remove(const T &k) {
if (p) {
p = NODE<T>::remove(k, p);
}
return *this;
}
AVLTree& remove_nth(int idx) {
if (p) {
p = NODE<T>::remove_nth(idx, p);
}
return *this;
}
AVLTree& display() {
if (p == nullptr) {
cout << "NO ELEMENTS IN THE SET\n";
} else {
p->print();
}
return *this;
}
int get_idx(const T &k) {
return p ? p->get_idx(k) : 0;
}
};
'Computer Science' 카테고리의 다른 글
[C++] 다항식 클래스와 라그랑주 보간법(Lagrange Interpolation) 구현 (0) | 2022.08.07 |
---|---|
[C++] SCC Maker class (0) | 2022.08.06 |
[C++] 선분 교차 판정 (0) | 2022.08.04 |
[C++] 유리수 클래스(class Fraction) 구현 (0) | 2022.08.04 |
[C++] class Modulo: 나머지환에서의 사칙연산을 구현해 보자 (0) | 2022.08.04 |
댓글