everything compiles, now need to figure out ref counting

This commit is contained in:
Pagwin 2026-01-15 14:21:46 -05:00
parent 2df44b29a2
commit 6ef58754d2
No known key found for this signature in database
GPG key ID: 81137023740CA260
3 changed files with 154 additions and 45 deletions

View file

@ -1,3 +1,3 @@
CXX=clang++ CXX=clang++
test: SharedPtr_test.cpp SharedPtr.hpp test: SharedPtr_test.cpp SharedPtr.hpp
$(CXX) SharedPtr_test.cpp -o test $(CXX) SharedPtr_test.cpp -Wall -fsanitize=address -g -o test

View file

@ -9,51 +9,160 @@
namespace cs440 { namespace cs440 {
using counter = std::atomic<std::size_t>; using counter = std::atomic<std::size_t>;
template <typename T> class SharedPtr { class Control {
counter count; counter count;
(???) underlying;
virtual void destroy() = 0;
public: public:
Control() : count{1} {};
void decrement() {
count--;
if (count == 0) {
this->destroy();
}
}
void increment() { count++; }
virtual ~Control() = default;
};
template <typename T> class ControlImpl : public Control {
T *val;
void destroy() override {
delete val;
delete this;
}
public:
ControlImpl(T *v) : Control{}, val{v} {}
~ControlImpl() = default;
};
template <typename T> class SharedPtr {
T *raw;
Control *underlying;
public:
template <typename U> friend class SharedPtr;
// typename U meaning a type which inherits from T // typename U meaning a type which inherits from T
SharedPtr(); SharedPtr() : raw{nullptr}, underlying{nullptr} {}
template <typename U> explicit SharedPtr(U *); template <typename U> explicit SharedPtr(U *p) {
this->raw = p;
this->underlying = new ControlImpl<U>(p);
}
SharedPtr(const SharedPtr &p); SharedPtr(const SharedPtr &p) : raw{p.raw}, underlying{p.underlying} {
template <typename U> SharedPtr(const SharedPtr<U> &p); this->underlying->increment();
}
template <typename U>
SharedPtr(const SharedPtr<U> &p) : raw{p.raw}, underlying{p.underlying} {
this->underlying->increment();
}
SharedPtr(SharedPtr &&p); SharedPtr(SharedPtr &&p) : raw{p.raw}, underlying{p.underlying} {
template <typename U> SharedPtr(SharedPtr<U> &&p); p.raw = nullptr;
p.underlying = nullptr;
};
template <typename U>
SharedPtr(SharedPtr<U> &&p) : raw{p.raw}, underlying{p.underlying} {
p.raw = nullptr;
p.underlying = nullptr;
}
// needs to handle self assignment // needs to handle self assignment
SharedPtr &operator=(const SharedPtr &); SharedPtr &operator=(const SharedPtr &rhs) {
template <typename U> SharedPtr<T> &operator=(const SharedPtr<U> &); if (this->underlying == rhs.underlying) {
return *this;
}
this->raw = rhs.raw;
this->underlying->decrement();
this->underlying = rhs.underlying;
this->underlying->increment();
return *this;
}
template <typename U> SharedPtr<T> &operator=(const SharedPtr<U> &rhs) {
if (this->underlying == rhs.underlying) {
return *this;
}
this->raw = rhs.raw;
this->underlying->decrement();
this->underlying = rhs.underlying;
this->underlying->increment();
return *this;
}
SharedPtr &operator=(SharedPtr &&p); SharedPtr &operator=(SharedPtr &&rhs) {
template <typename U> SharedPtr &operator=(SharedPtr<U> &&p); if (this == &rhs) {
~SharedPtr(); return *this;
}
this->raw = rhs.raw;
this->underlying->decrement();
this->underlying = rhs.underlying;
rhs.raw = nullptr;
rhs.underlying = nullptr;
}
template <typename U> SharedPtr &operator=(SharedPtr<U> &&rhs) {
if (this == &rhs) {
return *this;
}
this->raw = rhs.raw;
this->underlying->decrement();
this->underlying = rhs.underlying;
rhs.raw = nullptr;
rhs.underlying = nullptr;
}
~SharedPtr() { this->underlying->decrement(); }
void reset() { void reset() {
this->count = 0; this->raw = nullptr;
if (this->underlying != nullptr) {
this->underlying->decrement();
}
this->underlying = nullptr; this->underlying = nullptr;
} }
template <typename U> void reset(U *p); template <typename U> void reset(U *p) {
T *get() const; this->raw = p;
if (this->underlying != nullptr) {
this->underlying->decrement();
}
this->underlying = new ControlImpl<U>(p);
}
T *get() const { return this->raw; }
T &operator*() const { return *this->get(); } T &operator*() const { return *this->get(); }
T *operator->() const { return this->get(); } T *operator->() const { return this->get(); }
explicit operator bool() const; explicit operator bool() const { return this->raw; }
template <typename T1, typename T2>
friend bool operator==(const SharedPtr<T1> &, const SharedPtr<T2> &);
friend bool operator==(const SharedPtr<T> &, std::nullptr_t);
friend bool operator==(std::nullptr_t, const SharedPtr<T> &);
template <typename T1, typename T2>
friend bool operator!=(const SharedPtr<T1> &, const SharedPtr<T2> &);
friend bool operator!=(const SharedPtr<T> &, std::nullptr_t);
friend bool operator!=(std::nullptr_t, const SharedPtr<T> &);
template <typename U> template <typename U>
friend SharedPtr<T> static_pointer_cast(const SharedPtr<U> &sp); friend bool operator==(const SharedPtr<T> &lhs, const SharedPtr<U> &rhs) {
return lhs.raw == rhs.raw;
}
friend bool operator==(const SharedPtr<T> &lhs, std::nullptr_t) {
return lhs.raw == nullptr;
}
friend bool operator==(std::nullptr_t, const SharedPtr<T> &rhs) {
return nullptr == rhs.raw;
}
// template <typename T1, typename T2>
// friend bool operator!=(const SharedPtr<T1> &lhs, const SharedPtr<T2> &rhs)
// {
// return lhs.raw != rhs.raw;
// }
// friend bool operator!=(const SharedPtr<T> &lhs, std::nullptr_t) {
// return lhs.raw != nullptr;
// }
// friend bool operator!=(std::nullptr_t, const SharedPtr<T> &rhs) {
// return nullptr != rhs.raw;
// }
template <typename U> template <typename U>
friend SharedPtr<T> dynamic_pointer_cast(const SharedPtr<U> &sp); friend SharedPtr<T> static_pointer_cast(const SharedPtr<U> &sp) {
SharedPtr<T> ret{};
ret.raw = static_cast<T *>(sp.raw);
ret.underlying = sp.underlying;
}
template <typename U>
friend SharedPtr<T> dynamic_pointer_cast(const SharedPtr<U> &sp) {
SharedPtr<T> ret{};
ret.raw = dynamic_cast<T *>(sp.raw);
ret.underlying = sp.underlying;
}
}; };
} // namespace cs440 } // namespace cs440
#endif #endif

View file

@ -264,9 +264,9 @@ void basic_tests_1() {
{ {
SharedPtr<Derived> sp(new Derived); SharedPtr<Derived> sp(new Derived);
// Test template copy constructor. // Test template copy constructor.
// SharedPtr<Base1> sp3(sp); SharedPtr<Base1> sp3(sp);
// sp2 = sp; sp2 = sp;
// sp2 = sp2; sp2 = sp2;
} }
} }
} }
@ -279,7 +279,7 @@ void basic_tests_1() {
} }
{ {
SharedPtr<Derived> sp(new Derived); SharedPtr<Derived> sp(new Derived);
// SharedPtr<Base1> sp2(sp); SharedPtr<Base1> sp2(sp);
} }
} }
@ -294,9 +294,9 @@ void basic_tests_1() {
{ {
SharedPtr<Derived> sp(new Derived); SharedPtr<Derived> sp(new Derived);
SharedPtr<Base1> sp2; SharedPtr<Base1> sp2;
// sp2 = sp; sp2 = sp;
sp2 = sp2; // Self assignment. sp2 = sp2; // Self assignment.
// sp2 = sp; sp2 = sp;
sp = sp; sp = sp;
} }
} }
@ -306,7 +306,7 @@ void basic_tests_1() {
{ {
SharedPtr<Derived> sp(new Derived); SharedPtr<Derived> sp(new Derived);
SharedPtr<Base1> sp2; SharedPtr<Base1> sp2;
// sp2 = sp; sp2 = sp;
sp2 = sp2; sp2 = sp2;
sp.reset(); sp.reset();
sp.reset(new Derived); sp.reset(new Derived);
@ -339,10 +339,10 @@ void basic_tests_1() {
SharedPtr<Derived> sp(new Derived); SharedPtr<Derived> sp(new Derived);
(*sp).value = 1234; (*sp).value = 1234;
// SharedPtr<const Derived> sp2(sp); SharedPtr<const Derived> sp2(sp);
// int i = (*sp2).value; int i = (*sp2).value;
// assert(i == 1234); assert(i == 1234);
// (*sp2).value = 567; // Should give a syntax error if uncommented. //(*sp2).value = 567; // Should give a syntax error if uncommented.
} }
// Test ->. // Test ->.
@ -350,9 +350,9 @@ void basic_tests_1() {
SharedPtr<Derived> sp(new Derived); SharedPtr<Derived> sp(new Derived);
sp->value = 1234; sp->value = 1234;
// SharedPtr<const Derived> sp2(sp); SharedPtr<const Derived> sp2(sp);
// int i = sp2->value; int i = sp2->value;
// assert(i == 1234); assert(i == 1234);
// sp2->value = 567; // Should give a syntax error if uncommented. // sp2->value = 567; // Should give a syntax error if uncommented.
} }
@ -361,9 +361,9 @@ void basic_tests_1() {
SharedPtr<Derived> sp(new Derived); SharedPtr<Derived> sp(new Derived);
sp.get()->value = 1234; sp.get()->value = 1234;
// SharedPtr<const Derived> sp2(sp); SharedPtr<const Derived> sp2(sp);
// int i = sp2.get()->value; int i = sp2.get()->value;
// assert(i == 1234); assert(i == 1234);
} }
// Test bool. // Test bool.