Kernel: Make USBMassStorage send_scsi_command templated

This allows passing ConstByteSpan when data will not be modified
This commit is contained in:
Bananymous 2024-11-22 22:21:19 +02:00
parent 480368c878
commit 0247d47a3d
2 changed files with 22 additions and 17 deletions

View File

@ -22,8 +22,11 @@ namespace Kernel
USBSCSIDevice(USBMassStorageDriver& driver, uint8_t lun, BAN::UniqPtr<DMARegion>&&, uint64_t block_count, uint32_t block_size); USBSCSIDevice(USBMassStorageDriver& driver, uint8_t lun, BAN::UniqPtr<DMARegion>&&, uint64_t block_count, uint32_t block_size);
~USBSCSIDevice(); ~USBSCSIDevice();
static BAN::ErrorOr<size_t> send_scsi_command_impl(USBMassStorageDriver&, DMARegion& dma_region, uint8_t lun, BAN::ConstByteSpan command, BAN::ByteSpan data, bool in); template<bool IN, typename SPAN = BAN::either_or_t<IN, BAN::ByteSpan, BAN::ConstByteSpan>>
BAN::ErrorOr<size_t> send_scsi_command(BAN::ConstByteSpan command, BAN::ByteSpan data, bool in); BAN::ErrorOr<size_t> send_scsi_command(BAN::ConstByteSpan command, SPAN data);
template<bool IN, typename SPAN = BAN::either_or_t<IN, BAN::ByteSpan, BAN::ConstByteSpan>>
static BAN::ErrorOr<size_t> send_scsi_command_impl(USBMassStorageDriver&, DMARegion& dma_region, uint8_t lun, BAN::ConstByteSpan command, SPAN data);
BAN::ErrorOr<void> read_sectors_impl(uint64_t first_lba, uint64_t sector_count, BAN::ByteSpan buffer) override; BAN::ErrorOr<void> read_sectors_impl(uint64_t first_lba, uint64_t sector_count, BAN::ByteSpan buffer) override;
BAN::ErrorOr<void> write_sectors_impl(uint64_t lba, uint64_t sector_count, BAN::ConstByteSpan buffer) override; BAN::ErrorOr<void> write_sectors_impl(uint64_t lba, uint64_t sector_count, BAN::ConstByteSpan buffer) override;

View File

