Kernel: Rename RefCounted -> RefPtr and implement RefCounted

This commit is contained in:
Bananymous 2023-03-08 03:21:30 +02:00
parent f7ebda3bf1
commit 23b3028e15
12 changed files with 98 additions and 133 deletions

View File

@ -24,143 +24,107 @@ namespace BAN
#endif #endif
template<typename T> template<typename T>
class Unique class RefCounted
{ {
BAN_NON_COPYABLE(Unique); BAN_NON_COPYABLE(RefCounted);
BAN_NON_MOVABLE(RefCounted);
public: public:
template<typename... Args> uint32_t ref_count() const
Unique(const Args&... args)
{ {
m_pointer = new T(args...); return m_ref_count;
} }
~Unique() void ref() const
{ {
delete m_pointer; ASSERT(m_ref_count > 0);
m_ref_count++;
} }
operator bool() const void unref() const
{ {
return m_pointer; ASSERT(m_ref_count > 0);
m_ref_count--;
if (m_ref_count == 0)
delete (const T*)this;
} }
protected:
RefCounted() = default;
~RefCounted() { ASSERT(m_ref_count == 0); }
private: private:
T* m_pointer = nullptr; mutable uint32_t m_ref_count = 1;
}; };
template<typename T> template<typename T>
class RefCounted class RefPtr
{ {
public: public:
RefCounted() = default; RefPtr() = default;
RefCounted(const RefCounted<T>& other) ~RefPtr() { clear(); }
{
*this = other;
}
RefCounted(RefCounted<T>&& other)
{
*this = move(other);
}
~RefCounted()
{
clear();
}
template<typename U> template<typename U>
static ErrorOr<RefCounted<T>> adopt(U* data) static RefPtr adopt(U* pointer)
{ {
uint32_t* count = new uint32_t(1); return RefPtr(pointer);
if (!count)
return Error::from_errno(ENOMEM);
return RefCounted<T>((T*)data, count);
} }
template<typename... Args> template<typename... Args>
static ErrorOr<RefCounted<T>> create(Args... args) static ErrorOr<RefPtr> create(Args&&... args)
{ {
uint32_t* count = new uint32_t(1); T* pointer = new T(forward<Args>(args)...);
if (!count) if (pointer == nullptr)
return Error::from_errno(ENOMEM); return Error::from_errno(ENOMEM);
T* data = new T(forward<Args>(args)...); return RefPtr(pointer);
if (!data)
{
delete count;
return Error::from_errno(ENOMEM);
}
return RefCounted<T>(data, count);
} }
RefCounted<T>& operator=(const RefCounted<T>& other) RefPtr(const RefPtr& other) { *this = other; }
RefPtr(RefPtr&& other) { *this = move(other); }
RefPtr& operator=(const RefPtr& other)
{ {
clear(); clear();
if (other) m_pointer = other.m_pointer;
{ if (m_pointer)
m_pointer = other.m_pointer; m_pointer->ref();
m_count = other.m_count;
(*m_count)++;
}
return *this; return *this;
} }
RefCounted<T>& operator=(RefCounted<T>&& other) RefPtr& operator=(RefPtr&& other)
{ {
clear(); clear();
if (other) m_pointer = other.m_pointer;
{ other.m_pointer = nullptr;
m_pointer = other.m_pointer;
m_count = other.m_count;
other.m_pointer = nullptr;
other.m_count = nullptr;
}
return *this; return *this;
} }
T* ptr() { return m_pointer; } T* ptr() { ASSERT(!empty()); return m_pointer; }
const T* ptr() const { return m_pointer; } const T* ptr() const { ASSERT(!empty()); return m_pointer; }
T& operator*() { return *ptr();} T& operator*() { return *ptr(); }
const T& operator*() const { return *ptr();} const T& operator*() const { return *ptr(); }
T* operator->() { return ptr(); } T* operator->() { return ptr(); }
const T* operator->() const { return ptr(); } const T* operator->() const { return ptr(); }
bool empty() const { return m_pointer == nullptr; }
operator bool() const { return m_pointer; }
void clear() void clear()
{ {
if (!*this) if (m_pointer)
return; m_pointer->unref();
(*m_count)--;
if (*m_count == 0)
{
delete m_pointer;
delete m_count;
}
m_pointer = nullptr; m_pointer = nullptr;
m_count = nullptr;
}
operator bool() const
{
if (!m_count && !m_pointer)
return false;
ASSERT(m_count && m_pointer);
ASSERT(*m_count > 0);
return true;
} }
private: private:
RefCounted(T* pointer, uint32_t* count) RefPtr(T* pointer)
: m_pointer(pointer) : m_pointer(pointer)
, m_count(count) {}
{
ASSERT(!pointer == !count);
}
private: private:
T* m_pointer = nullptr; T* m_pointer = nullptr;
uint32_t* m_count = nullptr;
}; };
} }

