inheritence stuff is annoying to fix

This commit is contained in:
Pagwin 2024-12-13 20:55:06 -05:00
parent 45a0b765a4
commit 4a9fe59bb2
No known key found for this signature in database
GPG key ID: 81137023740CA260

View file

@ -1,6 +1,8 @@
#include <atomic> #include <atomic>
#include <type_traits>
#ifndef _POWELLCS440 #ifndef _POWELLCS440
#include <cassert> #include <cassert>
#include <concepts>
#include <iostream> #include <iostream>
#define todo(msg) \ #define todo(msg) \
@ -9,75 +11,65 @@
namespace cs440 { namespace cs440 {
using counter = std::atomic<std::size_t>; using counter = std::atomic<std::size_t>;
template <typename T> class SharedPtr; template <typename T> struct Base {
class Key { template <typename U> friend struct Base;
Key() {}
T *ptr;
Base(T *ptr) : ptr{ptr} {}
template <typename U> Base(Base<U> const &ptr) : ptr{ptr.ptr} {}
template <typename U> operator U() { return U(*this); }
public: public:
template <typename U> friend class SharedPtr; virtual T &operator*() { return *ptr; }
template <typename U> void increment(SharedPtr<U> const &ptr) { virtual T *operator->() { return ptr; }
const_cast<std::atomic<size_t> &>(*ptr.count)++; virtual ~Base() = default;
} };
template <typename U> counter *get_counter(SharedPtr<U> const &ptr) { template <typename T, typename U> class Derived : public Base<T> {
return const_cast<std::atomic<size_t> *>(ptr.count); static_assert(std::is_base_of_v<T, U>);
} U *ptr;
template <typename U> U *get_ptr(SharedPtr<U> const &ptr) {
return const_cast<U *>(ptr.ptr); public:
} Derived(Base<T> b) : Derived{*dynamic_cast<Derived<T, U> *>(&b)} {}
template <typename U> void nullify(SharedPtr<U> &&ptr) { U &operator*() { return *ptr; }
ptr.ptr = nullptr; U *operator->() { return ptr; }
ptr.count = nullptr; virtual ~Derived() { delete this->ptr; }
}
}; };
template <typename T> class SharedPtr { template <typename T> class SharedPtr {
T *ptr;
std::atomic<std::size_t> *count;
public: public:
friend class Key; Base<T> *ptr;
SharedPtr() : ptr{nullptr}, count{new std::atomic<std::size_t>{1}} {} counter *count;
// problem?
SharedPtr() : ptr{nullptr}, count{new std::atomic<std::size_t>{0}} {}
template <typename U> template <typename U>
explicit SharedPtr(U *ptr) explicit SharedPtr(U *ptr)
: ptr{ptr}, count{new std::atomic<std::size_t>{1}} {} : ptr{static_cast<Base<T> *>(new Derived<T, U>{ptr})},
SharedPtr(const SharedPtr &p) : ptr{p.ptr} { count{new std::atomic<std::size_t>{ptr ? 1 : 0}} {}
Key().increment(p); SharedPtr(const SharedPtr &p) : ptr{p.ptr} {}
this->count = Key().get_counter(p);
}
template <typename U> template <typename U>
SharedPtr(const SharedPtr<U> &p) : ptr{Key().get_ptr(p)} { SharedPtr(const SharedPtr<U> &p) : ptr(Base<T>(*p.ptr)), count{p.count} {
Key().increment(p); (*count)++;
this->count = Key().get_counter(p);
} }
SharedPtr(SharedPtr &&p) SharedPtr(SharedPtr &&p) : ptr{p.ptr} { p.ptr = nullptr; }
: count{Key().get_counter(p)}, ptr{Key().get_ptr(p)} { template <typename U> SharedPtr(SharedPtr<U> &&p) : ptr{p.ptr} {
Key().nullify(p); p.ptr = nullptr;
}
template <typename U>
SharedPtr(SharedPtr<U> &&p)
: count{Key().get_counter(p)}, ptr{Key().get_ptr(p)} {
Key().nullify(p);
} }
SharedPtr &operator=(const SharedPtr &ptr) { SharedPtr &operator=(const SharedPtr &ptr) {
Key().increment(ptr); this->ptr = ptr.ptr;
this->count = Key().get_counter(ptr); (*count)++;
this->ptr = Key().get_ptr(ptr);
return *this; return *this;
} }
template <typename U> SharedPtr<T> &operator=(const SharedPtr<U> &ptr) { template <typename U> SharedPtr<T> &operator=(const SharedPtr<U> &ptr) {
Key().increment(ptr); this->ptr = new Derived<T, U>{Base<T>(ptr.ptr)};
this->count = Key().get_counter(ptr);
this->ptr = Key().get_ptr(ptr);
return *this; return *this;
} }
SharedPtr &operator=(SharedPtr &&p) { SharedPtr &operator=(SharedPtr &&p) {
this->count = Key().get_counter(ptr); ptr = p.ptr;
this->ptr = Key().get_ptr(ptr); p.ptr = nullptr;
Key().nullify(p);
} }
template <typename U> SharedPtr &operator=(SharedPtr<U> &&p) { template <typename U> SharedPtr &operator=(SharedPtr<U> &&p) {
this->count = Key().get_counter(ptr); ptr = p.ptr;
this->ptr = Key().get_ptr(ptr); p.ptr = nullptr;
Key().nullify(p);
} }
~SharedPtr() { ~SharedPtr() {
if (this->count == nullptr) { if (this->count == nullptr) {
@ -91,33 +83,18 @@ public:
delete this->ptr; delete this->ptr;
} }
} }
void reset() { void reset() { this->ptr = nullptr; }
if (this->count != nullptr) { template <typename U> void reset(U *p) { this->ptr = new Derived<U>{p}; }
if (--(*this->count) == 0) { T *get() const { return this->ptr->ptr; }
delete this->count; T &operator*() const { return this->ptr->operator*(); }
delete this->ptr; T *operator->() const { return this->ptr->operator->(); }
return; explicit operator bool() const { return this->ptr->ptr != nullptr; }
}
}
this->ptr = nullptr;
this->count = nullptr;
}
template <typename U> void reset(U *p) {
this->reset();
this->count = new counter{1};
this->ptr = p;
}
T *get() const { return this->ptr; }
T &operator*() const { return *this->ptr; }
T *operator->() const { return this->ptr; }
explicit operator bool() const { return this->ptr != nullptr; }
template <typename T1, typename T2>
template <typename U> template <typename U>
friend SharedPtr<T> static_pointer_cast(const SharedPtr<U> &sp) { friend SharedPtr<U> static_pointer_cast(const SharedPtr<T> &sp) {
todo(""); todo("");
} }
template <typename U> template <typename U>
friend SharedPtr<T> dynamic_pointer_cast(const SharedPtr<U> &sp) { friend SharedPtr<U> dynamic_pointer_cast(const SharedPtr<T> &sp) {
todo(""); todo("");
} }
}; };