everything compiles, now need to figure out ref counting

This commit is contained in:
Pagwin 2026-01-15 14:21:46 -05:00
parent 2df44b29a2
commit 6ef58754d2
No known key found for this signature in database
GPG key ID: 81137023740CA260
3 changed files with 154 additions and 45 deletions

View file

@ -1,3 +1,3 @@
CXX=clang++
test: SharedPtr_test.cpp SharedPtr.hpp
$(CXX) SharedPtr_test.cpp -o test
$(CXX) SharedPtr_test.cpp -Wall -fsanitize=address -g -o test

View file

@ -9,51 +9,160 @@
namespace cs440 {
using counter = std::atomic<std::size_t>;
template <typename T> class SharedPtr {
class Control {
counter count;
(???) underlying;
virtual void destroy() = 0;
public:
Control() : count{1} {};
void decrement() {
count--;
if (count == 0) {
this->destroy();
}
}
void increment() { count++; }
virtual ~Control() = default;
};
template <typename T> class ControlImpl : public Control {
T *val;
void destroy() override {
delete val;
delete this;
}
public:
ControlImpl(T *v) : Control{}, val{v} {}
~ControlImpl() = default;
};
template <typename T> class SharedPtr {
T *raw;
Control *underlying;
public:
template <typename U> friend class SharedPtr;
// typename U meaning a type which inherits from T
SharedPtr();
SharedPtr() : raw{nullptr}, underlying{nullptr} {}
template <typename U> explicit SharedPtr(U *);
template <typename U> explicit SharedPtr(U *p) {
this->raw = p;
this->underlying = new ControlImpl<U>(p);
}
SharedPtr(const SharedPtr &p);
template <typename U> SharedPtr(const SharedPtr<U> &p);
SharedPtr(const SharedPtr &p) : raw{p.raw}, underlying{p.underlying} {
this->underlying->increment();
}
template <typename U>
SharedPtr(const SharedPtr<U> &p) : raw{p.raw}, underlying{p.underlying} {
this->underlying->increment();
}
SharedPtr(SharedPtr &&p);
template <typename U> SharedPtr(SharedPtr<U> &&p);
SharedPtr(SharedPtr &&p) : raw{p.raw}, underlying{p.underlying} {
p.raw = nullptr;
p.underlying = nullptr;
};
template <typename U>
SharedPtr(SharedPtr<U> &&p) : raw{p.raw}, underlying{p.underlying} {
p.raw = nullptr;
p.underlying = nullptr;
}
// needs to handle self assignment
SharedPtr &operator=(const SharedPtr &);
template <typename U> SharedPtr<T> &operator=(const SharedPtr<U> &);
SharedPtr &operator=(const SharedPtr &rhs) {
if (this->underlying == rhs.underlying) {
return *this;
}
this->raw = rhs.raw;
this->underlying->decrement();
this->underlying = rhs.underlying;
this->underlying->increment();
return *this;
}
template <typename U> SharedPtr<T> &operator=(const SharedPtr<U> &rhs) {
if (this->underlying == rhs.underlying) {
return *this;
}
this->raw = rhs.raw;
this->underlying->decrement();
this->underlying = rhs.underlying;
this->underlying->increment();
return *this;
}
SharedPtr &operator=(SharedPtr &&p);
template <typename U> SharedPtr &operator=(SharedPtr<U> &&p);
~SharedPtr();
SharedPtr &operator=(SharedPtr &&rhs) {
if (this == &rhs) {
return *this;
}
this->raw = rhs.raw;
this->underlying->decrement();
this->underlying = rhs.underlying;
rhs.raw = nullptr;
rhs.underlying = nullptr;
}
template <typename U> SharedPtr &operator=(SharedPtr<U> &&rhs) {
if (this == &rhs) {
return *this;
}
this->raw = rhs.raw;
this->underlying->decrement();
this->underlying = rhs.underlying;
rhs.raw = nullptr;
rhs.underlying = nullptr;
}
~SharedPtr() { this->underlying->decrement(); }
void reset() {
this->count = 0;
this->raw = nullptr;
if (this->underlying != nullptr) {
this->underlying->decrement();
}
this->underlying = nullptr;
}
template <typename U> void reset(U *p);
T *get() const;
template <typename U> void reset(U *p) {
this->raw = p;
if (this->underlying != nullptr) {
this->underlying->decrement();
}
this->underlying = new ControlImpl<U>(p);
}
T *get() const { return this->raw; }
T &operator*() const { return *this->get(); }
T *operator->() const { return this->get(); }
explicit operator bool() const;
template <typename T1, typename T2>
friend bool operator==(const SharedPtr<T1> &, const SharedPtr<T2> &);
friend bool operator==(const SharedPtr<T> &, std::nullptr_t);
friend bool operator==(std::nullptr_t, const SharedPtr<T> &);
template <typename T1, typename T2>
friend bool operator!=(const SharedPtr<T1> &, const SharedPtr<T2> &);
friend bool operator!=(const SharedPtr<T> &, std::nullptr_t);
friend bool operator!=(std::nullptr_t, const SharedPtr<T> &);
explicit operator bool() const { return this->raw; }
template <typename U>
friend SharedPtr<T> static_pointer_cast(const SharedPtr<U> &sp);
friend bool operator==(const SharedPtr<T> &lhs, const SharedPtr<U> &rhs) {
return lhs.raw == rhs.raw;
}
friend bool operator==(const SharedPtr<T> &lhs, std::nullptr_t) {
return lhs.raw == nullptr;
}
friend bool operator==(std::nullptr_t, const SharedPtr<T> &rhs) {
return nullptr == rhs.raw;
}
// template <typename T1, typename T2>
// friend bool operator!=(const SharedPtr<T1> &lhs, const SharedPtr<T2> &rhs)
// {
// return lhs.raw != rhs.raw;
// }
// friend bool operator!=(const SharedPtr<T> &lhs, std::nullptr_t) {
// return lhs.raw != nullptr;
// }
// friend bool operator!=(std::nullptr_t, const SharedPtr<T> &rhs) {
// return nullptr != rhs.raw;
// }
template <typename U>
friend SharedPtr<T> dynamic_pointer_cast(const SharedPtr<U> &sp);
friend SharedPtr<T> static_pointer_cast(const SharedPtr<U> &sp) {
SharedPtr<T> ret{};
ret.raw = static_cast<T *>(sp.raw);
ret.underlying = sp.underlying;
}
template <typename U>
friend SharedPtr<T> dynamic_pointer_cast(const SharedPtr<U> &sp) {
SharedPtr<T> ret{};
ret.raw = dynamic_cast<T *>(sp.raw);
ret.underlying = sp.underlying;
}
};
} // namespace cs440
#endif

View file

@ -264,9 +264,9 @@ void basic_tests_1() {
{
SharedPtr<Derived> sp(new Derived);
// Test template copy constructor.
// SharedPtr<Base1> sp3(sp);
// sp2 = sp;
// sp2 = sp2;
SharedPtr<Base1> sp3(sp);
sp2 = sp;
sp2 = sp2;
}
}
}
@ -279,7 +279,7 @@ void basic_tests_1() {
}
{
SharedPtr<Derived> sp(new Derived);
// SharedPtr<Base1> sp2(sp);
SharedPtr<Base1> sp2(sp);
}
}
@ -294,9 +294,9 @@ void basic_tests_1() {
{
SharedPtr<Derived> sp(new Derived);
SharedPtr<Base1> sp2;
// sp2 = sp;
sp2 = sp;
sp2 = sp2; // Self assignment.
// sp2 = sp;
sp2 = sp;
sp = sp;
}
}
@ -306,7 +306,7 @@ void basic_tests_1() {
{
SharedPtr<Derived> sp(new Derived);
SharedPtr<Base1> sp2;
// sp2 = sp;
sp2 = sp;
sp2 = sp2;
sp.reset();
sp.reset(new Derived);
@ -339,9 +339,9 @@ void basic_tests_1() {
SharedPtr<Derived> sp(new Derived);
(*sp).value = 1234;
// SharedPtr<const Derived> sp2(sp);
// int i = (*sp2).value;
// assert(i == 1234);
SharedPtr<const Derived> sp2(sp);
int i = (*sp2).value;
assert(i == 1234);
//(*sp2).value = 567; // Should give a syntax error if uncommented.
}
@ -350,9 +350,9 @@ void basic_tests_1() {
SharedPtr<Derived> sp(new Derived);
sp->value = 1234;
// SharedPtr<const Derived> sp2(sp);
// int i = sp2->value;
// assert(i == 1234);
SharedPtr<const Derived> sp2(sp);
int i = sp2->value;
assert(i == 1234);
// sp2->value = 567; // Should give a syntax error if uncommented.
}
@ -361,9 +361,9 @@ void basic_tests_1() {
SharedPtr<Derived> sp(new Derived);
sp.get()->value = 1234;
// SharedPtr<const Derived> sp2(sp);
// int i = sp2.get()->value;
// assert(i == 1234);
SharedPtr<const Derived> sp2(sp);
int i = sp2.get()->value;
assert(i == 1234);
}
// Test bool.