// cs440-assignment3/SharedPtr.hpp
//
// 189 lines
// 4.9 KiB
// C++

#ifndef _POWELLCS440
#define _POWELLCS440
#include <atomic>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <type_traits>
#include <utility>
// Prints `msg` to stderr, then fails an assertion (the assert is a no-op when
// NDEBUG is defined). `msg` is pasted directly after `std::cerr <<`, so it may
// itself be a `<<` chain, e.g. todo("x = " << x).
// The do/while(0) wrapper makes the expansion a single statement, so it behaves
// correctly under an unbraced `if`/`else`. Invoke it as `todo("...");`.
#define todo(msg)                                                              \
  do {                                                                         \
    std::cerr << msg << std::endl;                                             \
    assert(false);                                                             \
  } while (0)
namespace cs440 {
// Reference-count type stored in every control block.
using counter = std::atomic<std::size_t>;

// Type-erased control block holding the owner count for one managed object.
// SharedPtr only ever creates these with `new ControlImpl<U>`, and the block
// frees itself in decrement() once the last owner is gone.
class Control {
  counter count;
  // Destroys the managed object (not the control block itself).
  virtual void destroy() = 0;

public:
  Control() : count{1} {}
  Control(const Control &) = delete;
  Control &operator=(const Control &) = delete;

