BAN: Rewrite RefCounted to return ErrorOr

This commit is contained in:
Bananymous 2023-03-02 12:30:11 +02:00
parent 1dd61e93b6
commit 90a7268e5a
2 changed files with 63 additions and 47 deletions

View File

@ -54,16 +54,7 @@ namespace BAN
class RefCounted class RefCounted
{ {
public: public:
RefCounted() { } RefCounted() = default;
RefCounted(T* pointer)
{
if (pointer)
{
m_pointer = pointer;
m_count = new int32_t(1);
ASSERT(m_count);
}
}
RefCounted(const RefCounted<T>& other) RefCounted(const RefCounted<T>& other)
{ {
*this = other; *this = other;
@ -74,18 +65,33 @@ namespace BAN
} }
~RefCounted() ~RefCounted()
{ {
reset(); clear();
} }
template<typename... Args> template<typename U>
static RefCounted<T> create(Args... args) static ErrorOr<RefCounted<T>> adopt(U* data)
{ {
return RefCounted<T>(new T(forward<Args>(args)...), new int32_t(1)); uint32_t* count = new uint32_t(1);
if (!count)
return Error::from_string("RefCounted: Could not allocate memory");
return RefCounted<T>((T*)data, count);
}
template<typename... Args>
static ErrorOr<RefCounted<T>> create(Args... args)
{
uint32_t* count = new uint32_t(1);
if (!count)
return Error::from_string("RefCounted: Could not allocate memory");
T* data = new T(forward<Args>(args)...);
if (!data)
return Error::from_string("RefCounted: Could not allocate memory");
return RefCounted<T>(data, count);
} }
RefCounted<T>& operator=(const RefCounted<T>& other) RefCounted<T>& operator=(const RefCounted<T>& other)
{ {
reset(); clear();
if (other) if (other)
{ {
m_pointer = other.m_pointer; m_pointer = other.m_pointer;
@ -97,53 +103,53 @@ namespace BAN
RefCounted<T>& operator=(RefCounted<T>&& other) RefCounted<T>& operator=(RefCounted<T>&& other)
{ {
reset(); clear();
m_pointer = other.m_pointer; if (other)
m_count = other.m_count; {
other.m_pointer = nullptr; m_pointer = other.m_pointer;
other.m_count = nullptr; m_count = other.m_count;
if (!(*this)) other.m_pointer = nullptr;
reset(); other.m_count = nullptr;
}
return *this; return *this;
} }
T& operator*() { return *m_pointer;} T* ptr() { return m_pointer; }
const T& operator*() const { return *m_pointer;} const T* ptr() const { return m_pointer; }
T* operator->() { return m_pointer; } T& operator*() { return *ptr();}
const T* operator->() const { return m_pointer; } const T& operator*() const { return *ptr();}
void reset() T* operator->() { return ptr(); }
const T* operator->() const { return ptr(); }
void clear()
{ {
ASSERT(!m_count == !m_pointer); if (!*this)
if (!m_count)
return; return;
(*m_count)--; (*m_count)--;
if (*m_count == 0) if (*m_count == 0)
{ {
delete m_count;
delete m_pointer; delete m_pointer;
delete m_count;
} }
m_count = nullptr;
m_pointer = nullptr; m_pointer = nullptr;
m_count = nullptr;
} }
operator bool() const operator bool() const
{ {
ASSERT(!m_count == !m_pointer); if (!m_count && !m_pointer)
return m_count && *m_count > 0;
}
bool operator==(const RefCounted<T>& other) const
{
if (m_pointer != other.m_pointer)
return false; return false;
ASSERT(m_count == other.m_count); ASSERT(m_count && m_pointer);
return !m_count || *m_count > 0; ASSERT(*m_count > 0);
return true;
} }
private: private:
RefCounted(T* pointer, int32_t* count) RefCounted(T* pointer, uint32_t* count)
: m_pointer(pointer) : m_pointer(pointer)
, m_count(count) , m_count(count)
{ {
@ -152,7 +158,7 @@ namespace BAN
private: private:
T* m_pointer = nullptr; T* m_pointer = nullptr;
int32_t* m_count = nullptr; uint32_t* m_count = nullptr;
}; };
} }

View File

@ -270,8 +270,10 @@ namespace Kernel
BAN::StringView entry_name = BAN::StringView(entry->name, entry->name_len); BAN::StringView entry_name = BAN::StringView(entry->name, entry->name_len);
if (entry->inode && file_name == entry_name) if (entry->inode && file_name == entry_name)
{ {
Ext2::Inode asked_inode = TRY(m_fs->read_inode(entry->inode)); Ext2Inode* inode = new Ext2Inode(m_fs, TRY(m_fs->read_inode(entry->inode)), entry_name);
result = BAN::RefCounted<Inode>(new Ext2Inode(m_fs, BAN::move(asked_inode), entry_name)); if (inode == nullptr)
return BAN::Error::from_string("Could not allocate Ext2Inode");
result = TRY(BAN::RefCounted<Inode>::adopt(inode));
return false; return false;
} }
entry_addr += entry->rec_len; entry_addr += entry->rec_len;
@ -304,8 +306,11 @@ namespace Kernel
{ {
BAN::StringView entry_name = BAN::StringView(entry->name, entry->name_len); BAN::StringView entry_name = BAN::StringView(entry->name, entry->name_len);
Ext2::Inode current_inode = TRY(m_fs->read_inode(entry->inode)); Ext2::Inode current_inode = TRY(m_fs->read_inode(entry->inode));
auto ref_counted_inode = BAN::RefCounted<Inode>(new Ext2Inode(m_fs, BAN::move(current_inode), entry_name));
TRY(inodes.push_back(BAN::move(ref_counted_inode))); Ext2Inode* inode = new Ext2Inode(m_fs, BAN::move(current_inode), entry_name);
if (inode == nullptr)
return BAN::Error::from_string("Could not allocate memory for Ext2Inode");
TRY(inodes.push_back(TRY(BAN::RefCounted<Inode>::adopt(inode))));
} }
entry_addr += entry->rec_len; entry_addr += entry->rec_len;
} }
@ -423,7 +428,12 @@ namespace Kernel
BAN::ErrorOr<void> Ext2FS::initialize_root_inode() BAN::ErrorOr<void> Ext2FS::initialize_root_inode()
{ {
m_root_inode = BAN::RefCounted<Inode>(new Ext2Inode(this, TRY(read_inode(Ext2::Enum::ROOT_INO)), ""));
Ext2Inode* root_inode = new Ext2Inode(this, TRY(read_inode(Ext2::Enum::ROOT_INO)), "");
if (root_inode == nullptr)
return BAN::Error::from_string("Could not allocate Ext2Inode");
m_root_inode = TRY(BAN::RefCounted<Inode>::adopt(root_inode));
#if EXT2_DEBUG_PRINT #if EXT2_DEBUG_PRINT
dprintln("root inode:"); dprintln("root inode:");
dprintln(" created {}", ext2_root_inode().ctime); dprintln(" created {}", ext2_root_inode().ctime);