본문 바로가기
Computer Science

[C++] AVL Tree와 이를 이용한 Ordered-Set 구현

by invrtd.h 2022. 8. 4.

C++을 컴파일할 때 G++ 컴파일러를 쓰는 사람은 PBDS(Policy-Based Data Structure)라고 해서 유용한 연산들을 많이 제공해 주는 자료 구조를 쓸 수 있다. 하지만 난 Xcode를 쓰기 때문에 PBDS를 못 쓴다! 그래서 그냥 내가 만들기로 했다. 밑의 코드는 AVL Tree를 통해 Ordered-Set을 구현한 것이다. Ordered-Set이란? 기존의 Balanced Binary Search Tree 기반의 Set에 몇몇 기능을 추가한 것이다. 주요 기능으로 다음 두 가지 기능이 있다.

  • find-by-order : Set의 n번째 원소가 무엇인지를 O(log n) 안에 찾아준다.
  • order-of-key : Set의 원소 k가 몇 번째로 작은 원소인지를 O(log n) 안에 찾아준다.

BBST의 각 노드들이 subtree의 원소의 개수를 갖고 있다면, 이 일은 어려운 일이 아닐 것이다. (...설명 작성 중)

template<class T>
class AVLTree;

template<class T>
class NODE {
	T key;
	int h, size, n_of_keys;
	NODE *left, *right;
	
	friend AVLTree<T>;

	NODE(const T &_key, int _h = 0) : key(_key), h(_h), size(1), n_of_keys(1), left(NULL), right(NULL) {}
	~NODE() {if (left) {delete left;} if (right) {delete right;}}
	
	static NODE* insert(const T &_key, NODE *node) {
		++node->size;
		if (_key == node->key) {
			++node->n_of_keys;
			return node;
		}
		if (_key < node->key) {
			if (node->left) {
				node->left = insert(_key, node->left);
			} else {
				node->left = new NODE(_key);
			}
		} else {
			if (node->right) {
				node->right = insert(_key, node->right);
			} else {
				node->right = new NODE(_key);
			}
		}
		
		node->recompute_height();
		if (node->unbal()) {
			return restruct(node);
		} return node;
	}
	static NODE* remove(const T &_key, NODE *node) {
		if (_key < node->key && node->left) {
			node->left = remove(_key, node->left);
			node->size = node->lsize() + node->rsize() + node->n_of_keys;
		} else if (_key > node->key && node->right) {
			node->right = remove(_key, node->right);
			node->size = node->lsize() + node->rsize() + node->n_of_keys;
		} else if (_key == node->key) {
			--node->size;
			--node->n_of_keys;
			if (node->n_of_keys > 0) {
				return node;
			}
			if (node->left && node->right) {
				const NODE *temp = find_first(node->right);
				node->key = temp->key;
				node->n_of_keys = temp->n_of_keys;
				node->right = remove_first(node->right, node->n_of_keys);
			} else {
				if (node->left) {
					NODE *temp = node->left;
					node->left = nullptr;
					delete node;
					return temp;
				} else if (node->right) {
					NODE *temp = node->right;
					node->right = nullptr;
					delete node;
					return temp;
				} else {
					delete node;
					return nullptr;
				}
			}
		}
		
		node->recompute_height();
		if (node->unbal()) {
			return restruct(node);
		} return node;
	}
	static NODE* remove_nth(int idx, NODE *node) {
		--node->size;
		int lsz = node->lsize();
		if (idx >= lsz + node->n_of_keys && node->right) {
			node->right = remove_nth(idx - lsz - node->n_of_keys, node->right);
		} else if (idx < lsz && node->left) {
			node->left = remove_nth(idx, node->left);
		} else {
			cout << node->key << '\n';
			--node->n_of_keys;
			if (node->n_of_keys > 0) {
				return node;
			}
			if (node->left && node->right) {
				NODE *temp = find_first(node->right);
				node->key = temp->key;
				node->n_of_keys = temp->n_of_keys;
				node->right = remove_first(node->right, node->n_of_keys);
			} else {
				if (node->left) {
					NODE *temp = node->left;
					node->left = nullptr;
					delete node;
					return temp;
				} else if (node->right) {
					NODE *temp = node->right;
					node->right = nullptr;
					delete node;
					return temp;
				} else {
					delete node;
					return nullptr;
				}
			}
		}
		
		node->recompute_height();
		if (node->unbal()) {
			return restruct(node);
		} return node;
	}

