diff --git a/kernel/include/kernel/Storage/DiskCache.h b/kernel/include/kernel/Storage/DiskCache.h index bfce7be109..342a855a4a 100644 --- a/kernel/include/kernel/Storage/DiskCache.h +++ b/kernel/include/kernel/Storage/DiskCache.h @@ -7,16 +7,18 @@ namespace Kernel { + class StorageDevice; + class DiskCache { public: - DiskCache(size_t sector_size); + DiskCache(size_t sector_size, StorageDevice&); ~DiskCache(); bool read_from_cache(uint64_t sector, uint8_t* buffer); BAN::ErrorOr write_to_cache(uint64_t sector, const uint8_t* buffer, bool dirty); - void sync(); + BAN::ErrorOr sync(); size_t release_clean_pages(size_t); size_t release_pages(size_t); void release_all_pages(); @@ -32,6 +34,7 @@ namespace Kernel private: const size_t m_sector_size; + StorageDevice& m_device; BAN::Vector m_cache; }; diff --git a/kernel/include/kernel/Storage/StorageDevice.h b/kernel/include/kernel/Storage/StorageDevice.h index 17df45d50e..fe3edbce7b 100644 --- a/kernel/include/kernel/Storage/StorageDevice.h +++ b/kernel/include/kernel/Storage/StorageDevice.h @@ -76,12 +76,15 @@ namespace Kernel BAN::Vector& partitions() { return m_partitions; } const BAN::Vector& partitions() const { return m_partitions; } + BAN::ErrorOr sync_disk_cache(); + protected: virtual BAN::ErrorOr read_sectors_impl(uint64_t lba, uint8_t sector_count, uint8_t* buffer) = 0; virtual BAN::ErrorOr write_sectors_impl(uint64_t lba, uint8_t sector_count, const uint8_t* buffer) = 0; void add_disk_cache(); private: + SpinLock m_lock; BAN::Optional m_disk_cache; BAN::Vector m_partitions; diff --git a/kernel/kernel/Storage/DiskCache.cpp b/kernel/kernel/Storage/DiskCache.cpp index 08f81934dc..9307709ede 100644 --- a/kernel/kernel/Storage/DiskCache.cpp +++ b/kernel/kernel/Storage/DiskCache.cpp @@ -8,8 +8,9 @@ namespace Kernel { - DiskCache::DiskCache(size_t sector_size) + DiskCache::DiskCache(size_t sector_size, StorageDevice& device) : m_sector_size(sector_size) + , m_device(device) { ASSERT(PAGE_SIZE % m_sector_size == 0); ASSERT(PAGE_SIZE / m_sector_size <= sizeof(PageCache::sector_mask) * 8); @@ -31,7 +32,6 @@ namespace Kernel LockGuard page_table_locker(page_table); ASSERT(page_table.is_page_free(0)); - CriticalScope _; for (auto& cache : m_cache) { if (cache.first_sector < page_cache_start) @@ -42,6 +42,7 @@ namespace Kernel if (!(cache.sector_mask & (1 << page_cache_offset))) continue; + CriticalScope _; page_table.map_page_at(cache.paddr, 0, PageTable::Flags::Present); memcpy(buffer, (void*)(page_cache_offset * m_sector_size), m_sector_size); page_table.unmap_page(0); @@ -62,8 +63,6 @@ namespace Kernel LockGuard page_table_locker(page_table); ASSERT(page_table.is_page_free(0)); - CriticalScope _; - size_t index = 0; // Search the cache if the have this sector in memory @@ -76,9 +75,12 @@ namespace Kernel if (cache.first_sector > page_cache_start) break; - page_table.map_page_at(cache.paddr, 0, PageTable::Flags::ReadWrite | PageTable::Flags::Present); - memcpy((void*)(page_cache_offset * m_sector_size), buffer, m_sector_size); - page_table.unmap_page(0); + { + CriticalScope _; + page_table.map_page_at(cache.paddr, 0, PageTable::Flags::ReadWrite | PageTable::Flags::Present); + memcpy((void*)(page_cache_offset * m_sector_size), buffer, m_sector_size); + page_table.unmap_page(0); + } cache.sector_mask |= 1 << page_cache_offset; if (dirty) @@ -104,25 +106,51 @@ namespace Kernel return ret.error(); } - page_table.map_page_at(cache.paddr, 0, PageTable::Flags::Present); - memcpy((void*)(page_cache_offset * m_sector_size), buffer, m_sector_size); - page_table.unmap_page(0); + { + CriticalScope _; + page_table.map_page_at(cache.paddr, 0, PageTable::Flags::Present); + memcpy((void*)(page_cache_offset * m_sector_size), buffer, m_sector_size); + page_table.unmap_page(0); + } return {}; } - void DiskCache::sync() + BAN::ErrorOr DiskCache::sync() { - CriticalScope _; + BAN::Vector sector_buffer; + TRY(sector_buffer.resize(m_sector_size)); + + PageTable& page_table = PageTable::current(); + LockGuard page_table_locker(page_table); + ASSERT(page_table.is_page_free(0)); + for (auto& cache : m_cache) - ASSERT(cache.dirty_mask == 0); + { + for (int i = 0; cache.dirty_mask; i++) + { + if (!(cache.dirty_mask & (1 << i))) + continue; + + { + CriticalScope _; + page_table.map_page_at(cache.paddr, 0, PageTable::Flags::Present); + memcpy(sector_buffer.data(), (void*)(i * m_sector_size), m_sector_size); + page_table.unmap_page(0); + } + + TRY(m_device.write_sectors_impl(cache.first_sector + i, 1, sector_buffer.data())); + cache.dirty_mask &= ~(1 << i); + } + } + + return {}; } size_t DiskCache::release_clean_pages(size_t page_count) { // NOTE: There might not actually be page_count pages after this // function returns. The synchronization must be done elsewhere. - CriticalScope _; size_t released = 0; for (size_t i = 0; i < m_cache.size() && released < page_count;) @@ -147,8 +175,9 @@ namespace Kernel size_t released = release_clean_pages(page_count); if (released >= page_count) return released; - - ASSERT_NOT_REACHED(); + if (!sync().is_error()) + released += release_clean_pages(page_count - released); + return released; } void DiskCache::release_all_pages() diff --git a/kernel/kernel/Storage/StorageDevice.cpp b/kernel/kernel/Storage/StorageDevice.cpp index 01ea322bad..93070eab69 100644 --- a/kernel/kernel/Storage/StorageDevice.cpp +++ b/kernel/kernel/Storage/StorageDevice.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -257,15 +258,26 @@ namespace Kernel void StorageDevice::add_disk_cache() { + LockGuard _(m_lock); ASSERT(!m_disk_cache.has_value()); - m_disk_cache = DiskCache(sector_size()); + m_disk_cache = DiskCache(sector_size(), *this); + } + + BAN::ErrorOr StorageDevice::sync_disk_cache() + { + LockGuard _(m_lock); + if (m_disk_cache.has_value()) + TRY(m_disk_cache->sync()); + return {}; } BAN::ErrorOr StorageDevice::read_sectors(uint64_t lba, uint8_t sector_count, uint8_t* buffer) { for (uint8_t offset = 0; offset < sector_count; offset++) { - Thread::TerminateBlocker _(Thread::current()); + LockGuard _(m_lock); + Thread::TerminateBlocker blocker(Thread::current()); + uint8_t* buffer_ptr = buffer + offset * sector_size(); if (m_disk_cache.has_value()) if (m_disk_cache->read_from_cache(lba + offset, buffer_ptr)) @@ -283,11 +295,12 @@ namespace Kernel // TODO: use disk cache for dirty pages. I don't wanna think about how to do it safely now for (uint8_t offset = 0; offset < sector_count; offset++) { - Thread::TerminateBlocker _(Thread::current()); + LockGuard _(m_lock); + Thread::TerminateBlocker blocker(Thread::current()); + const uint8_t* buffer_ptr = buffer + offset * sector_size(); - TRY(write_sectors_impl(lba + offset, 1, buffer_ptr)); - if (m_disk_cache.has_value()) - (void)m_disk_cache->write_to_cache(lba + offset, buffer_ptr, false); + if (!m_disk_cache.has_value() || m_disk_cache->write_to_cache(lba + offset, buffer_ptr, true).is_error()) + TRY(write_sectors_impl(lba + offset, 1, buffer_ptr)); } return {};