View File

@ -131,8 +131,8 @@ namespace Kernel
virtual BAN::StringView name() const override { return m_name; } virtual BAN::StringView name() const override { return m_name; }
virtual BAN::ErrorOr<BAN::Vector<uint8_t>> read_all() override; virtual BAN::ErrorOr<BAN::Vector<uint8_t>> read_all() override;
virtual BAN::ErrorOr<BAN::Vector<BAN::RefCounted<Inode>>> directory_inodes() override; virtual BAN::ErrorOr<BAN::Vector<BAN::RefPtr<Inode>>> directory_inodes() override;
virtual BAN::ErrorOr<BAN::RefCounted<Inode>> directory_find(BAN::StringView) override; virtual BAN::ErrorOr<BAN::RefPtr<Inode>> directory_find(BAN::StringView) override;
private: private:
BAN::ErrorOr<void> for_each_block(BAN::Function<BAN::ErrorOr<bool>(const BAN::Vector<uint8_t>&)>&); BAN::ErrorOr<void> for_each_block(BAN::Function<BAN::ErrorOr<bool>(const BAN::Vector<uint8_t>&)>&);
@ -158,7 +158,7 @@ namespace Kernel
public: public:
static BAN::ErrorOr<Ext2FS*> create(StorageDevice::Partition&); static BAN::ErrorOr<Ext2FS*> create(StorageDevice::Partition&);
virtual const BAN::RefCounted<Inode> root_inode() const override { return m_root_inode; } virtual const BAN::RefPtr<Inode> root_inode() const override { return m_root_inode; }
private: private:
Ext2FS(StorageDevice::Partition& partition) Ext2FS(StorageDevice::Partition& partition)
@ -179,7 +179,7 @@ namespace Kernel
private: private:
StorageDevice::Partition& m_partition; StorageDevice::Partition& m_partition;
BAN::RefCounted<Inode> m_root_inode; BAN::RefPtr<Inode> m_root_inode;
Ext2::Superblock m_superblock; Ext2::Superblock m_superblock;
BAN::Vector<Ext2::BlockGroupDescriptor> m_block_group_descriptors; BAN::Vector<Ext2::BlockGroupDescriptor> m_block_group_descriptors;

View File

@ -9,7 +9,7 @@ namespace Kernel
class FileSystem class FileSystem
{ {
public: public:
virtual const BAN::RefCounted<Inode> root_inode() const = 0; virtual const BAN::RefPtr<Inode> root_inode() const = 0;
}; };
} }

View File

@ -8,7 +8,7 @@
namespace Kernel namespace Kernel
{ {
class Inode class Inode : public BAN::RefCounted<Inode>
{ {
public: public:
union Mode union Mode
@ -48,8 +48,8 @@ namespace Kernel
virtual BAN::StringView name() const = 0; virtual BAN::StringView name() const = 0;
virtual BAN::ErrorOr<BAN::Vector<uint8_t>> read_all() = 0; virtual BAN::ErrorOr<BAN::Vector<uint8_t>> read_all() = 0;
virtual BAN::ErrorOr<BAN::Vector<BAN::RefCounted<Inode>>> directory_inodes() = 0; virtual BAN::ErrorOr<BAN::Vector<BAN::RefPtr<Inode>>> directory_inodes() = 0;
virtual BAN::ErrorOr<BAN::RefCounted<Inode>> directory_find(BAN::StringView) = 0; virtual BAN::ErrorOr<BAN::RefPtr<Inode>> directory_find(BAN::StringView) = 0;
}; };
} }

View File

@ -13,16 +13,16 @@ namespace Kernel
static VirtualFileSystem& get(); static VirtualFileSystem& get();
static bool is_initialized(); static bool is_initialized();
virtual const BAN::RefCounted<Inode> root_inode() const override { return m_root_inode; } virtual const BAN::RefPtr<Inode> root_inode() const override { return m_root_inode; }
BAN::ErrorOr<BAN::RefCounted<Inode>> from_absolute_path(BAN::StringView); BAN::ErrorOr<BAN::RefPtr<Inode>> from_absolute_path(BAN::StringView);
private: private:
VirtualFileSystem() = default; VirtualFileSystem() = default;
BAN::ErrorOr<void> initialize_impl(); BAN::ErrorOr<void> initialize_impl();
private: private:
BAN::RefCounted<Inode> m_root_inode; BAN::RefPtr<Inode> m_root_inode;
BAN::Vector<StorageController*> m_storage_controllers; BAN::Vector<StorageController*> m_storage_controllers;
}; };

View File

