// cs440-assignment3/SharedPtr.hpp
//
// (source-viewer metadata: 130 lines, 3.7 KiB, C++)

#include <atomic>
#include <cstddef>
#include <utility>
#ifndef _POWELLCS440
#include <cassert>
#include <iostream>
// Placeholder for unimplemented code: prints `msg` to stderr, then fires
// assert(false) (a no-op when NDEBUG is defined, so execution then continues).
// The do { ... } while (0) wrapper makes `todo(...);` expand to exactly ONE
// statement, so it is safe in an unbraced if/else. `msg` is deliberately left
// unparenthesized so callers may stream several values: todo("x = " << x).
#define todo(msg)                                                              \
  do {                                                                         \
    std::cerr << msg << std::endl;                                             \
    assert(false);                                                             \
  } while (0)
namespace cs440 {

/// Atomic reference-count type (public alias kept for compatibility).
using counter = std::atomic<std::size_t>;

/// Legacy owning-handle base from an earlier design. SharedPtr below does not
/// use it; it is retained only so code that names cs440::Base still compiles.
/// Its destructor does NOT delete `ptr`.
template <typename T> struct Base {
  template <typename U> friend struct Base;
  T *ptr;
  Base(T *ptr) : ptr{ptr} {}
  template <typename U> Base(Base<U> &ptr) : ptr{ptr.ptr} {}
  template <typename U> Base(Base<U> &&ptr) : ptr{ptr.ptr} {
    ptr.ptr = nullptr; // leave the moved-from handle empty
  }

public:
  virtual T &operator*() { return *ptr; }
  virtual T *operator->() { return ptr; }
  virtual ~Base() = default;
};

/// Legacy helper: a Base<T> view of a U* that deletes the pointer through its
/// original type U when destroyed. Not used by SharedPtr; kept for
/// compatibility. WARNING: MyDerived(Base<U>&) copies the raw pointer without
/// detaching it from the source, so if the source also owns it, the pointer
/// is deleted twice.
template <typename T, typename U> class MyDerived : public Base<T> {
  U *ptr; // the same object as Base<T>::ptr, kept with its static type U

public:
  MyDerived(U *ptr) : Base<T>{ptr} { this->ptr = ptr; }
  MyDerived(Base<U> &ptr) : Base<T>{ptr.ptr} { this->ptr = ptr.ptr; }
  U &operator*() { return *ptr; }
  U *operator->() { return ptr; }
  virtual ~MyDerived() { delete this->ptr; }
};

namespace detail {

/// Type-erased ownership record shared by every SharedPtr that refers to the
/// same managed object. It is not templated on the SharedPtr element type,
/// which is what lets SharedPtr<Derived> and SharedPtr<Base> share one count.
struct ControlBlockBase {
  counter refs{1}; // number of SharedPtr objects currently sharing this block
  virtual ~ControlBlockBase() = default;
};

/// Control block that remembers the pointer's original type U, so the object
/// is deleted as a U even when the last owner is a SharedPtr<Base> (and even
/// if Base lacks a virtual destructor).
template <typename U> struct ControlBlock final : ControlBlockBase {
  explicit ControlBlock(U *p) noexcept : owned{p} {}
  ~ControlBlock() override { delete owned; }
  U *owned;
};

} // namespace detail

/// Reference-counted owning pointer modelled on std::shared_ptr.
///
/// A SharedPtr stores the pointer it hands out (`ptr_`) and a pointer to a
/// control block shared by all SharedPtrs owning the same object (`ctrl_`).
/// The managed object is deleted, through its original type, when the last
/// SharedPtr sharing the control block is destroyed, reset, or reassigned.
/// The count is std::atomic, so copies held by different threads may be
/// created and destroyed concurrently; one SharedPtr object is not itself
/// synchronized.
template <typename T> class SharedPtr {
  // Conversions between SharedPtr<T> and SharedPtr<U> read each other's
  // control block.
  template <typename U> friend class SharedPtr;

public:
  /// Constructs an empty pointer: get() == nullptr, use_count() == 0.
  SharedPtr() noexcept : ptr_{nullptr}, ctrl_{nullptr} {}

  /// Takes ownership of `p`, which must come from `new U` (a null `p` yields
  /// an empty pointer). U* must convert implicitly to T*. If allocating the
  /// control block throws, `p` is deleted before the exception propagates.
  template <typename U>
  explicit SharedPtr(U *p) : ptr_{p}, ctrl_{adopt(p)} {}

  /// Shares ownership with `other` (increments the count).
  SharedPtr(const SharedPtr &other) noexcept
      : ptr_{other.ptr_}, ctrl_{other.ctrl_} {
    retain();
  }

  /// Shares ownership with `other`; U* must convert to T* (e.g. Derived->Base).
  template <typename U>
  SharedPtr(const SharedPtr<U> &other) noexcept
      : ptr_{other.ptr_}, ctrl_{other.ctrl_} {
    retain();
  }

  /// Aliasing constructor: shares ownership with `owner` but get() returns
  /// `p`. Used by static_pointer_cast / dynamic_pointer_cast.
  template <typename U>
  SharedPtr(const SharedPtr<U> &owner, T *p) noexcept
      : ptr_{p}, ctrl_{owner.ctrl_} {
    retain();
  }

  /// Takes over `other`'s ownership, leaving `other` empty; count unchanged.
  SharedPtr(SharedPtr &&other) noexcept
      : ptr_{other.ptr_}, ctrl_{other.ctrl_} {
    other.ptr_ = nullptr;
    other.ctrl_ = nullptr;
  }

  /// Converting move: like the move constructor; U* must convert to T*.
  template <typename U>
  SharedPtr(SharedPtr<U> &&other) noexcept
      : ptr_{other.ptr_}, ctrl_{other.ctrl_} {
    other.ptr_ = nullptr;
    other.ctrl_ = nullptr;
  }

  // All assignments use copy/move-and-swap: the new state is built in a
  // temporary first, then the temporary carries our previous ownership away
  // and releases it on destruction. This is also correct for self-assignment.
  SharedPtr &operator=(const SharedPtr &other) noexcept {
    SharedPtr copy{other};
    copy.swap(*this);
    return *this;
  }

  template <typename U> SharedPtr &operator=(const SharedPtr<U> &other) noexcept {
    SharedPtr copy{other};
    copy.swap(*this);
    return *this;
  }

  SharedPtr &operator=(SharedPtr &&other) noexcept {
    SharedPtr moved{std::move(other)};
    moved.swap(*this);
    return *this;
  }

  template <typename U> SharedPtr &operator=(SharedPtr<U> &&other) noexcept {
    SharedPtr moved{std::move(other)};
    moved.swap(*this);
    return *this;
  }

  /// Drops this owner; deletes the object if it was the last one.
  ~SharedPtr() { release(); }

  /// Releases ownership and becomes empty.
  void reset() noexcept { SharedPtr{}.swap(*this); }

  /// Releases the current ownership and takes ownership of `p` instead.
  template <typename U> void reset(U *p) { SharedPtr{p}.swap(*this); }

  /// Exchanges state with `other` without touching either count.
  void swap(SharedPtr &other) noexcept {
    std::swap(ptr_, other.ptr_);
    std::swap(ctrl_, other.ctrl_);
  }

  /// Stored pointer; nullptr when empty.
  T *get() const noexcept { return ptr_; }
  /// Dereferences the stored pointer; must not be called when get() is null.
  T &operator*() const { return *ptr_; }
  T *operator->() const noexcept { return ptr_; }
  /// True iff get() != nullptr (safe on an empty pointer).
  explicit operator bool() const noexcept { return ptr_ != nullptr; }

  /// Number of SharedPtrs sharing ownership with this one (0 when empty).
  std::size_t use_count() const noexcept {
    return ctrl_ == nullptr ? std::size_t{0} : ctrl_->refs.load();
  }

private:
  /// Allocates the control block for a freshly adopted raw pointer.
  template <typename U> static detail::ControlBlockBase *adopt(U *p) {
    if (p == nullptr) {
      return nullptr;
    }
    try {
      return new detail::ControlBlock<U>{p};
    } catch (...) {
      delete p; // never leak the pointer we were handed
      throw;
    }
  }

  /// Registers one more owner of the current control block, if any.
  void retain() noexcept {
    if (ctrl_ != nullptr) {
      ++ctrl_->refs;
    }
  }

  /// Unregisters this owner. The atomic decrement guarantees exactly one
  /// owner observes zero and frees the block (which deletes the object).
  void release() noexcept {
    if (ctrl_ != nullptr && --ctrl_->refs == 0) {
      delete ctrl_;
    }
  }

  T *ptr_;                         // pointer returned by get()/*/->
  detail::ControlBlockBase *ctrl_; // null iff this SharedPtr owns nothing
};

/// SharedPtr<T> that shares ownership with `sp` and stores
/// static_cast<T*>(sp.get()). As with static_cast, the caller must ensure the
/// conversion is valid.
template <typename T, typename U>
SharedPtr<T> static_pointer_cast(const SharedPtr<U> &sp) {
  return SharedPtr<T>(sp, static_cast<T *>(sp.get()));
}

/// SharedPtr<T> that shares ownership with `sp` if dynamic_cast<T*>(sp.get())
/// succeeds; otherwise an empty SharedPtr<T>.
template <typename T, typename U>
SharedPtr<T> dynamic_pointer_cast(const SharedPtr<U> &sp) {
  T *const p = dynamic_cast<T *>(sp.get());
  return p != nullptr ? SharedPtr<T>(sp, p) : SharedPtr<T>();
}

/// Pointer comparisons: two SharedPtrs compare equal iff their get() values do.
template <typename T1, typename T2>
bool operator==(const SharedPtr<T1> &lhs, const SharedPtr<T2> &rhs) {
  return lhs.get() == rhs.get();
}
template <typename T>
bool operator==(const SharedPtr<T> &lhs, std::nullptr_t rhs) {
  return lhs.get() == rhs;
}
template <typename T>
bool operator==(std::nullptr_t lhs, const SharedPtr<T> &rhs) {
  return lhs == rhs.get();
}
template <typename T1, typename T2>
bool operator!=(const SharedPtr<T1> &lhs, const SharedPtr<T2> &rhs) {
  return !(lhs == rhs);
}
template <typename T>
bool operator!=(const SharedPtr<T> &lhs, std::nullptr_t rhs) {
  return !(lhs == rhs);
}
template <typename T>
bool operator!=(std::nullptr_t lhs, const SharedPtr<T> &rhs) {
  return !(lhs == rhs);
}

} // namespace cs440
#endif