#ifndef _POWELL_CS440
#define _POWELL_CS440

#include <iostream>
// uncomment on submission/performance test
// #define NDEBUG
#include <cassert>
#include <cstddef>
#include <memory>
#include <optional>
#include <utility>

namespace cs440 {

// Shared helper types. These live directly in cs440 rather than in an
// anonymous namespace: an anonymous namespace in a header gives every
// translation unit its own distinct Direction/Color, which would make
// Map<K, V>::Node a different type in each TU (an ODR violation).

/// Which child of a node: left or right.
enum class Direction { Left, Right };

/// Returns the opposite direction. `inline` because this is a header.
inline Direction operator!(Direction dir) {
  return dir == Direction::Left ? Direction::Right : Direction::Left;
}

/// Red-black node color.
enum class Color { Red, Black };

/// Ordered associative container implemented as a red-black tree.
///
/// Layout:
///  - Each node owns its children through std::unique_ptr and `root_` owns
///    the root. Nodes are never moved or have their values copied once
///    allocated, so an Iterator stays valid until the element it refers to is
///    erased (erasing other elements does not invalidate it).
///  - Every node is also threaded into a doubly linked list (`prev`/`next`)
///    in ascending key order; `min_`/`max_` are the list ends, so iteration is
///    O(1) per step.
///
/// Keys are compared with `operator<` only: two keys are considered equal
/// when neither is less than the other.
template <typename Key_T, typename Mapped_T>
class Map {
 public:
  using ValueType = std::pair<const Key_T, Mapped_T>;

 private:
  struct Node {
    explicit Node(const ValueType &v) : val{v} {}

    /// Owning slot holding the child on side `dir`.
    std::unique_ptr<Node> &link(Direction dir) {
      return dir == Direction::Left ? left : right;
    }

    /// Which side of this node the (non-null) `child` hangs on.
    Direction side_of(const Node *child) const {
      assert(child != nullptr);
      assert(left.get() == child || right.get() == child);
      return left.get() == child ? Direction::Left : Direction::Right;
    }

    ValueType val;
    Node *parent = nullptr;  // non-owning; nullptr for the root
    std::unique_ptr<Node> left;
    std::unique_ptr<Node> right;
    Node *prev = nullptr;      // in-order predecessor (non-owning)
    Node *next = nullptr;      // in-order successor (non-owning)
    Color color = Color::Red;  // freshly inserted nodes start red
  };

 public:
  /// Bidirectional iterator over the elements in ascending key order.
  /// end() is represented by a null node; decrementing it yields the last
  /// element of the map the iterator came from.
  class Iterator {
   public:
    friend Map;
    Iterator() = delete;

    ValueType &operator*() const {
      assert(node_ != nullptr && "dereferencing end()");
      return node_->val;
    }
    ValueType *operator->() const {
      assert(node_ != nullptr && "dereferencing end()");
      return &node_->val;
    }

    Iterator &operator++() {
      assert(node_ != nullptr && "incrementing end()");
      node_ = node_->next;
      return *this;
    }
    Iterator operator++(int) {
      Iterator old = *this;
      ++*this;
      return old;
    }
    Iterator &operator--() {
      // From end() step back to the current maximum of the owning map.
      node_ = node_ != nullptr ? node_->prev : map_->max_;
      return *this;
    }
    Iterator operator--(int) {
      Iterator old = *this;
      --*this;
      return old;
    }

    friend bool operator==(const Iterator &a, const Iterator &b) {
      return a.node_ == b.node_;
    }
    friend bool operator!=(const Iterator &a, const Iterator &b) {
      return !(a == b);
    }

    /// Debug helper: asserts that the ordering links around this element are
    /// mutually consistent and strictly increasing. No-op on end().
    void check() const {
      if (node_ == nullptr) {
        return;
      }
      if (node_->prev != nullptr) {
        assert(node_->prev->next == node_);
        assert(node_->prev->val.first < node_->val.first);
      }
      if (node_->next != nullptr) {
        assert(node_->next->prev == node_);
        assert(node_->val.first < node_->next->val.first);
      }
    }

