Kernel: Refactor USB mass storage code

Also increment command timeout to 10 seconds so commands don't timeout
when they are not supposted to :)
This commit is contained in:
Bananymous 2024-11-23 01:24:32 +02:00
parent 076001462e
commit 793c0368f2
4 changed files with 186 additions and 115 deletions

View File

@ -14,6 +14,9 @@ namespace Kernel
BAN_NON_COPYABLE(USBMassStorageDriver);
BAN_NON_MOVABLE(USBMassStorageDriver);
public:
static constexpr size_t transfer_stall = -2;
public:
void handle_stall(uint8_t endpoint_id) override;
void handle_input_data(size_t byte_count, uint8_t endpoint_id) override;
@ -21,14 +24,17 @@ namespace Kernel
BAN::ErrorOr<size_t> send_bytes(paddr_t, size_t count);
BAN::ErrorOr<size_t> recv_bytes(paddr_t, size_t count);
void lock() { m_mutex.lock(); }
void unlock() { m_mutex.unlock(); }
template<bool IN, typename SPAN>
BAN::ErrorOr<size_t> send_command(uint8_t lun, BAN::ConstByteSpan scsi_command, SPAN data);
private:
USBMassStorageDriver(USBDevice&, const USBDevice::InterfaceDescriptor&);
~USBMassStorageDriver();
BAN::ErrorOr<void> initialize() override;
BAN::ErrorOr<void> mass_storage_reset();
BAN::ErrorOr<void> clear_feature(uint8_t endpoint_id);
BAN::ErrorOr<void> reset_recovery();
private:
USBDevice& m_device;
@ -42,6 +48,8 @@ namespace Kernel
uint8_t m_out_endpoint_id { 0 };
BAN::Function<void(size_t)> m_out_callback;
BAN::UniqPtr<DMARegion> m_data_region;
BAN::Vector<BAN::RefPtr<StorageDevice>> m_storage_devices;
friend class BAN::UniqPtr<USBMassStorageDriver>;

View File

@ -19,22 +19,16 @@ namespace Kernel
BAN::StringView name() const override { return m_name; }
private:
USBSCSIDevice(USBMassStorageDriver& driver, uint8_t lun, BAN::UniqPtr<DMARegion>&&, uint64_t block_count, uint32_t block_size);
USBSCSIDevice(USBMassStorageDriver& driver, uint8_t lun, uint32_t max_packet_size, uint64_t block_count, uint32_t block_size);
~USBSCSIDevice();
template<bool IN, typename SPAN = BAN::either_or_t<IN, BAN::ByteSpan, BAN::ConstByteSpan>>
BAN::ErrorOr<size_t> send_scsi_command(BAN::ConstByteSpan command, SPAN data);
template<bool IN, typename SPAN = BAN::either_or_t<IN, BAN::ByteSpan, BAN::ConstByteSpan>>
static BAN::ErrorOr<size_t> send_scsi_command_impl(USBMassStorageDriver&, DMARegion& dma_region, uint8_t lun, BAN::ConstByteSpan command, SPAN data);
BAN::ErrorOr<void> read_sectors_impl(uint64_t first_lba, uint64_t sector_count, BAN::ByteSpan buffer) override;
BAN::ErrorOr<void> write_sectors_impl(uint64_t lba, uint64_t sector_count, BAN::ConstByteSpan buffer) override;
private:
USBMassStorageDriver& m_driver;
BAN::UniqPtr<DMARegion> m_dma_region;
const uint32_t m_max_packet_size;
const uint8_t m_lun;
const uint64_t m_block_count;

View File

