From 793c0368f2e96cf4392f5d90e66629bfe8ccff0d Mon Sep 17 00:00:00 2001 From: Bananymous Date: Sat, 23 Nov 2024 01:24:32 +0200 Subject: [PATCH] Kernel: Refactor USB mass storage code Also increment command timeout to 10 seconds so commands don't timeout when they are not supposted to :) --- .../USB/MassStorage/MassStorageDriver.h | 12 +- .../kernel/USB/MassStorage/SCSIDevice.h | 10 +- .../USB/MassStorage/MassStorageDriver.cpp | 181 +++++++++++++++--- kernel/kernel/USB/MassStorage/SCSIDevice.cpp | 98 ++-------- 4 files changed, 186 insertions(+), 115 deletions(-) diff --git a/kernel/include/kernel/USB/MassStorage/MassStorageDriver.h b/kernel/include/kernel/USB/MassStorage/MassStorageDriver.h index f24b2b9c..bd26e798 100644 --- a/kernel/include/kernel/USB/MassStorage/MassStorageDriver.h +++ b/kernel/include/kernel/USB/MassStorage/MassStorageDriver.h @@ -14,6 +14,9 @@ namespace Kernel BAN_NON_COPYABLE(USBMassStorageDriver); BAN_NON_MOVABLE(USBMassStorageDriver); + public: + static constexpr size_t transfer_stall = -2; + public: void handle_stall(uint8_t endpoint_id) override; void handle_input_data(size_t byte_count, uint8_t endpoint_id) override; @@ -21,14 +24,17 @@ namespace Kernel BAN::ErrorOr send_bytes(paddr_t, size_t count); BAN::ErrorOr recv_bytes(paddr_t, size_t count); - void lock() { m_mutex.lock(); } - void unlock() { m_mutex.unlock(); } + template + BAN::ErrorOr send_command(uint8_t lun, BAN::ConstByteSpan scsi_command, SPAN data); private: USBMassStorageDriver(USBDevice&, const USBDevice::InterfaceDescriptor&); ~USBMassStorageDriver(); BAN::ErrorOr initialize() override; + BAN::ErrorOr mass_storage_reset(); + BAN::ErrorOr clear_feature(uint8_t endpoint_id); + BAN::ErrorOr reset_recovery(); private: USBDevice& m_device; @@ -42,6 +48,8 @@ namespace Kernel uint8_t m_out_endpoint_id { 0 }; BAN::Function m_out_callback; + BAN::UniqPtr m_data_region; + BAN::Vector> m_storage_devices; friend class BAN::UniqPtr; diff --git a/kernel/include/kernel/USB/MassStorage/SCSIDevice.h b/kernel/include/kernel/USB/MassStorage/SCSIDevice.h index 645d13fa..f3a5733e 100644 --- a/kernel/include/kernel/USB/MassStorage/SCSIDevice.h +++ b/kernel/include/kernel/USB/MassStorage/SCSIDevice.h @@ -19,22 +19,16 @@ namespace Kernel BAN::StringView name() const override { return m_name; } private: - USBSCSIDevice(USBMassStorageDriver& driver, uint8_t lun, BAN::UniqPtr&&, uint64_t block_count, uint32_t block_size); + USBSCSIDevice(USBMassStorageDriver& driver, uint8_t lun, uint32_t max_packet_size, uint64_t block_count, uint32_t block_size); ~USBSCSIDevice(); - template> - BAN::ErrorOr send_scsi_command(BAN::ConstByteSpan command, SPAN data); - - template> - static BAN::ErrorOr send_scsi_command_impl(USBMassStorageDriver&, DMARegion& dma_region, uint8_t lun, BAN::ConstByteSpan command, SPAN data); - BAN::ErrorOr read_sectors_impl(uint64_t first_lba, uint64_t sector_count, BAN::ByteSpan buffer) override; BAN::ErrorOr write_sectors_impl(uint64_t lba, uint64_t sector_count, BAN::ConstByteSpan buffer) override; private: USBMassStorageDriver& m_driver; - BAN::UniqPtr m_dma_region; + const uint32_t m_max_packet_size; const uint8_t m_lun; const uint64_t m_block_count; diff --git a/kernel/kernel/USB/MassStorage/MassStorageDriver.cpp b/kernel/kernel/USB/MassStorage/MassStorageDriver.cpp index acfa4117..376a1f76 100644 --- a/kernel/kernel/USB/MassStorage/MassStorageDriver.cpp +++ b/kernel/kernel/USB/MassStorage/MassStorageDriver.cpp @@ -5,12 +5,15 @@ #include #include #include +#include #include #include namespace Kernel { + static constexpr uint64_t s_timeout_ms = 10'000; + USBMassStorageDriver::USBMassStorageDriver(USBDevice& device, const USBDevice::InterfaceDescriptor& interface) : m_device(device) , m_interface(interface) @@ -27,20 +30,9 @@ namespace Kernel return BAN::Error::from_errno(ENOTSUP); } - auto dma_region = TRY(DMARegion::create(PAGE_SIZE)); + m_data_region = TRY(DMARegion::create(PAGE_SIZE)); - // Bulk-Only Mass Storage Reset - { - USBDeviceRequest reset_request { - .bmRequestType = USB::RequestType::HostToDevice | USB::RequestType::Class | USB::RequestType::Interface, - .bRequest = 0xFF, - .wValue = 0x0000, - .wIndex = m_interface.descriptor.bInterfaceNumber, - .wLength = 0x0000, - }; - - TRY(m_device.send_request(reset_request, 0)); - } + TRY(mass_storage_reset()); // Get Max LUN { @@ -53,9 +45,9 @@ namespace Kernel }; uint32_t max_lun = 0; - const auto lun_result = m_device.send_request(lun_request, dma_region->paddr()); + const auto lun_result = m_device.send_request(lun_request, m_data_region->paddr()); if (!lun_result.is_error() && lun_result.value() == 1) - max_lun = *reinterpret_cast(dma_region->vaddr()); + max_lun = *reinterpret_cast(m_data_region->vaddr()); TRY(m_storage_devices.resize(max_lun + 1)); } @@ -105,8 +97,9 @@ namespace Kernel create_device_func = [](USBMassStorageDriver& driver, uint8_t lun, uint32_t max_packet_size) -> BAN::ErrorOr> { - auto ret = TRY(USBSCSIDevice::create(driver, lun, max_packet_size)); - return BAN::RefPtr(ret); + return BAN::RefPtr( + TRY(USBSCSIDevice::create(driver, lun, max_packet_size)) + ); }; break; default: @@ -121,6 +114,45 @@ namespace Kernel return {}; } + BAN::ErrorOr USBMassStorageDriver::mass_storage_reset() + { + USBDeviceRequest reset_request { + .bmRequestType = USB::RequestType::HostToDevice | USB::RequestType::Class | USB::RequestType::Interface, + .bRequest = 0xFF, + .wValue = 0x0000, + .wIndex = m_interface.descriptor.bInterfaceNumber, + .wLength = 0x0000, + }; + + TRY(m_device.send_request(reset_request, 0)); + return {}; + } + + BAN::ErrorOr USBMassStorageDriver::clear_feature(uint8_t endpoint_id) + { + const uint8_t direction = (endpoint_id % 2) ? 0x80 : 0x00; + const uint8_t number = endpoint_id / 2; + + USBDeviceRequest clear_feature_request { + .bmRequestType = USB::RequestType::HostToDevice | USB::RequestType::Standard | USB::RequestType::Endpoint, + .bRequest = USB::Request::CLEAR_FEATURE, + .wValue = 0x0000, + .wIndex = static_cast(direction | number), + .wLength = 0x0000, + }; + + TRY(m_device.send_request(clear_feature_request, 0)); + return {}; + } + + BAN::ErrorOr USBMassStorageDriver::reset_recovery() + { + TRY(mass_storage_reset()); + TRY(clear_feature(m_in_endpoint_id)); + TRY(clear_feature(m_out_endpoint_id)); + return {}; + } + BAN::ErrorOr USBMassStorageDriver::send_bytes(paddr_t paddr, size_t count) { ASSERT(m_mutex.is_locked()); @@ -136,10 +168,15 @@ namespace Kernel m_device.send_data_buffer(m_out_endpoint_id, paddr, count); - const uint64_t timeout_ms = SystemTimer::get().ms_since_boot() + 100; + const uint64_t timeout_ms = SystemTimer::get().ms_since_boot() + s_timeout_ms; while (bytes_sent == invalid) - if (SystemTimer::get().ms_since_boot() > timeout_ms) - return BAN::Error::from_errno(EIO); + { + if (SystemTimer::get().ms_since_boot() < timeout_ms) + continue; + if (reset_recovery().is_error()) + dwarnln_if(DEBUG_USB_MASS_STORAGE, "could not reset USBMassStorage"); + return BAN::Error::from_errno(EIO); + } return static_cast(bytes_sent); } @@ -159,20 +196,114 @@ namespace Kernel m_device.send_data_buffer(m_in_endpoint_id, paddr, count); - const uint64_t timeout_ms = SystemTimer::get().ms_since_boot() + 100; + const uint64_t timeout_ms = SystemTimer::get().ms_since_boot() + s_timeout_ms; while (bytes_recv == invalid) - if (SystemTimer::get().ms_since_boot() > timeout_ms) - return BAN::Error::from_errno(EIO); + { + if (SystemTimer::get().ms_since_boot() < timeout_ms) + continue; + if (reset_recovery().is_error()) + dwarnln_if(DEBUG_USB_MASS_STORAGE, "could not reset USBMassStorage"); + return BAN::Error::from_errno(EIO); + } m_in_callback.clear(); return static_cast(bytes_recv); } + template + BAN::ErrorOr USBMassStorageDriver::send_command(uint8_t lun, BAN::ConstByteSpan command, SPAN data) + { + ASSERT(command.size() <= 16); + + LockGuard _(m_mutex); + + auto& cbw = *reinterpret_cast(m_data_region->vaddr()); + cbw = { + .dCBWSignature = 0x43425355, + .dCBWTag = 0x00000000, + .dCBWDataTransferLength = static_cast(data.size()), + .bmCBWFlags = IN ? 0x80 : 0x00, + .bCBWLUN = lun, + .bCBWCBLength = static_cast(command.size()), + .CBWCB = {}, + }; + memcpy(cbw.CBWCB, command.data(), command.size()); + + if (TRY(send_bytes(m_data_region->paddr(), sizeof(USBMassStorage::CBW))) != sizeof(USBMassStorage::CBW)) + { + dwarnln("failed to send CBW"); + return BAN::Error::from_errno(EIO); + } + + const size_t ntransfer = + TRY([&]() -> BAN::ErrorOr + { + if (data.empty()) + return 0; + if constexpr (IN) + return TRY(recv_bytes(m_data_region->paddr(), data.size())); + memcpy(reinterpret_cast(m_data_region->vaddr()), data.data(), data.size()); + return TRY(send_bytes(m_data_region->paddr(), data.size())); + }()); + + if (ntransfer == transfer_stall) + TRY(clear_feature(IN ? m_in_endpoint_id : m_out_endpoint_id)); + + if constexpr (IN) + memcpy(data.data(), reinterpret_cast(m_data_region->vaddr()), ntransfer); + + size_t csw_ntransfer = TRY(recv_bytes(m_data_region->paddr(), sizeof(USBMassStorage::CSW))); + if (csw_ntransfer == transfer_stall) + { + TRY(clear_feature(m_in_endpoint_id)); + csw_ntransfer = TRY(recv_bytes(m_data_region->paddr(), sizeof(USBMassStorage::CSW))); + } + + if (csw_ntransfer != sizeof(USBMassStorage::CSW)) + { + dwarnln("could not receive CSW"); + return BAN::Error::from_errno(EFAULT); + } + + const auto& csw = *reinterpret_cast(m_data_region->vaddr()); + switch (csw.bmCSWStatus) + { + case 0x00: + case 0x01: + return data.size() - csw.dCSWDataResidue; + default: + dwarnln_if(DEBUG_USB_MASS_STORAGE, "received invalid CSW"); + // fall through + case 0x02: + TRY(reset_recovery()); + return BAN::Error::from_errno(EIO); + } + + ASSERT_NOT_REACHED(); + } + + template BAN::ErrorOr USBMassStorageDriver::send_command(uint8_t, BAN::ConstByteSpan, BAN::ByteSpan); + template BAN::ErrorOr USBMassStorageDriver::send_command(uint8_t, BAN::ConstByteSpan, BAN::ConstByteSpan); + void USBMassStorageDriver::handle_stall(uint8_t endpoint_id) { - (void)endpoint_id; - // FIXME: do something :) + if (endpoint_id != m_in_endpoint_id && endpoint_id != m_out_endpoint_id) + return; + + dprintln_if(DEBUG_USB_MASS_STORAGE, "got STALL to {} endpoint", endpoint_id == m_in_endpoint_id ? "IN" : "OUT"); + + if (m_in_endpoint_id == endpoint_id) + { + ASSERT(m_in_callback); + return m_in_callback(transfer_stall); + } + + if (m_out_endpoint_id == endpoint_id) + { + ASSERT(m_out_callback); + return m_out_callback(transfer_stall); + } } void USBMassStorageDriver::handle_input_data(size_t byte_count, uint8_t endpoint_id) diff --git a/kernel/kernel/USB/MassStorage/SCSIDevice.cpp b/kernel/kernel/USB/MassStorage/SCSIDevice.cpp index 495dde47..9033554b 100644 --- a/kernel/kernel/USB/MassStorage/SCSIDevice.cpp +++ b/kernel/kernel/USB/MassStorage/SCSIDevice.cpp @@ -1,7 +1,6 @@ #include #include -#include #include #include @@ -72,8 +71,6 @@ namespace Kernel BAN::ErrorOr> USBSCSIDevice::create(USBMassStorageDriver& driver, uint8_t lun, uint32_t max_packet_size) { - auto dma_region = TRY(DMARegion::create(max_packet_size)); - dprintln("USB SCSI device"); { @@ -85,14 +82,14 @@ namespace Kernel 0x00 }; SCSI::InquiryRes inquiry_res; - TRY(send_scsi_command_impl(driver, *dma_region, lun, BAN::ConstByteSpan::from(scsi_inquiry_req), BAN::ByteSpan::from(inquiry_res))); + TRY(driver.send_command(lun, BAN::ConstByteSpan::from(scsi_inquiry_req), BAN::ByteSpan::from(inquiry_res))); dprintln(" vendor: {}", BAN::StringView(reinterpret_cast(inquiry_res.t10_vendor_identification), 8)); dprintln(" product: {}", BAN::StringView(reinterpret_cast(inquiry_res.product_identification), 16)); dprintln(" revision: {}", BAN::StringView(reinterpret_cast(inquiry_res.product_revision_level), 4)); } - uint32_t block_count; + uint64_t block_count; uint32_t block_size; { @@ -105,7 +102,7 @@ namespace Kernel 0x00 }; SCSI::ReadCapacity10 read_capacity_res; - TRY(send_scsi_command_impl(driver, *dma_region, lun, BAN::ConstByteSpan::from(scsi_read_capacity_req), BAN::ByteSpan::from(read_capacity_res))); + TRY(driver.send_command(lun, BAN::ConstByteSpan::from(scsi_read_capacity_req), BAN::ByteSpan::from(read_capacity_res))); block_count = read_capacity_res.logical_block_address + 1; block_size = read_capacity_res.block_length; @@ -121,7 +118,7 @@ namespace Kernel dprintln(" total size: {} MiB", block_count * block_size / 1024 / 1024); } - auto result = TRY(BAN::RefPtr::create(driver, lun, BAN::move(dma_region), block_count, block_size)); + auto result = TRY(BAN::RefPtr::create(driver, lun, max_packet_size, block_count, block_size)); result->add_disk_cache(); DevFileSystem::get().add_device(result); if (auto res = result->initialize_partitions(result->name()); res.is_error()) @@ -129,9 +126,9 @@ namespace Kernel return result; } - USBSCSIDevice::USBSCSIDevice(USBMassStorageDriver& driver, uint8_t lun, BAN::UniqPtr&& dma_region, uint64_t block_count, uint32_t block_size) + USBSCSIDevice::USBSCSIDevice(USBMassStorageDriver& driver, uint8_t lun, uint32_t max_packet_size, uint64_t block_count, uint32_t block_size) : m_driver(driver) - , m_dma_region(BAN::move(dma_region)) + , m_max_packet_size(max_packet_size) , m_lun(lun) , m_block_count(block_count) , m_block_size(block_size) @@ -144,75 +141,11 @@ namespace Kernel scsi_free_rdev(m_rdev); } - template - BAN::ErrorOr USBSCSIDevice::send_scsi_command(BAN::ConstByteSpan scsi_command, SPAN data) - { - return TRY(send_scsi_command_impl(m_driver, *m_dma_region, m_lun, scsi_command, data)); - } - - template - BAN::ErrorOr USBSCSIDevice::send_scsi_command_impl(USBMassStorageDriver& driver, DMARegion& dma_region, uint8_t lun, BAN::ConstByteSpan scsi_command, SPAN data) - { - ASSERT(scsi_command.size() <= 16); - - LockGuard _(driver); - - auto& cbw = *reinterpret_cast(dma_region.vaddr()); - cbw = { - .dCBWSignature = 0x43425355, - .dCBWTag = 0x00000000, - .dCBWDataTransferLength = static_cast(data.size()), - .bmCBWFlags = IN ? 0x80 : 0x00, - .bCBWLUN = lun, - .bCBWCBLength = static_cast(scsi_command.size()), - .CBWCB = {}, - }; - memcpy(cbw.CBWCB, scsi_command.data(), scsi_command.size()); - - if (TRY(driver.send_bytes(dma_region.paddr(), sizeof(USBMassStorage::CBW))) != sizeof(USBMassStorage::CBW)) - { - dwarnln("failed to send full CBW"); - return BAN::Error::from_errno(EFAULT); - } - - const size_t ntransfer = - TRY([&]() -> BAN::ErrorOr - { - if (data.empty()) - return 0; - if constexpr(IN) - return TRY(driver.recv_bytes(dma_region.paddr(), data.size())); - memcpy(reinterpret_cast(dma_region.vaddr()), data.data(), data.size()); - return TRY(driver.send_bytes(dma_region.paddr(), data.size())); - }()); - - if (ntransfer != data.size()) - { - dwarnln("device responded with {}/{} bytes", ntransfer, data.size()); - return BAN::Error::from_errno(EFAULT); - } - - if constexpr (IN) - memcpy(data.data(), reinterpret_cast(dma_region.vaddr()), ntransfer); - - if (TRY(driver.recv_bytes(dma_region.paddr(), sizeof(USBMassStorage::CSW))) != sizeof(USBMassStorage::CSW)) - { - dwarnln("could not receive full CSW"); - return BAN::Error::from_errno(EFAULT); - } - - if (auto status = reinterpret_cast(dma_region.vaddr())->bmCSWStatus) - { - dwarnln("CSW status {2H}", status); - return BAN::Error::from_errno(EFAULT); - } - - return ntransfer; - } - BAN::ErrorOr USBSCSIDevice::read_sectors_impl(uint64_t first_lba, uint64_t sector_count, BAN::ByteSpan buffer) { - const size_t max_blocks_per_read = m_dma_region->size() / m_block_size; + dprintln_if(DEBUG_USB_MASS_STORAGE, "read_blocks({}, {})", first_lba, sector_count); + + const size_t max_blocks_per_read = m_max_packet_size / m_block_size; ASSERT(max_blocks_per_read <= 0xFFFF); for (uint64_t i = 0; i < sector_count;) @@ -228,7 +161,10 @@ namespace Kernel (uint8_t)(count >> 8), (uint8_t)(count >> 0), 0x00 }; - TRY(send_scsi_command(BAN::ConstByteSpan::from(scsi_read_req), buffer.slice(i * m_block_size, count * m_block_size))); + + const size_t nread = TRY(m_driver.send_command(m_lun, BAN::ConstByteSpan::from(scsi_read_req), buffer.slice(i * m_block_size, count * m_block_size))); + if (nread != count * m_block_size) + return BAN::Error::from_errno(EIO); i += count; } @@ -238,9 +174,9 @@ namespace Kernel BAN::ErrorOr USBSCSIDevice::write_sectors_impl(uint64_t first_lba, uint64_t sector_count, BAN::ConstByteSpan buffer) { - dprintln("write_sectors_impl({}, {})", first_lba, sector_count); + dprintln_if(DEBUG_USB_MASS_STORAGE, "write_blocks({}, {})", first_lba, sector_count); - const size_t max_blocks_per_write = m_dma_region->size() / m_block_size; + const size_t max_blocks_per_write = m_max_packet_size / m_block_size; ASSERT(max_blocks_per_write <= 0xFFFF); for (uint64_t i = 0; i < sector_count;) @@ -256,7 +192,9 @@ namespace Kernel (uint8_t)(count >> 8), (uint8_t)(count >> 0), 0x00 }; - TRY(send_scsi_command(BAN::ConstByteSpan::from(scsi_write_req), buffer.slice(i * m_block_size, count * m_block_size))); + const size_t nwrite = TRY(m_driver.send_command(m_lun, BAN::ConstByteSpan::from(scsi_write_req), buffer.slice(i * m_block_size, count * m_block_size))); + if (nwrite != count * m_block_size) + return BAN::Error::from_errno(EIO); i += count; }