// Symbol Table (ST) using a Binary Search Tree
#include <iostream>
#include <queue>
using namespace std;
/// A single node of the binary search tree used by BST.
template <typename Key, typename Value>
struct Node
{
    // Child links; nullptr marks an empty subtree.
    Node *left = nullptr;
    Node *right = nullptr;
    Key key{};
    Value val{};
    // In each node, we store the number of nodes in the subtree rooted at that
    // node; a freshly created node is a leaf, so it starts at 1.
    int count = 1;

    // Empty node: null links, value-initialised key and value, count 1.
    Node() = default;

    // Leaf node holding (key, val). Key and value are copy-constructed in the
    // member-init list instead of default-constructed and then assigned.
    Node(Key key, Value val) : key(key), val(val) {}
};
template <typename Key, typename Value>
class BST
{
Node<Key, Value> *root = nullptr;
Node<Key, Value> *put(Node<Key, Value> *x, Key key, Value val)
{
if (x == nullptr)
{
// Inserting first element, the count will not correctly updated
return new Node<Key, Value>(key, val);
}
if (x->key > key)
x->left = put(x->left, key, val);
else if (x->key < key)
x->right = put(x->right, key, val);
else
x->val = val;
x->count = 1 + size(x->left) + size(x->right);
return x;
}
Node<Key, Value> *floor(Node<Key, Value> *x, Key key)
{
if (x == nullptr)
return nullptr;
if (key == x->key)
return x;
if (key < x->key)
return floor(x->left, key);
Node<Key, Value> *temp = floor(x->right, key);
if (temp != nullptr)
return temp;
else
return x;
}
bool contain(Node<Key, Value> *x, Key key)
{
if (x == nullptr)
return false;
if (x->key < key)
return contain(x->right, key);
else if (x->key > key)
return contain(x->left, key);
else
return true;
}
int size(Node<Key, Value> *x)
{
return x == nullptr ? 0 : x->count;
}
int rank(Node<Key, Value> *x, Key key)
{
if (x == nullptr)
return 0;
if (key < x->key)
return rank(x->left, key);
else if (key > x->key)
return 1 + size(x->left) + rank(x->right, key);
else
return size(x->left);
}
Node<Key, Value> *delMax(Node<Key, Value> *x)
{
if (x->right == nullptr)
{
Node<Key, Value> *temp = x->left;
return temp;
}
x->right = delMax(x->right);
x->count = 1 + size(x->left) + size(x->right);
return x;
}
Node<Key, Value> *delMin(Node<Key, Value> *x)
{
if (x->left == nullptr)
{
// Node<Key, Value> *temp = x;
return x->right;
}
x->left = delMin(x->left);
x->count = 1 + size(x->left) + size(x->right);
return x;
}
Node<Key, Value> *delPair(Node<Key, Value> *x, Key key)
{
if (x == nullptr)
return x;
if (key < x->key)
x->left = delPair(x->left, key);
else if (key > x->key)
x->right = delPair(x->right, key);
else
{
if (x->left == nullptr)
return x->right;
if (x->right == nullptr)
return x->left;
Node<Key, Value> *t = x;
x = min(t->right);
x->right = delMin(t->right); // will return t->right
x->left = t->left;
}
x->count = 1 + size(x->left) + size(x->right);
return x;
}
Node<Key, Value> *min(Node<Key, Value> *x)
{
Node<Key, Value> *temp = x;
while (temp->left != nullptr)
{
temp = temp->left;
}
return temp;
}
Node<Key, Value> *max(Node<Key, Value> *x)
{
Node<Key, Value> *temp = x;
while (temp->right != nullptr)
{
temp = temp->right;
}
return temp;
}
void inorder(Node<Key, Value> *x)
{
if (x == nullptr)
return;
inorder(x->left);
cout << x->key << ": " << x->val << " ";
inorder(x->right);
}
void preorder(Node<Key, Value> *x)
{
if (x == nullptr)
return;
cout << x->key << ": " << x->val << " ";
preorder(x->left);
preorder(x->right);
}
void postorder(Node<Key, Value> *x)
{
if (x == nullptr)
return;
preorder(x->left);
preorder(x->right);
cout << x->key << ": " << x->val << " ";
}
void levelorder(Node<Key, Value> *x)
{
queue<Node<Key, Value> *> qt;
qt.push(root);
qt.push(nullptr);
cout << x->key << ": " << x->val << " " << endl;
while (!qt.empty())
{
Node<Key, Value> *cur = qt.front();
qt.pop();
if (cur != nullptr)
{
if (cur->left != nullptr)
{
qt.push(cur->left);
cout << cur->left->key << ": " << cur->left->val << " ";
}
else
{
cout << "NULL ";
}
if (cur->right != nullptr)
{
qt.push(cur->right);
cout << cur->right->key << ": " << cur->right->val << " ";
}
else
{
cout << "NULL ";
}
}
if (qt.front() == nullptr)
{
cout << endl;
qt.pop();
qt.push(nullptr);
}
}
}
public:
void put(Key key, Value val)
{
root = put(root, key, val);
}
// 1d Range Count
int size(Key hi, Key lo)
{
if (contain(hi))
return rank(hi) - rank(lo) + 1;
else
return rank(hi) - rank(lo); // no of keys < hi
}
Key floor(Key key)
{
Node<Key, Value> *x = floor(root, key);
if (x == nullptr)
throw "";
return x->key;
}
bool contain(Key key)
{
contain(root, key);
}
Value get(Key key)
{
Node<Key, Value> *cur = root;
while (cur != nullptr)
{
if (cur->key < key)
{
cur = cur->right;
}
else if (cur->key > key)
cur = cur->left;
else
return cur->val;
}
throw "Key Not found!";
}
int rank(Key key)
{
return rank(root, key);
}
Key min()
{
if (root == nullptr)
throw "Empty Symbol Table";
Node<Key, Value> *temp = min(root);
return temp->key;
}
Key max()
{
if (root == nullptr)
throw "Empty Symbol Table";
Node<Key, Value> *temp = max(root);
return temp->key;
}
void delMin()
{
root = delMin(root);
}
void delMax()
{
root = delMax(root);
}
void delPair(Key key)
{
root = delPair(root, key);
}
bool checkBST() // Handle some cases - will Modify :)
{
queue<Node<Key, Value> *> q;
q.push(root);
while (!q.empty())
{
Node<Key, Value> *tempRoot = q.front();
q.pop();
if (tempRoot != nullptr)
{
if (tempRoot->left != nullptr)
{
if (tempRoot->left->key > tempRoot->key || (tempRoot->right != nullptr && tempRoot->right->key < tempRoot->key))
return false;
q.push(tempRoot->left);
}
if (tempRoot->right != nullptr)
{
if (tempRoot->right->key < tempRoot->key || (tempRoot->left != nullptr && tempRoot->left->key > tempRoot->key))
return false;
q.push(tempRoot->right);
}
}
}
return true;
}
void print_inorder()
{
inorder(root);
}
void print_preorder()
{
preorder(root);
}
void print_postorder()
{
postorder(root);
}
void print_levelorder()
{
levelorder(root);
}
};
// Small demo: build a table, overwrite one value, and print a few queries.
int main()
{
    // Automatic storage: the table object is destroyed when main returns
    // instead of being leaked through a naked `new`.
    BST<char, string> st;
    st.put('S', "CS");
    st.put('E', "LLB");
    st.put('X', "SE");
    st.put('A', "BBA");
    st.put('R', "DVM");
    st.put('C', "DPT");
    st.put('H', "D-Pharm");
    st.put('M', "IT");
    st.put('X', "BA"); // just update the value
    // st.delPair('E');
    // st.delMin();
    // st.delMax();
    st.print_levelorder();
    cout << st.rank('M') << endl;                // number of keys less than 'M'
    cout << st.size('T', 'F') << endl;           // range count; upper bound 'T' comes first
    cout << boolalpha << st.checkBST() << endl;  // boolalpha only matters for bool output
    cout << st.min() << endl;
    cout << st.max() << endl;
}