BAN: Write RefPtr and WeakPtr to be thread safe

This commit is contained in:
Bananymous 2024-06-28 22:00:29 +03:00
parent 48a76426e7
commit 010c2c934b
2 changed files with 45 additions and 13 deletions

View File

@ -1,5 +1,6 @@
#pragma once #pragma once
#include <BAN/Atomic.h>
#include <BAN/Errors.h> #include <BAN/Errors.h>
#include <BAN/Move.h> #include <BAN/Move.h>
#include <BAN/NoCopyMove.h> #include <BAN/NoCopyMove.h>
@ -22,15 +23,27 @@ namespace BAN
void ref() const void ref() const
{ {
ASSERT(m_ref_count > 0); uint32_t old = m_ref_count.fetch_add(1, MemoryOrder::memory_order_relaxed);
m_ref_count++; ASSERT(old > 0);
}
bool try_ref() const
{
uint32_t expected = m_ref_count.load(MemoryOrder::memory_order_relaxed);
for (;;)
{
if (expected == 0)
return false;
if (m_ref_count.compare_exchange(expected, expected + 1, MemoryOrder::memory_order_acquire))
return true;
}
} }
void unref() const void unref() const
{ {
ASSERT(m_ref_count > 0); uint32_t old = m_ref_count.fetch_sub(1);
m_ref_count--; ASSERT(old > 0);
if (m_ref_count == 0) if (old == 1)
delete (const T*)this; delete (const T*)this;
} }
@ -39,7 +52,7 @@ namespace BAN
virtual ~RefCounted() { ASSERT(m_ref_count == 0); } virtual ~RefCounted() { ASSERT(m_ref_count == 0); }
private: private:
mutable uint32_t m_ref_count = 1; mutable Atomic<uint32_t> m_ref_count = 1;
}; };
template<typename T> template<typename T>

View File

@ -2,6 +2,10 @@
#include <BAN/RefPtr.h> #include <BAN/RefPtr.h>
#if __is_kernel
#include <kernel/Lock/SpinLock.h>
#endif
namespace BAN namespace BAN
{ {
@ -11,22 +15,37 @@ namespace BAN
template<typename T> template<typename T>
class WeakPtr; class WeakPtr;
// FIXME: Write this without using locks...
template<typename T> template<typename T>
class WeakLink : public RefCounted<WeakLink<T>> class WeakLink : public RefCounted<WeakLink<T>>
{ {
public: public:
RefPtr<T> lock() { ASSERT(m_ptr); return raw_ptr(); } RefPtr<T> try_lock()
T* raw_ptr() { return m_ptr; } {
#if __is_kernel
Kernel::SpinLockGuard _(m_weak_lock);
#endif
if (m_ptr && m_ptr->try_ref())
return RefPtr<T>::adopt(m_ptr);
return nullptr;
}
bool valid() const { return m_ptr; } bool valid() const { return m_ptr; }
void invalidate() { m_ptr = nullptr; } void invalidate()
{
#if __is_kernel
Kernel::SpinLockGuard _(m_weak_lock);
#endif
m_ptr = nullptr;
}
private: private:
WeakLink(T* ptr) : m_ptr(ptr) {} WeakLink(T* ptr) : m_ptr(ptr) {}
private: private:
T* m_ptr; T* m_ptr;
#if __is_kernel
Kernel::SpinLock m_weak_lock;
#endif
friend class RefPtr<WeakLink<T>>; friend class RefPtr<WeakLink<T>>;
}; };
@ -82,8 +101,8 @@ namespace BAN
RefPtr<T> lock() RefPtr<T> lock()
{ {
if (valid()) if (m_link)
return m_link->lock(); return m_link->try_lock();
return nullptr; return nullptr;
} }