#ifndef _POWELL_CS440
#define _POWELL_CS440

// Uncomment on submission/performance test. NDEBUG has to be defined before
// <cassert> is included for it to disable the assertions in this header.
// #define NDEBUG

#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <memory>
#include <stdexcept>
#include <utility>

namespace cs440 {

namespace detail {
// Node color used by the red-black balancing rules. This lives in a named
// namespace rather than an anonymous one: an anonymous namespace in a header
// gives every translation unit its own distinct type, which makes Map<...>
// an ODR violation when it is used from more than one .cpp file.
enum class Color : unsigned char { Red, Black };
} // namespace detail

/// Ordered associative container backed by a red-black tree.
///
/// Lookup orders and matches keys using `operator<` only. Each node is also
/// threaded into an in-order doubly linked list (`prev`/`next`), and the
/// first/last nodes are tracked in `min`/`max`, so iterator steps are O(1).
/// Insert and erase relink nodes instead of moving values between them, so
/// iterators to elements other than an erased one remain valid.
template <typename Key_T, typename Mapped_T>
class Map {
public:
  using ValueType = std::pair<const Key_T, Mapped_T>;

private:
  using Color = detail::Color;

  // Child indices; for a side `s`, `1 - s` is the opposite side.
  static constexpr int kLeft = 0;
  static constexpr int kRight = 1;

  struct Node {
    explicit Node(const ValueType &v) : val{v} {}

    ValueType val;
    Node *parent = nullptr;         // non-owning; nullptr for the root
    std::unique_ptr<Node> child[2]; // owning; indexed by kLeft / kRight
    Color color = Color::Red;       // freshly inserted nodes start red
    Node *prev = nullptr;           // in-order predecessor, nullptr at min
    Node *next = nullptr;           // in-order successor, nullptr at max
  };

  std::unique_ptr<Node> root;
  std::size_t _size = 0;
  Node *min = nullptr;
  Node *max = nullptr;

public:
  class ConstIterator;
  class ReverseIterator;

  /// Bidirectional iterator over the elements in ascending key order.
  ///
  /// The past-the-end position is a null node. Stepping away from it wraps
  /// to the ends of the owning map: `--end()` yields the largest element and
  /// `++` from the null position yields the smallest one (ReverseIterator
  /// relies on the latter to step back from `rend()`).
  class Iterator {
    Node *underlying;
    const Map *owner;

    Iterator(Node *ptr, const Map *map) : underlying{ptr}, owner{map} {}

  public:
    friend Map;
    friend ConstIterator;
    friend ReverseIterator;

    Iterator() = delete;
    Iterator(const Iterator &rhs) = default;
    Iterator &operator=(const Iterator &) = default;
    ~Iterator() = default;

    // preincrement
    Iterator &operator++() {
      underlying = underlying != nullptr ? underlying->next : owner->min;
      return *this;
    }

    // postincrement
    Iterator operator++(int) {
      Iterator before{*this};
      ++*this;
      return before;
    }

    // predecrement
    Iterator &operator--() {
      underlying = underlying != nullptr ? underlying->prev : owner->max;
      return *this;
    }

    // postdecrement
    Iterator operator--(int) {
      Iterator before{*this};
      --*this;
      return before;
    }

    ValueType &operator*() const { return underlying->val; }

    ValueType *operator->() const { return &underlying->val; }

    friend bool operator==(const Iterator &lhs, const Iterator &rhs) {
      return lhs.underlying == rhs.underlying;
    }
    friend bool operator!=(const Iterator &lhs, const Iterator &rhs) {
      return !(lhs == rhs);
    }
  };

  /// Read-only view of an Iterator; implicitly constructible from one.
  class ConstIterator {
    Iterator underlying;

  public:
    friend Map;

    ConstIterator() = delete;
    ConstIterator(const Iterator &underlying) : underlying{underlying} {}
    ConstIterator(const ConstIterator &rhs) = default;
    ConstIterator &operator=(const ConstIterator &) = default;
    ~ConstIterator() = default;

    ConstIterator &operator++() {
      ++underlying;
      return *this;
    }
    ConstIterator operator++(int) {
      ConstIterator before{*this};
      ++underlying;
      return before;
    }
    ConstIterator &operator--() {
      --underlying;
      return *this;
    }
    ConstIterator operator--(int) {
      ConstIterator before{*this};
      --underlying;
      return before;
    }