   private:
    Iterator(Map *owner, Node *node) : map_{owner}, node_{node} {}

    Map *map_;    // owning map, used only to decrement from end()
    Node *node_;  // current element; nullptr means end()
  };

  Map() = default;

  /// Deep copy: clones tree shape and colors, rebuilding the ordering links.
  Map(const Map &rhs) : size_{rhs.size_} {
    Node *last = nullptr;
    root_ = clone_subtree(rhs.root_.get(), nullptr, last);
    max_ = last;
  }

  /// Steals rhs's nodes; rhs is left empty.
  Map(Map &&rhs) noexcept
      : root_{std::move(rhs.root_)},
        min_{std::exchange(rhs.min_, nullptr)},
        max_{std::exchange(rhs.max_, nullptr)},
        size_{std::exchange(rhs.size_, 0)} {}

  ~Map() = default;

  Map &operator=(const Map &rhs) {
    if (this != &rhs) {
      Map copy(rhs);  // copy-and-swap: strong exception guarantee
      swap(copy);
    }
    return *this;
  }

  Map &operator=(Map &&rhs) noexcept {
    if (this != &rhs) {
      Map stolen(std::move(rhs));  // our old contents die with `stolen`
      swap(stolen);
    }
    return *this;
  }

  /// Exchanges contents with `other`. Nodes do not move, but end()
  /// iterators keep referring to their original map object.
  void swap(Map &other) noexcept {
    std::swap(root_, other.root_);
    std::swap(min_, other.min_);
    std::swap(max_, other.max_);
    std::swap(size_, other.size_);
  }

  /// Number of stored elements.
  std::size_t size() const { return size_; }

  /// Returns an iterator to the element with `key`, or end() if absent.
  Iterator find(const Key_T &key) { return Iterator{this, lookup(key)}; }

  /// Same as find(), but prints the search path to std::cerr, e.g.
  /// "root->left->right->found".
  Iterator find_trace(const Key_T &key) {
    if (!root_) {
      std::cerr << "(map empty)" << '\n';
      return end();
    }
    std::cerr << "root";
    Node *cur = root_.get();
    while (cur != nullptr) {
      if (key < cur->val.first) {
        std::cerr << "->left";
        cur = cur->left.get();
      } else if (cur->val.first < key) {
        std::cerr << "->right";
        cur = cur->right.get();
      } else {
        std::cerr << "->found" << '\n';
        return Iterator{this, cur};
      }
    }
    std::cerr << "->missing" << '\n';
    return end();
  }

  /// Inserts `val` if its key is not present.
  /// @return iterator to the element with that key, and whether an insertion
  ///         happened (false means the existing element was left untouched).
  std::pair<Iterator, bool> insert(const ValueType &val) {
    Node *parent = nullptr;
    Direction dir = Direction::Left;
    for (Node *cur = root_.get(); cur != nullptr; cur = cur->link(dir).get()) {
      if (val.first < cur->val.first) {
        dir = Direction::Left;
      } else if (cur->val.first < val.first) {
        dir = Direction::Right;
      } else {
        return std::make_pair(Iterator{this, cur}, false);
      }
      parent = cur;
    }

    auto fresh = std::make_unique<Node>(val);
    Node *n = fresh.get();
    n->parent = parent;
    if (parent == nullptr) {
      root_ = std::move(fresh);
    } else {
      parent->link(dir) = std::move(fresh);
    }

    // A new leaf hung on the left of `parent` sits immediately before it in
    // key order; one hung on the right sits immediately after it.
    if (parent == nullptr) {
      min_ = max_ = n;
    } else if (dir == Direction::Left) {
      n->next = parent;
      n->prev = parent->prev;
      if (parent->prev != nullptr) {
        parent->prev->next = n;
      } else {
        min_ = n;
      }
      parent->prev = n;
    } else {
      n->prev = parent;
      n->next = parent->next;
      if (parent->next != nullptr) {
        parent->next->prev = n;
      } else {
        max_ = n;
      }
      parent->next = n;
    }

    ++size_;
    insert_fixup(n);
    return std::make_pair(Iterator{this, n}, true);
  }

