Kernel: Generalize ATA device and cleanup code

This commit is contained in:
Bananymous 2023-10-12 21:35:25 +03:00
parent f4b901a646
commit d3e5c8e0aa
9 changed files with 135 additions and 105 deletions

View File

@ -27,8 +27,6 @@ namespace Kernel
virtual void handle_irq() override; virtual void handle_irq() override;
void initialize_devfs();
private: private:
ATABus(uint16_t base, uint16_t ctrl) ATABus(uint16_t base, uint16_t ctrl)
: m_base(base) : m_base(base)
@ -54,7 +52,7 @@ namespace Kernel
const uint16_t m_ctrl; const uint16_t m_ctrl;
SpinLock m_lock; SpinLock m_lock;
bool m_has_got_irq { false }; volatile bool m_has_got_irq { false };
// Non-owning pointers // Non-owning pointers
BAN::Vector<ATADevice*> m_devices; BAN::Vector<ATADevice*> m_devices;

View File

@ -12,7 +12,7 @@ namespace Kernel
class ATAController : public StorageController class ATAController : public StorageController
{ {
public: public:
static BAN::ErrorOr<BAN::UniqPtr<StorageController>> create(PCI::Device&); static BAN::ErrorOr<BAN::RefPtr<StorageController>> create(PCI::Device&);
virtual BAN::ErrorOr<void> initialize() override; virtual BAN::ErrorOr<void> initialize() override;
private: private:

View File

@ -6,52 +6,69 @@
namespace Kernel namespace Kernel
{ {
class ATADevice final : public StorageDevice namespace detail
{
class ATABaseDevice : public StorageDevice
{
public:
enum class Command
{
Read,
Write
};
public:
virtual ~ATABaseDevice() {};
virtual uint32_t sector_size() const override { return m_sector_words * 2; }
virtual uint64_t total_size() const override { return m_lba_count * sector_size(); }
uint32_t words_per_sector() const { return m_sector_words; }
uint64_t sector_count() const { return m_lba_count; }
BAN::StringView model() const { return m_model; }
BAN::StringView name() const;
virtual dev_t rdev() const override { return m_rdev; }
virtual BAN::ErrorOr<size_t> read_impl(off_t, void*, size_t) override;
virtual BAN::ErrorOr<size_t> write_impl(off_t, const void*, size_t) override;
protected:
ATABaseDevice();
BAN::ErrorOr<void> initialize(BAN::Span<const uint16_t> identify_data);
protected:
uint16_t m_signature;
uint16_t m_capabilities;
uint32_t m_command_set;
uint32_t m_sector_words;
uint64_t m_lba_count;
char m_model[41];
const dev_t m_rdev;
};
}
class ATADevice final : public detail::ATABaseDevice
{ {
public: public:
static BAN::ErrorOr<BAN::RefPtr<ATADevice>> create(BAN::RefPtr<ATABus>, ATABus::DeviceType, bool is_secondary, BAN::Span<const uint16_t> identify_data); static BAN::ErrorOr<BAN::RefPtr<ATADevice>> create(BAN::RefPtr<ATABus>, ATABus::DeviceType, bool is_secondary, BAN::Span<const uint16_t> identify_data);
virtual uint32_t sector_size() const override { return m_sector_words * 2; }
virtual uint64_t total_size() const override { return m_lba_count * sector_size(); }
bool is_secondary() const { return m_is_secondary; } bool is_secondary() const { return m_is_secondary; }
uint32_t words_per_sector() const { return m_sector_words; }
uint64_t sector_count() const { return m_lba_count; }
BAN::StringView model() const { return m_model; }
BAN::StringView name() const;
protected:
virtual BAN::ErrorOr<void> read_sectors_impl(uint64_t, uint8_t, uint8_t*) override;
virtual BAN::ErrorOr<void> write_sectors_impl(uint64_t, uint8_t, const uint8_t*) override;
private: private:
ATADevice(BAN::RefPtr<ATABus>, ATABus::DeviceType, bool is_secodary); ATADevice(BAN::RefPtr<ATABus>, ATABus::DeviceType, bool is_secodary);
BAN::ErrorOr<void> initialize(BAN::Span<const uint16_t> identify_data);
virtual BAN::ErrorOr<void> read_sectors_impl(uint64_t, uint64_t, uint8_t*) override;
virtual BAN::ErrorOr<void> write_sectors_impl(uint64_t, uint64_t, const uint8_t*) override;
private: private:
BAN::RefPtr<ATABus> m_bus; BAN::RefPtr<ATABus> m_bus;
const ATABus::DeviceType m_type; const ATABus::DeviceType m_type;
const bool m_is_secondary; const bool m_is_secondary;
uint16_t m_signature;
uint16_t m_capabilities;
uint32_t m_command_set;
uint32_t m_sector_words;
uint64_t m_lba_count;
char m_model[41];
public:
virtual Mode mode() const override { return { Mode::IFBLK | Mode::IRUSR | Mode::IRGRP }; }
virtual uid_t uid() const override { return 0; }
virtual gid_t gid() const override { return 0; }
virtual dev_t rdev() const override { return m_rdev; }
private:
virtual BAN::ErrorOr<size_t> read_impl(off_t, void*, size_t) override;
public:
const dev_t m_rdev;
}; };
} }

View File

@ -1,9 +1,11 @@
#pragma once #pragma once
#include <BAN/RefPtr.h>
namespace Kernel namespace Kernel
{ {
class StorageController class StorageController : public BAN::RefCounted<StorageController>
{ {
public: public:
virtual ~StorageController() {} virtual ~StorageController() {}

View File

@ -73,8 +73,8 @@ namespace Kernel
BAN::ErrorOr<void> initialize_partitions(); BAN::ErrorOr<void> initialize_partitions();
BAN::ErrorOr<void> read_sectors(uint64_t lba, uint8_t sector_count, uint8_t* buffer); BAN::ErrorOr<void> read_sectors(uint64_t lba, uint64_t sector_count, uint8_t* buffer);
BAN::ErrorOr<void> write_sectors(uint64_t lba, uint8_t sector_count, const uint8_t* buffer); BAN::ErrorOr<void> write_sectors(uint64_t lba, uint64_t sector_count, const uint8_t* buffer);
virtual uint32_t sector_size() const = 0; virtual uint32_t sector_size() const = 0;
virtual uint64_t total_size() const = 0; virtual uint64_t total_size() const = 0;
@ -86,8 +86,8 @@ namespace Kernel
virtual bool is_storage_device() const override { return true; } virtual bool is_storage_device() const override { return true; }
protected: protected:
virtual BAN::ErrorOr<void> read_sectors_impl(uint64_t lba, uint8_t sector_count, uint8_t* buffer) = 0; virtual BAN::ErrorOr<void> read_sectors_impl(uint64_t lba, uint64_t sector_count, uint8_t* buffer) = 0;
virtual BAN::ErrorOr<void> write_sectors_impl(uint64_t lba, uint8_t sector_count, const uint8_t* buffer) = 0; virtual BAN::ErrorOr<void> write_sectors_impl(uint64_t lba, uint64_t sector_count, const uint8_t* buffer) = 0;
void add_disk_cache(); void add_disk_cache();
private: private:

View File

@ -41,6 +41,10 @@ namespace Kernel
else else
device_type = res.value(); device_type = res.value();
// Enable interrupts
select_device(is_secondary);
io_write(ATA_PORT_CONTROL, 0);
auto device_or_error = ATADevice::create(this, device_type, is_secondary, identify_buffer.span()); auto device_or_error = ATADevice::create(this, device_type, is_secondary, identify_buffer.span());
if (device_or_error.is_error()) if (device_or_error.is_error())
@ -50,35 +54,23 @@ namespace Kernel
} }
auto device = device_or_error.release_value(); auto device = device_or_error.release_value();
device->ref();
TRY(m_devices.push_back(device.ptr())); TRY(m_devices.push_back(device.ptr()));
} }
// Enable disk interrupts
for (auto& device : m_devices)
{
select_device(device->is_secondary());
io_write(ATA_PORT_CONTROL, 0);
}
return {}; return {};
} }
void ATABus::initialize_devfs() static void select_delay()
{ {
for (auto& device : m_devices) auto time = SystemTimer::get().ns_since_boot();
{ while (SystemTimer::get().ns_since_boot() < time + 400)
DevFileSystem::get().add_device(device); continue;
if (auto res = device->initialize_partitions(); res.is_error())
dprintln("{}", res.error());
device->unref();
}
} }
void ATABus::select_device(bool secondary) void ATABus::select_device(bool secondary)
{ {
io_write(ATA_PORT_DRIVE_SELECT, 0xA0 | ((uint8_t)secondary << 4)); io_write(ATA_PORT_DRIVE_SELECT, 0xA0 | ((uint8_t)secondary << 4));
SystemTimer::get().sleep(1); select_delay();
} }
BAN::ErrorOr<ATABus::DeviceType> ATABus::identify(bool secondary, BAN::Span<uint16_t> buffer) BAN::ErrorOr<ATABus::DeviceType> ATABus::identify(bool secondary, BAN::Span<uint16_t> buffer)
@ -236,6 +228,9 @@ namespace Kernel
{ {
// LBA28 // LBA28
io_write(ATA_PORT_DRIVE_SELECT, 0xE0 | ((uint8_t)device.is_secondary() << 4) | ((lba >> 24) & 0x0F)); io_write(ATA_PORT_DRIVE_SELECT, 0xE0 | ((uint8_t)device.is_secondary() << 4) | ((lba >> 24) & 0x0F));
select_delay();
io_write(ATA_PORT_CONTROL, 0);
io_write(ATA_PORT_SECTOR_COUNT, sector_count); io_write(ATA_PORT_SECTOR_COUNT, sector_count);
io_write(ATA_PORT_LBA0, (uint8_t)(lba >> 0)); io_write(ATA_PORT_LBA0, (uint8_t)(lba >> 0));
io_write(ATA_PORT_LBA1, (uint8_t)(lba >> 8)); io_write(ATA_PORT_LBA1, (uint8_t)(lba >> 8));
@ -268,14 +263,15 @@ namespace Kernel
{ {
// LBA28 // LBA28
io_write(ATA_PORT_DRIVE_SELECT, 0xE0 | ((uint8_t)device.is_secondary() << 4) | ((lba >> 24) & 0x0F)); io_write(ATA_PORT_DRIVE_SELECT, 0xE0 | ((uint8_t)device.is_secondary() << 4) | ((lba >> 24) & 0x0F));
select_delay();
io_write(ATA_PORT_CONTROL, 0);
io_write(ATA_PORT_SECTOR_COUNT, sector_count); io_write(ATA_PORT_SECTOR_COUNT, sector_count);
io_write(ATA_PORT_LBA0, (uint8_t)(lba >> 0)); io_write(ATA_PORT_LBA0, (uint8_t)(lba >> 0));
io_write(ATA_PORT_LBA1, (uint8_t)(lba >> 8)); io_write(ATA_PORT_LBA1, (uint8_t)(lba >> 8));
io_write(ATA_PORT_LBA2, (uint8_t)(lba >> 16)); io_write(ATA_PORT_LBA2, (uint8_t)(lba >> 16));
io_write(ATA_PORT_COMMAND, ATA_COMMAND_WRITE_SECTORS); io_write(ATA_PORT_COMMAND, ATA_COMMAND_WRITE_SECTORS);
SystemTimer::get().sleep(1);
for (uint32_t sector = 0; sector < sector_count; sector++) for (uint32_t sector = 0; sector < sector_count; sector++)
{ {
write_buffer(ATA_PORT_DATA, (uint16_t*)buffer + sector * device.words_per_sector(), device.words_per_sector()); write_buffer(ATA_PORT_DATA, (uint16_t*)buffer + sector * device.words_per_sector(), device.words_per_sector());

View File

@ -7,7 +7,7 @@
namespace Kernel namespace Kernel
{ {
BAN::ErrorOr<BAN::UniqPtr<StorageController>> ATAController::create(PCI::Device& pci_device) BAN::ErrorOr<BAN::RefPtr<StorageController>> ATAController::create(PCI::Device& pci_device)
{ {
StorageController* controller_ptr = nullptr; StorageController* controller_ptr = nullptr;
@ -29,7 +29,7 @@ namespace Kernel
if (controller_ptr == nullptr) if (controller_ptr == nullptr)
return BAN::Error::from_errno(ENOMEM); return BAN::Error::from_errno(ENOMEM);
auto controller = BAN::UniqPtr<StorageController>::adopt(controller_ptr); auto controller = BAN::RefPtr<StorageController>::adopt(controller_ptr);
TRY(controller->initialize()); TRY(controller->initialize());
return controller; return controller;
} }
@ -67,9 +67,6 @@ namespace Kernel
dprintln("unsupported IDE ATABus in native mode"); dprintln("unsupported IDE ATABus in native mode");
} }
for (auto& bus : buses)
bus->initialize_devfs();
return {}; return {};
} }

View File

@ -21,24 +21,11 @@ namespace Kernel
return minor++; return minor++;
} }
BAN::ErrorOr<BAN::RefPtr<ATADevice>> ATADevice::create(BAN::RefPtr<ATABus> bus, ATABus::DeviceType type, bool is_secondary, BAN::Span<const uint16_t> identify_data) detail::ATABaseDevice::ATABaseDevice()
{ : m_rdev(makedev(get_ata_dev_major(), get_ata_dev_minor()))
auto* device_ptr = new ATADevice(bus, type, is_secondary);
if (device_ptr == nullptr)
return BAN::Error::from_errno(ENOMEM);
auto device = BAN::RefPtr<ATADevice>::adopt(device_ptr);
TRY(device->initialize(identify_data));
return device;
}
ATADevice::ATADevice(BAN::RefPtr<ATABus> bus, ATABus::DeviceType type, bool is_secondary)
: m_bus(bus)
, m_type(type)
, m_is_secondary(is_secondary)
, m_rdev(makedev(get_ata_dev_major(), get_ata_dev_minor()))
{ } { }
BAN::ErrorOr<void> ATADevice::initialize(BAN::Span<const uint16_t> identify_data) BAN::ErrorOr<void> detail::ATABaseDevice::initialize(BAN::Span<const uint16_t> identify_data)
{ {
ASSERT(identify_data.size() >= 256); ASSERT(identify_data.size() >= 256);
@ -77,41 +64,74 @@ namespace Kernel
} }
m_model[40] = 0; m_model[40] = 0;
dprintln("ATA disk {} MB", total_size() / 1024 / 1024); size_t model_len = 40;
while (model_len > 0 && m_model[model_len - 1] == ' ')
model_len--;
dprintln("Initialized disk '{}' {} MB", BAN::StringView(m_model, model_len), total_size() / 1024 / 1024);
add_disk_cache(); add_disk_cache();
DevFileSystem::get().add_device(this);
if (auto res = initialize_partitions(); res.is_error())
dprintln("{}", res.error());
return {}; return {};
} }
BAN::ErrorOr<void> ATADevice::read_sectors_impl(uint64_t lba, uint8_t sector_count, uint8_t* buffer) BAN::ErrorOr<size_t> detail::ATABaseDevice::read_impl(off_t offset, void* buffer, size_t bytes)
{ {
TRY(m_bus->read(*this, lba, sector_count, buffer)); if (offset % sector_size())
return {}; return BAN::Error::from_errno(EINVAL);
} if (bytes % sector_size())
BAN::ErrorOr<void> ATADevice::write_sectors_impl(uint64_t lba, uint8_t sector_count, const uint8_t* buffer)
{
TRY(m_bus->write(*this, lba, sector_count, buffer));
return {};
}
BAN::ErrorOr<size_t> ATADevice::read_impl(off_t offset, void* buffer, size_t bytes)
{
ASSERT(offset >= 0);
if (offset % sector_size() || bytes % sector_size())
return BAN::Error::from_errno(EINVAL); return BAN::Error::from_errno(EINVAL);
if ((size_t)offset == total_size())
return 0;
TRY(read_sectors(offset / sector_size(), bytes / sector_size(), (uint8_t*)buffer)); TRY(read_sectors(offset / sector_size(), bytes / sector_size(), (uint8_t*)buffer));
return bytes; return bytes;
} }
BAN::StringView ATADevice::name() const BAN::ErrorOr<size_t> detail::ATABaseDevice::write_impl(off_t offset, const void* buffer, size_t bytes)
{
if (offset % sector_size())
return BAN::Error::from_errno(EINVAL);
if (bytes % sector_size())
return BAN::Error::from_errno(EINVAL);
TRY(write_sectors(offset / sector_size(), bytes / sector_size(), (const uint8_t*)buffer));
return bytes;
}
BAN::StringView detail::ATABaseDevice::name() const
{ {
static char device_name[] = "sda"; static char device_name[] = "sda";
device_name[2] += minor(m_rdev); device_name[2] += minor(m_rdev);
return device_name; return device_name;
} }
BAN::ErrorOr<BAN::RefPtr<ATADevice>> ATADevice::create(BAN::RefPtr<ATABus> bus, ATABus::DeviceType type, bool is_secondary, BAN::Span<const uint16_t> identify_data)
{
auto* device_ptr = new ATADevice(bus, type, is_secondary);
if (device_ptr == nullptr)
return BAN::Error::from_errno(ENOMEM);
auto device = BAN::RefPtr<ATADevice>::adopt(device_ptr);
TRY(device->initialize(identify_data));
return device;
}
ATADevice::ATADevice(BAN::RefPtr<ATABus> bus, ATABus::DeviceType type, bool is_secondary)
: m_bus(bus)
, m_type(type)
, m_is_secondary(is_secondary)
{ }
BAN::ErrorOr<void> ATADevice::read_sectors_impl(uint64_t lba, uint64_t sector_count, uint8_t* buffer)
{
TRY(m_bus->read(*this, lba, sector_count, buffer));
return {};
}
BAN::ErrorOr<void> ATADevice::write_sectors_impl(uint64_t lba, uint64_t sector_count, const uint8_t* buffer)
{
TRY(m_bus->write(*this, lba, sector_count, buffer));
return {};
}
} }

