// cs440-assignment2/Map.hpp
//
// 969 lines
// 27 KiB
// C++

#ifndef _POWELL_CS440
#define _POWELL_CS440
#include <algorithm>
// uncomment on submission/performance test
// #define NDEBUG
#include <cassert>
#include <memory>
#include <optional>
#include <stdexcept>
#include <utility>
namespace cs440 {
// universal type defs here
namespace {
/// Side of a parent on which a child hangs; also names the direction of a
/// rotation.
enum class Direction { Left, Right };
/// Returns the opposite side: `!Left == Right`, `!Right == Left`.
///
/// Written as a single total expression so that no control path reaches the
/// end of the function without returning, even when `assert` is compiled out
/// with NDEBUG.
constexpr Direction operator!(Direction dir) noexcept {
  return dir == Direction::Left ? Direction::Right : Direction::Left;
}
/// Node colour used by the red-black balancing code.
enum class Color { Red, Black };
} // namespace
template <typename Key_T, typename Mapped_T> class Map {
// Type definitions here
using ValueType = std::pair<const Key_T, Mapped_T>;
using internal_ValueType = std::pair<Key_T, Mapped_T>;
struct Node {
// Liveness canary: 0x13371337 while the node holds its contents; the move
// constructor and move assignment set it to 0 on the moved-from node, and
// Map::erase compares against it.
int valid = 0x13371337;
// Non-owning link to the node whose left/right pointer owns this node;
// nullptr when the node has no parent.
Node *parent = nullptr;
// Heap-allocated key/value pair; becomes null once the node has been moved
// from.
std::unique_ptr<internal_ValueType> val;
// Owning pointers to the two subtrees.
std::unique_ptr<Node> left;
std::unique_ptr<Node> right;
// Red-black colour; freshly inserted nodes start out Red.
Color color;
// Non-owning in-order neighbours (see calc_pred / calc_succ); nullptr when
// there is no such neighbour. Maintained by restore_ordering.
Node *prev;
Node *next;
// Non-owning back-pointer used to reach the map's min/max and rotate_root.
Map *map;
/// Builds a detached red node holding `val` and belonging to `map`.
///
/// `val` is taken by value and moved into the heap allocation, so the caller's
/// pair is copied once rather than twice. All links start out as nullptr.
Node(internal_ValueType val, Map *map)
    : parent{nullptr},
      val{std::make_unique<internal_ValueType>(std::move(val))}, left{},
      right{}, color{Color::Red}, prev{nullptr}, next{nullptr}, map{map} {}
Node(const Node &rhs)
: parent{nullptr},
val{rhs.val ? new internal_ValueType{*rhs.val} : nullptr},
left{rhs.left ? std::make_unique<Node>(*rhs.left) : nullptr},
right{rhs.right ? std::make_unique<Node>(*rhs.right) : nullptr},
color{rhs.color}, prev{nullptr}, next{nullptr}, map{rhs.map} {
this->valid = 0x13371337;
if (this->left) {
this->left->parent = this;
}
if (this->right) {
this->right->parent = this;
}
this->next = rhs.next;
this->prev = rhs.prev;
}
/// Steals the payload, both subtrees, the colour, the in-order links and the
/// map pointer of `rhs`. The new node has no parent; the stolen children are
/// re-parented to it. `rhs` is marked as dead through its `valid` canary.
/// Neighbours that still point at `rhs` are not patched here.
Node(Node &&rhs)
    : parent{nullptr}, val{std::move(rhs.val)}, left{std::move(rhs.left)},
      right{std::move(rhs.right)}, color{rhs.color}, prev{rhs.prev},
      next{rhs.next}, map{rhs.map} {
  // `valid` of the new node comes from its default member initialiser
  rhs.valid = 0;
  if (Node *kid = this->left.get()) {
    kid->parent = this;
  }
  if (Node *kid = this->right.get()) {
    kid->parent = this;
  }
}
/// Nothing to do by hand: the unique_ptr members release the payload and both
/// subtrees.
~Node() = default;
// Copy assignment: replaces this node's payload and both subtrees with deep
// copies of rhs's, and takes rhs's colour and map pointer.
// `parent` is deliberately left untouched, so the node stays attached where
// it already is. prev/next are recomputed from the tree for this node and its
// two direct children only.
// NOTE(review): assigning to this->left / this->right destroys the subtree
// previously owned there, so rhs must not live inside one of them.
Node &operator=(const Node &rhs) {
// retain parent as is, common case is the copy or move is happening due
// to a rotation where parent can get wonky
// this->parent
this->val =
rhs.val ? std::unique_ptr<internal_ValueType>{new internal_ValueType{
*rhs.val}}
: nullptr;
this->left = rhs.left ? std::make_unique<Node>(*rhs.left) : nullptr;
this->right = rhs.right ? std::make_unique<Node>(*rhs.right) : nullptr;
this->color = rhs.color;
this->valid = 0x13371337;
// re-parent the freshly copied children and re-thread them
if (this->left) {
this->left->parent = this;
this->left->restore_ordering();
}
if (this->right) {
this->right->parent = this;
this->right->restore_ordering();
}
// finally re-thread this node against its new neighbours
this->restore_ordering();
this->map = rhs.map;
return *this;
}
// Move assignment: steals rhs's payload and both subtrees and takes its
// colour and map pointer; rhs is marked dead through its `valid` canary.
// `parent` is deliberately left untouched, so the node stays attached where
// it already is. prev/next are recomputed from the tree for this node and its
// two direct children.
// NOTE(review): assigning to this->left / this->right destroys whatever they
// owned before; the callers in this file first detach rhs from this node.
Node &operator=(Node &&rhs) {
// retain parent as is, common case is the copy or move is happening due
// to a rotation where parent can get wonky
// this->parent
this->val = std::move(rhs.val);
this->left = std::move(rhs.left);
this->right = std::move(rhs.right);
this->color = rhs.color;
this->valid = 0x13371337;
rhs.valid = 0;
// re-parent the adopted children and re-thread them
if (this->left) {
this->left->parent = this;
this->left->restore_ordering();
}
if (this->right) {
this->right->parent = this;
this->right->restore_ordering();
}
// finally re-thread this node against its new neighbours
this->restore_ordering();
this->map = rhs.map;
return *this;
}
/// Returns the child on side `dir` without transferring ownership, or nullptr
/// if there is none.
///
/// Implemented as one total expression: the former `switch` ended in
/// `default: assert(false);`, which falls off the end of a non-void function
/// once NDEBUG removes the assert.
Node *child(Direction dir) {
  return dir == Direction::Left ? this->left.get() : this->right.get();
}
/// Const overload of child(Direction).
Node const *child(Direction dir) const {
  return dir == Direction::Left ? this->left.get() : this->right.get();
}
/// Detaches and returns the child on side `dir`; the corresponding slot of
/// this node is left empty. The returned node's `parent` is not modified.
///
/// Uses a conditional expression rather than a `switch` with an asserting
/// default, so that every path returns even under NDEBUG.
std::unique_ptr<Node> uchild(Direction dir) {
  return std::move(dir == Direction::Left ? this->left : this->right);
}
/// Detaches and returns `child`, which must currently be one of this node's
/// children (see which_child).
std::unique_ptr<Node> uchild(Node *child) {
  return this->uchild(this->which_child(child));
}
/// Installs `new_child` (possibly null) as the child on side `dir`, replacing
/// and destroying whatever was there, and points its `parent` at this node.
///
/// @return a reference to the slot that now owns the child.
///
/// The slot is chosen with a conditional expression, so there is no asserting
/// `default:` branch that would run off the end of the function under NDEBUG,
/// and the parent link is written once instead of twice.
std::unique_ptr<Node> &set_child(Direction dir,
                                 std::unique_ptr<Node> new_child) {
  if (new_child) {
    new_child->parent = this;
  }
  std::unique_ptr<Node> &slot =
      dir == Direction::Left ? this->left : this->right;
  slot = std::move(new_child);
  return slot;
}
/// Tells on which side of this node `n` hangs.
///
/// `n` is expected to be one of this node's children; that is checked with an
/// assert. With asserts compiled out the function still returns a value
/// (Right) instead of running off its end as the previous version did.
Direction which_child(const Node *n) const {
  if (this->left.get() == n) {
    return Direction::Left;
  }
  assert(this->right.get() == n);
  return Direction::Right;
}
/// Destroys the child `n`, which must be one of this node's children.
void erase_child(Node *n) { this->erase_child(this->which_child(n)); }
/// Destroys the child on side `dir` (and everything below it), splices it out
/// of the prev/next chain and, if it was the map's min or max, makes this
/// node the new min or max respectively.
///
/// Intended for leaf children: for a leaf that is the minimum (maximum), its
/// parent is the next smallest (largest) node. A missing child is reported by
/// an assert and otherwise ignored instead of being dereferenced.
void erase_child(Direction dir) {
  // take ownership locally so the node is destroyed when we leave, after the
  // neighbour links have been fixed
  std::unique_ptr<Node> dropping = this->uchild(dir);
  assert(dropping);
  if (!dropping) {
    return;
  }
  const bool was_min = dropping.get() == this->map->min;
  const bool was_max = dropping.get() == this->map->max;
  if (dropping->prev != nullptr) {
    dropping->prev->next = dropping->next;
  }
  if (dropping->next != nullptr) {
    dropping->next->prev = dropping->prev;
  }
  if (was_min) {
    this->map->min = this;
  }
  if (was_max) {
    this->map->max = this;
  }
}
void restore_ordering() {
this->prev = this->calc_pred();
this->next = this->calc_succ();
if (this->prev) {
this->prev->next = this;
}
if (this->next) {
this->next->prev = this;
}
}
/// Finds the in-order predecessor using only parent/child links.
///
/// @return the right-most node of the left subtree if there is one; otherwise
///         the nearest ancestor reached by stepping up from a right child, or
///         nullptr if no such ancestor exists.
Node *calc_pred() {
  if (Node *cursor = this->left.get()) {
    while (cursor->right) {
      cursor = cursor->right.get();
    }
    return cursor;
  }
  Node *came_from = this;
  Node *up = this->parent;
  // keep climbing while we arrive from a left child
  while (up != nullptr && up->which_child(came_from) == Direction::Left) {
    came_from = up;
    up = came_from->parent;
  }
  return up;
}
/// Finds the in-order successor using only parent/child links.
///
/// @return the left-most node of the right subtree if there is one; otherwise
///         the nearest ancestor reached by stepping up from a left child, or
///         nullptr if no such ancestor exists.
Node *calc_succ() {
  if (Node *cursor = this->right.get()) {
    while (cursor->left) {
      cursor = cursor->left.get();
    }
    return cursor;
  }
  Node *came_from = this;
  Node *up = this->parent;
  // keep climbing while we arrive from a right child
  while (up != nullptr && up->which_child(came_from) == Direction::Right) {
    came_from = up;
    up = came_from->parent;
  }
  return up;
}
/// Rotates this (non-root) node towards `dir`: its child on the opposite side
/// rises into this node's place and this node becomes that child's `dir`
/// child. The rising child's former `dir` subtree becomes this node's `!dir`
/// subtree. Node addresses and the in-order sequence are unchanged.
///
/// Preconditions (asserted): this node has a parent and has a `!dir` child.
///
/// The slot this node occupies in its parent is looked up with which_child;
/// the previous version assumed it was always the `!dir` slot, which is only
/// true by coincidence.
void rotate(Direction dir) {
  // we can't be root for this rotate operation
  assert(this->parent != nullptr);
  // without a child on the opposite side there is nothing to lift
  assert(this->child(!dir));
  Node *const above = this->parent;
  const Direction slot = above->which_child(this);
  // pull ourselves out of the parent so we are not destroyed when the slot
  // is overwritten
  std::unique_ptr<Node> self = above->uchild(slot);
  assert(self.get() == this);
  // the rising child takes over our former position
  Node *const rising = above->set_child(slot, this->uchild(!dir)).get();
  // adopt the rising child's inner subtree
  this->set_child(!dir, rising->uchild(dir));
  // and hang ourselves below the rising child
  rising->set_child(dir, std::move(self));
}
// Referencing
// https://en.wikipedia.org/wiki/Red%E2%80%93black_tree#Notes_to_the_insert_diagrams
/// Restores the red-black invariants after this (red) node was linked into
/// the tree.
///
/// @param dir side of the parent on which this node was attached. It is only
///            a starting value: the side is re-derived from the tree on every
///            iteration, because after a recolouring step the loop continues
///            two levels up where the original side no longer applies.
///
/// Rotation directions follow the referenced article: an inner child is
/// first rotated to the outside (lifting it above its parent), then the
/// parent is lifted above the grandparent. When the grandparent is the root
/// the rotation goes through Map::rotate_root, after which node pointers held
/// here are no longer used.
void restore_red_black_insert(Direction dir) {
  Node *self = this;
  // loop instead of recursion for case 2's sake
  while (true) {
    Node *parent = self->parent;
    // we're root (case 3)
    if (!parent) {
      self->color = Color::Black;
      return;
    }
    dir = parent->which_child(self);
    // parent is black so there is no violation (case 1)
    if (parent->color == Color::Black) {
      return;
    }
    Node *grandparent = parent->parent;
    // parent is a red root (case 4)
    if (!grandparent) {
      parent->color = Color::Black;
      return;
    }
    // a red parent below a red grandparent would already have been a
    // violation before the insert
    assert(grandparent->color == Color::Black);
    const Direction parent_side = grandparent->which_child(parent);
    Node *uncle = grandparent->child(!parent_side);
    if (uncle == nullptr || uncle->color == Color::Black) {
      if (dir != parent_side) {
        // case 5: we're an inner child, lift ourselves above the parent so
        // the red pair sits on the outside
        parent->rotate(parent_side);
        // the old parent is now the lower red node
        std::swap(self, parent);
      }
      // case 6
      // recolor first so nothing is touched after a root rotation
      parent->color = Color::Black;
      grandparent->color = Color::Red;
      if (grandparent->parent == nullptr) {
        map->rotate_root(!parent_side);
      } else {
        grandparent->rotate(!parent_side);
      }
      return;
    }
    // case 2 (by process of elimination): red uncle, push the blackness down
    parent->color = Color::Black;
    uncle->color = Color::Black;
    grandparent->color = Color::Red;
    self = grandparent;
  }
}
};
// data needed for implementation
// Root node, stored inside the map object itself; empty for an empty map.
std::optional<Node> root;
// Number of stored elements.
std::size_t _size = 0;
// Smallest / largest node, or nullptr for an empty map. They carry default
// initialisers so that constructors which do not mention them never leave
// them indeterminate (begin()/end() read them unconditionally).
Node *min = nullptr;
Node *max = nullptr;
public:
friend Node;
class ConstIterator;
class ReverseIterator;
// public type definitions
/// Bidirectional iterator over the map's elements in key order.
///
/// `underlying` is the current node, or nullptr for a past-the-end (or
/// before-the-begin) position. In that state `store` remembers the node that
/// was left, so stepping back the other way can return to it.
class Iterator {
  Node *underlying;
  Node *store;
  Iterator(Node *ptr, Node *potential = nullptr)
      : underlying{ptr}, store{potential} {}

public:
  friend Map;
  friend ConstIterator;
  friend ReverseIterator;
  Iterator() = delete;
  Iterator(const Iterator &rhs) = default;
  Iterator &operator=(const Iterator &) = default;
  ~Iterator() = default;
  /// Pre-increment: moves to the in-order successor.
  Iterator &operator++() {
    if (this->underlying == nullptr) {
      // off the end of the sequence: step back onto the remembered node
      this->underlying = this->store;
      this->store = nullptr;
      return *this;
    }
    if (this->underlying->next == nullptr) {
      // about to walk off the end, remember where we came from
      this->store = this->underlying;
    }
    this->underlying = this->underlying->next;
    return *this;
  }
  /// Post-increment.
  Iterator operator++(int) {
    auto copy = *this;
    this->operator++();
    return copy;
  }
  /// Pre-decrement: moves to the in-order predecessor.
  Iterator &operator--() {
    if (this->underlying == nullptr) {
      this->underlying = this->store;
      this->store = nullptr;
      return *this;
    }
    if (this->underlying->prev == nullptr) {
      this->store = this->underlying;
    }
    this->underlying = this->underlying->prev;
    return *this;
  }
  /// Post-decrement.
  Iterator operator--(int) {
    auto copy = *this;
    this->operator--();
    return copy;
  }
  /// Returns the element the iterator points at. Must not be called on an
  /// end iterator.
  ///
  /// The node stores its pair behind a unique_ptr, so the pointee has to be
  /// fetched with get(); the previous code cast the address of the
  /// unique_ptr object itself.
  /// NOTE(review): the stored pair<Key_T, Mapped_T> is viewed as
  /// pair<const Key_T, Mapped_T>; this relies on both having the same layout.
  ValueType &operator*() const {
    return *reinterpret_cast<ValueType *>(this->underlying->val.get());
  }
  ValueType *operator->() const { return &this->operator*(); }
  /// Two iterators are equal when they designate the same node; all end
  /// positions compare equal.
  friend bool operator==(const Iterator &lhs, const Iterator &rhs) {
    return lhs.underlying == rhs.underlying;
  }
  friend bool operator!=(const Iterator &lhs, const Iterator &rhs) {
    return !(lhs == rhs);
  }
};
class ConstIterator {
Iterator underlying;
public:
friend Map;
ConstIterator() = delete;
ConstIterator(const Iterator &underlying) : underlying{underlying} {}
ConstIterator(const ConstIterator &rhs) = default;
ConstIterator &operator=(const ConstIterator &) = default;
~ConstIterator() = default;
ConstIterator &operator++() {
++underlying;
return *this;
}
ConstIterator operator++(int) {
auto copy = *this;
this->operator++();
return copy;
}
ConstIterator &operator--() {
--underlying;
return *this;
}
ConstIterator operator--(int) {
auto copy = *this;
this->operator--();
return copy;
}
const ValueType &operator*() const { return this->underlying.operator*(); }
const ValueType *operator->() const { return &this->operator*(); }
friend bool operator==(const ConstIterator &lhs, const ConstIterator &rhs) {
return lhs.underlying == rhs.underlying;
}
friend bool operator!=(const ConstIterator &lhs, const ConstIterator &rhs) {
return !(lhs == rhs);
}
friend bool operator==(const Iterator &lhs, const ConstIterator &rhs) {
return lhs == rhs.underlying;
}
friend bool operator!=(const Iterator &lhs, const ConstIterator &rhs) {
return !(lhs == rhs);
}
friend bool operator==(const ConstIterator &lhs, const Iterator &rhs) {
return lhs.underlying == rhs;
}
friend bool operator!=(const ConstIterator &lhs, const Iterator &rhs) {
return !(lhs == rhs);
}
};
class ReverseIterator {
Iterator underlying;
ReverseIterator(const Iterator &underlying) : underlying{underlying} {}
public:
friend Map;
ReverseIterator() = delete;
ReverseIterator(const ReverseIterator &) = default;
~ReverseIterator() = default;
ReverseIterator &operator=(const ReverseIterator &) = default;
ReverseIterator &operator++() {
--underlying;
return *this;
}
ReverseIterator operator++(int) {
auto copy = *this;
this->operator++();
return copy;
}
ReverseIterator &operator--() {
++underlying;
return *this;
}
ReverseIterator operator--(int) {
auto copy = *this;
this->operator++();
return copy;
}
const ValueType &operator*() const { return this->underlying.operator*(); }
const ValueType *operator->() const { return &this->operator*(); }
};
Map() : root{}, _size{0} {}
Map(const Map &rhs) : root{rhs.root}, _size{rhs._size} {
this->min = &this->root.value();
this->max = &this->root.value();
while (min->left) {
min = min->left.get();
}
while (min->right) {
min = min->left.get();
}
}
Map(Map &&rhs) : root{std::move(rhs.root)}, _size{rhs._size} {
this->min = &this->root.value();
this->max = &this->root.value();
while (min->left) {
min = min->left.get();
}
while (min->right) {
min = min->left.get();
}
}
Map &operator=(const Map &rhs) {
this->root = rhs.root;
this->_size = rhs._size;
return *this;
}
Map &operator=(Map &&rhs) {
this->root = std::move(rhs.root);
this->_size = rhs._size;
return *this;
}
Map(std::initializer_list<std::pair<const Key_T, Mapped_T>> items) : Map{} {
this->insert(items.begin(), items.end());
}
/// Debug helper: asserts that a non-empty map has a black root.
void check() {
  if (this->root.has_value()) {
    assert(this->root->color == Color::Black);
  }
}
std::size_t size() const { return this->_size; }
bool empty() const { return this->size() == 0; }
private:
// private helpers
/// Rotates the tree at its root towards `dir`: the root's `!dir` child
/// becomes the new root element and the old root element becomes its `dir`
/// child.
///
/// The root node lives inside the map object and cannot change address, so
/// the rotation is done by exchanging payload and colour between the root
/// node and the rising child and re-hanging the subtrees. Consequently, after
/// the call the root node holds the risen element and the node object that
/// used to be the `!dir` child (now `root->child(dir)`) holds the old root
/// element. No node is allocated or leaked (the previous version released the
/// promoted node without ever deleting it), and min/max follow their elements
/// to whichever node object now holds them.
///
/// Preconditions (asserted): the map is non-empty and the root has a `!dir`
/// child.
void rotate_root(Direction dir) {
  assert(root.has_value());
  Node &top = root.value();
  std::unique_ptr<Node> low = top.uchild(!dir);
  // can't make null the new root
  assert(low);
  Node *const low_ptr = low.get();
  // exchange the elements; the node objects keep their positions in memory
  std::swap(top.val, low->val);
  std::swap(top.color, low->color);
  // re-hang the three subtrees around the two nodes
  std::unique_ptr<Node> far_sub = low->uchild(!dir);
  std::unique_ptr<Node> near_sub = low->uchild(dir);
  low->set_child(dir, top.uchild(dir));
  low->set_child(!dir, std::move(near_sub));
  top.set_child(!dir, std::move(far_sub));
  top.set_child(dir, std::move(low));
  // the two node objects traded places in key order: re-thread both, which
  // also repairs the links of their neighbours
  low_ptr->restore_ordering();
  top.restore_ordering();
  // min/max designate elements, so they move with the payload
  const auto follow_element = [&top, low_ptr](Node *&end) {
    if (end == &top) {
      end = low_ptr;
    } else if (end == low_ptr) {
      end = &top;
    }
  };
  follow_element(this->min);
  follow_element(this->max);
}
template <bool trace = false>
std::pair<Node const *, Direction> locate(const Key_T &key) const {
Node const *ret_parent;
Direction ret_dir;
// map is empty
if (!this->root.has_value()) {
return std::make_pair(nullptr, ret_dir);
}
// value is in root
if (this->root.value().val->first == key) {
return std::make_pair(nullptr, ret_dir);
}
ret_parent = &this->root.value();
if (key < ret_parent->val->first) {
ret_dir = Direction::Left;
} else {
ret_dir = Direction::Right;
}
while (ret_parent->child(ret_dir) &&
!(ret_parent->child(ret_dir)->val->first == key)) {
ret_parent = ret_parent->child(ret_dir);
if (key < ret_parent->val->first) {
ret_dir = Direction::Left;
} else {
ret_dir = Direction::Right;
}
}
return std::make_pair(ret_parent, ret_dir);
}
/// Removes the black, childless, non-root node `n` and rebalances the tree.
/// Case numbers refer to the removal section of the Wikipedia red-black tree
/// article.
///
/// Differences from the previous version:
///  - A rotation at the root goes through rotate_root, after which the old
///    parent element lives in `root->child(dir)` and the risen sibling
///    element lives in the root node. The local pointers are refreshed
///    accordingly before any recolouring, instead of recolouring through
///    stale pointers.
///  - The `redcheck` macro and the gotos are replaced by ordinary control
///    flow, so the header no longer leaks a macro. The order in which the
///    cases are tested is unchanged.
void hard_erase(Node *n) {
  assert(n->parent);
  Node *parent = n->parent;
  Direction dir = parent->which_child(n);
  // destroys *n; `n` is not dereferenced again until it is reassigned
  parent->erase_child(n);
  const auto is_red = [](const Node *node) {
    return node != nullptr && node->color == Color::Red;
  };
  // Rotates `pivot` towards `toward`, lifting `riser` (its opposite child),
  // and keeps both pointers referring to the same elements afterwards.
  const auto lift = [this](Node *&pivot, Node *&riser, Direction toward) {
    if (pivot->parent != nullptr) {
      pivot->rotate(toward);
      return;
    }
    this->rotate_root(toward);
    riser = &this->root.value();
    pivot = riser->child(toward);
  };
  while (true) {
    Node *sibling = parent->child(!dir);
    Node *close = sibling ? sibling->child(dir) : nullptr;
    Node *distant = sibling ? sibling->child(!dir) : nullptr;
    bool close_is_red_nephew = false;
    if (is_red(sibling)) {
      // case 3: red sibling, rotate it above the parent
      lift(parent, sibling, dir);
      parent->color = Color::Red;
      sibling->color = Color::Black;
      // the former close nephew is the new (black) sibling
      sibling = close;
      assert(sibling);
      distant = sibling->child(!dir);
      close = sibling->child(dir);
      if (!is_red(distant)) {
        if (!is_red(close)) {
          // case 4
          sibling->color = Color::Red;
          parent->color = Color::Black;
          return;
        }
        close_is_red_nephew = true;
      }
    } else if (is_red(close)) {
      close_is_red_nephew = true;
    } else if (!is_red(distant)) {
      if (parent->color == Color::Red) {
        // case 4
        assert(sibling);
        sibling->color = Color::Red;
        parent->color = Color::Black;
        return;
      }
      // case 2: everything black, push the problem one level up
      assert(sibling);
      sibling->color = Color::Red;
      n = parent;
      parent = n->parent;
      if (parent == nullptr) {
        // we're at root we're done (case 1)
        return;
      }
      dir = parent->which_child(n);
      continue;
    }
    if (close_is_red_nephew) {
      // case 5: rotate the red close nephew above the sibling so the red
      // node ends up on the distant side
      assert(sibling);
      assert(close);
      sibling->rotate(!dir);
      sibling->color = Color::Red;
      close->color = Color::Black;
      distant = sibling;
      sibling = close;
    }
    // case 6: red distant nephew
    assert(parent);
    assert(sibling);
    assert(distant);
    lift(parent, sibling, dir);
    sibling->color = parent->color;
    parent->color = Color::Black;
    distant->color = Color::Black;
    return;
  }
}
/// Removes the element held by `erasing` from the tree.
///
/// @return true if `erasing` had exactly one child and absorbed that child's
///         contents (the node object stays in the tree); false in every other
///         case, including the two-children case, where the successor's
///         element is moved into `erasing` and the successor is removed
///         instead.
///
/// Fixes over the previous version:
///  - The single child is held in a local unique_ptr while its contents are
///    moved out, so the emptied node is freed instead of leaked through
///    release().
///  - If that child was the map's min or max, min/max now designate
///    `erasing`, which holds its element afterwards.
///  - The root is reset through `this` rather than through the node's map
///    pointer, and the successor is derived from the tree structure.
bool core_erase(Node *erasing) {
  const Color c = erasing->color;
  // 2 children
  if (erasing->left && erasing->right) {
    Node *succ = erasing->calc_succ();
    erasing->val = std::move(succ->val);
    this->core_erase(succ);
    return false;
  }
  // 1 child
  if (erasing->left || erasing->right) {
    std::unique_ptr<Node> only =
        std::move(erasing->left ? erasing->left : erasing->right);
    const bool only_was_min = this->min == only.get();
    const bool only_was_max = this->max == only.get();
    *erasing = std::move(*only);
    if (erasing->prev != nullptr) {
      erasing->prev->next = erasing;
    }
    if (erasing->next != nullptr) {
      erasing->next->prev = erasing;
    }
    // the node keeps its own colour, only the contents were replaced
    erasing->color = c;
    if (only_was_min) {
      this->min = erasing;
    }
    if (only_was_max) {
      this->max = erasing;
    }
    return true;
  }
  // no children and root
  if (erasing->parent == nullptr) {
    this->root.reset();
    return false;
  }
  // no children and red
  if (erasing->color == Color::Red) {
    erasing->parent->erase_child(erasing);
    return false;
  }
  // no children and black
  hard_erase(erasing);
  return false;
}
public:
// baseline find using locate
/// Looks `key` up and returns an iterator to its element, or end() if the
/// key is not stored.
Iterator find(const Key_T &key) {
  const auto hit = locate(key);
  Node const *const owner = hit.first;
  if (owner != nullptr) {
    // the key, if present, hangs below `owner`
    Node const *const found = owner->child(hit.second);
    if (found == nullptr) {
      return this->end();
    }
    return Iterator{const_cast<Node *>(found), nullptr};
  }
  // no owner: either the map is empty or the key sits in the root
  if (this->root.has_value() && this->root.value().val->first == key) {
    return Iterator{&this->root.value()};
  }
  return this->end();
}
/// Const overload of find.
ConstIterator find(const Key_T &key) const {
  const auto hit = locate(key);
  Node const *const owner = hit.first;
  if (owner != nullptr) {
    Node const *const found = owner->child(hit.second);
    if (found == nullptr) {
      return this->end();
    }
    return Iterator{const_cast<Node *>(found), nullptr};
  }
  if (this->root.has_value() && this->root.value().val->first == key) {
    return Iterator{const_cast<Node *>(&this->root.value())};
  }
  return this->end();
}
// baseline modification operations
std::pair<Iterator, bool> insert(const ValueType &val) {
// for convenience
auto &[key, map] = val;
auto [parent, dir] = locate(key);
// located root node
if (parent == nullptr) {
if (this->root.has_value()) {
return std::make_pair(Iterator{&root.value()}, false);
} else {
this->root = Node{val, this};
this->root.value().color = Color::Black;
this->min = &this->root.value();
this->max = &this->root.value();
this->_size++;
return std::make_pair(Iterator{&root.value()}, true);
}
}
// non-root node
if (parent->child(dir)) {
// node already present
return std::make_pair(Iterator{const_cast<Node *>(parent->child(dir))},
false);
}
// need to insert non-root node
Map *m = const_cast<Map *>(this);
Node *new_node = const_cast<Node *>(parent)
->set_child(dir, std::make_unique<Node>(Node{val, m}))
.get();
new_node->restore_red_black_insert(dir);
new_node->restore_ordering();
if (this->min == parent && dir == Direction::Left) {
this->min = new_node;
}
if (this->max == parent && dir == Direction::Right) {
this->max = new_node;
}
this->_size++;
return std::make_pair(Iterator{new_node}, true);
}
/// Removes the element `pos` points at; erasing end() is a no-op.
///
/// The previous version kept pointers to the neighbouring nodes across
/// core_erase and then read their `valid` field, although core_erase may
/// already have destroyed one of them. It also never updated min/max when a
/// node absorbed its only child. Node removal already splices the removed
/// node out of the prev/next chain, so afterwards only min and max are
/// recomputed, by walking down from the root.
void erase(Iterator pos) {
  Node *target = pos.underlying;
  if (target == nullptr) {
    return;
  }
  this->_size--;
  if (core_erase(target)) {
    // the node object survived with new contents: re-thread it
    target->restore_ordering();
  }
  this->min = nullptr;
  this->max = nullptr;
  if (this->root.has_value()) {
    Node *edge = &this->root.value();
    while (edge->left) {
      edge = edge->left.get();
    }
    this->min = edge;
    edge = &this->root.value();
    while (edge->right) {
      edge = edge->right.get();
    }
    this->max = edge;
  }
}
// baseline iterator creation
/// Iterator to the smallest element.
Iterator begin() { return Iterator(this->min, nullptr); }
/// Past-the-end iterator; remembers the largest node so it can be
/// decremented back onto it.
Iterator end() { return Iterator(nullptr, this->max); }
/// Const overloads of begin/end.
ConstIterator begin() const { return Iterator(this->min, nullptr); }
ConstIterator end() const { return Iterator(nullptr, this->max); }
/// Reverse iteration starts at the largest element ...
ReverseIterator rbegin() { return Iterator(this->max, nullptr); }
/// ... and ends one step before the smallest.
ReverseIterator rend() { return Iterator(nullptr, this->min); }
// misc that can be implemented with the above or trivially
/// Removes every element. min and max are reset as well, so begin() and
/// end() compare equal afterwards instead of pointing at destroyed nodes.
void clear() {
  this->root.reset();
  this->_size = 0;
  this->min = nullptr;
  this->max = nullptr;
}
/// Returns the value mapped to `key`.
/// @throws std::out_of_range if the key is not in the map.
Mapped_T &at(const Key_T &key) {
  Iterator hit = this->find(key);
  if (hit != this->end()) {
    return hit->second;
  }
  throw std::out_of_range{"key not in map"};
}
/// Const overload of at.
/// @throws std::out_of_range if the key is not in the map.
const Mapped_T &at(const Key_T &key) const {
  ConstIterator hit = this->find(key);
  if (hit != this->end()) {
    return hit->second;
  }
  throw std::out_of_range{"key not in map"};
}
/// Returns the value mapped to `key`, first inserting a default-constructed
/// value if the key is not present yet.
Mapped_T &operator[](const Key_T &key) {
  const ValueType fallback = {key, {}};
  // a no-op when the key already exists
  this->insert(fallback);
  return this->at(key);
}
/// Inserts every element of [range_beg, range_end); keys already present are
/// skipped.
///
/// Each element is bound by reference and turned into a ValueType once; the
/// previous version copied the pair, copied it again through make_pair and
/// then converted it.
template <typename IT_T> void insert(IT_T range_beg, IT_T range_end) {
  for (; range_beg != range_end; ++range_beg) {
    const auto &[first, second] = *range_beg;
    this->insert(ValueType(first, second));
  }
}
void erase(const Key_T &key) { this->erase(this->find(key)); }
/// Two maps are equal when they have the same size and their elements
/// compare equal pairwise in key order.
friend bool operator==(const Map &lhs, const Map &rhs) {
  if (lhs.size() != rhs.size()) {
    return false;
  }
  auto right = rhs.begin();
  for (auto left = lhs.begin(); left != lhs.end(); ++left, ++right) {
    if (*left != *right) {
      return false;
    }
  }
  return true;
}
friend bool operator!=(const Map &lhs, const Map &rhs) {
  const bool same = lhs == rhs;
  return !same;
}
/// Lexicographic comparison of the element sequences; when one sequence is a
/// prefix of the other, the shorter map is the smaller one.
friend bool operator<(const Map &lhs, const Map &rhs) {
  auto left = lhs.begin();
  auto right = rhs.begin();
  for (; left != lhs.end() && right != rhs.end(); ++left, ++right) {
    if (*left < *right) {
      return true;
    }
    // not smaller and not equal: lhs is greater
    if (*left != *right) {
      return false;
    }
  }
  return lhs.size() < rhs.size();
}
};
} // namespace cs440
#endif