  /// Removes the element at `pos`. `pos` must not be end(). Iterators to
  /// other elements remain valid.
  void erase(Iterator pos) {
    assert(pos.node_ != nullptr && "cannot erase end()");
    if (pos.node_ == nullptr) {
      return;
    }
    erase_node(pos.node_);
  }

  Iterator begin() { return Iterator{this, min_}; }
  Iterator end() { return Iterator{this, nullptr}; }

 private:
  static bool is_black(const Node *n) {
    return n == nullptr || n->color == Color::Black;
  }

  static Node *leftmost(Node *n) {
    while (n != nullptr && n->left) {
      n = n->left.get();
    }
    return n;
  }

  /// The unique_ptr that owns `n` (root_ or a child slot of its parent).
  std::unique_ptr<Node> &slot_of(Node *n) {
    if (n->parent == nullptr) {
      return root_;
    }
    return n->parent->link(n->parent->side_of(n));
  }

  /// Rotates around `x` toward `dir`. For dir == Left, x's right child moves
  /// up into x's position and x becomes its left child. The opposite child
  /// must exist.
  void rotate(Node *x, Direction dir) {
    const Direction up = !dir;
    std::unique_ptr<Node> &x_slot = slot_of(x);
    std::unique_ptr<Node> x_owner = std::move(x_slot);
    std::unique_ptr<Node> y_owner = std::move(x->link(up));
    Node *y = y_owner.get();
    assert(y != nullptr && "rotation needs the opposite child");

    // y's inner subtree is re-parented under x.
    x->link(up) = std::move(y->link(dir));
    if (x->link(up)) {
      x->link(up)->parent = x;
    }
    y->parent = x->parent;
    x->parent = y;
    y->link(dir) = std::move(x_owner);
    x_slot = std::move(y_owner);
  }

  Node *lookup(const Key_T &key) const {
    Node *cur = root_.get();
    while (cur != nullptr) {
      if (key < cur->val.first) {
        cur = cur->left.get();
      } else if (cur->val.first < key) {
        cur = cur->right.get();
      } else {
        return cur;
      }
    }
    return nullptr;
  }

  /// Restores red-black properties after inserting red leaf `n`.
  void insert_fixup(Node *n) {
    while (n->parent != nullptr && n->parent->color == Color::Red) {
      Node *p = n->parent;
      Node *g = p->parent;  // exists: the root is always black
      const Direction pd = g->side_of(p);
      Node *uncle = g->link(!pd).get();

      if (!is_black(uncle)) {
        // Red uncle: push blackness down from g and continue from g.
        p->color = Color::Black;
        uncle->color = Color::Black;
        g->color = Color::Red;
        n = g;
        continue;
      }
      if (p->side_of(n) != pd) {
        // Inner grandchild: rotate it to the outside first.
        rotate(p, pd);
        n = p;
        p = n->parent;
      }
      // Outer grandchild: rotate g away from p and swap their colors.
      p->color = Color::Black;
      g->color = Color::Red;
      rotate(g, !pd);
    }
    root_->color = Color::Black;
  }