@ -5,12 +5,15 @@
#include <kernel/FS/VirtualFileSystem.h>
#include <kernel/Lock/LockGuard.h>
#include <kernel/Timer/Timer.h>
#include <kernel/USB/MassStorage/Definitions.h>
#include <kernel/USB/MassStorage/MassStorageDriver.h>
#include <kernel/USB/MassStorage/SCSIDevice.h>
namespace Kernel
{
static constexpr uint64_t s_timeout_ms = 10'000;
USBMassStorageDriver::USBMassStorageDriver(USBDevice& device, const USBDevice::InterfaceDescriptor& interface)
: m_device(device)
, m_interface(interface)
@ -27,20 +30,9 @@ namespace Kernel
return BAN::Error::from_errno(ENOTSUP);
}
auto dma_region = TRY(DMARegion::create(PAGE_SIZE));
m_data_region = TRY(DMARegion::create(PAGE_SIZE));
// Bulk-Only Mass Storage Reset
{
USBDeviceRequest reset_request {
.bmRequestType = USB::RequestType::HostToDevice | USB::RequestType::Class | USB::RequestType::Interface,
.bRequest = 0xFF,
.wValue = 0x0000,
.wIndex = m_interface.descriptor.bInterfaceNumber,
.wLength = 0x0000,
};
TRY(m_device.send_request(reset_request, 0));
}
TRY(mass_storage_reset());
// Get Max LUN
{
@ -53,9 +45,9 @@ namespace Kernel
};
uint32_t max_lun = 0;
const auto lun_result = m_device.send_request(lun_request, dma_region->paddr());
const auto lun_result = m_device.send_request(lun_request, m_data_region->paddr());
if (!lun_result.is_error() && lun_result.value() == 1)
max_lun = *reinterpret_cast<uint8_t*>(dma_region->vaddr());
max_lun = *reinterpret_cast<uint8_t*>(m_data_region->vaddr());
TRY(m_storage_devices.resize(max_lun + 1));
}
@ -105,8 +97,9 @@ namespace Kernel
create_device_func =
[](USBMassStorageDriver& driver, uint8_t lun, uint32_t max_packet_size) -> BAN::ErrorOr<BAN::RefPtr<StorageDevice>>
{
auto ret = TRY(USBSCSIDevice::create(driver, lun, max_packet_size));
return BAN::RefPtr<StorageDevice>(ret);
return BAN::RefPtr<StorageDevice>(
TRY(USBSCSIDevice::create(driver, lun, max_packet_size))
);
};
break;
default:
@ -121,6 +114,45 @@ namespace Kernel
return {};
}
BAN::ErrorOr<void> USBMassStorageDriver::mass_storage_reset()
{
USBDeviceRequest reset_request {
.bmRequestType = USB::RequestType::HostToDevice | USB::RequestType::Class | USB::RequestType::Interface,
.bRequest = 0xFF,
.wValue = 0x0000,
.wIndex = m_interface.descriptor.bInterfaceNumber,
.wLength = 0x0000,
};
TRY(m_device.send_request(reset_request, 0));
return {};
}
BAN::ErrorOr<void> USBMassStorageDriver::clear_feature(uint8_t endpoint_id)
{
const uint8_t direction = (endpoint_id % 2) ? 0x80 : 0x00;
const uint8_t number = endpoint_id / 2;
USBDeviceRequest clear_feature_request {
.bmRequestType = USB::RequestType::HostToDevice | USB::RequestType::Standard | USB::RequestType::Endpoint,
.bRequest = USB::Request::CLEAR_FEATURE,
.wValue = 0x0000,
.wIndex = static_cast<uint16_t>(direction | number),
.wLength = 0x0000,
};
TRY(m_device.send_request(clear_feature_request, 0));
return {};
}
BAN::ErrorOr<void> USBMassStorageDriver::reset_recovery()
{
TRY(mass_storage_reset());
TRY(clear_feature(m_in_endpoint_id));
TRY(clear_feature(m_out_endpoint_id));
return {};
}
BAN::ErrorOr<size_t> USBMassStorageDriver::send_bytes(paddr_t paddr, size_t count)
{
ASSERT(m_mutex.is_locked());
@ -136,10 +168,15 @@ namespace Kernel
m_device.send_data_buffer(m_out_endpoint_id, paddr, count);
const uint64_t timeout_ms = SystemTimer::get().ms_since_boot() + 100;
const uint64_t timeout_ms = SystemTimer::get().ms_since_boot() + s_timeout_ms;
while (bytes_sent == invalid)
if (SystemTimer::get().ms_since_boot() > timeout_ms)
{
if (SystemTimer::get().ms_since_boot() < timeout_ms)
continue;
if (reset_recovery().is_error())
dwarnln_if(DEBUG_USB_MASS_STORAGE, "could not reset USBMassStorage");
return BAN::Error::from_errno(EIO);
}
return static_cast<size_t>(bytes_sent);
}
@ -159,20 +196,114 @@ namespace Kernel
m_device.send_data_buffer(m_in_endpoint_id, paddr, count);
const uint64_t timeout_ms = SystemTimer::get().ms_since_boot() + 100;
const uint64_t timeout_ms = SystemTimer::get().ms_since_boot() + s_timeout_ms;
while (bytes_recv == invalid)
if (SystemTimer::get().ms_since_boot() > timeout_ms)
{
if (SystemTimer::get().ms_since_boot() < timeout_ms)
continue;
if (reset_recovery().is_error())
dwarnln_if(DEBUG_USB_MASS_STORAGE, "could not reset USBMassStorage");
return BAN::Error::from_errno(EIO);
}
m_in_callback.clear();
return static_cast<size_t>(bytes_recv);
}
template<bool IN, typename SPAN>
BAN::ErrorOr<size_t> USBMassStorageDriver::send_command(uint8_t lun, BAN::ConstByteSpan command, SPAN data)
{
ASSERT(command.size() <= 16);
LockGuard _(m_mutex);
auto& cbw = *reinterpret_cast<USBMassStorage::CBW*>(m_data_region->vaddr());
cbw = {
.dCBWSignature = 0x43425355,
.dCBWTag = 0x00000000,
.dCBWDataTransferLength = static_cast<uint32_t>(data.size()),
.bmCBWFlags = IN ? 0x80 : 0x00,
.bCBWLUN = lun,
.bCBWCBLength = static_cast<uint8_t>(command.size()),
.CBWCB = {},
};
memcpy(cbw.CBWCB, command.data(), command.size());
if (TRY(send_bytes(m_data_region->paddr(), sizeof(USBMassStorage::CBW))) != sizeof(USBMassStorage::CBW))
{
dwarnln("failed to send CBW");
return BAN::Error::from_errno(EIO);
}
const size_t ntransfer =
TRY([&]() -> BAN::ErrorOr<size_t>
{
if (data.empty())
return 0;
if constexpr (IN)
return TRY(recv_bytes(m_data_region->paddr(), data.size()));
memcpy(reinterpret_cast<void*>(m_data_region->vaddr()), data.data(), data.size());
return TRY(send_bytes(m_data_region->paddr(), data.size()));
}());
if (ntransfer == transfer_stall)
TRY(clear_feature(IN ? m_in_endpoint_id : m_out_endpoint_id));
if constexpr (IN)
memcpy(data.data(), reinterpret_cast<void*>(m_data_region->vaddr()), ntransfer);
size_t csw_ntransfer = TRY(recv_bytes(m_data_region->paddr(), sizeof(USBMassStorage::CSW)));
if (csw_ntransfer == transfer_stall)
{
TRY(clear_feature(m_in_endpoint_id));
csw_ntransfer = TRY(recv_bytes(m_data_region->paddr(), sizeof(USBMassStorage::CSW)));
}
if (csw_ntransfer != sizeof(USBMassStorage::CSW))
{
dwarnln("could not receive CSW");
return BAN::Error::from_errno(EFAULT);
}
const auto& csw = *reinterpret_cast<USBMassStorage::CSW*>(m_data_region->vaddr());
switch (csw.bmCSWStatus)
{
case 0x00:
case 0x01:
return data.size() - csw.dCSWDataResidue;
default:
dwarnln_if(DEBUG_USB_MASS_STORAGE, "received invalid CSW");
// fall through
case 0x02:
TRY(reset_recovery());
return BAN::Error::from_errno(EIO);
}
ASSERT_NOT_REACHED();
}
template BAN::ErrorOr<size_t> USBMassStorageDriver::send_command<true, BAN::ByteSpan >(uint8_t, BAN::ConstByteSpan, BAN::ByteSpan);
template BAN::ErrorOr<size_t> USBMassStorageDriver::send_command<false, BAN::ConstByteSpan>(uint8_t, BAN::ConstByteSpan, BAN::ConstByteSpan);
void USBMassStorageDriver::handle_stall(uint8_t endpoint_id)
{
(void)endpoint_id;
// FIXME: do something :)
if (endpoint_id != m_in_endpoint_id && endpoint_id != m_out_endpoint_id)
return;
dprintln_if(DEBUG_USB_MASS_STORAGE, "got STALL to {} endpoint", endpoint_id == m_in_endpoint_id ? "IN" : "OUT");
if (m_in_endpoint_id == endpoint_id)
{
ASSERT(m_in_callback);
return m_in_callback(transfer_stall);
}
if (m_out_endpoint_id == endpoint_id)
{
ASSERT(m_out_callback);
return m_out_callback(transfer_stall);
}
}
void USBMassStorageDriver::handle_input_data(size_t byte_count, uint8_t endpoint_id)

