diff --git a/kernel/include/kernel/USB/MassStorage/SCSIDevice.h b/kernel/include/kernel/USB/MassStorage/SCSIDevice.h index 7c05d766..645d13fa 100644 --- a/kernel/include/kernel/USB/MassStorage/SCSIDevice.h +++ b/kernel/include/kernel/USB/MassStorage/SCSIDevice.h @@ -22,8 +22,11 @@ namespace Kernel USBSCSIDevice(USBMassStorageDriver& driver, uint8_t lun, BAN::UniqPtr&&, uint64_t block_count, uint32_t block_size); ~USBSCSIDevice(); - static BAN::ErrorOr send_scsi_command_impl(USBMassStorageDriver&, DMARegion& dma_region, uint8_t lun, BAN::ConstByteSpan command, BAN::ByteSpan data, bool in); - BAN::ErrorOr send_scsi_command(BAN::ConstByteSpan command, BAN::ByteSpan data, bool in); + template> + BAN::ErrorOr send_scsi_command(BAN::ConstByteSpan command, SPAN data); + + template> + static BAN::ErrorOr send_scsi_command_impl(USBMassStorageDriver&, DMARegion& dma_region, uint8_t lun, BAN::ConstByteSpan command, SPAN data); BAN::ErrorOr read_sectors_impl(uint64_t first_lba, uint64_t sector_count, BAN::ByteSpan buffer) override; BAN::ErrorOr write_sectors_impl(uint64_t lba, uint64_t sector_count, BAN::ConstByteSpan buffer) override; diff --git a/kernel/kernel/USB/MassStorage/SCSIDevice.cpp b/kernel/kernel/USB/MassStorage/SCSIDevice.cpp index 70242fee..495dde47 100644 --- a/kernel/kernel/USB/MassStorage/SCSIDevice.cpp +++ b/kernel/kernel/USB/MassStorage/SCSIDevice.cpp @@ -85,7 +85,7 @@ namespace Kernel 0x00 }; SCSI::InquiryRes inquiry_res; - TRY(send_scsi_command_impl(driver, *dma_region, lun, BAN::ConstByteSpan::from(scsi_inquiry_req), BAN::ByteSpan::from(inquiry_res), true)); + TRY(send_scsi_command_impl(driver, *dma_region, lun, BAN::ConstByteSpan::from(scsi_inquiry_req), BAN::ByteSpan::from(inquiry_res))); dprintln(" vendor: {}", BAN::StringView(reinterpret_cast(inquiry_res.t10_vendor_identification), 8)); dprintln(" product: {}", BAN::StringView(reinterpret_cast(inquiry_res.product_identification), 16)); @@ -105,7 +105,7 @@ namespace Kernel 0x00 }; SCSI::ReadCapacity10 read_capacity_res; - TRY(send_scsi_command_impl(driver, *dma_region, lun, BAN::ConstByteSpan::from(scsi_read_capacity_req), BAN::ByteSpan::from(read_capacity_res), true)); + TRY(send_scsi_command_impl(driver, *dma_region, lun, BAN::ConstByteSpan::from(scsi_read_capacity_req), BAN::ByteSpan::from(read_capacity_res))); block_count = read_capacity_res.logical_block_address + 1; block_size = read_capacity_res.block_length; @@ -144,12 +144,14 @@ namespace Kernel scsi_free_rdev(m_rdev); } - BAN::ErrorOr USBSCSIDevice::send_scsi_command(BAN::ConstByteSpan scsi_command, BAN::ByteSpan data, bool in) + template + BAN::ErrorOr USBSCSIDevice::send_scsi_command(BAN::ConstByteSpan scsi_command, SPAN data) { - return TRY(send_scsi_command_impl(m_driver, *m_dma_region, m_lun, scsi_command, data, in)); + return TRY(send_scsi_command_impl(m_driver, *m_dma_region, m_lun, scsi_command, data)); } - BAN::ErrorOr USBSCSIDevice::send_scsi_command_impl(USBMassStorageDriver& driver, DMARegion& dma_region, uint8_t lun, BAN::ConstByteSpan scsi_command, BAN::ByteSpan data, bool in) + template + BAN::ErrorOr USBSCSIDevice::send_scsi_command_impl(USBMassStorageDriver& driver, DMARegion& dma_region, uint8_t lun, BAN::ConstByteSpan scsi_command, SPAN data) { ASSERT(scsi_command.size() <= 16); @@ -160,7 +162,7 @@ namespace Kernel .dCBWSignature = 0x43425355, .dCBWTag = 0x00000000, .dCBWDataTransferLength = static_cast(data.size()), - .bmCBWFlags = static_cast(in ? 0x80 : 0x00), + .bmCBWFlags = IN ? 0x80 : 0x00, .bCBWLUN = lun, .bCBWCBLength = static_cast(scsi_command.size()), .CBWCB = {}, @@ -178,19 +180,19 @@ namespace Kernel { if (data.empty()) return 0; - if (in) + if constexpr(IN) return TRY(driver.recv_bytes(dma_region.paddr(), data.size())); memcpy(reinterpret_cast(dma_region.vaddr()), data.data(), data.size()); return TRY(driver.send_bytes(dma_region.paddr(), data.size())); }()); - if (ntransfer > data.size()) + if (ntransfer != data.size()) { - dwarnln("device responded with more bytes than requested"); + dwarnln("device responded with {}/{} bytes", ntransfer, data.size()); return BAN::Error::from_errno(EFAULT); } - if (in && !data.empty()) + if constexpr (IN) memcpy(data.data(), reinterpret_cast(dma_region.vaddr()), ntransfer); if (TRY(driver.recv_bytes(dma_region.paddr(), sizeof(USBMassStorage::CSW))) != sizeof(USBMassStorage::CSW)) @@ -226,7 +228,7 @@ namespace Kernel (uint8_t)(count >> 8), (uint8_t)(count >> 0), 0x00 }; - TRY(send_scsi_command(BAN::ConstByteSpan::from(scsi_read_req), buffer.slice(i * m_block_size, count * m_block_size), true)); + TRY(send_scsi_command(BAN::ConstByteSpan::from(scsi_read_req), buffer.slice(i * m_block_size, count * m_block_size))); i += count; } @@ -234,13 +236,13 @@ namespace Kernel return {}; } - BAN::ErrorOr USBSCSIDevice::write_sectors_impl(uint64_t first_lba, uint64_t sector_count, BAN::ConstByteSpan _buffer) + BAN::ErrorOr USBSCSIDevice::write_sectors_impl(uint64_t first_lba, uint64_t sector_count, BAN::ConstByteSpan buffer) { + dprintln("write_sectors_impl({}, {})", first_lba, sector_count); + const size_t max_blocks_per_write = m_dma_region->size() / m_block_size; ASSERT(max_blocks_per_write <= 0xFFFF); - auto buffer = BAN::ByteSpan(const_cast(_buffer.data()), _buffer.size()); - for (uint64_t i = 0; i < sector_count;) { const uint32_t lba = first_lba + i; @@ -254,7 +256,7 @@ namespace Kernel (uint8_t)(count >> 8), (uint8_t)(count >> 0), 0x00 }; - TRY(send_scsi_command(BAN::ConstByteSpan::from(scsi_write_req), buffer.slice(i * m_block_size, count * m_block_size), false)); + TRY(send_scsi_command(BAN::ConstByteSpan::from(scsi_write_req), buffer.slice(i * m_block_size, count * m_block_size))); i += count; }