  // Drops one owner. The owner that takes the count from 1 to 0 destroys the
  // managed object and then deletes this control block.
  void decrement() {
    // fetch_sub returns the previous value in one atomic step, so exactly one
    // caller observes the 1 -> 0 transition. A separate `count--` followed by
    // `count == 0` lets two concurrent callers both see zero.
    if (count.fetch_sub(1) == 1) {
      this->destroy();
      delete this;
    }
  }
  // Adds one owner.
  void increment() { count.fetch_add(1); }
  virtual ~Control() = default;
};

// Control block for an object allocated with `new T`. The object is deleted
// through the original T* it was created with, so T's destructor runs even when
// the owning SharedPtr's element type is a base class of T.
template <typename T> class ControlImpl : public Control {
  T *val;
  void destroy() override { delete val; }

public:
  explicit ControlImpl(T *v) : Control{}, val{v} {}
  ~ControlImpl() override = default;
};

// Reference-counted owning pointer modeled on std::shared_ptr.
//
// `raw` is what get() returns; `underlying` is the control block that decides
// when the managed object is deleted. They normally refer to the same object
// but may differ after a pointer cast or the aliasing constructor. An empty
// SharedPtr has both set to nullptr.
template <typename T> class SharedPtr {
  T *raw;
  Control *underlying;

  // SFINAE guard: enables a converting member only when U* converts to T*.
  template <typename U>
  using IfCompatible =
      typename std::enable_if<std::is_convertible<U *, T *>::value, int>::type;

  // Registers *this as one more owner of `underlying`, if there is one.
  void acquire() noexcept {
    if (this->underlying != nullptr) {
      this->underlying->increment();
    }
  }

public:
  template <typename U> friend class SharedPtr;

  // Constructs an empty pointer.
  SharedPtr() noexcept : raw{nullptr}, underlying{nullptr} {}

  // Takes ownership of `p`, which must come from `new U` (or be null), because
  // the control block eventually calls `delete p`. If allocating the control
  // block throws, `p` is deleted and the exception propagates, so nothing leaks.
  template <typename U, IfCompatible<U> = 0>
  explicit SharedPtr(U *p) : raw{p}, underlying{nullptr} {
    try {
      this->underlying = new ControlImpl<U>(p);
    } catch (...) {
      delete p;
      throw;
    }
  }

  // Shares ownership with `p`.
  SharedPtr(const SharedPtr &p) noexcept
      : raw{p.raw}, underlying{p.underlying} {
    this->acquire();
  }

  // Shares ownership with a SharedPtr whose element type converts to T.
  template <typename U, IfCompatible<U> = 0>
  SharedPtr(const SharedPtr<U> &p) noexcept
      : raw{p.raw}, underlying{p.underlying} {
    this->acquire();
  }

  // Aliasing constructor: shares ownership with `owner`, but get() returns
  // `ptr`. The pointer casts below are built on this.
  template <typename U>
  SharedPtr(const SharedPtr<U> &owner, T *ptr) noexcept
      : raw{ptr}, underlying{owner.underlying} {
    this->acquire();
  }

  // Steals ownership from `p`, leaving it empty. Reference counts do not change.
  SharedPtr(SharedPtr &&p) noexcept : raw{p.raw}, underlying{p.underlying} {
    p.raw = nullptr;
    p.underlying = nullptr;
  }

  // Steals ownership from a SharedPtr whose element type converts to T.
  template <typename U, IfCompatible<U> = 0>
  SharedPtr(SharedPtr<U> &&p) noexcept : raw{p.raw}, underlying{p.underlying} {
    p.raw = nullptr;
    p.underlying = nullptr;
  }

  // All assignments use copy/move-and-swap. The new owner is registered before
  // the old one is released, which gives three guarantees:
  //  - self-assignment is harmless;
  //  - `rhs` may live inside the object *this currently owns;
  //  - `raw` is refreshed even when both sides share one control block.
  SharedPtr &operator=(const SharedPtr &rhs) noexcept {
    SharedPtr tmp(rhs);
    this->swap(tmp);
    return *this;
  }

  template <typename U, IfCompatible<U> = 0>
  SharedPtr<T> &operator=(const SharedPtr<U> &rhs) noexcept {
    SharedPtr tmp(rhs);
    this->swap(tmp);
    return *this;
  }

  SharedPtr &operator=(SharedPtr &&rhs) noexcept {
    SharedPtr tmp(std::move(rhs));
    this->swap(tmp);
    return *this;
  }

  template <typename U, IfCompatible<U> = 0>
  SharedPtr &operator=(SharedPtr<U> &&rhs) noexcept {
    SharedPtr tmp(std::move(rhs));
    this->swap(tmp);
    return *this;
  }

  // Releases this owner; the last owner deletes the managed object.
  ~SharedPtr() {
    if (this->underlying != nullptr) {
      this->underlying->decrement();
    }
  }

  // Releases ownership and becomes empty.
  void reset() noexcept {
    SharedPtr tmp;
    this->swap(tmp);
  }

  // Replaces the managed object with `p` (same precondition as SharedPtr(U*)).
  // If allocation throws, `p` is deleted and *this is left unchanged.
  template <typename U, IfCompatible<U> = 0> void reset(U *p) {
    SharedPtr tmp(p);
    this->swap(tmp);
  }

  // Exchanges the contents of two pointers without touching reference counts.
  void swap(SharedPtr &other) noexcept {
    std::swap(this->raw, other.raw);
    std::swap(this->underlying, other.underlying);
  }

  T *get() const noexcept { return this->raw; }
  T &operator*() const noexcept { return *this->get(); }
  T *operator->() const noexcept { return this->get(); }
  explicit operator bool() const noexcept { return this->raw != nullptr; }

  // Comparisons use get() so the mixed-type overloads never need access to
  // SharedPtr<U>'s private members.
  template <typename U>
  friend bool operator==(const SharedPtr<T> &lhs, const SharedPtr<U> &rhs) {
    return lhs.get() == rhs.get();
  }
  friend bool operator==(const SharedPtr<T> &lhs, std::nullptr_t) {
    return lhs.get() == nullptr;
  }
  friend bool operator==(std::nullptr_t, const SharedPtr<T> &rhs) {
    return nullptr == rhs.get();
  }
  template <typename U>
  friend bool operator!=(const SharedPtr<T> &lhs, const SharedPtr<U> &rhs) {
    return lhs.get() != rhs.get();
  }
  friend bool operator!=(const SharedPtr<T> &lhs, std::nullptr_t) {
    return lhs.get() != nullptr;
  }
  friend bool operator!=(std::nullptr_t, const SharedPtr<T> &rhs) {
    return nullptr != rhs.get();
  }
};

// Returns a SharedPtr<T> that shares ownership with `sp` and points to
// `static_cast<T *>(sp.get())`. The cast is unchecked; an empty `sp` yields an
// empty result.
template <typename T, typename U>
SharedPtr<T> static_pointer_cast(const SharedPtr<U> &sp) noexcept {
  return SharedPtr<T>(sp, static_cast<T *>(sp.get()));
}

// Returns a SharedPtr<T> sharing ownership with `sp` when
// `dynamic_cast<T *>(sp.get())` succeeds, or an empty SharedPtr when it yields
// nullptr (so a failed cast never keeps the object alive).
template <typename T, typename U>
SharedPtr<T> dynamic_pointer_cast(const SharedPtr<U> &sp) noexcept {
  T *cast = dynamic_cast<T *>(sp.get());
  if (cast == nullptr) {
    return SharedPtr<T>();
  }
  return SharedPtr<T>(sp, cast);
}
} // namespace cs440
#endif