  /// Unlinks and frees `z`, rebalancing afterwards. No element values are
  /// copied, so other nodes keep their addresses.
  void erase_node(Node *z) {
    // Remove z from the ordering list; the other nodes keep their relative
    // order no matter how the tree is restructured below.
    if (z->prev != nullptr) {
      z->prev->next = z->next;
    } else {
      min_ = z->next;
    }
    if (z->next != nullptr) {
      z->next->prev = z->prev;
    } else {
      max_ = z->prev;
    }

    Color removed_color = z->color;
    Node *x = nullptr;         // node moving into the vacated spot (may be null)
    Node *x_parent = nullptr;  // tracked separately because x may be null
    std::unique_ptr<Node> doomed;  // takes ownership of z; freed on return

    if (!z->left || !z->right) {
      std::unique_ptr<Node> child = std::move(z->left ? z->left : z->right);
      x = child.get();
      x_parent = z->parent;
      if (child) {
        child->parent = z->parent;
      }
      std::unique_ptr<Node> &z_slot = slot_of(z);
      doomed = std::move(z_slot);
      z_slot = std::move(child);
    } else {
      // Two children: splice in the successor y (leftmost of the right
      // subtree, which has no left child) in z's place, with z's color.
      Node *y = leftmost(z->right.get());
      removed_color = y->color;
      x = y->right.get();
      std::unique_ptr<Node> y_owner;
      if (y->parent == z) {
        x_parent = y;
        y_owner = std::move(z->right);
      } else {
        x_parent = y->parent;
        std::unique_ptr<Node> &y_slot = y->parent->left;  // y is a left child
        y_owner = std::move(y_slot);
        y_slot = std::move(y->right);
        if (y_slot) {
          y_slot->parent = x_parent;
        }
        y->right = std::move(z->right);
        y->right->parent = y;
      }
      y->left = std::move(z->left);
      y->left->parent = y;
      y->color = z->color;
      y->parent = z->parent;
      std::unique_ptr<Node> &z_slot = slot_of(z);
      doomed = std::move(z_slot);
      z_slot = std::move(y_owner);
    }

    --size_;
    if (removed_color == Color::Black) {
      erase_fixup(x, x_parent);
    }
  }

  /// Repairs the black-height deficit at `x` (child of `parent`) after a
  /// black node was removed.
  void erase_fixup(Node *x, Node *parent) {
    while (x != root_.get() && is_black(x)) {
      // If x is null its sibling is non-null (it had black height >= 1), so
      // exactly one child slot of parent is empty and identifies x's side.
      const Direction xd =
          parent->left.get() == x ? Direction::Left : Direction::Right;
      Node *w = parent->link(!xd).get();
      assert(w != nullptr);

      if (w->color == Color::Red) {
        // Red sibling: rotate so that x gets a black sibling.
        w->color = Color::Black;
        parent->color = Color::Red;
        rotate(parent, xd);
        w = parent->link(!xd).get();
      }
      if (is_black(w->left.get()) && is_black(w->right.get())) {
        // Black sibling with black children: recolor and move the deficit up.
        w->color = Color::Red;
        x = parent;
        parent = x->parent;
      } else {
        if (is_black(w->link(!xd).get())) {
          // Near nephew red, far nephew black: make the far one red.
          w->link(xd)->color = Color::Black;
          w->color = Color::Red;
          rotate(w, !xd);
          w = parent->link(!xd).get();
        }
        // Far nephew red: one rotation at parent fixes the deficit.
        w->color = parent->color;
        parent->color = Color::Black;
        w->link(!xd)->color = Color::Black;
        rotate(parent, xd);
        x = root_.get();
      }
    }
    if (x != nullptr) {
      x->color = Color::Black;
    }
  }

  /// Clones `src` (and its subtree) under `parent`, threading the new nodes
  /// in order after `last` and updating min_ for the very first node.
  std::unique_ptr<Node> clone_subtree(const Node *src, Node *parent,
                                      Node *&last) {
    if (src == nullptr) {
      return nullptr;
    }
    auto dst = std::make_unique<Node>(src->val);
    dst->color = src->color;
    dst->parent = parent;
    dst->left = clone_subtree(src->left.get(), dst.get(), last);
    dst->prev = last;
    if (last != nullptr) {
      last->next = dst.get();
    } else {
      min_ = dst.get();
    }
    last = dst.get();
    dst->right = clone_subtree(src->right.get(), dst.get(), last);
    return dst;
  }

  std::unique_ptr<Node> root_;
  Node *min_ = nullptr;  // first element in key order (nullptr when empty)
  Node *max_ = nullptr;  // last element in key order (nullptr when empty)
  std::size_t size_ = 0;
};

}  // namespace cs440

#endif