diff --git a/Map.hpp b/Map.hpp index 37a1a53..ab642cd 100644 --- a/Map.hpp +++ b/Map.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include namespace cs440 { // universal type defs here @@ -47,12 +48,17 @@ template class Map { left{std::make_unique(*rhs.left)}, right{std::make_unique(*rhs.right)}, color{rhs.color}, prev{nullptr}, next{nullptr}, map{rhs.map} { + this->valid = 0x13371337; + if (this->left) { this->left->parent = this; } if (this->right) { this->right->parent = this; } + + this->next = rhs.next; + this->prev = rhs.prev; } Node(Node &&rhs) : parent{nullptr}, val{std::move(rhs.val)}, left{std::move(rhs.left)}, @@ -72,17 +78,22 @@ template class Map { if (this->right) { this->right->parent = this; } + + this->next = rhs.next; + this->prev = rhs.prev; } ~Node() {} Node &operator=(const Node &rhs) { // retain parent as is, common case is the copy or move is happening due // to a rotation where parent can get wonky // this->parent + this->val = rhs.val; this->left = std::make_unique(*rhs.left); this->right = std::make_unique(*rhs.right); this->color = rhs.color; this->valid = 0x13371337; + if (this->left) { this->left->parent = this; this->left->restore_ordering(); @@ -91,7 +102,7 @@ template class Map { this->right->parent = this; this->right->restore_ordering(); } - + this->restore_ordering(); this->map = rhs.map; return *this; } @@ -113,6 +124,7 @@ template class Map { this->right->parent = this; this->right->restore_ordering(); } + this->restore_ordering(); this->map = rhs.map; return *this; } @@ -126,6 +138,16 @@ template class Map { assert(false); } } + Node const *child(Direction dir) const { + switch (dir) { + case Direction::Left: + return this->left.get(); + case Direction::Right: + return this->right.get(); + default: + assert(false); + } + } std::unique_ptr uchild(Direction dir) { switch (dir) { case Direction::Left: @@ -174,6 +196,8 @@ template class Map { } void erase_child(Node *n) { this->erase_child(this->which_child(n)); } void erase_child(Direction dir) { + bool minmax = this->child(dir) == this->map->min || + this->child(dir) == this->map->max; // bringing ownership to this function scope so Deleter gets called at end // of function and we can do reordering things std::unique_ptr dropping; @@ -197,6 +221,18 @@ template class Map { if (dropping->next != nullptr) { dropping->next->prev = dropping->prev; } + + if (minmax) { + switch (dir) { + case Direction::Left: + this->map->min = this; + break; + case Direction::Right: + this->map->max = this; + break; + assert(false); + } + } } void restore_ordering() { this->prev = this->calc_pred(); @@ -344,6 +380,8 @@ template class Map { public: friend Node; + class ConstIterator; + class ReverseIterator; // public type definitions class Iterator { Node *underlying; @@ -353,34 +391,146 @@ public: public: friend Map; + friend ConstIterator; + friend ReverseIterator; Iterator() = delete; - void check() { + Iterator(const Iterator &rhs) = default; + Iterator &operator=(const Iterator &) = default; + ~Iterator() = default; + // precrement + Iterator &operator++() { if (this->underlying == nullptr) { - return; + this->underlying = this->store; + this->store = nullptr; + return *this; } - if (this->underlying->parent) { - switch (this->underlying->parent->which_child(this->underlying)) { - case Direction::Left: - assert(this->underlying->val.first < - this->underlying->parent->val.first); - break; - case Direction::Right: - assert(this->underlying->val.first > - this->underlying->parent->val.first); - break; - default: - assert(false); - } + if (this->underlying->next == nullptr) { + this->store = this->underlying; } - if (this->underlying->right) { - assert(this->underlying->right->val.first > - this->underlying->val.first); + this->underlying = this->underlying->next; + return *this; + } + // postcrement + Iterator operator++(int) { + auto copy = *this; + this->operator++(); + return copy; + } + + // precrement + Iterator &operator--() { + if (this->underlying == nullptr) { + this->underlying = this->store; + this->store = nullptr; + return *this; } - if (this->underlying->left) { - assert(this->underlying->left->val.first < this->underlying->val.first); + if (this->underlying->prev == nullptr) { + this->store = this->underlying; } + this->underlying = this->underlying->prev; + return *this; + } + // postcrement + Iterator operator--(int) { + auto copy = *this; + this->operator--(); + return copy; + } + + ValueType &operator*() const { + ValueType *ret = (ValueType *)(&this->underlying->val); + return *ret; + } + ValueType *operator->() const { return &this->operator*(); } + friend bool operator==(const Iterator &lhs, const Iterator &rhs) { + return lhs.underlying == rhs.underlying; + } + friend bool operator!=(const Iterator &lhs, const Iterator &rhs) { + return !(lhs == rhs); } }; + class ConstIterator { + Iterator underlying; + + public: + friend Map; + ConstIterator() = delete; + ConstIterator(const Iterator &underlying) : underlying{underlying} {} + ConstIterator(const ConstIterator &rhs) = default; + ConstIterator &operator=(const ConstIterator &) = default; + ~ConstIterator() = default; + ConstIterator &operator++() { + ++underlying; + return *this; + } + ConstIterator operator++(int) { + auto copy = *this; + this->operator++(); + return copy; + } + ConstIterator &operator--() { + --underlying; + return *this; + } + ConstIterator operator--(int) { + auto copy = *this; + this->operator--(); + return copy; + } + const ValueType &operator*() const { return this->underlying.operator*(); } + const ValueType *operator->() const { return &this->operator*(); } + + friend bool operator==(const ConstIterator &lhs, const ConstIterator &rhs) { + return lhs.underlying == rhs.underlying; + } + friend bool operator!=(const ConstIterator &lhs, const ConstIterator &rhs) { + return !(lhs == rhs); + } + friend bool operator==(const Iterator &lhs, const ConstIterator &rhs) { + return lhs == rhs.underlying; + } + friend bool operator!=(const Iterator &lhs, const ConstIterator &rhs) { + return !(lhs == rhs); + } + friend bool operator==(const ConstIterator &lhs, const Iterator &rhs) { + return lhs.underlying == rhs; + } + friend bool operator!=(const ConstIterator &lhs, const Iterator &rhs) { + return !(lhs == rhs); + } + }; + class ReverseIterator { + Iterator underlying; + ReverseIterator(const Iterator &underlying) : underlying{underlying} {} + + public: + friend Map; + ReverseIterator() = delete; + ReverseIterator(const ReverseIterator &) = default; + ~ReverseIterator() = default; + ReverseIterator &operator=(const ReverseIterator &) = default; + ReverseIterator &operator++() { + --underlying; + return *this; + } + ReverseIterator operator++(int) { + auto copy = *this; + this->operator++(); + return copy; + } + ReverseIterator &operator--() { + ++underlying; + return *this; + } + ReverseIterator operator--(int) { + auto copy = *this; + this->operator++(); + return copy; + } + const ValueType &operator*() const { return this->underlying.operator*(); } + const ValueType *operator->() const { return &this->operator*(); } + }; + Map() : root{}, _size{0} {} Map(const Map &rhs) : root{rhs.root}, _size{rhs._size} {} Map(Map &&rhs) : root{std::move(rhs.root)}, _size{rhs._size} {} @@ -394,10 +544,14 @@ public: this->_size = rhs._size; return *this; } + Map(std::initializer_list> items) : Map{} { + this->insert(items.begin(), items.end()); + } void check() { assert(!this->root || this->root.value().color == Color::Black); } std::size_t size() { return this->_size; } + bool empty() { return this->size() == 0; } private: // private helpers @@ -421,18 +575,32 @@ private: if (old_root->right) { old_root->right->parent = old_root.get(); } + if (old_root->next) { + old_root->next->prev = old_root.get(); + } + if (old_root->prev) { + old_root->prev->next = old_root.get(); + } root.value().set_child(dir, std::move(old_root)); + root.value().child(dir)->restore_ordering(); + if (root.value().left) { root.value().left->parent = &root.value(); + if (min == &root.value()) { + min = root.value().left.get(); + } } if (root.value().right) { root.value().right->parent = &root.value(); + if (max == &root.value()) { + max = root.value().right.get(); + } } } template - std::pair locate(const Key_T &key) { - Node *ret_parent; + std::pair locate(const Key_T &key) const { + Node const *ret_parent; Direction ret_dir; // map is empty if (!this->root.has_value()) { @@ -639,22 +807,22 @@ public: return this->end(); } if (parent->child(dir) != nullptr) { - return Iterator{parent->child(dir), nullptr}; + return Iterator{const_cast(parent->child(dir)), nullptr}; } return this->end(); } - Iterator find_trace(const Key_T &key) { - auto [parent, dir] = locate(key); + ConstIterator find(const Key_T &key) const { + auto [parent, dir] = locate(key); if (parent == nullptr) { if (this->root.has_value()) { if (this->root.value().val.first == key) { - return Iterator{&this->root.value()}; + return Iterator{const_cast(&this->root.value())}; } } return this->end(); } if (parent->child(dir) != nullptr) { - return Iterator{parent->child(dir), nullptr}; + return Iterator{const_cast(parent->child(dir)), nullptr}; } return this->end(); } @@ -671,6 +839,9 @@ public: } else { this->root = Node{val, this}; this->root.value().color = Color::Black; + this->min = &this->root.value(); + this->max = &this->root.value(); + this->_size++; return std::make_pair(Iterator{&root.value()}, true); } } @@ -678,21 +849,35 @@ public: // non-root node if (parent->child(dir)) { // node already present - return std::make_pair(Iterator{parent->child(dir)}, false); + return std::make_pair(Iterator{const_cast(parent->child(dir))}, + false); } // need to insert non-root node - Node *new_node = - parent->set_child(dir, std::make_unique(Node{val, this})).get(); + Map *m = const_cast(this); + Node *new_node = const_cast(parent) + ->set_child(dir, std::make_unique(Node{val, m})) + .get(); new_node->restore_red_black_insert(dir); new_node->restore_ordering(); + if (this->min == parent && dir == Direction::Left) { + this->min = new_node; + } + if (this->max == parent && dir == Direction::Right) { + this->max = new_node; + } + this->_size++; return std::make_pair(Iterator{new_node}, true); } void erase(Iterator pos) { + if (pos.underlying == nullptr) { + return; + } + this->_size--; + Node *before = pos.underlying->prev; Node *after = pos.underlying->next; - if (core_erase(pos.underlying)) { pos.underlying->restore_ordering(); } else { @@ -701,12 +886,16 @@ public: if (before->next != nullptr) { before->next->prev = before; } + } else { + this->min = after; } if (after != nullptr) { after->prev = after->calc_pred(); if (after->prev != nullptr) { after->prev->next = after; } + } else { + this->max = before; } } } @@ -714,6 +903,46 @@ public: // baseline iterator creation Iterator begin() { return Iterator{min, nullptr}; } Iterator end() { return Iterator{nullptr, max}; } + ConstIterator begin() const { return Iterator{min, nullptr}; } + ConstIterator end() const { return Iterator{nullptr, max}; } + ReverseIterator rbegin() { return Iterator{max, nullptr}; } + ReverseIterator rend() { return Iterator{nullptr, min}; } + + // misc that can be implemented with the above or trivially + void clear() { + this->root = std::move(std::nullopt); + this->_size = 0; + } + Mapped_T &at(const Key_T &key) { + auto ret = this->find(key); + if (ret == this->end()) { + throw std::out_of_range{"key not in map"}; + } + return (*ret).second; + } + const Mapped_T &at(const Key_T &key) const { + auto ret = this->find(key); + if (ret == this->end()) { + throw std::out_of_range{"key not in map"}; + } + return (*ret).first; + } + Mapped_T &operator[](const Key_T &key) { + this->insert({key, {}}); + return this->at(key); + } + template void insert(IT_T range_beg, IT_T range_end) { + while (range_beg != range_end) { + auto [first, second] = *range_beg; + this->insert(std::make_pair(first, second)); + ++range_beg; + } + } + void erase(const Key_T &key) { this->erase(this->find(key)); } + // TODO: + friend bool operator==(const Map &lhs, const Map &rhs) { assert(false); } + friend bool operator!=(const Map &lhs, const Map &rhs) { assert(false); } + friend bool operator<(const Map &lhs, const Map &rhs) { assert(false); } }; } // namespace cs440 #endif diff --git a/t.cpp b/t.cpp index 8e5eb47..74f7470 100644 --- a/t.cpp +++ b/t.cpp @@ -1,22 +1,20 @@ #include "Map.hpp" +#include template class cs440::Map; int main(void) { - cs440::Map a; - for (std::size_t i = 1; i <= 10; i++) { - a.insert({i, i}); - a.check(); - for (std::size_t j = 1; j <= i; j++) { - a.find(j).check(); - } - } - for (std::size_t i = 1; i <= 10; i++) { - a.find(i).check(); - } - for (std::size_t i = 1; i <= 5; i++) { - std::cout << i << std::endl; - auto b = a.find(i); - a.erase(b); - } + // cs440::Map a; + // for (std::size_t i = 10; i >= 1; i--) { + // a.insert({i, i}); + // for (std::size_t j = 10; j >= i; j--) { + // } + // } + // for (std::size_t i = 10; i >= 5; i--) { + // std::cout << i << std::endl; + // auto b = a.find(i); + // a.erase(b); + // for (std::size_t j = 1; j <= i; j++) { + // } + // } return 0; }