View File

@ -283,9 +283,9 @@ namespace Kernel
return {}; return {};
} }
BAN::ErrorOr<void> StorageDevice::read_sectors(uint64_t lba, uint8_t sector_count, uint8_t* buffer) BAN::ErrorOr<void> StorageDevice::read_sectors(uint64_t lba, uint64_t sector_count, uint8_t* buffer)
{ {
for (uint8_t offset = 0; offset < sector_count; offset++) for (uint64_t offset = 0; offset < sector_count; offset++)
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
Thread::TerminateBlocker blocker(Thread::current()); Thread::TerminateBlocker blocker(Thread::current());
@ -302,7 +302,7 @@ namespace Kernel
return {}; return {};
} }
BAN::ErrorOr<void> StorageDevice::write_sectors(uint64_t lba, uint8_t sector_count, const uint8_t* buffer) BAN::ErrorOr<void> StorageDevice::write_sectors(uint64_t lba, uint64_t sector_count, const uint8_t* buffer)
{ {
// TODO: use disk cache for dirty pages. I don't wanna think about how to do it safely now // TODO: use disk cache for dirty pages. I don't wanna think about how to do it safely now
for (uint8_t offset = 0; offset < sector_count; offset++) for (uint8_t offset = 0; offset < sector_count; offset++)