Page cover

Symbol Table (ST) Using a Binary Search Tree

#include <iostream>
#include <queue>
#include <string>
#include <utility>
using namespace std;

// One node of the binary search tree: a key/value pair plus links to the
// left and right children.
//
// `count` caches the number of nodes in the subtree rooted at this node
// (the node itself included), so a freshly created leaf has count == 1.
template <typename Key, typename Value>
struct Node
{
     Node *left = nullptr;
     Node *right = nullptr;
     Key key{};   // value-initialised so a default-constructed node never holds garbage
     Value val{}; // same for the payload
     int count = 1;

     // Creates a detached leaf with a value-initialised key and value.
     Node() = default;

     // Creates a detached leaf holding the given pair.
     // The arguments are taken by value and moved into place, so the members
     // are constructed once instead of being default-constructed and then
     // assigned; Key and Value therefore need not be default-constructible.
     Node(Key key, Value val) : key(std::move(key)), val(std::move(val)) {}
};
template <typename Key, typename Value>
class BST
{
     Node<Key, Value> *root = nullptr;

     Node<Key, Value> *put(Node<Key, Value> *x, Key key, Value val)
     {
          if (x == nullptr)
          {
               // Inserting first element, the count will not correctly updated
               return new Node<Key, Value>(key, val);
          }
          if (x->key > key)
               x->left = put(x->left, key, val);
          else if (x->key < key)
               x->right = put(x->right, key, val);
          else
               x->val = val;
          x->count = 1 + size(x->left) + size(x->right);
          return x;
     }

     Node<Key, Value> *floor(Node<Key, Value> *x, Key key)
     {
          if (x == nullptr)
               return nullptr;
          if (key == x->key)
               return x;
          if (key < x->key)
               return floor(x->left, key);

          Node<Key, Value> *temp = floor(x->right, key);
          if (temp != nullptr)
               return temp;
          else
               return x;
     }

     bool contain(Node<Key, Value> *x, Key key)
     {
          if (x == nullptr)
               return false;
          if (x->key < key)
               return contain(x->right, key);
          else if (x->key > key)
               return contain(x->left, key);
          else
               return true;
     }

     int size(Node<Key, Value> *x)
     {
          return x == nullptr ? 0 : x->count;
     }

     int rank(Node<Key, Value> *x, Key key)
     {
          if (x == nullptr)
               return 0;
          if (key < x->key)
               return rank(x->left, key);
          else if (key > x->key)
               return 1 + size(x->left) + rank(x->right, key);
          else
               return size(x->left);
     }

     Node<Key, Value> *delMax(Node<Key, Value> *x)
     {
          if (x->right == nullptr)
          {
               Node<Key, Value> *temp = x->left;

               return temp;
          }
          x->right = delMax(x->right);
          x->count = 1 + size(x->left) + size(x->right);
          return x;
     }

     Node<Key, Value> *delMin(Node<Key, Value> *x)
     {
          if (x->left == nullptr)
          {
               // Node<Key, Value> *temp = x;
               return x->right;
          }
          x->left = delMin(x->left);
          x->count = 1 + size(x->left) + size(x->right);
          return x;
     }

     Node<Key, Value> *delPair(Node<Key, Value> *x, Key key)
     {
          if (x == nullptr)
               return x;
          if (key < x->key)
               x->left = delPair(x->left, key);
          else if (key > x->key)
               x->right = delPair(x->right, key);
          else
          {
               if (x->left == nullptr)
                    return x->right;
               if (x->right == nullptr)
                    return x->left;
               Node<Key, Value> *t = x;
               x = min(t->right);
               x->right = delMin(t->right); // will return t->right
               x->left = t->left;
          }
          x->count = 1 + size(x->left) + size(x->right);
          return x;
     }

     Node<Key, Value> *min(Node<Key, Value> *x)
     {
          Node<Key, Value> *temp = x;
          while (temp->left != nullptr)
          {
               temp = temp->left;
          }
          return temp;
     }

     Node<Key, Value> *max(Node<Key, Value> *x)
     {
          Node<Key, Value> *temp = x;
          while (temp->right != nullptr)
          {
               temp = temp->right;
          }
          return temp;
     }

     void inorder(Node<Key, Value> *x)
     {
          if (x == nullptr)
               return;
          inorder(x->left);
          cout << x->key << ": " << x->val << "  ";
          inorder(x->right);
     }

     void preorder(Node<Key, Value> *x)
     {
          if (x == nullptr)
               return;
          cout << x->key << ": " << x->val << "  ";
          preorder(x->left);
          preorder(x->right);
     }

