From 010c2c934b9e5bf3e4475543cdff5bb0be6e59c8 Mon Sep 17 00:00:00 2001 From: Bananymous Date: Fri, 28 Jun 2024 22:00:29 +0300 Subject: [PATCH] BAN: Write RefPtr and WeakPtr to be thread safe --- BAN/include/BAN/RefPtr.h | 25 +++++++++++++++++++------ BAN/include/BAN/WeakPtr.h | 33 ++++++++++++++++++++++++++------- 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/BAN/include/BAN/RefPtr.h b/BAN/include/BAN/RefPtr.h index 8595dac1db..6f0e5108aa 100644 --- a/BAN/include/BAN/RefPtr.h +++ b/BAN/include/BAN/RefPtr.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -22,15 +23,27 @@ namespace BAN void ref() const { - ASSERT(m_ref_count > 0); - m_ref_count++; + uint32_t old = m_ref_count.fetch_add(1, MemoryOrder::memory_order_relaxed); + ASSERT(old > 0); + } + + bool try_ref() const + { + uint32_t expected = m_ref_count.load(MemoryOrder::memory_order_relaxed); + for (;;) + { + if (expected == 0) + return false; + if (m_ref_count.compare_exchange(expected, expected + 1, MemoryOrder::memory_order_acquire)) + return true; + } } void unref() const { - ASSERT(m_ref_count > 0); - m_ref_count--; - if (m_ref_count == 0) + uint32_t old = m_ref_count.fetch_sub(1); + ASSERT(old > 0); + if (old == 1) delete (const T*)this; } @@ -39,7 +52,7 @@ namespace BAN virtual ~RefCounted() { ASSERT(m_ref_count == 0); } private: - mutable uint32_t m_ref_count = 1; + mutable Atomic m_ref_count = 1; }; template diff --git a/BAN/include/BAN/WeakPtr.h b/BAN/include/BAN/WeakPtr.h index b8626875a0..d68bfd05a5 100644 --- a/BAN/include/BAN/WeakPtr.h +++ b/BAN/include/BAN/WeakPtr.h @@ -2,6 +2,10 @@ #include +#if __is_kernel +#include +#endif + namespace BAN { @@ -11,22 +15,37 @@ namespace BAN template class WeakPtr; + // FIXME: Write this without using locks... template class WeakLink : public RefCounted> { public: - RefPtr lock() { ASSERT(m_ptr); return raw_ptr(); } - T* raw_ptr() { return m_ptr; } - + RefPtr try_lock() + { +#if __is_kernel + Kernel::SpinLockGuard _(m_weak_lock); +#endif + if (m_ptr && m_ptr->try_ref()) + return RefPtr::adopt(m_ptr); + return nullptr; + } bool valid() const { return m_ptr; } - void invalidate() { m_ptr = nullptr; } + void invalidate() + { +#if __is_kernel + Kernel::SpinLockGuard _(m_weak_lock); +#endif + m_ptr = nullptr; + } private: WeakLink(T* ptr) : m_ptr(ptr) {} private: T* m_ptr; - +#if __is_kernel + Kernel::SpinLock m_weak_lock; +#endif friend class RefPtr>; }; @@ -82,8 +101,8 @@ namespace BAN RefPtr lock() { - if (valid()) - return m_link->lock(); + if (m_link) + return m_link->try_lock(); return nullptr; }