diff --git a/kernel/CMakeLists.txt b/kernel/CMakeLists.txt index ff654ded01..26019ea20f 100644 --- a/kernel/CMakeLists.txt +++ b/kernel/CMakeLists.txt @@ -46,6 +46,7 @@ set(KERNEL_SOURCES kernel/Storage/ATABus.cpp kernel/Storage/ATAController.cpp kernel/Storage/ATADevice.cpp + kernel/Storage/DiskCache.cpp kernel/Storage/StorageDevice.cpp kernel/Syscall.cpp kernel/Syscall.S diff --git a/kernel/include/kernel/Storage/ATADevice.h b/kernel/include/kernel/Storage/ATADevice.h index 584d1576be..6c6a166c58 100644 --- a/kernel/include/kernel/Storage/ATADevice.h +++ b/kernel/include/kernel/Storage/ATADevice.h @@ -15,13 +15,15 @@ namespace Kernel { } BAN::ErrorOr initialize(ATABus::DeviceType, const uint16_t*); - virtual BAN::ErrorOr read_sectors(uint64_t, uint8_t, uint8_t*) override; - virtual BAN::ErrorOr write_sectors(uint64_t, uint8_t, const uint8_t*) override; virtual uint32_t sector_size() const override { return m_sector_words * 2; } virtual uint64_t total_size() const override { return m_lba_count * sector_size(); } BAN::StringView model() const { return m_model; } + protected: + virtual BAN::ErrorOr read_sectors_impl(uint64_t, uint8_t, uint8_t*) override; + virtual BAN::ErrorOr write_sectors_impl(uint64_t, uint8_t, const uint8_t*) override; + private: ATABus* m_bus; uint8_t m_index; diff --git a/kernel/include/kernel/Storage/DiskCache.h b/kernel/include/kernel/Storage/DiskCache.h new file mode 100644 index 0000000000..ed017707ee --- /dev/null +++ b/kernel/include/kernel/Storage/DiskCache.h @@ -0,0 +1,47 @@ +#pragma once + +#include +#include +#include + +namespace Kernel +{ + + class StorageDevice; + + class DiskCache + { + public: + DiskCache(StorageDevice&); + ~DiskCache(); + + BAN::ErrorOr read_sector(uint64_t sector, uint8_t* buffer); + BAN::ErrorOr write_sector(uint64_t sector, const uint8_t* buffer); + + size_t release_clean_pages(size_t); + size_t release_pages(size_t); + void release_all_pages(); + + private: + struct SectorCache + { + uint64_t sector { 0 }; + bool dirty { false }; + }; + struct CacheBlock + { + paddr_t paddr { 0 }; + BAN::Array sectors; + + void sync(StorageDevice&); + void read_sector(StorageDevice&, size_t, uint8_t*); + void write_sector(StorageDevice&, size_t, const uint8_t*); + }; + + private: + SpinLock m_lock; + StorageDevice& m_device; + BAN::Vector m_cache; + }; + +} \ No newline at end of file diff --git a/kernel/include/kernel/Storage/StorageDevice.h b/kernel/include/kernel/Storage/StorageDevice.h index 8fb93e6dd5..44c7452ffe 100644 --- a/kernel/include/kernel/Storage/StorageDevice.h +++ b/kernel/include/kernel/Storage/StorageDevice.h @@ -2,6 +2,7 @@ #include #include +#include namespace Kernel { @@ -61,20 +62,29 @@ namespace Kernel class StorageDevice : public BlockDevice { public: - virtual ~StorageDevice() {} + virtual ~StorageDevice(); BAN::ErrorOr initialize_partitions(); - virtual BAN::ErrorOr read_sectors(uint64_t lba, uint8_t sector_count, uint8_t* buffer) = 0; - virtual BAN::ErrorOr write_sectors(uint64_t lba, uint8_t sector_count, const uint8_t* buffer) = 0; + BAN::ErrorOr read_sectors(uint64_t lba, uint8_t sector_count, uint8_t* buffer); + BAN::ErrorOr write_sectors(uint64_t lba, uint8_t sector_count, const uint8_t* buffer); + virtual uint32_t sector_size() const = 0; virtual uint64_t total_size() const = 0; BAN::Vector& partitions() { return m_partitions; } const BAN::Vector& partitions() const { return m_partitions; } + + protected: + virtual BAN::ErrorOr read_sectors_impl(uint64_t lba, uint8_t sector_count, uint8_t* buffer) = 0; + virtual BAN::ErrorOr write_sectors_impl(uint64_t lba, uint8_t sector_count, const uint8_t* buffer) = 0; + void add_disk_cache(); private: + DiskCache* m_disk_cache { nullptr }; BAN::Vector m_partitions; + + friend class DiskCache; }; } \ No newline at end of file diff --git a/kernel/kernel/Storage/ATABus.cpp b/kernel/kernel/Storage/ATABus.cpp index 049480d25c..3251c340ca 100644 --- a/kernel/kernel/Storage/ATABus.cpp +++ b/kernel/kernel/Storage/ATABus.cpp @@ -68,7 +68,6 @@ namespace Kernel BAN::ScopeGuard guard([this, i] { m_devices[i]->unref(); m_devices[i] = nullptr; }); - auto type = identify(device, identify_buffer); if (type == DeviceType::None) continue; @@ -109,7 +108,7 @@ namespace Kernel io_write(ATA_PORT_COMMAND, ATA_COMMAND_IDENTIFY); PIT::sleep(1); - + // No device on port if (io_read(ATA_PORT_STATUS) == 0) return DeviceType::None; diff --git a/kernel/kernel/Storage/ATADevice.cpp b/kernel/kernel/Storage/ATADevice.cpp index 8e09cb5624..bfbbee445f 100644 --- a/kernel/kernel/Storage/ATADevice.cpp +++ b/kernel/kernel/Storage/ATADevice.cpp @@ -53,16 +53,18 @@ namespace Kernel dprintln("{} {} MB", m_device_name, total_size() / 1024 / 1024); + add_disk_cache(); + return {}; } - BAN::ErrorOr ATADevice::read_sectors(uint64_t lba, uint8_t sector_count, uint8_t* buffer) + BAN::ErrorOr ATADevice::read_sectors_impl(uint64_t lba, uint8_t sector_count, uint8_t* buffer) { TRY(m_bus->read(this, lba, sector_count, buffer)); return {}; } - BAN::ErrorOr ATADevice::write_sectors(uint64_t lba, uint8_t sector_count, const uint8_t* buffer) + BAN::ErrorOr ATADevice::write_sectors_impl(uint64_t lba, uint8_t sector_count, const uint8_t* buffer) { TRY(m_bus->write(this, lba, sector_count, buffer)); return {}; diff --git a/kernel/kernel/Storage/DiskCache.cpp b/kernel/kernel/Storage/DiskCache.cpp new file mode 100644 index 0000000000..1ee990f795 --- /dev/null +++ b/kernel/kernel/Storage/DiskCache.cpp @@ -0,0 +1,239 @@ +#include +#include +#include +#include +#include + +namespace Kernel +{ + + DiskCache::DiskCache(StorageDevice& device) + : m_device(device) + { } + + DiskCache::~DiskCache() + { + if (m_device.sector_size() == 0) + return; + release_all_pages(); + } + + BAN::ErrorOr DiskCache::read_sector(uint64_t sector, uint8_t* buffer) + { + LockGuard _(m_lock); + + ASSERT(m_device.sector_size() > 0); + ASSERT(m_device.sector_size() <= PAGE_SIZE); + + for (auto& cache_block : m_cache) + { + for (size_t i = 0; i < cache_block.sectors.size(); i++) + { + if (cache_block.sectors[i].sector != sector) + continue; + cache_block.read_sector(m_device, i, buffer); + return {}; + } + } + + // Sector was not cached so we must read it from disk + TRY(m_device.read_sectors_impl(sector, 1, buffer)); + + // We try to add the sector to exisiting cache block + if (!m_cache.empty()) + { + auto& cache_block = m_cache.back(); + for (size_t i = 0; i < m_cache.back().sectors.size(); i++) + { + if (cache_block.sectors[i].sector) + continue; + cache_block.write_sector(m_device, i, buffer); + cache_block.sectors[i].sector = sector; + cache_block.sectors[i].dirty = false; + return {}; + } + } + + // We try to allocate new cache block for this sector + TRY(m_cache.emplace_back()); + if (paddr_t paddr = Heap::get().take_free_page()) + { + auto& cache_block = m_cache.back(); + cache_block.paddr = paddr; + cache_block.write_sector(m_device, 0, buffer); + cache_block.sectors[0].sector = sector; + cache_block.sectors[0].dirty = false; + return {}; + } + + // We could not cache the sector + return {}; + } + + BAN::ErrorOr DiskCache::write_sector(uint64_t sector, const uint8_t* buffer) + { + LockGuard _(m_lock); + + ASSERT(m_device.sector_size() > 0); + ASSERT(m_device.sector_size() <= PAGE_SIZE); + + // Try to find this sector in the cache + for (auto& cache_block : m_cache) + { + for (size_t i = 0; i < cache_block.sectors.size(); i++) + { + if (cache_block.sectors[i].sector != sector) + continue; + cache_block.write_sector(m_device, i, buffer); + cache_block.sectors[i].dirty = true; + return {}; + } + } + + // Sector was not in the cache, we try to add it to exisiting cache block + if (!m_cache.empty()) + { + auto& cache_block = m_cache.back(); + for (size_t i = 0; i < m_cache.back().sectors.size(); i++) + { + if (cache_block.sectors[i].sector) + continue; + cache_block.write_sector(m_device, i, buffer); + cache_block.sectors[i].sector = sector; + cache_block.sectors[i].dirty = true; + return {}; + } + } + + // We try to allocate new cache block + TRY(m_cache.emplace_back()); + if (paddr_t paddr = Heap::get().take_free_page()) + { + auto& cache_block = m_cache.back(); + cache_block.paddr = paddr; + cache_block.write_sector(m_device, 0, buffer); + cache_block.sectors[0].sector = sector; + cache_block.sectors[0].dirty = true; + return {}; + } + + // We could not allocate cache, so we must sync it to disk + // right away + TRY(m_device.write_sectors_impl(sector, 1, buffer)); + return {}; + } + + size_t DiskCache::release_clean_pages(size_t page_count) + { + LockGuard _(m_lock); + + ASSERT(m_device.sector_size() > 0); + ASSERT(m_device.sector_size() <= PAGE_SIZE); + + size_t released = 0; + for (size_t i = 0; i < m_cache.size() && released < page_count;) + { + bool dirty = false; + for (size_t j = 0; j < sizeof(m_cache[i].sectors) / sizeof(SectorCache); j++) + if (m_cache[i].sectors[j].dirty) + dirty = true; + if (dirty) + { + i++; + continue; + } + + Heap::get().release_page(m_cache[i].paddr); + m_cache.remove(i); + released++; + } + + return released; + } + + size_t DiskCache::release_pages(size_t page_count) + { + ASSERT(m_device.sector_size() > 0); + ASSERT(m_device.sector_size() <= PAGE_SIZE); + + size_t released = release_clean_pages(page_count); + if (released >= page_count) + return page_count; + + // NOTE: There might not actually be page_count pages after this + // function returns. The synchronization must be done elsewhere. + LockGuard _(m_lock); + + while (!m_cache.empty() && released < page_count) + { + m_cache.back().sync(m_device); + Heap::get().release_page(m_cache.back().paddr); + m_cache.pop_back(); + released++; + } + + return released; + } + + void DiskCache::release_all_pages() + { + LockGuard _(m_lock); + + ASSERT(m_device.sector_size() > 0); + ASSERT(m_device.sector_size() <= PAGE_SIZE); + + uint8_t* temp_buffer = (uint8_t*)kmalloc(m_device.sector_size()); + ASSERT(temp_buffer); + + while (!m_cache.empty()) + { + auto& cache_block = m_cache.back(); + cache_block.sync(m_device); + Heap::get().release_page(cache_block.paddr); + m_cache.pop_back(); + } + } + + + void DiskCache::CacheBlock::sync(StorageDevice& device) + { + uint8_t* temp_buffer = (uint8_t*)kmalloc(device.sector_size()); + ASSERT(temp_buffer); + + for (size_t i = 0; i < sectors.size(); i++) + { + if (!sectors[i].dirty) + continue; + read_sector(device, i, temp_buffer); + MUST(device.write_sectors_impl(sectors[i].sector, 1, temp_buffer)); + sectors[i].dirty = false; + } + + kfree(temp_buffer); + } + + void DiskCache::CacheBlock::read_sector(StorageDevice& device, size_t index, uint8_t* buffer) + { + ASSERT(index < sectors.size()); + + PageTableScope _(PageTable::current()); + ASSERT(PageTable::current().is_page_free(0)); + PageTable::current().map_page_at(paddr, 0, PageTable::Flags::Present); + memcpy(buffer, (void*)(index * device.sector_size()), device.sector_size()); + PageTable::current().unmap_page(0); + PageTable::current().invalidate(0); + } + + void DiskCache::CacheBlock::write_sector(StorageDevice& device, size_t index, const uint8_t* buffer) + { + ASSERT(index < sectors.size()); + + PageTableScope _(PageTable::current()); + ASSERT(PageTable::current().is_page_free(0)); + PageTable::current().map_page_at(paddr, 0, PageTable::Flags::ReadWrite | PageTable::Flags::Present); + memcpy((void*)(index * device.sector_size()), buffer, device.sector_size()); + PageTable::current().unmap_page(0); + PageTable::current().invalidate(0); + } + +} \ No newline at end of file diff --git a/kernel/kernel/Storage/StorageDevice.cpp b/kernel/kernel/Storage/StorageDevice.cpp index eeb0e742db..eb8477ac10 100644 --- a/kernel/kernel/Storage/StorageDevice.cpp +++ b/kernel/kernel/Storage/StorageDevice.cpp @@ -252,4 +252,37 @@ namespace Kernel return sector_count * m_device.sector_size(); } + StorageDevice::~StorageDevice() + { + if (m_disk_cache) + delete m_disk_cache; + m_disk_cache = nullptr; + } + + void StorageDevice::add_disk_cache() + { + ASSERT(m_disk_cache == nullptr); + m_disk_cache = new DiskCache(*this); + ASSERT(m_disk_cache); + } + + BAN::ErrorOr StorageDevice::read_sectors(uint64_t lba, uint8_t sector_count, uint8_t* buffer) + { + if (!m_disk_cache) + return read_sectors_impl(lba, sector_count, buffer); + for (uint8_t sector = 0; sector < sector_count; sector++) + TRY(m_disk_cache->read_sector(lba + sector, buffer + sector * sector_size())); + return {}; + } + + BAN::ErrorOr StorageDevice::write_sectors(uint64_t lba, uint8_t sector_count, const uint8_t* buffer) + { + if (!m_disk_cache) + return write_sectors_impl(lba, sector_count, buffer); + for (uint8_t sector = 0; sector < sector_count; sector++) + TRY(m_disk_cache->write_sector(lba + sector, buffer + sector * sector_size())); + return {}; + } + + } \ No newline at end of file