    const ValueType &operator*() const { return *underlying; }
    const ValueType *operator->() const { return &*underlying; }

    friend bool operator==(const ConstIterator &lhs, const ConstIterator &rhs) {
      return lhs.underlying == rhs.underlying;
    }
    friend bool operator!=(const ConstIterator &lhs, const ConstIterator &rhs) {
      return !(lhs == rhs);
    }
    friend bool operator==(const Iterator &lhs, const ConstIterator &rhs) {
      return lhs == rhs.underlying;
    }
    friend bool operator!=(const Iterator &lhs, const ConstIterator &rhs) {
      return !(lhs == rhs);
    }
    friend bool operator==(const ConstIterator &lhs, const Iterator &rhs) {
      return lhs.underlying == rhs;
    }
    friend bool operator!=(const ConstIterator &lhs, const Iterator &rhs) {
      return !(lhs == rhs);
    }
  };

  /// Iterator over the elements in descending key order.
  class ReverseIterator {
    Iterator underlying;

    explicit ReverseIterator(const Iterator &it) : underlying{it} {}

  public:
    friend Map;

    ReverseIterator() = delete;
    ReverseIterator(const ReverseIterator &) = default;
    ReverseIterator &operator=(const ReverseIterator &) = default;
    ~ReverseIterator() = default;

    ReverseIterator &operator++() {
      --underlying;
      return *this;
    }
    ReverseIterator operator++(int) {
      ReverseIterator before{*this};
      --underlying;
      return before;
    }
    ReverseIterator &operator--() {
      ++underlying;
      return *this;
    }
    // Previously called operator++ by mistake; must step the other way.
    ReverseIterator operator--(int) {
      ReverseIterator before{*this};
      ++underlying;
      return before;
    }

    const ValueType &operator*() const { return *underlying; }
    const ValueType *operator->() const { return &*underlying; }

    friend bool operator==(const ReverseIterator &lhs,
                           const ReverseIterator &rhs) {
      return lhs.underlying == rhs.underlying;
    }
    friend bool operator!=(const ReverseIterator &lhs,
                           const ReverseIterator &rhs) {
      return !(lhs == rhs);
    }
  };

  /// Empty map; min/max start as nullptr so begin() == end().
  Map() = default;

  /// Deep copy: the new map owns its own nodes, with the same shape and
  /// colors as `rhs`, and its own in-order thread.
  Map(const Map &rhs) : _size{rhs._size} {
    Node *last = nullptr;
    root = clone(rhs.root.get(), nullptr, last);
    min = leftmost(root.get());
    max = last;
  }

  /// Takes over rhs's nodes; rhs is left empty.
  Map(Map &&rhs) noexcept
      : root{std::move(rhs.root)}, _size{std::exchange(rhs._size, 0)},
        min{std::exchange(rhs.min, nullptr)},
        max{std::exchange(rhs.max, nullptr)} {}

  /// Copy-and-swap, so min/max always describe the tree actually held.
  Map &operator=(const Map &rhs) {
    if (this != &rhs) {
      Map copy(rhs);
      swap_contents(copy);
    }
    return *this;
  }

  Map &operator=(Map &&rhs) noexcept {
    if (this != &rhs) {
      Map taken(std::move(rhs));
      swap_contents(taken);
    }
    return *this;
  }

  ~Map() = default;

  /// Inserts each pair in order; later duplicates of a key are ignored.
  Map(std::initializer_list<ValueType> items) : Map{} {
    this->insert(items.begin(), items.end());
  }

  /// Debug helper: asserts the red-black, parent-link, ordering and thread
  /// invariants. Does nothing when NDEBUG is defined.
  void check() const {
#ifndef NDEBUG
    assert(!root || root->color == Color::Black);
    assert(!root || root->parent == nullptr);
    black_height(root.get());
    std::size_t count = 0;
    for (const Node *n = min; n != nullptr; n = n->next) {
      assert(n->next == nullptr || n->val.first < n->next->val.first);
      assert(n->next == nullptr || n->next->prev == n);
      ++count;
    }
    assert(count == _size);
    assert(min == leftmost(root.get()));
#endif
  }

  std::size_t size() const { return _size; }

  bool empty() const { return _size == 0; }

private:
  static bool is_red(const Node *n) {
    return n != nullptr && n->color == Color::Red;
  }