@ -16,7 +16,7 @@ namespace Kernel
void start(); void start();
void reschedule(); void reschedule();
BAN::ErrorOr<void> add_thread(BAN::RefCounted<Thread>); BAN::ErrorOr<void> add_thread(BAN::RefPtr<Thread>);
void set_current_thread_sleeping(uint64_t); void set_current_thread_sleeping(uint64_t);
[[noreturn]] void set_current_thread_done(); [[noreturn]] void set_current_thread_done();
@ -24,7 +24,7 @@ namespace Kernel
private: private:
Scheduler() = default; Scheduler() = default;
BAN::RefCounted<Thread> current_thread(); BAN::RefPtr<Thread> current_thread();
void wake_threads(); void wake_threads();
[[nodiscard]] bool save_current_thread(); [[nodiscard]] bool save_current_thread();
@ -34,17 +34,17 @@ namespace Kernel
private: private:
struct ActiveThread struct ActiveThread
{ {
BAN::RefCounted<Thread> thread; BAN::RefPtr<Thread> thread;
uint64_t padding; uint64_t padding = 0;
}; };
struct SleepingThread struct SleepingThread
{ {
BAN::RefCounted<Thread> thread; BAN::RefPtr<Thread> thread;
uint64_t wake_delta; uint64_t wake_time;
}; };
BAN::RefCounted<Thread> m_idle_thread; BAN::RefPtr<Thread> m_idle_thread;
BAN::LinkedList<ActiveThread> m_active_threads; BAN::LinkedList<ActiveThread> m_active_threads;
BAN::LinkedList<SleepingThread> m_sleeping_threads; BAN::LinkedList<SleepingThread> m_sleeping_threads;

View File

@ -6,13 +6,10 @@
namespace Kernel namespace Kernel
{ {
class Thread class Thread : public BAN::RefCounted<Thread>
{ {
BAN_NON_COPYABLE(Thread);
BAN_NON_MOVABLE(Thread);
public: public:
static BAN::ErrorOr<BAN::RefCounted<Thread>> create(const BAN::Function<void()>&); static BAN::ErrorOr<BAN::RefPtr<Thread>> create(const BAN::Function<void()>&);
~Thread(); ~Thread();
uint32_t tid() const { return m_tid; } uint32_t tid() const { return m_tid; }
@ -40,7 +37,7 @@ namespace Kernel
BAN::Function<void()> m_function; BAN::Function<void()> m_function;
friend class BAN::RefCounted<Thread>; friend class BAN::RefPtr<Thread>;
}; };
} }

View File