@ -85,7 +85,7 @@ namespace Kernel
0x00 0x00
}; };
SCSI::InquiryRes inquiry_res; SCSI::InquiryRes inquiry_res;
TRY(send_scsi_command_impl(driver, *dma_region, lun, BAN::ConstByteSpan::from(scsi_inquiry_req), BAN::ByteSpan::from(inquiry_res), true)); TRY(send_scsi_command_impl<true>(driver, *dma_region, lun, BAN::ConstByteSpan::from(scsi_inquiry_req), BAN::ByteSpan::from(inquiry_res)));
dprintln(" vendor: {}", BAN::StringView(reinterpret_cast<const char*>(inquiry_res.t10_vendor_identification), 8)); dprintln(" vendor: {}", BAN::StringView(reinterpret_cast<const char*>(inquiry_res.t10_vendor_identification), 8));
dprintln(" product: {}", BAN::StringView(reinterpret_cast<const char*>(inquiry_res.product_identification), 16)); dprintln(" product: {}", BAN::StringView(reinterpret_cast<const char*>(inquiry_res.product_identification), 16));
@ -105,7 +105,7 @@ namespace Kernel
0x00 0x00
}; };
SCSI::ReadCapacity10 read_capacity_res; SCSI::ReadCapacity10 read_capacity_res;
TRY(send_scsi_command_impl(driver, *dma_region, lun, BAN::ConstByteSpan::from(scsi_read_capacity_req), BAN::ByteSpan::from(read_capacity_res), true)); TRY(send_scsi_command_impl<true>(driver, *dma_region, lun, BAN::ConstByteSpan::from(scsi_read_capacity_req), BAN::ByteSpan::from(read_capacity_res)));
block_count = read_capacity_res.logical_block_address + 1; block_count = read_capacity_res.logical_block_address + 1;
block_size = read_capacity_res.block_length; block_size = read_capacity_res.block_length;
@ -144,12 +144,14 @@ namespace Kernel
scsi_free_rdev(m_rdev); scsi_free_rdev(m_rdev);
} }
BAN::ErrorOr<size_t> USBSCSIDevice::send_scsi_command(BAN::ConstByteSpan scsi_command, BAN::ByteSpan data, bool in) template<bool IN, typename SPAN>
BAN::ErrorOr<size_t> USBSCSIDevice::send_scsi_command(BAN::ConstByteSpan scsi_command, SPAN data)
{ {
return TRY(send_scsi_command_impl(m_driver, *m_dma_region, m_lun, scsi_command, data, in)); return TRY(send_scsi_command_impl<IN>(m_driver, *m_dma_region, m_lun, scsi_command, data));
} }
BAN::ErrorOr<size_t> USBSCSIDevice::send_scsi_command_impl(USBMassStorageDriver& driver, DMARegion& dma_region, uint8_t lun, BAN::ConstByteSpan scsi_command, BAN::ByteSpan data, bool in) template<bool IN, typename SPAN>
BAN::ErrorOr<size_t> USBSCSIDevice::send_scsi_command_impl(USBMassStorageDriver& driver, DMARegion& dma_region, uint8_t lun, BAN::ConstByteSpan scsi_command, SPAN data)
{ {
ASSERT(scsi_command.size() <= 16); ASSERT(scsi_command.size() <= 16);
@ -160,7 +162,7 @@ namespace Kernel
.dCBWSignature = 0x43425355, .dCBWSignature = 0x43425355,
.dCBWTag = 0x00000000, .dCBWTag = 0x00000000,
.dCBWDataTransferLength = static_cast<uint32_t>(data.size()), .dCBWDataTransferLength = static_cast<uint32_t>(data.size()),
.bmCBWFlags = static_cast<uint8_t>(in ? 0x80 : 0x00), .bmCBWFlags = IN ? 0x80 : 0x00,
.bCBWLUN = lun, .bCBWLUN = lun,
.bCBWCBLength = static_cast<uint8_t>(scsi_command.size()), .bCBWCBLength = static_cast<uint8_t>(scsi_command.size()),
.CBWCB = {}, .CBWCB = {},
@ -178,19 +180,19 @@ namespace Kernel
{ {
if (data.empty()) if (data.empty())
return 0; return 0;
if (in) if constexpr(IN)
return TRY(driver.recv_bytes(dma_region.paddr(), data.size())); return TRY(driver.recv_bytes(dma_region.paddr(), data.size()));
memcpy(reinterpret_cast<void*>(dma_region.vaddr()), data.data(), data.size()); memcpy(reinterpret_cast<void*>(dma_region.vaddr()), data.data(), data.size());
return TRY(driver.send_bytes(dma_region.paddr(), data.size())); return TRY(driver.send_bytes(dma_region.paddr(), data.size()));
}()); }());
if (ntransfer > data.size()) if (ntransfer != data.size())
{ {
dwarnln("device responded with more bytes than requested"); dwarnln("device responded with {}/{} bytes", ntransfer, data.size());
return BAN::Error::from_errno(EFAULT); return BAN::Error::from_errno(EFAULT);
} }
if (in && !data.empty()) if constexpr (IN)
memcpy(data.data(), reinterpret_cast<void*>(dma_region.vaddr()), ntransfer); memcpy(data.data(), reinterpret_cast<void*>(dma_region.vaddr()), ntransfer);
if (TRY(driver.recv_bytes(dma_region.paddr(), sizeof(USBMassStorage::CSW))) != sizeof(USBMassStorage::CSW)) if (TRY(driver.recv_bytes(dma_region.paddr(), sizeof(USBMassStorage::CSW))) != sizeof(USBMassStorage::CSW))
@ -226,7 +228,7 @@ namespace Kernel
(uint8_t)(count >> 8), (uint8_t)(count >> 0), (uint8_t)(count >> 8), (uint8_t)(count >> 0),
0x00 0x00
}; };
TRY(send_scsi_command(BAN::ConstByteSpan::from(scsi_read_req), buffer.slice(i * m_block_size, count * m_block_size), true)); TRY(send_scsi_command<true>(BAN::ConstByteSpan::from(scsi_read_req), buffer.slice(i * m_block_size, count * m_block_size)));
i += count; i += count;
} }
@ -234,13 +236,13 @@ namespace Kernel
return {}; return {};
} }
BAN::ErrorOr<void> USBSCSIDevice::write_sectors_impl(uint64_t first_lba, uint64_t sector_count, BAN::ConstByteSpan _buffer) BAN::ErrorOr<void> USBSCSIDevice::write_sectors_impl(uint64_t first_lba, uint64_t sector_count, BAN::ConstByteSpan buffer)
{ {
dprintln("write_sectors_impl({}, {})", first_lba, sector_count);
const size_t max_blocks_per_write = m_dma_region->size() / m_block_size; const size_t max_blocks_per_write = m_dma_region->size() / m_block_size;
ASSERT(max_blocks_per_write <= 0xFFFF); ASSERT(max_blocks_per_write <= 0xFFFF);
auto buffer = BAN::ByteSpan(const_cast<uint8_t*>(_buffer.data()), _buffer.size());
for (uint64_t i = 0; i < sector_count;) for (uint64_t i = 0; i < sector_count;)
{ {
const uint32_t lba = first_lba + i; const uint32_t lba = first_lba + i;
@ -254,7 +256,7 @@ namespace Kernel
(uint8_t)(count >> 8), (uint8_t)(count >> 0), (uint8_t)(count >> 8), (uint8_t)(count >> 0),
0x00 0x00
}; };
TRY(send_scsi_command(BAN::ConstByteSpan::from(scsi_write_req), buffer.slice(i * m_block_size, count * m_block_size), false)); TRY(send_scsi_command<false>(BAN::ConstByteSpan::from(scsi_write_req), buffer.slice(i * m_block_size, count * m_block_size)));
i += count; i += count;
} }