  static Node *leftmost(Node *n) {
    while (n != nullptr && n->child[kLeft]) {
      n = n->child[kLeft].get();
    }
    return n;
  }

  // Recursively validates `n`'s subtree and returns its black height
  // (null leaves count as 1). Only used by check().
  static int black_height(const Node *n) {
    if (n == nullptr) {
      return 1;
    }
    const Node *l = n->child[kLeft].get();
    const Node *r = n->child[kRight].get();
    assert(!l || (l->parent == n && l->val.first < n->val.first));
    assert(!r || (r->parent == n && n->val.first < r->val.first));
    assert(n->color == Color::Black || (!is_red(l) && !is_red(r)));
    const int left_height = black_height(l);
    const int right_height = black_height(r);
    assert(left_height == right_height);
    (void)right_height;
    return left_height + (n->color == Color::Black ? 1 : 0);
  }

  // Deep-copies `src`'s subtree under `parent`, appending every copied node
  // (in order) to the prev/next thread whose current tail is `last`.
  static std::unique_ptr<Node> clone(const Node *src, Node *parent,
                                     Node *&last) {
    if (src == nullptr) {
      return nullptr;
    }
    auto copy = std::make_unique<Node>(src->val);
    copy->color = src->color;
    copy->parent = parent;
    copy->child[kLeft] = clone(src->child[kLeft].get(), copy.get(), last);
    copy->prev = last;
    if (last != nullptr) {
      last->next = copy.get();
    }
    last = copy.get();
    copy->child[kRight] = clone(src->child[kRight].get(), copy.get(), last);
    return copy;
  }

  void swap_contents(Map &other) noexcept {
    using std::swap;
    swap(root, other.root);
    swap(_size, other._size);
    swap(min, other.min);
    swap(max, other.max);
  }

  // Returns the node whose key is equivalent to `key`, or nullptr.
  Node *find_node(const Key_T &key) const {
    Node *cur = root.get();
    while (cur != nullptr) {
      if (key < cur->val.first) {
        cur = cur->child[kLeft].get();
      } else if (cur->val.first < key) {
        cur = cur->child[kRight].get();
      } else {
        return cur;
      }
    }
    return nullptr;
  }

  // The unique_ptr that owns `n`: `root`, or the matching child slot of
  // its parent.
  std::unique_ptr<Node> &slot_of(Node *n) {
    if (n->parent == nullptr) {
      return root;
    }
    Node *p = n->parent;
    return p->child[kLeft].get() == n ? p->child[kLeft] : p->child[kRight];
  }

  // Rotates `x` down toward `dir`: its child on the opposite side (which
  // must exist) takes x's place and x becomes that child's `dir` child.
  // rotate(x, kLeft) is the textbook left rotation. Only ownership changes
  // hands; no Node is destroyed or relocated, so raw Node* stay valid.
  void rotate(Node *x, int dir) {
    const int other = 1 - dir;
    std::unique_ptr<Node> &x_slot = slot_of(x);
    std::unique_ptr<Node> x_owned = std::move(x_slot);
    std::unique_ptr<Node> y_owned = std::move(x->child[other]);
    Node *y = y_owned.get();
    assert(y != nullptr);

    // y's inner subtree moves across to x.
    x->child[other] = std::move(y->child[dir]);
    if (x->child[other]) {
      x->child[other]->parent = x;
    }
    y->parent = x->parent;
    x->parent = y;
    y->child[dir] = std::move(x_owned);
    x_slot = std::move(y_owned);
  }

  // Restores the red-black invariants after `node` was attached as a red
  // leaf.
  void rebalance_after_insert(Node *node) {
    while (node->parent != nullptr && node->parent->color == Color::Red) {
      Node *parent = node->parent;
      // A red node is never the root, so the grandparent exists.
      Node *grand = parent->parent;
      const int side = grand->child[kLeft].get() == parent ? kLeft : kRight;
      Node *uncle = grand->child[1 - side].get();

      if (is_red(uncle)) {
        // Push the blackness down from the grandparent and continue above.
        parent->color = Color::Black;
        uncle->color = Color::Black;
        grand->color = Color::Red;
        node = grand;
      } else {
        if (node == parent->child[1 - side].get()) {
          // Inner grandchild: rotate it to the outside first.
          rotate(parent, side);
          node = parent;
          parent = node->parent;
        }
        parent->color = Color::Black;
        grand->color = Color::Red;
        rotate(grand, 1 - side);
      }
    }
    root->color = Color::Black;
  }