     void postorder(Node<Key, Value> *x)
     {
          if (x == nullptr)
               return;
          preorder(x->left);
          preorder(x->right);
          cout << x->key << ": " << x->val << "  ";
     }

     void levelorder(Node<Key, Value> *x)
     {
          queue<Node<Key, Value> *> qt;
          qt.push(root);
          qt.push(nullptr);
          cout << x->key << ": " << x->val << " " << endl;

          while (!qt.empty())
          {
               Node<Key, Value> *cur = qt.front();
               qt.pop();
               if (cur != nullptr)
               {
                    if (cur->left != nullptr)
                    {
                         qt.push(cur->left);
                         cout << cur->left->key << ": " << cur->left->val << " ";
                    }
                    else
                    {
                         cout << "NULL ";
                    }
                    if (cur->right != nullptr)
                    {
                         qt.push(cur->right);
                         cout << cur->right->key << ": " << cur->right->val << " ";
                    }
                    else
                    {
                         cout << "NULL ";
                    }
               }
               if (qt.front() == nullptr)
               {
                    cout << endl;
                    qt.pop();
                    qt.push(nullptr);
               }
          }
     }

public:
     void put(Key key, Value val)
     {
          root = put(root, key, val);
     }

     // 1d Range Count
     int size(Key hi, Key lo)
     {
          if (contain(hi))
               return rank(hi) - rank(lo) + 1;
          else
               return rank(hi) - rank(lo); // no of keys < hi
     }

     Key floor(Key key)
     {
          Node<Key, Value> *x = floor(root, key);
          if (x == nullptr)
               throw "";
          return x->key;
     }

     bool contain(Key key)
     {
          contain(root, key);
     }

     Value get(Key key)
     {
          Node<Key, Value> *cur = root;
          while (cur != nullptr)
          {
               if (cur->key < key)
               {
                    cur = cur->right;
               }
               else if (cur->key > key)
                    cur = cur->left;
               else
                    return cur->val;
          }
          throw "Key Not found!";
     }

     int rank(Key key)
     {
          return rank(root, key);
     }

     Key min()
     {
          if (root == nullptr)
               throw "Empty Symbol Table";
          Node<Key, Value> *temp = min(root);
          return temp->key;
     }

     Key max()
     {
          if (root == nullptr)
               throw "Empty Symbol Table";
          Node<Key, Value> *temp = max(root);
          return temp->key;
     }

     void delMin()
     {
          root = delMin(root);
     }

     void delMax()
     {
          root = delMax(root);
     }

     void delPair(Key key)
     {
          root = delPair(root, key);
     }

     bool checkBST() // Handle some cases - will Modify :)
     {
          queue<Node<Key, Value> *> q;
          q.push(root);
          while (!q.empty())
          {
               Node<Key, Value> *tempRoot = q.front();
               q.pop();
               if (tempRoot != nullptr)
               {
                    if (tempRoot->left != nullptr)
                    {
                         if (tempRoot->left->key > tempRoot->key || (tempRoot->right != nullptr && tempRoot->right->key < tempRoot->key))
                              return false;
                         q.push(tempRoot->left);
                    }
                    if (tempRoot->right != nullptr)
                    {
                         if (tempRoot->right->key < tempRoot->key || (tempRoot->left != nullptr && tempRoot->left->key > tempRoot->key))
                              return false;
                         q.push(tempRoot->right);
                    }
               }
          }
          return true;
     }

     void print_inorder()
     {
          inorder(root);
     }

     void print_preorder()
     {
          preorder(root);
     }

     void print_postorder()
     {
          postorder(root);
     }

     void print_levelorder()
     {
          levelorder(root);
     }
};

// Demo driver: builds a small symbol table, prints it level by level and
// then exercises rank, range count, the invariant check, min and max.
int main()
{
     // Automatic storage instead of `new`: nothing to delete, nothing to leak.
     BST<char, std::string> st;
     st.put('S', "CS");
     st.put('E', "LLB");
     st.put('X', "SE");
     st.put('A', "BBA");
     st.put('R', "DVM");
     st.put('C', "DPT");
     st.put('H', "D-Pharm");
     st.put('M', "IT");
     st.put('X', "BA"); // just update the value
     // st.delPair('E');
     // st.delMin();
     // st.delMax();
     st.print_levelorder();

     std::cout << st.rank('M') << '\n';
     std::cout << st.size('T', 'F') << '\n'; // arguments are (hi, lo)
     // boolalpha only matters for the bool result; the integers above are unaffected by it.
     std::cout << std::boolalpha << st.checkBST() << '\n';
     std::cout << st.min() << '\n';
     std::cout << st.max() << '\n';
     return 0;
}

Last updated

Was this helpful?