	static const NODE* find_first(const NODE *node) {
		while (node->left) {
			node = node->left;
		} return node;
	}
	static const NODE* find_last(const NODE *node) {
		while (node->left) {
			node = node->left;
		} return node;
	}
	static NODE* remove_first(NODE *node, int sz) {
		node->size -= sz;
		if (node->left) {
			node->left = remove_first(node->left, sz);
			node->recompute_height();
			if (node->unbal()) {
				return restruct(node);
			} return node;
		} else if (node->right) {
			NODE *temp = node->right;
			node->right = nullptr;
			delete node;
			return temp;
		} else {
			delete node;
			return nullptr;
		}
	}
	
	bool __contains__(const T &_key) const {
		if (_key == key) {
			return true;
		}
		if (_key < key) {
			if (left) {
				return left->__contains__(_key);
			} else {
				return false;
			}
		} else {
			if (right) {
				return right->__contains__(_key);
			} else {
				return false;
			}
		}
	}
	
	int get_idx(const T &_key) const {
		if (_key < key) {
			if (left) {
				int temp = left->get_idx(_key);
				return temp;
			} else {
				return 0;
			}
		} else {
			if (right) {
				int temp = right->get_idx(_key);
				return temp + lsize() + n_of_keys;
			} else {
				return lsize() + n_of_keys;
			}
		}
	}
	
	T get_nth(int idx) const {
		int lsz = lsize();
		if (idx > lsz) {
			return right->get_nth(idx - lsz - n_of_keys);
		} else if (idx < lsz) {
			return left->get_nth(idx);
		} else {
			return key;
		}
	}
	
	void recompute_height() {
		int l = lh(), r = rh();
		h = l < r ? r + 1 : l + 1;
	}
	
	bool unbal() const {
		return lh() - rh() > 1 || rh() - lh() > 1;
	}
	
	int lh() const {
		return left ? left->h : -1;
	}
	int rh() const {
		return right ? right->h : -1;
	}
	
	int lsize() const {
		return left ? left->size : 0;
	}
	
	int rsize() const {
		return right ? right->size : 0;
	}
	
	static NODE* lrot(NODE *node) {
		NODE *root = node->right;
		node->right = root->left;
		node->recompute_height();
		root->left = node;
		root->recompute_height();
		
		node->size = node->lsize() + node->rsize() + node->n_of_keys;
		root->size = root->lsize() + root->rsize() + root->n_of_keys;
		
		return root;
	}
	
	static NODE* rrot(NODE *node) {
		NODE *root = node->left;
		node->left = root->right;
		node->recompute_height();
		root->right = node;
		root->recompute_height();
		
		node->size = node->lsize() + node->rsize() + node->n_of_keys;
		root->size = root->lsize() + root->rsize() + root->n_of_keys;
		
		return root;
	}
	
	static NODE* restruct(NODE *node) {
		if (node->rh() > node->lh()) {
			if (node->right->lh() > node->right->rh()) {
				node->right = rrot(node->right);
			} return lrot(node);
		} else {
			if (node->left->rh() > node->left->lh()) {
				node->left = lrot(node->left);
			} return rrot(node);
		}
	}
	
	void print(int indent = 0) const {
		if (left) {
			left->print(indent + 2);
		}
		for (int i = 0; i < indent; ++i) {
			cout << '_';
		}
		cout << "key " << key << " size " << size << " n " << n_of_keys << endl;
		if (right) {
			right->print(indent + 2);
		}
	}
};

template<class T>
class AVLTree {
	NODE<T> *p;
	
public:
	AVLTree() : p(nullptr) {}
	~AVLTree() {if (p) {delete p;}}
	
	AVLTree& insert(const T &k) {
		if (p == nullptr) {
			p = new NODE<T>(k);
		} else {
			p = NODE<T>::insert(k, p);
		}
		
		return *this;
	}
	
	AVLTree& remove(const T &k) {
		if (p) {
			p = NODE<T>::remove(k, p);
		}
		
		return *this;
	}
	
	AVLTree& remove_nth(int idx) {
		if (p) {
			p = NODE<T>::remove_nth(idx, p);
		}
		
		return *this;
	}
	
	AVLTree& display() {
		if (p == nullptr) {
			cout << "NO ELEMENTS IN THE SET\n";
		} else {
			p->print();
		}
		
		return *this;
	}
	
	int get_idx(const T &k) {
		return p ? p->get_idx(k) : 0;
	}
};

댓글