  // Restores the red-black invariants after a black node was unlinked.
  // `x` (possibly null) is the subtree now one black level short, and
  // `parent` is its parent, needed because `x` may be null.
  void rebalance_after_erase(Node *x, Node *parent) {
    while (x != root.get() && !is_red(x)) {
      const int side = parent->child[kLeft].get() == x ? kLeft : kRight;
      const int other = 1 - side;
      // The sibling subtree is one black level taller, so it is non-null.
      Node *sibling = parent->child[other].get();

      if (is_red(sibling)) {
        sibling->color = Color::Black;
        parent->color = Color::Red;
        rotate(parent, side);
        sibling = parent->child[other].get();
      }

      if (!is_red(sibling->child[kLeft].get()) &&
          !is_red(sibling->child[kRight].get())) {
        sibling->color = Color::Red;
        x = parent;
        parent = x->parent;
      } else {
        if (!is_red(sibling->child[other].get())) {
          sibling->child[side]->color = Color::Black;
          sibling->color = Color::Red;
          rotate(sibling, other);
          sibling = parent->child[other].get();
        }
        sibling->color = parent->color;
        parent->color = Color::Black;
        sibling->child[other]->color = Color::Black;
        rotate(parent, side);
        x = root.get();
      }
    }
    if (x != nullptr) {
      x->color = Color::Black;
    }
  }

  // Unlinks and destroys `z`. When z has two children, its in-order
  // successor node is moved (by relinking, not by copying the value) into
  // z's position, so no surviving element changes node.
  void remove_node(Node *z) {
    // Unthread first; the structural changes below preserve in-order order.
    if (z->prev != nullptr) {
      z->prev->next = z->next;
    } else {
      min = z->next;
    }
    if (z->next != nullptr) {
      z->next->prev = z->prev;
    } else {
      max = z->prev;
    }

    Node *fix = nullptr;        // subtree that lost a black level
    Node *fix_parent = nullptr; // its parent (fix may be null)
    Color removed_color = z->color;

    if (z->child[kLeft] && z->child[kRight]) {
      // Successor: leftmost node of the right subtree, so it has no left
      // child.
      Node *y = z->next;
      removed_color = y->color;
      fix_parent = y->parent == z ? y : y->parent;

      // Detach y, letting its right subtree take its place.
      std::unique_ptr<Node> &y_slot = slot_of(y);
      std::unique_ptr<Node> y_owned = std::move(y_slot);
      y_slot = std::move(y->child[kRight]);
      if (y_slot) {
        y_slot->parent = y->parent;
      }
      fix = y_slot.get();

      // Put y where z was, adopting z's children and color.
      std::unique_ptr<Node> &z_slot = slot_of(z);
      std::unique_ptr<Node> doomed = std::move(z_slot);
      y->child[kLeft] = std::move(z->child[kLeft]);
      y->child[kLeft]->parent = y;
      y->child[kRight] = std::move(z->child[kRight]);
      if (y->child[kRight]) {
        y->child[kRight]->parent = y;
      }
      y->parent = z->parent;
      y->color = z->color;
      z_slot = std::move(y_owned);
    } else {
      // At most one child: it simply replaces z.
      const int side = z->child[kLeft] ? kLeft : kRight;
      fix_parent = z->parent;
      std::unique_ptr<Node> &z_slot = slot_of(z);
      std::unique_ptr<Node> doomed = std::move(z_slot);
      z_slot = std::move(z->child[side]);
      if (z_slot) {
        z_slot->parent = z->parent;
      }
      fix = z_slot.get();
    }

    --_size;
    if (removed_color == Color::Black) {
      rebalance_after_erase(fix, fix_parent);
    }
  }

public:
  /// Iterator to the element with an equivalent key, or end().
  Iterator find(const Key_T &key) { return Iterator{find_node(key), this}; }

  ConstIterator find(const Key_T &key) const {
    return Iterator{find_node(key), this};
  }

