diff --git a/kernel/CMakeLists.txt b/kernel/CMakeLists.txt index e92fbe0d..11dca457 100644 --- a/kernel/CMakeLists.txt +++ b/kernel/CMakeLists.txt @@ -106,6 +106,8 @@ set(KERNEL_SOURCES kernel/USB/HID/HIDDriver.cpp kernel/USB/HID/Keyboard.cpp kernel/USB/HID/Mouse.cpp + kernel/USB/MassStorage/MassStorageDriver.cpp + kernel/USB/MassStorage/SCSIDevice.cpp kernel/USB/USBManager.cpp kernel/USB/XHCI/Controller.cpp kernel/USB/XHCI/Device.cpp diff --git a/kernel/include/kernel/Debug.h b/kernel/include/kernel/Debug.h index 9bd8b901..c626f045 100644 --- a/kernel/include/kernel/Debug.h +++ b/kernel/include/kernel/Debug.h @@ -66,6 +66,7 @@ #define DEBUG_USB_HID 0 #define DEBUG_USB_KEYBOARD 0 #define DEBUG_USB_MOUSE 0 +#define DEBUG_USB_MASS_STORAGE 0 namespace Debug diff --git a/kernel/include/kernel/USB/MassStorage/Definitions.h b/kernel/include/kernel/USB/MassStorage/Definitions.h new file mode 100644 index 00000000..c2931485 --- /dev/null +++ b/kernel/include/kernel/USB/MassStorage/Definitions.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +namespace Kernel::USBMassStorage +{ + + struct CBW + { + uint32_t dCBWSignature; + uint32_t dCBWTag; + uint32_t dCBWDataTransferLength; + uint8_t bmCBWFlags; + uint8_t bCBWLUN; + uint8_t bCBWCBLength; + uint8_t CBWCB[16]; + } __attribute__((packed)); + static_assert(sizeof(CBW) == 31); + + struct CSW + { + uint32_t dCSWSignature; + uint32_t dCSWTag; + uint32_t dCSWDataResidue; + uint8_t bmCSWStatus; + } __attribute__((packed)); + static_assert(sizeof(CSW) == 13); + +} diff --git a/kernel/include/kernel/USB/MassStorage/MassStorageDriver.h b/kernel/include/kernel/USB/MassStorage/MassStorageDriver.h new file mode 100644 index 00000000..5cee4fe5 --- /dev/null +++ b/kernel/include/kernel/USB/MassStorage/MassStorageDriver.h @@ -0,0 +1,49 @@ +#pragma once + +#include + +#include +#include +#include + +namespace Kernel +{ + + class USBMassStorageDriver final : public USBClassDriver + { + BAN_NON_COPYABLE(USBMassStorageDriver); + BAN_NON_MOVABLE(USBMassStorageDriver); + + public: + void handle_input_data(size_t byte_count, uint8_t endpoint_id) override; + + BAN::ErrorOr send_bytes(paddr_t, size_t count); + BAN::ErrorOr recv_bytes(paddr_t, size_t count); + + void lock() { m_mutex.lock(); } + void unlock() { m_mutex.unlock(); } + + private: + USBMassStorageDriver(USBDevice&, const USBDevice::InterfaceDescriptor&); + ~USBMassStorageDriver(); + + BAN::ErrorOr initialize() override; + + private: + USBDevice& m_device; + USBDevice::InterfaceDescriptor m_interface; + + Mutex m_mutex; + + uint8_t m_in_endpoint_id { 0 }; + BAN::Function m_in_callback; + + uint8_t m_out_endpoint_id { 0 }; + BAN::Function m_out_callback; + + BAN::Vector> m_storage_devices; + + friend class BAN::UniqPtr; + }; + +} diff --git a/kernel/include/kernel/USB/MassStorage/SCSIDevice.h b/kernel/include/kernel/USB/MassStorage/SCSIDevice.h new file mode 100644 index 00000000..7c05d766 --- /dev/null +++ b/kernel/include/kernel/USB/MassStorage/SCSIDevice.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include +#include + +namespace Kernel +{ + + class USBSCSIDevice : public StorageDevice + { + public: + static BAN::ErrorOr> create(USBMassStorageDriver& driver, uint8_t lun, uint32_t max_packet_size); + + uint32_t sector_size() const override { return m_block_size; } + uint64_t total_size() const override { return m_block_size * m_block_count; } + + dev_t rdev() const override { return m_rdev; } + BAN::StringView name() const override { return m_name; } + + private: + USBSCSIDevice(USBMassStorageDriver& driver, uint8_t lun, BAN::UniqPtr&&, uint64_t block_count, uint32_t block_size); + ~USBSCSIDevice(); + + static BAN::ErrorOr send_scsi_command_impl(USBMassStorageDriver&, DMARegion& dma_region, uint8_t lun, BAN::ConstByteSpan command, BAN::ByteSpan data, bool in); + BAN::ErrorOr send_scsi_command(BAN::ConstByteSpan command, BAN::ByteSpan data, bool in); + + BAN::ErrorOr read_sectors_impl(uint64_t first_lba, uint64_t sector_count, BAN::ByteSpan buffer) override; + BAN::ErrorOr write_sectors_impl(uint64_t lba, uint64_t sector_count, BAN::ConstByteSpan buffer) override; + + private: + USBMassStorageDriver& m_driver; + BAN::UniqPtr m_dma_region; + + const uint8_t m_lun; + + const uint64_t m_block_count; + const uint32_t m_block_size; + + const dev_t m_rdev; + const char m_name[4]; + + friend class BAN::RefPtr; + }; + +} diff --git a/kernel/kernel/USB/Device.cpp b/kernel/kernel/USB/Device.cpp index 3bbdb3c1..4967a3c7 100644 --- a/kernel/kernel/USB/Device.cpp +++ b/kernel/kernel/USB/Device.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #define USB_DUMP_DESCRIPTORS 0 @@ -167,7 +168,8 @@ namespace Kernel dprintln_if(DEBUG_USB, "Found Printer interface"); break; case USB::InterfaceBaseClass::MassStorage: - dprintln_if(DEBUG_USB, "Found MassStorage interface"); + if (auto result = BAN::UniqPtr::create(*this, interface); !result.is_error()) + TRY(m_class_drivers.push_back(result.release_value())); break; case USB::InterfaceBaseClass::CDCData: dprintln_if(DEBUG_USB, "Found CDCData interface"); diff --git a/kernel/kernel/USB/MassStorage/MassStorageDriver.cpp b/kernel/kernel/USB/MassStorage/MassStorageDriver.cpp new file mode 100644 index 00000000..23947dd8 --- /dev/null +++ b/kernel/kernel/USB/MassStorage/MassStorageDriver.cpp @@ -0,0 +1,199 @@ +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace Kernel +{ + + USBMassStorageDriver::USBMassStorageDriver(USBDevice& device, const USBDevice::InterfaceDescriptor& interface) + : m_device(device) + , m_interface(interface) + { } + + USBMassStorageDriver::~USBMassStorageDriver() + { } + + BAN::ErrorOr USBMassStorageDriver::initialize() + { + if (m_interface.descriptor.bInterfaceProtocol != 0x50) + { + dwarnln("Only USB Mass Storage BBB is supported"); + return BAN::Error::from_errno(ENOTSUP); + } + + auto dma_region = TRY(DMARegion::create(PAGE_SIZE)); + + // Bulk-Only Mass Storage Reset + { + USBDeviceRequest reset_request { + .bmRequestType = USB::RequestType::HostToDevice | USB::RequestType::Class | USB::RequestType::Interface, + .bRequest = 0xFF, + .wValue = 0x0000, + .wIndex = m_interface.descriptor.bInterfaceNumber, + .wLength = 0x0000, + }; + + TRY(m_device.send_request(reset_request, 0)); + } + + // Get Max LUN + { + USBDeviceRequest lun_request { + .bmRequestType = USB::RequestType::DeviceToHost | USB::RequestType::Class | USB::RequestType::Interface, + .bRequest = 0xFE, + .wValue = 0x0000, + .wIndex = m_interface.descriptor.bInterfaceNumber, + .wLength = 0x0001, + }; + + uint32_t max_lun = 0; + const auto lun_result = m_device.send_request(lun_request, dma_region->paddr()); + if (!lun_result.is_error() && lun_result.value() == 1) + max_lun = *reinterpret_cast(dma_region->vaddr()); + TRY(m_storage_devices.resize(max_lun + 1)); + } + + uint32_t max_packet_size = -1; + + // Initialize bulk-in and bulk-out endpoints + { + constexpr size_t invalid_index = -1; + + size_t bulk_in_index = invalid_index; + size_t bulk_out_index = invalid_index; + + for (size_t i = 0; i < m_interface.endpoints.size(); i++) + { + const auto& endpoint = m_interface.endpoints[i].descriptor; + if (endpoint.bmAttributes != 0x02) + continue; + ((endpoint.bEndpointAddress & 0x80) ? bulk_in_index : bulk_out_index) = i; + } + + if (bulk_in_index == invalid_index || bulk_out_index == invalid_index) + { + dwarnln("USB Mass Storage device does not contain bulk-in and bulk-out endpoints"); + return BAN::Error::from_errno(EFAULT); + } + + TRY(m_device.initialize_endpoint(m_interface.endpoints[bulk_in_index].descriptor)); + TRY(m_device.initialize_endpoint(m_interface.endpoints[bulk_out_index].descriptor)); + + { + const auto& desc = m_interface.endpoints[bulk_in_index].descriptor; + m_in_endpoint_id = (desc.bEndpointAddress & 0x0F) * 2 + !!(desc.bEndpointAddress & 0x80); + max_packet_size = BAN::Math::min(max_packet_size, desc.wMaxPacketSize); + } + + { + const auto& desc = m_interface.endpoints[bulk_out_index].descriptor; + m_out_endpoint_id = (desc.bEndpointAddress & 0x0F) * 2 + !!(desc.bEndpointAddress & 0x80); + max_packet_size = BAN::Math::min(max_packet_size, desc.wMaxPacketSize); + } + } + + BAN::Function>(USBMassStorageDriver&, uint8_t, uint32_t)> create_device_func; + switch (m_interface.descriptor.bInterfaceSubClass) + { + case 0x06: + create_device_func = + [](USBMassStorageDriver& driver, uint8_t lun, uint32_t max_packet_size) -> BAN::ErrorOr> + { + auto ret = TRY(USBSCSIDevice::create(driver, lun, max_packet_size)); + return BAN::RefPtr(ret); + }; + break; + default: + dwarnln("Unsupported command block {2H}", m_interface.descriptor.bInterfaceSubClass); + return BAN::Error::from_errno(ENOTSUP); + } + + ASSERT(m_storage_devices.size() <= 0xFF); + for (uint8_t lun = 0; lun < m_storage_devices.size(); lun++) + m_storage_devices[lun] = TRY(create_device_func(*this, lun, max_packet_size)); + + return {}; + } + + BAN::ErrorOr USBMassStorageDriver::send_bytes(paddr_t paddr, size_t count) + { + ASSERT(m_mutex.is_locked()); + + constexpr size_t invalid = -1; + + static volatile size_t bytes_sent; + bytes_sent = invalid; + + ASSERT(!m_out_callback); + m_out_callback = [](size_t bytes) { bytes_sent = bytes; }; + BAN::ScopeGuard _([this] { m_out_callback.clear(); }); + + m_device.send_data_buffer(m_out_endpoint_id, paddr, count); + + const uint64_t timeout_ms = SystemTimer::get().ms_since_boot() + 100; + while (bytes_sent == invalid) + if (SystemTimer::get().ms_since_boot() > timeout_ms) + return BAN::Error::from_errno(EIO); + + return static_cast(bytes_sent); + } + + BAN::ErrorOr USBMassStorageDriver::recv_bytes(paddr_t paddr, size_t count) + { + ASSERT(m_mutex.is_locked()); + + constexpr size_t invalid = -1; + + static volatile size_t bytes_recv; + bytes_recv = invalid; + + ASSERT(!m_in_callback); + m_in_callback = [](size_t bytes) { bytes_recv = bytes; }; + BAN::ScopeGuard _([this] { m_in_callback.clear(); }); + + m_device.send_data_buffer(m_in_endpoint_id, paddr, count); + + const uint64_t timeout_ms = SystemTimer::get().ms_since_boot() + 100; + while (bytes_recv == invalid) + if (SystemTimer::get().ms_since_boot() > timeout_ms) + return BAN::Error::from_errno(EIO); + + m_in_callback.clear(); + + return static_cast(bytes_recv); + } + + void USBMassStorageDriver::handle_input_data(size_t byte_count, uint8_t endpoint_id) + { + if (endpoint_id != m_in_endpoint_id && endpoint_id != m_out_endpoint_id) + return; + + dprintln_if(DEBUG_USB_MASS_STORAGE, "got {} bytes to {} endpoint", byte_count, endpoint_id == m_in_endpoint_id ? "IN" : "OUT"); + + if (endpoint_id == m_in_endpoint_id) + { + if (m_in_callback) + m_in_callback(byte_count); + else + dwarnln("ignoring {} bytes to IN endpoint", byte_count); + return; + } + + if (endpoint_id == m_out_endpoint_id) + { + if (m_out_callback) + m_out_callback(byte_count); + else + dwarnln("ignoring {} bytes to OUT endpoint", byte_count); + return; + } + + } + +} diff --git a/kernel/kernel/USB/MassStorage/SCSIDevice.cpp b/kernel/kernel/USB/MassStorage/SCSIDevice.cpp new file mode 100644 index 00000000..70242fee --- /dev/null +++ b/kernel/kernel/USB/MassStorage/SCSIDevice.cpp @@ -0,0 +1,265 @@ +#include + +#include +#include +#include + +#include + +namespace Kernel +{ + + namespace SCSI + { + + struct InquiryRes + { + uint8_t peripheral_device_type : 5; + uint8_t peripheral_qualifier : 3; + + uint8_t reserved0 : 7; + uint8_t rmb : 1; + + uint8_t version; + + uint8_t response_data_format : 4; + uint8_t hisup : 1; + uint8_t normaca : 1; + uint8_t obsolete0 : 1; + uint8_t obsolete1 : 1; + + uint8_t additional_length; + + uint8_t protect : 1; + uint8_t reserved1 : 2; + uint8_t _3pc : 1; + uint8_t tgps : 2; + uint8_t acc : 1; + uint8_t sccs : 1; + + uint8_t obsolete2 : 1; + uint8_t obsolete3 : 1; + uint8_t obsolete4 : 1; + uint8_t obsolete5 : 1; + uint8_t multip : 1; + uint8_t vs0 : 1; + uint8_t encserv : 1; + uint8_t obsolete6 : 1; + + uint8_t vs1 : 1; + uint8_t cmdque : 1; + uint8_t obsolete7 : 1; + uint8_t obsolete8 : 1; + uint8_t obsolete9 : 1; + uint8_t obsolete10 : 1; + uint8_t obsolete11 : 1; + uint8_t obsolete12 : 1; + + uint8_t t10_vendor_identification[8]; + uint8_t product_identification[16]; + uint8_t product_revision_level[4]; + }; + static_assert(sizeof(InquiryRes) == 36); + + struct ReadCapacity10 + { + BAN::BigEndian logical_block_address {}; + BAN::BigEndian block_length; + }; + static_assert(sizeof(ReadCapacity10) == 8); + + } + + BAN::ErrorOr> USBSCSIDevice::create(USBMassStorageDriver& driver, uint8_t lun, uint32_t max_packet_size) + { + auto dma_region = TRY(DMARegion::create(max_packet_size)); + + dprintln("USB SCSI device"); + + { + const uint8_t scsi_inquiry_req[6] { + 0x12, + 0x00, + 0x00, + 0x00, sizeof(SCSI::InquiryRes), + 0x00 + }; + SCSI::InquiryRes inquiry_res; + TRY(send_scsi_command_impl(driver, *dma_region, lun, BAN::ConstByteSpan::from(scsi_inquiry_req), BAN::ByteSpan::from(inquiry_res), true)); + + dprintln(" vendor: {}", BAN::StringView(reinterpret_cast(inquiry_res.t10_vendor_identification), 8)); + dprintln(" product: {}", BAN::StringView(reinterpret_cast(inquiry_res.product_identification), 16)); + dprintln(" revision: {}", BAN::StringView(reinterpret_cast(inquiry_res.product_revision_level), 4)); + } + + uint32_t block_count; + uint32_t block_size; + + { + const uint8_t scsi_read_capacity_req[10] { + 0x25, + 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, + 0x00, + 0x00 + }; + SCSI::ReadCapacity10 read_capacity_res; + TRY(send_scsi_command_impl(driver, *dma_region, lun, BAN::ConstByteSpan::from(scsi_read_capacity_req), BAN::ByteSpan::from(read_capacity_res), true)); + + block_count = read_capacity_res.logical_block_address + 1; + block_size = read_capacity_res.block_length; + + if (block_count == 0) + { + dwarnln("Too big USB storage"); + return BAN::Error::from_errno(ENOTSUP); + } + + dprintln(" last LBA: {}", block_count); + dprintln(" block size: {}", block_size); + dprintln(" total size: {} MiB", block_count * block_size / 1024 / 1024); + } + + auto result = TRY(BAN::RefPtr::create(driver, lun, BAN::move(dma_region), block_count, block_size)); + result->add_disk_cache(); + DevFileSystem::get().add_device(result); + if (auto res = result->initialize_partitions(result->name()); res.is_error()) + dprintln("{}", res.error()); + return result; + } + + USBSCSIDevice::USBSCSIDevice(USBMassStorageDriver& driver, uint8_t lun, BAN::UniqPtr&& dma_region, uint64_t block_count, uint32_t block_size) + : m_driver(driver) + , m_dma_region(BAN::move(dma_region)) + , m_lun(lun) + , m_block_count(block_count) + , m_block_size(block_size) + , m_rdev(scsi_get_rdev()) + , m_name { 's', 'd', (char)('a' + minor(m_rdev)), '\0' } + { } + + USBSCSIDevice::~USBSCSIDevice() + { + scsi_free_rdev(m_rdev); + } + + BAN::ErrorOr USBSCSIDevice::send_scsi_command(BAN::ConstByteSpan scsi_command, BAN::ByteSpan data, bool in) + { + return TRY(send_scsi_command_impl(m_driver, *m_dma_region, m_lun, scsi_command, data, in)); + } + + BAN::ErrorOr USBSCSIDevice::send_scsi_command_impl(USBMassStorageDriver& driver, DMARegion& dma_region, uint8_t lun, BAN::ConstByteSpan scsi_command, BAN::ByteSpan data, bool in) + { + ASSERT(scsi_command.size() <= 16); + + LockGuard _(driver); + + auto& cbw = *reinterpret_cast(dma_region.vaddr()); + cbw = { + .dCBWSignature = 0x43425355, + .dCBWTag = 0x00000000, + .dCBWDataTransferLength = static_cast(data.size()), + .bmCBWFlags = static_cast(in ? 0x80 : 0x00), + .bCBWLUN = lun, + .bCBWCBLength = static_cast(scsi_command.size()), + .CBWCB = {}, + }; + memcpy(cbw.CBWCB, scsi_command.data(), scsi_command.size()); + + if (TRY(driver.send_bytes(dma_region.paddr(), sizeof(USBMassStorage::CBW))) != sizeof(USBMassStorage::CBW)) + { + dwarnln("failed to send full CBW"); + return BAN::Error::from_errno(EFAULT); + } + + const size_t ntransfer = + TRY([&]() -> BAN::ErrorOr + { + if (data.empty()) + return 0; + if (in) + return TRY(driver.recv_bytes(dma_region.paddr(), data.size())); + memcpy(reinterpret_cast(dma_region.vaddr()), data.data(), data.size()); + return TRY(driver.send_bytes(dma_region.paddr(), data.size())); + }()); + + if (ntransfer > data.size()) + { + dwarnln("device responded with more bytes than requested"); + return BAN::Error::from_errno(EFAULT); + } + + if (in && !data.empty()) + memcpy(data.data(), reinterpret_cast(dma_region.vaddr()), ntransfer); + + if (TRY(driver.recv_bytes(dma_region.paddr(), sizeof(USBMassStorage::CSW))) != sizeof(USBMassStorage::CSW)) + { + dwarnln("could not receive full CSW"); + return BAN::Error::from_errno(EFAULT); + } + + if (auto status = reinterpret_cast(dma_region.vaddr())->bmCSWStatus) + { + dwarnln("CSW status {2H}", status); + return BAN::Error::from_errno(EFAULT); + } + + return ntransfer; + } + + BAN::ErrorOr USBSCSIDevice::read_sectors_impl(uint64_t first_lba, uint64_t sector_count, BAN::ByteSpan buffer) + { + const size_t max_blocks_per_read = m_dma_region->size() / m_block_size; + ASSERT(max_blocks_per_read <= 0xFFFF); + + for (uint64_t i = 0; i < sector_count;) + { + const uint32_t lba = first_lba + i; + const uint32_t count = BAN::Math::min(max_blocks_per_read, sector_count - i); + + const uint8_t scsi_read_req[10] { + 0x28, + 0x00, + (uint8_t)(lba >> 24), (uint8_t)(lba >> 16), (uint8_t)(lba >> 8), (uint8_t)(lba >> 0), + 0x00, + (uint8_t)(count >> 8), (uint8_t)(count >> 0), + 0x00 + }; + TRY(send_scsi_command(BAN::ConstByteSpan::from(scsi_read_req), buffer.slice(i * m_block_size, count * m_block_size), true)); + + i += count; + } + + return {}; + } + + BAN::ErrorOr USBSCSIDevice::write_sectors_impl(uint64_t first_lba, uint64_t sector_count, BAN::ConstByteSpan _buffer) + { + const size_t max_blocks_per_write = m_dma_region->size() / m_block_size; + ASSERT(max_blocks_per_write <= 0xFFFF); + + auto buffer = BAN::ByteSpan(const_cast(_buffer.data()), _buffer.size()); + + for (uint64_t i = 0; i < sector_count;) + { + const uint32_t lba = first_lba + i; + const uint32_t count = BAN::Math::min(max_blocks_per_write, sector_count - i); + + const uint8_t scsi_write_req[10] { + 0x2A, + 0x00, + (uint8_t)(lba >> 24), (uint8_t)(lba >> 16), (uint8_t)(lba >> 8), (uint8_t)(lba >> 0), + 0x00, + (uint8_t)(count >> 8), (uint8_t)(count >> 0), + 0x00 + }; + TRY(send_scsi_command(BAN::ConstByteSpan::from(scsi_write_req), buffer.slice(i * m_block_size, count * m_block_size), false)); + + i += count; + } + + return {}; + } + +}