diff --git a/SharedPtr.hpp b/SharedPtr.hpp index 82bb20c..d119ae6 100644 --- a/SharedPtr.hpp +++ b/SharedPtr.hpp @@ -1,6 +1,8 @@ #include +#include #ifndef _POWELLCS440 #include +#include #include #define todo(msg) \ @@ -9,75 +11,65 @@ namespace cs440 { using counter = std::atomic; -template class SharedPtr; -class Key { - Key() {} +template struct Base { + template friend struct Base; + + T *ptr; + Base(T *ptr) : ptr{ptr} {} + template Base(Base const &ptr) : ptr{ptr.ptr} {} + template operator U() { return U(*this); } public: - template friend class SharedPtr; - template void increment(SharedPtr const &ptr) { - const_cast &>(*ptr.count)++; - } - template counter *get_counter(SharedPtr const &ptr) { - return const_cast *>(ptr.count); - } - template U *get_ptr(SharedPtr const &ptr) { - return const_cast(ptr.ptr); - } - template void nullify(SharedPtr &&ptr) { - ptr.ptr = nullptr; - ptr.count = nullptr; - } + virtual T &operator*() { return *ptr; } + virtual T *operator->() { return ptr; } + virtual ~Base() = default; +}; +template class Derived : public Base { + static_assert(std::is_base_of_v); + U *ptr; + +public: + Derived(Base b) : Derived{*dynamic_cast *>(&b)} {} + U &operator*() { return *ptr; } + U *operator->() { return ptr; } + virtual ~Derived() { delete this->ptr; } }; template class SharedPtr { - T *ptr; - std::atomic *count; - public: - friend class Key; - SharedPtr() : ptr{nullptr}, count{new std::atomic{1}} {} + Base *ptr; + counter *count; + + // problem? + SharedPtr() : ptr{nullptr}, count{new std::atomic{0}} {} template explicit SharedPtr(U *ptr) - : ptr{ptr}, count{new std::atomic{1}} {} - SharedPtr(const SharedPtr &p) : ptr{p.ptr} { - Key().increment(p); - this->count = Key().get_counter(p); - } + : ptr{static_cast *>(new Derived{ptr})}, + count{new std::atomic{ptr ? 1 : 0}} {} + SharedPtr(const SharedPtr &p) : ptr{p.ptr} {} template - SharedPtr(const SharedPtr &p) : ptr{Key().get_ptr(p)} { - Key().increment(p); - this->count = Key().get_counter(p); + SharedPtr(const SharedPtr &p) : ptr(Base(*p.ptr)), count{p.count} { + (*count)++; } - SharedPtr(SharedPtr &&p) - : count{Key().get_counter(p)}, ptr{Key().get_ptr(p)} { - Key().nullify(p); - } - template - SharedPtr(SharedPtr &&p) - : count{Key().get_counter(p)}, ptr{Key().get_ptr(p)} { - Key().nullify(p); + SharedPtr(SharedPtr &&p) : ptr{p.ptr} { p.ptr = nullptr; } + template SharedPtr(SharedPtr &&p) : ptr{p.ptr} { + p.ptr = nullptr; } SharedPtr &operator=(const SharedPtr &ptr) { - Key().increment(ptr); - this->count = Key().get_counter(ptr); - this->ptr = Key().get_ptr(ptr); + this->ptr = ptr.ptr; + (*count)++; return *this; } template SharedPtr &operator=(const SharedPtr &ptr) { - Key().increment(ptr); - this->count = Key().get_counter(ptr); - this->ptr = Key().get_ptr(ptr); + this->ptr = new Derived{Base(ptr.ptr)}; return *this; } SharedPtr &operator=(SharedPtr &&p) { - this->count = Key().get_counter(ptr); - this->ptr = Key().get_ptr(ptr); - Key().nullify(p); + ptr = p.ptr; + p.ptr = nullptr; } template SharedPtr &operator=(SharedPtr &&p) { - this->count = Key().get_counter(ptr); - this->ptr = Key().get_ptr(ptr); - Key().nullify(p); + ptr = p.ptr; + p.ptr = nullptr; } ~SharedPtr() { if (this->count == nullptr) { @@ -91,33 +83,18 @@ public: delete this->ptr; } } - void reset() { - if (this->count != nullptr) { - if (--(*this->count) == 0) { - delete this->count; - delete this->ptr; - return; - } - } - this->ptr = nullptr; - this->count = nullptr; - } - template void reset(U *p) { - this->reset(); - this->count = new counter{1}; - this->ptr = p; - } - T *get() const { return this->ptr; } - T &operator*() const { return *this->ptr; } - T *operator->() const { return this->ptr; } - explicit operator bool() const { return this->ptr != nullptr; } - template + void reset() { this->ptr = nullptr; } + template void reset(U *p) { this->ptr = new Derived{p}; } + T *get() const { return this->ptr->ptr; } + T &operator*() const { return this->ptr->operator*(); } + T *operator->() const { return this->ptr->operator->(); } + explicit operator bool() const { return this->ptr->ptr != nullptr; } template - friend SharedPtr static_pointer_cast(const SharedPtr &sp) { + friend SharedPtr static_pointer_cast(const SharedPtr &sp) { todo(""); } template - friend SharedPtr dynamic_pointer_cast(const SharedPtr &sp) { + friend SharedPtr dynamic_pointer_cast(const SharedPtr &sp) { todo(""); } };