  /// Inserts a copy of `val` unless its key is already present.
  /// @return iterator to the element with that key, and whether it was
  ///         inserted.
  std::pair<Iterator, bool> insert(const ValueType &val) {
    Node *parent = nullptr;
    int side = kLeft;
    for (Node *cur = root.get(); cur != nullptr;
         cur = cur->child[side].get()) {
      parent = cur;
      if (val.first < cur->val.first) {
        side = kLeft;
      } else if (cur->val.first < val.first) {
        side = kRight;
      } else {
        return {Iterator{cur, this}, false};
      }
    }

    auto fresh = std::make_unique<Node>(val);
    Node *node = fresh.get();
    node->parent = parent;
    if (parent == nullptr) {
      root = std::move(fresh);
      min = node;
      max = node;
    } else {
      parent->child[side] = std::move(fresh);
      // A new leaf sits directly between its parent and the parent's old
      // neighbor on that side.
      if (side == kLeft) {
        node->next = parent;
        node->prev = parent->prev;
      } else {
        node->prev = parent;
        node->next = parent->next;
      }
      if (node->prev != nullptr) {
        node->prev->next = node;
      } else {
        min = node;
      }
      if (node->next != nullptr) {
        node->next->prev = node;
      } else {
        max = node;
      }
    }
    ++_size;
    rebalance_after_insert(node);
    return {Iterator{node, this}, true};
  }

  /// Removes the element at `pos`; erasing end() is a no-op.
  void erase(Iterator pos) {
    if (pos.underlying == nullptr) {
      return;
    }
    remove_node(pos.underlying);
  }

  Iterator begin() { return Iterator{min, this}; }
  Iterator end() { return Iterator{nullptr, this}; }
  ConstIterator begin() const { return Iterator{min, this}; }
  ConstIterator end() const { return Iterator{nullptr, this}; }
  ReverseIterator rbegin() { return ReverseIterator{Iterator{max, this}}; }
  ReverseIterator rend() { return ReverseIterator{Iterator{nullptr, this}}; }

  /// Destroys every element and resets min/max so begin() == end().
  void clear() {
    root.reset();
    min = nullptr;
    max = nullptr;
    _size = 0;
  }

  /// @throws std::out_of_range if `key` is not present.
  Mapped_T &at(const Key_T &key) {
    Node *n = find_node(key);
    if (n == nullptr) {
      throw std::out_of_range{"key not in map"};
    }
    return n->val.second;
  }

  /// @throws std::out_of_range if `key` is not present.
  const Mapped_T &at(const Key_T &key) const {
    const Node *n = find_node(key);
    if (n == nullptr) {
      throw std::out_of_range{"key not in map"};
    }
    return n->val.second;
  }

  /// Returns the mapped value for `key`, inserting a value-initialized
  /// Mapped_T first if the key is absent.
  Mapped_T &operator[](const Key_T &key) {
    return this->insert(ValueType{key, Mapped_T{}}).first->second;
  }

  /// Inserts every (key, mapped) element of [range_beg, range_end).
  template <typename IT_T>
  void insert(IT_T range_beg, IT_T range_end) {
    for (; range_beg != range_end; ++range_beg) {
      const auto &[first, second] = *range_beg;
      this->insert(ValueType(first, second));
    }
  }

  /// Removes the element with key `key`, if any.
  void erase(const Key_T &key) {
    if (Node *n = find_node(key)) {
      remove_node(n);
    }
  }

  friend bool operator==(const Map &lhs, const Map &rhs) {
    if (lhs.size() != rhs.size()) {
      return false;
    }
    auto lhs_iter = lhs.begin();
    auto rhs_iter = rhs.begin();
    while (lhs_iter != lhs.end()) {
      if (*lhs_iter != *rhs_iter) {
        return false;
      }
      ++lhs_iter;
      ++rhs_iter;
    }
    return true;
  }

  friend bool operator!=(const Map &lhs, const Map &rhs) {
    return !(lhs == rhs);
  }

  /// Lexicographical comparison of the element sequences (needs only `<`).
  friend bool operator<(const Map &lhs, const Map &rhs) {
    auto lhs_iter = lhs.begin();
    auto rhs_iter = rhs.begin();
    while (lhs_iter != lhs.end() && rhs_iter != rhs.end()) {
      if (*lhs_iter < *rhs_iter) {
        return true;
      }
      if (*rhs_iter < *lhs_iter) {
        return false;
      }
      ++lhs_iter;
      ++rhs_iter;
    }
    return lhs.size() < rhs.size();
  }
};

} // namespace cs440

#endif