@ -253,12 +253,12 @@ namespace Kernel
return data_buffer; return data_buffer;
} }
BAN::ErrorOr<BAN::RefCounted<Inode>> Ext2Inode::directory_find(BAN::StringView file_name) BAN::ErrorOr<BAN::RefPtr<Inode>> Ext2Inode::directory_find(BAN::StringView file_name)
{ {
if (!ifdir()) if (!ifdir())
return BAN::Error::from_errno(ENOTDIR); return BAN::Error::from_errno(ENOTDIR);
BAN::RefCounted<Inode> result; BAN::RefPtr<Inode> result;
BAN::Function<BAN::ErrorOr<bool>(const BAN::Vector<uint8_t>&)> function( BAN::Function<BAN::ErrorOr<bool>(const BAN::Vector<uint8_t>&)> function(
[&](const BAN::Vector<uint8_t>& block_data) -> BAN::ErrorOr<bool> [&](const BAN::Vector<uint8_t>& block_data) -> BAN::ErrorOr<bool>
{ {
@ -273,7 +273,7 @@ namespace Kernel
Ext2Inode* inode = new Ext2Inode(m_fs, TRY(m_fs->read_inode(entry->inode)), entry_name); Ext2Inode* inode = new Ext2Inode(m_fs, TRY(m_fs->read_inode(entry->inode)), entry_name);
if (inode == nullptr) if (inode == nullptr)
return BAN::Error::from_errno(ENOMEM); return BAN::Error::from_errno(ENOMEM);
result = TRY(BAN::RefCounted<Inode>::adopt(inode)); result = BAN::RefPtr<Inode>::adopt(inode);
return false; return false;
} }
entry_addr += entry->rec_len; entry_addr += entry->rec_len;
@ -288,12 +288,12 @@ namespace Kernel
return BAN::Error::from_errno(ENOENT); return BAN::Error::from_errno(ENOENT);
} }
BAN::ErrorOr<BAN::Vector<BAN::RefCounted<Inode>>> Ext2Inode::directory_inodes() BAN::ErrorOr<BAN::Vector<BAN::RefPtr<Inode>>> Ext2Inode::directory_inodes()
{ {
if (!ifdir()) if (!ifdir())
return BAN::Error::from_errno(ENOTDIR); return BAN::Error::from_errno(ENOTDIR);
BAN::Vector<BAN::RefCounted<Inode>> inodes; BAN::Vector<BAN::RefPtr<Inode>> inodes;
BAN::Function<BAN::ErrorOr<bool>(const BAN::Vector<uint8_t>&)> function( BAN::Function<BAN::ErrorOr<bool>(const BAN::Vector<uint8_t>&)> function(
[&](const BAN::Vector<uint8_t>& block_data) -> BAN::ErrorOr<bool> [&](const BAN::Vector<uint8_t>& block_data) -> BAN::ErrorOr<bool>
{ {
@ -310,7 +310,7 @@ namespace Kernel
Ext2Inode* inode = new Ext2Inode(m_fs, BAN::move(current_inode), entry_name); Ext2Inode* inode = new Ext2Inode(m_fs, BAN::move(current_inode), entry_name);
if (inode == nullptr) if (inode == nullptr)
return BAN::Error::from_errno(ENOMEM); return BAN::Error::from_errno(ENOMEM);
TRY(inodes.push_back(TRY(BAN::RefCounted<Inode>::adopt(inode)))); TRY(inodes.push_back(BAN::RefPtr<Inode>::adopt(inode)));
} }
entry_addr += entry->rec_len; entry_addr += entry->rec_len;
} }
@ -431,7 +431,7 @@ namespace Kernel
Ext2Inode* root_inode = new Ext2Inode(this, TRY(read_inode(Ext2::Enum::ROOT_INO)), ""); Ext2Inode* root_inode = new Ext2Inode(this, TRY(read_inode(Ext2::Enum::ROOT_INO)), "");
if (root_inode == nullptr) if (root_inode == nullptr)
return BAN::Error::from_errno(ENOMEM); return BAN::Error::from_errno(ENOMEM);
m_root_inode = TRY(BAN::RefCounted<Inode>::adopt(root_inode)); m_root_inode = BAN::RefPtr<Inode>::adopt(root_inode);
#if EXT2_DEBUG_PRINT #if EXT2_DEBUG_PRINT
dprintln("root inode:"); dprintln("root inode:");

View File

@ -112,7 +112,7 @@ namespace Kernel
return BAN::Error::from_string("Could not locate root partition"); return BAN::Error::from_string("Could not locate root partition");
} }
BAN::ErrorOr<BAN::RefCounted<Inode>> VirtualFileSystem::from_absolute_path(BAN::StringView path) BAN::ErrorOr<BAN::RefPtr<Inode>> VirtualFileSystem::from_absolute_path(BAN::StringView path)
{ {
if (path.front() != '/') if (path.front() != '/')
return BAN::Error::from_string("Path must be an absolute path"); return BAN::Error::from_string("Path must be an absolute path");

View File

@ -46,9 +46,9 @@ namespace Kernel
ASSERT_NOT_REACHED(); ASSERT_NOT_REACHED();
} }
Thread& Scheduler::current_thread() BAN::RefPtr<Thread> Scheduler::current_thread()
{ {
return m_current_thread ? *m_current_thread->thread : *m_idle_thread; return m_current_thread ? m_current_thread->thread : m_idle_thread;
} }
void Scheduler::reschedule() void Scheduler::reschedule()
@ -147,9 +147,9 @@ namespace Kernel
} }
read_rsp(rsp); read_rsp(rsp);
auto& current = current_thread(); auto current = current_thread();
current.set_rip(rip); current->set_rip(rip);
current.set_rsp(rsp); current->set_rsp(rsp);
return false; return false;
} }
@ -157,7 +157,7 @@ namespace Kernel
{ {
VERIFY_CLI(); VERIFY_CLI();
auto& current = current_thread(); auto& current = *current_thread();
if (current.started()) if (current.started())
{ {

View File

@ -182,7 +182,7 @@ argument_done:
s_thread_spinlock.lock(); s_thread_spinlock.lock();
MUST(Scheduler::get().add_thread(MUST(Thread::create( auto thread_or_error = Thread::create(
[this, &arguments] [this, &arguments]
{ {
auto args = arguments; auto args = arguments;
@ -191,7 +191,11 @@ argument_done:
PIT::sleep(5000); PIT::sleep(5000);
process_command(args); process_command(args);
} }
)))); );
if (thread_or_error.is_error())
return TTY_PRINTLN("{}", thread_or_error.error());
MUST(Scheduler::get().add_thread(thread));
while (s_thread_spinlock.is_locked()); while (s_thread_spinlock.is_locked());
} }

View File

@ -22,9 +22,9 @@ namespace Kernel
} }
BAN::ErrorOr<BAN::RefCounted<Thread>> Thread::create(const BAN::Function<void()>& function) BAN::ErrorOr<BAN::RefPtr<Thread>> Thread::create(const BAN::Function<void()>& function)
{ {
return BAN::RefCounted<Thread>::create(function); return BAN::RefPtr<Thread>::create(function);
} }
Thread::Thread(const BAN::Function<void()>& function) Thread::Thread(const BAN::Function<void()>& function)