// cs440::Map -- an ordered associative container mapping Key_T to Mapped_T.
//
// Storage layout:
//   * Every element lives in a Node owned by `nodes_`, a std::list that is
//     kept in ascending key order. std::list never relocates its elements, so
//     node addresses -- and therefore the iterators handed out to callers --
//     stay valid until that particular element is erased. Iteration is plain
//     list iteration, so ++/-- are O(1) and end() can be decremented.
//   * The same nodes are also linked into a red-black tree (parent/left/right
//     pointers plus a colour) that provides O(log n) find, insert and erase.
//   * Keys are compared exclusively with Key_T::operator<.
//
// The balancing code follows the textbook red-black algorithms
// (https://en.wikipedia.org/wiki/Red%E2%80%93black_tree): the insertion
// recolour/rotate cases and the CLRS-style delete fix-up with null leaves.

#ifndef CS440_MAP_HPP
#define CS440_MAP_HPP

#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <list>
#include <memory>
#include <stdexcept>
#include <tuple>
#include <utility>

namespace cs440 {

template <typename Key_T, typename Mapped_T> class Map {
public:
  /// Element type; the key part is immutable once inserted.
  using ValueType = std::pair<const Key_T, Mapped_T>;

  class Iterator;
  class ConstIterator;
  class ReverseIterator;

private:
  enum class Color : unsigned char { Red, Black };
  enum class Direction : unsigned char { Left, Right };

  static Direction flip(Direction dir) noexcept {
    return dir == Direction::Left ? Direction::Right : Direction::Left;
  }

  struct Node;
  // Owns every node, kept in ascending key order.
  using Storage = std::list<Node>;

  struct Node {
    template <typename... Args>
    explicit Node(std::in_place_t, Args &&...args)
        : value(std::forward<Args>(args)...) {}

    ValueType value;
    // This node's own position inside Map::nodes_; used to build iterators
    // from a tree node and to remove the node from the list in O(1).
    typename Storage::iterator self{};
    Color color = Color::Red;
    // Red-black tree links; nullptr means "no such node".
    Node *parent = nullptr;
    Node *left = nullptr;
    Node *right = nullptr;

    Node *&child(Direction dir) noexcept {
      return dir == Direction::Left ? left : right;
    }
  };

  // Absent (null) children count as black leaves.
  static bool is_red(const Node *node) noexcept {
    return node != nullptr && node->color == Color::Red;
  }

public:
  /// Bidirectional iterator over mutable elements, in ascending key order.
  /// Stays valid until the element it refers to is erased.
  class Iterator {
    friend class Map;
    friend class ConstIterator;
    friend class ReverseIterator;

  public:
    Iterator() = delete;

    Iterator &operator++() {
      ++pos_;
      return *this;
    }
    Iterator operator++(int) {
      Iterator old = *this;
      ++pos_;
      return old;
    }
    Iterator &operator--() {
      --pos_;
      return *this;
    }
    Iterator operator--(int) {
      Iterator old = *this;
      --pos_;
      return old;
    }

    ValueType &operator*() const { return pos_->value; }
    ValueType *operator->() const { return std::addressof(pos_->value); }

    friend bool operator==(const Iterator &lhs, const Iterator &rhs) {
      return lhs.pos_ == rhs.pos_;
    }
    friend bool operator!=(const Iterator &lhs, const Iterator &rhs) {
      return !(lhs == rhs);
    }

  private:
    explicit Iterator(typename Storage::iterator pos) : pos_{pos} {}

    typename Storage::iterator pos_;
  };

  /// Bidirectional iterator over read-only elements, in ascending key order.
  class ConstIterator {
    friend class Map;

  public:
    ConstIterator() = delete;
    // Implicit on purpose: an Iterator may be used wherever a ConstIterator
    // is expected (e.g. comparing find() results against cend()).
    ConstIterator(const Iterator &it) : pos_{it.pos_} {}

    ConstIterator &operator++() {
      ++pos_;
      return *this;
    }
    ConstIterator operator++(int) {
      ConstIterator old = *this;
      ++pos_;
      return old;
    }
    ConstIterator &operator--() {
      --pos_;
      return *this;
    }
    ConstIterator operator--(int) {
      ConstIterator old = *this;
      --pos_;
      return old;
    }

    const ValueType &operator*() const { return pos_->value; }
    const ValueType *operator->() const {
      return std::addressof(pos_->value);
    }

    friend bool operator==(const ConstIterator &lhs, const ConstIterator &rhs) {
      return lhs.pos_ == rhs.pos_;
    }
    friend bool operator!=(const ConstIterator &lhs, const ConstIterator &rhs) {
      return !(lhs == rhs);
    }
    // Mixed comparisons go through the Iterator -> ConstIterator conversion.
    friend bool operator==(const Iterator &lhs, const ConstIterator &rhs) {
      return ConstIterator{lhs} == rhs;
    }
    friend bool operator!=(const Iterator &lhs, const ConstIterator &rhs) {
      return !(lhs == rhs);
    }
    friend bool operator==(const ConstIterator &lhs, const Iterator &rhs) {
      return lhs == ConstIterator{rhs};
    }
    friend bool operator!=(const ConstIterator &lhs, const Iterator &rhs) {
      return !(lhs == rhs);
    }

  private:
    explicit ConstIterator(typename Storage::const_iterator pos) : pos_{pos} {}

    typename Storage::const_iterator pos_;
  };

  /// Bidirectional iterator over mutable elements, in descending key order.
  class ReverseIterator {
    friend class Map;

  public:
    ReverseIterator() = delete;
    // Refers to the same element as `it`; `it` must be dereferenceable
    // (i.e. not end()).
    ReverseIterator(Iterator it) : pos_{std::next(it.pos_)} {}

    ReverseIterator &operator++() {
      ++pos_;
      return *this;
    }
    ReverseIterator operator++(int) {
      ReverseIterator old = *this;
      ++pos_;
      return old;
    }
    ReverseIterator &operator--() {
      --pos_;
      return *this;
    }
    ReverseIterator operator--(int) {
      ReverseIterator old = *this;
      --pos_;
      return old;
    }

    ValueType &operator*() const { return (*pos_).value; }
    ValueType *operator->() const { return std::addressof((*pos_).value); }

    friend bool operator==(const ReverseIterator &lhs,
                           const ReverseIterator &rhs) {
      return lhs.pos_ == rhs.pos_;
    }
    friend bool operator!=(const ReverseIterator &lhs,
                           const ReverseIterator &rhs) {
      return !(lhs == rhs);
    }

  private:
    explicit ReverseIterator(typename Storage::reverse_iterator pos)
        : pos_{pos} {}

    typename Storage::reverse_iterator pos_;
  };

  // ---------------------------------------------------------------- lifetime

  Map() = default;

  /// Deep copy: the new map owns its own nodes.
  Map(const Map &rhs) : Map() { insert(rhs.begin(), rhs.end()); }

  /// Steals rhs's nodes; rhs is left empty. Iterators into rhs now refer to
  /// elements of *this.
  Map(Map &&rhs) noexcept : Map() { swap(rhs); }

  Map(std::initializer_list<ValueType> elems) : Map() {
    insert(elems.begin(), elems.end());
  }

  /// Copy-and-swap: *this is unchanged if copying rhs throws.
  Map &operator=(const Map &rhs) {
    if (this != &rhs) {
      Map copy(rhs);
      swap(copy);
    }
    return *this;
  }

  Map &operator=(Map &&rhs) noexcept {
    if (this != &rhs) {
      clear();
      swap(rhs);
    }
    return *this;
  }

  ~Map() = default;

  // ---------------------------------------------------------------- capacity

  std::size_t size() const { return nodes_.size(); }
  bool empty() const { return nodes_.empty(); }

  // --------------------------------------------------------------- iteration

  Iterator begin() { return Iterator{nodes_.begin()}; }
  Iterator end() { return Iterator{nodes_.end()}; }
  ConstIterator begin() const { return ConstIterator{nodes_.cbegin()}; }
  ConstIterator end() const { return ConstIterator{nodes_.cend()}; }
  ConstIterator cbegin() const { return begin(); }
  ConstIterator cend() const { return end(); }
  ReverseIterator rbegin() { return ReverseIterator{nodes_.rbegin()}; }
  ReverseIterator rend() { return ReverseIterator{nodes_.rend()}; }

  // ------------------------------------------------------------------ lookup

  /// Returns an iterator to the element with `key`, or end() if absent.
  Iterator find(const Key_T &key) {
    Node *hit = find_node(key);
    return hit == nullptr ? end() : Iterator{hit->self};
  }

  ConstIterator find(const Key_T &key) const {
    const Node *hit = find_node(key);
    if (hit == nullptr) {
      return end();
    }
    return ConstIterator{typename Storage::const_iterator{hit->self}};
  }

  /// Mapped value for `key`. @throws std::out_of_range if `key` is absent.
  Mapped_T &at(const Key_T &key) {
    return require_node(key, "cs440::Map::at: key not found")->value.second;
  }

  const Mapped_T &at(const Key_T &key) const {
    return require_node(key, "cs440::Map::at: key not found")->value.second;
  }

  /// Mapped value for `key`, inserting a value-initialized Mapped_T first if
  /// the key is absent.
  Mapped_T &operator[](const Key_T &key) {
    return emplace_unique(key, std::piecewise_construct,
                          std::forward_as_tuple(key), std::forward_as_tuple())
        .first->second;
  }

  // --------------------------------------------------------------- insertion

  /// Inserts `val` if its key is absent. Returns {iterator to new element,
  /// true}; if the key already exists nothing changes and the result is
  /// {iterator to the existing element, false}.
  std::pair<Iterator, bool> insert(const ValueType &val) {
    return emplace_unique(val.first, val);
  }

  /// Inserts every element of [range_beg, range_end) via insert(const
  /// ValueType &); duplicate keys are skipped.
  template <typename IT_T> void insert(IT_T range_beg, IT_T range_end) {
    for (; range_beg != range_end; ++range_beg) {
      insert(*range_beg);
    }
  }

  // ----------------------------------------------------------------- removal

  /// Removes the element at `pos`, which must be dereferenceable. Iterators
  /// to other elements remain valid.
  void erase(Iterator pos) {
    unlink(std::addressof(*pos.pos_));
    nodes_.erase(pos.pos_);
  }

  /// Removes the element with `key`. @throws std::out_of_range if absent.
  void erase(const Key_T &key) {
    erase(Iterator{require_node(key, "cs440::Map::erase: key not found")->self});
  }

  void clear() noexcept {
    nodes_.clear();
    root_ = nullptr;
  }

  // ------------------------------------------------------------- comparisons

  /// Equal sizes and element-wise equal (ValueType ==) in key order.
  friend bool operator==(const Map &lhs, const Map &rhs) {
    if (lhs.size() != rhs.size()) {
      return false;
    }
    auto rhs_it = rhs.begin();
    for (const ValueType &entry : lhs) {
      if (entry != *rhs_it) {
        return false;
      }
      ++rhs_it;
    }
    return true;
  }

  friend bool operator!=(const Map &lhs, const Map &rhs) {
    return !(lhs == rhs);
  }

  /// Lexicographic comparison of the element sequences (ValueType <); a
  /// proper prefix compares less.
  friend bool operator<(const Map &lhs, const Map &rhs) {
    auto l_it = lhs.begin();
    auto r_it = rhs.begin();
    for (; l_it != lhs.end() && r_it != rhs.end(); ++l_it, ++r_it) {
      if (*l_it < *r_it) {
        return true;
      }
      if (*r_it < *l_it) {
        return false;
      }
    }
    return lhs.size() < rhs.size();
  }

private:
  Storage nodes_;
  Node *root_ = nullptr;

  void swap(Map &other) noexcept {
    nodes_.swap(other.nodes_);
    std::swap(root_, other.root_);
  }

  // Plain BST descent using only Key_T::operator<; nullptr when absent.
  Node *find_node(const Key_T &key) const {
    Node *cur = root_;
    while (cur != nullptr) {
      if (key < cur->value.first) {
        cur = cur->left;
      } else if (cur->value.first < key) {
        cur = cur->right;
      } else {
        return cur;
      }
    }
    return nullptr;
  }

  // Like find_node, but throws std::out_of_range(`what`) when absent.
  Node *require_node(const Key_T &key, const char *what) const {
    Node *hit = find_node(key);
    if (hit == nullptr) {
      throw std::out_of_range{what};
    }
    return hit;
  }

  // Inserts a node built from `args` unless `key` is present already.
  template <typename... Args>
  std::pair<Iterator, bool> emplace_unique(const Key_T &key, Args &&...args) {
    Node *parent = nullptr;
    Direction side = Direction::Left;
    for (Node *cur = root_; cur != nullptr; cur = cur->child(side)) {
      if (key < cur->value.first) {
        side = Direction::Left;
      } else if (cur->value.first < key) {
        side = Direction::Right;
      } else {
        return std::make_pair(Iterator{cur->self}, false);
      }
      parent = cur;
    }

    // A new left child sits just before its parent in key order, a new right
    // child just after it; the first element of an empty map is appended.
    typename Storage::iterator where = nodes_.end();
    if (parent != nullptr) {
      where = side == Direction::Left ? parent->self : std::next(parent->self);
    }
    // Construct before linking: if construction throws, neither the list nor
    // the tree has been modified.
    typename Storage::iterator pos =
        nodes_.emplace(where, std::in_place, std::forward<Args>(args)...);
    Node *node = std::addressof(*pos);
    node->self = pos;
    node->parent = parent;
    if (parent == nullptr) {
      root_ = node;
    } else {
      parent->child(side) = node;
    }
    rebalance_after_insert(node);
    return std::make_pair(Iterator{pos}, true);
  }

  // Makes `replacement` (possibly null) take `old_child`'s place under
  // `parent`, or become the root when `parent` is null.
  void replace_child(Node *parent, Node *old_child,
                     Node *replacement) noexcept {
    if (replacement != nullptr) {
      replacement->parent = parent;
    }
    if (parent == nullptr) {
      root_ = replacement;
    } else if (parent->left == old_child) {
      parent->left = replacement;
    } else {
      parent->right = replacement;
    }
  }

  // Rotates the subtree rooted at `top` towards `dir`: top's child on the
  // opposite side becomes the subtree root and `top` becomes its `dir` child.
  // The in-order sequence (and therefore nodes_) is unaffected.
  void rotate(Node *top, Direction dir) noexcept {
    Node *pivot = top->child(flip(dir));
    assert(pivot != nullptr && "rotation needs a child on the opposite side");
    Node *inner = pivot->child(dir);
    Node *above = top->parent;

    top->child(flip(dir)) = inner;
    if (inner != nullptr) {
      inner->parent = top;
    }
    pivot->child(dir) = top;
    top->parent = pivot;
    replace_child(above, top, pivot);
  }

  // Restores the red-black invariants after `node` was attached as a red leaf.
  void rebalance_after_insert(Node *node) noexcept {
    for (;;) {
      Node *parent = node->parent;
      if (parent == nullptr) {
        // node is the root; a black root never violates anything.
        node->color = Color::Black;
        return;
      }
      if (parent->color == Color::Black) {
        return;
      }
      Node *grandparent = parent->parent;
      if (grandparent == nullptr) {
        // Red root parent: recolouring it is enough.
        parent->color = Color::Black;
        return;
      }
      const Direction side = grandparent->left == parent ? Direction::Left
                                                          : Direction::Right;
      Node *uncle = grandparent->child(flip(side));
      if (is_red(uncle)) {
        // Push the blackness down from the grandparent and retry above it.
        parent->color = Color::Black;
        uncle->color = Color::Black;
        grandparent->color = Color::Red;
        node = grandparent;
        continue;
      }
      if (node == parent->child(flip(side))) {
        // Inner grandchild: rotate it into the outer position first.
        rotate(parent, side);
        parent = node;
      }
      rotate(grandparent, flip(side));
      parent->color = Color::Black;
      grandparent->color = Color::Red;
      return;
    }
  }

  // Detaches `target` from the tree (it stays in nodes_) and rebalances.
  void unlink(Node *target) noexcept {
    Node *replacement;        // node moved into the vacated slot (may be null)
    Node *replacement_parent; // parent of that slot after the splice
    Color removed_color;      // colour that disappeared from its position

    if (target->left == nullptr || target->right == nullptr) {
      replacement = target->left != nullptr ? target->left : target->right;
      replacement_parent = target->parent;
      removed_color = target->color;
      replace_child(target->parent, target, replacement);
    } else {
      // Two children: move the in-order successor into target's place.
      Node *successor = target->right;
      while (successor->left != nullptr) {
        successor = successor->left;
      }
      removed_color = successor->color;
      replacement = successor->right;
      if (successor->parent == target) {
        replacement_parent = successor;
      } else {
        replacement_parent = successor->parent;
        replace_child(successor->parent, successor, replacement);
        successor->right = target->right;
        successor->right->parent = successor;
      }
      replace_child(target->parent, target, successor);
      successor->left = target->left;
      successor->left->parent = successor;
      successor->color = target->color;
    }

    if (removed_color == Color::Black) {
      rebalance_after_erase(replacement, replacement_parent);
    }
  }

  // Fixes the missing black on the path through `node` (possibly null), whose
  // parent is `parent`.
  void rebalance_after_erase(Node *node, Node *parent) noexcept {
    while (node != root_ && !is_red(node)) {
      const Direction side =
          node == parent->left ? Direction::Left : Direction::Right;
      Node *sibling = parent->child(flip(side));
      assert(sibling != nullptr && "black-height invariant was broken");
      if (is_red(sibling)) {
        // Red sibling: rotate so that the sibling becomes black.
        sibling->color = Color::Black;
        parent->color = Color::Red;
        rotate(parent, side);
        sibling = parent->child(flip(side));
      }
      if (!is_red(sibling->left) && !is_red(sibling->right)) {
        // Both nephews black: remove a black from the sibling side and move
        // the problem one level up.
        sibling->color = Color::Red;
        node = parent;
        parent = node->parent;
        continue;
      }
      if (!is_red(sibling->child(flip(side)))) {
        // Close nephew red, distant nephew black: make the distant one red.
        sibling->child(side)->color = Color::Black;
        sibling->color = Color::Red;
        rotate(sibling, flip(side));
        sibling = parent->child(flip(side));
      }
      // Distant nephew red: one rotation at the parent restores balance.
      sibling->color = parent->color;
      parent->color = Color::Black;
      sibling->child(flip(side))->color = Color::Black;
      rotate(parent, side);
      node = root_;
      break;
    }
    if (node != nullptr) {
      node->color = Color::Black;
    }
  }
};

} // namespace cs440

#endif // CS440_MAP_HPP