View File

@ -1,7 +1,6 @@
#include <BAN/Endianness.h>
#include <kernel/Storage/SCSI.h>
#include <kernel/USB/MassStorage/Definitions.h>
#include <kernel/USB/MassStorage/SCSIDevice.h>
#include <sys/sysmacros.h>
@ -72,8 +71,6 @@ namespace Kernel
BAN::ErrorOr<BAN::RefPtr<USBSCSIDevice>> USBSCSIDevice::create(USBMassStorageDriver& driver, uint8_t lun, uint32_t max_packet_size)
{
auto dma_region = TRY(DMARegion::create(max_packet_size));
dprintln("USB SCSI device");
{
@ -85,14 +82,14 @@ namespace Kernel
0x00
};
SCSI::InquiryRes inquiry_res;
TRY(send_scsi_command_impl<true>(driver, *dma_region, lun, BAN::ConstByteSpan::from(scsi_inquiry_req), BAN::ByteSpan::from(inquiry_res)));
TRY(driver.send_command<true>(lun, BAN::ConstByteSpan::from(scsi_inquiry_req), BAN::ByteSpan::from(inquiry_res)));
dprintln(" vendor: {}", BAN::StringView(reinterpret_cast<const char*>(inquiry_res.t10_vendor_identification), 8));
dprintln(" product: {}", BAN::StringView(reinterpret_cast<const char*>(inquiry_res.product_identification), 16));
dprintln(" revision: {}", BAN::StringView(reinterpret_cast<const char*>(inquiry_res.product_revision_level), 4));
}
uint32_t block_count;
uint64_t block_count;
uint32_t block_size;
{
@ -105,7 +102,7 @@ namespace Kernel
0x00
};
SCSI::ReadCapacity10 read_capacity_res;
TRY(send_scsi_command_impl<true>(driver, *dma_region, lun, BAN::ConstByteSpan::from(scsi_read_capacity_req), BAN::ByteSpan::from(read_capacity_res)));
TRY(driver.send_command<true>(lun, BAN::ConstByteSpan::from(scsi_read_capacity_req), BAN::ByteSpan::from(read_capacity_res)));
block_count = read_capacity_res.logical_block_address + 1;
block_size = read_capacity_res.block_length;
@ -121,7 +118,7 @@ namespace Kernel
dprintln(" total size: {} MiB", block_count * block_size / 1024 / 1024);
}
auto result = TRY(BAN::RefPtr<USBSCSIDevice>::create(driver, lun, BAN::move(dma_region), block_count, block_size));
auto result = TRY(BAN::RefPtr<USBSCSIDevice>::create(driver, lun, max_packet_size, block_count, block_size));
result->add_disk_cache();
DevFileSystem::get().add_device(result);
if (auto res = result->initialize_partitions(result->name()); res.is_error())
@ -129,9 +126,9 @@ namespace Kernel
return result;
}
USBSCSIDevice::USBSCSIDevice(USBMassStorageDriver& driver, uint8_t lun, BAN::UniqPtr<DMARegion>&& dma_region, uint64_t block_count, uint32_t block_size)
USBSCSIDevice::USBSCSIDevice(USBMassStorageDriver& driver, uint8_t lun, uint32_t max_packet_size, uint64_t block_count, uint32_t block_size)
: m_driver(driver)
, m_dma_region(BAN::move(dma_region))
, m_max_packet_size(max_packet_size)
, m_lun(lun)
, m_block_count(block_count)
, m_block_size(block_size)
@ -144,75 +141,11 @@ namespace Kernel
scsi_free_rdev(m_rdev);
}
template<bool IN, typename SPAN>
BAN::ErrorOr<size_t> USBSCSIDevice::send_scsi_command(BAN::ConstByteSpan scsi_command, SPAN data)
{
return TRY(send_scsi_command_impl<IN>(m_driver, *m_dma_region, m_lun, scsi_command, data));
}
template<bool IN, typename SPAN>
BAN::ErrorOr<size_t> USBSCSIDevice::send_scsi_command_impl(USBMassStorageDriver& driver, DMARegion& dma_region, uint8_t lun, BAN::ConstByteSpan scsi_command, SPAN data)
{
ASSERT(scsi_command.size() <= 16);
LockGuard _(driver);
auto& cbw = *reinterpret_cast<USBMassStorage::CBW*>(dma_region.vaddr());
cbw = {
.dCBWSignature = 0x43425355,
.dCBWTag = 0x00000000,
.dCBWDataTransferLength = static_cast<uint32_t>(data.size()),
.bmCBWFlags = IN ? 0x80 : 0x00,
.bCBWLUN = lun,
.bCBWCBLength = static_cast<uint8_t>(scsi_command.size()),
.CBWCB = {},
};
memcpy(cbw.CBWCB, scsi_command.data(), scsi_command.size());
if (TRY(driver.send_bytes(dma_region.paddr(), sizeof(USBMassStorage::CBW))) != sizeof(USBMassStorage::CBW))
{
dwarnln("failed to send full CBW");
return BAN::Error::from_errno(EFAULT);
}
const size_t ntransfer =
TRY([&]() -> BAN::ErrorOr<size_t>
{
if (data.empty())
return 0;
if constexpr(IN)
return TRY(driver.recv_bytes(dma_region.paddr(), data.size()));
memcpy(reinterpret_cast<void*>(dma_region.vaddr()), data.data(), data.size());
return TRY(driver.send_bytes(dma_region.paddr(), data.size()));
}());
if (ntransfer != data.size())
{
dwarnln("device responded with {}/{} bytes", ntransfer, data.size());
return BAN::Error::from_errno(EFAULT);
}
if constexpr (IN)
memcpy(data.data(), reinterpret_cast<void*>(dma_region.vaddr()), ntransfer);
if (TRY(driver.recv_bytes(dma_region.paddr(), sizeof(USBMassStorage::CSW))) != sizeof(USBMassStorage::CSW))
{
dwarnln("could not receive full CSW");
return BAN::Error::from_errno(EFAULT);
}
if (auto status = reinterpret_cast<USBMassStorage::CSW*>(dma_region.vaddr())->bmCSWStatus)
{
dwarnln("CSW status {2H}", status);
return BAN::Error::from_errno(EFAULT);
}
return ntransfer;
}
BAN::ErrorOr<void> USBSCSIDevice::read_sectors_impl(uint64_t first_lba, uint64_t sector_count, BAN::ByteSpan buffer)
{
const size_t max_blocks_per_read = m_dma_region->size() / m_block_size;
dprintln_if(DEBUG_USB_MASS_STORAGE, "read_blocks({}, {})", first_lba, sector_count);
const size_t max_blocks_per_read = m_max_packet_size / m_block_size;
ASSERT(max_blocks_per_read <= 0xFFFF);
for (uint64_t i = 0; i < sector_count;)
@ -228,7 +161,10 @@ namespace Kernel
(uint8_t)(count >> 8), (uint8_t)(count >> 0),
0x00
};
TRY(send_scsi_command<true>(BAN::ConstByteSpan::from(scsi_read_req), buffer.slice(i * m_block_size, count * m_block_size)));
const size_t nread = TRY(m_driver.send_command<true>(m_lun, BAN::ConstByteSpan::from(scsi_read_req), buffer.slice(i * m_block_size, count * m_block_size)));
if (nread != count * m_block_size)
return BAN::Error::from_errno(EIO);
i += count;
}
@ -238,9 +174,9 @@ namespace Kernel
BAN::ErrorOr<void> USBSCSIDevice::write_sectors_impl(uint64_t first_lba, uint64_t sector_count, BAN::ConstByteSpan buffer)
{
dprintln("write_sectors_impl({}, {})", first_lba, sector_count);
dprintln_if(DEBUG_USB_MASS_STORAGE, "write_blocks({}, {})", first_lba, sector_count);
const size_t max_blocks_per_write = m_dma_region->size() / m_block_size;
const size_t max_blocks_per_write = m_max_packet_size / m_block_size;
ASSERT(max_blocks_per_write <= 0xFFFF);
for (uint64_t i = 0; i < sector_count;)
@ -256,7 +192,9 @@ namespace Kernel
(uint8_t)(count >> 8), (uint8_t)(count >> 0),
0x00
};
TRY(send_scsi_command<false>(BAN::ConstByteSpan::from(scsi_write_req), buffer.slice(i * m_block_size, count * m_block_size)));
const size_t nwrite = TRY(m_driver.send_command<false>(m_lun, BAN::ConstByteSpan::from(scsi_write_req), buffer.slice(i * m_block_size, count * m_block_size)));
if (nwrite != count * m_block_size)
return BAN::Error::from_errno(EIO);
i += count;
}