diff --git a/kernel/kernel/USB/XHCI/Device.cpp b/kernel/kernel/USB/XHCI/Device.cpp index 9c1d384f..43d6e57d 100644 --- a/kernel/kernel/USB/XHCI/Device.cpp +++ b/kernel/kernel/USB/XHCI/Device.cpp @@ -355,17 +355,9 @@ namespace Kernel BAN::ErrorOr XHCIDevice::send_request(const USBDeviceRequest& request, paddr_t buffer_paddr) { - // FIXME: This is more or less generic USB code + ASSERT(request.wLength == 0 || buffer_paddr); - auto& endpoint = m_endpoints[0]; - - // minus 3: Setup, Status, Link (this is probably too generous and will result in STALL) - if (request.wLength > (m_transfer_ring_trb_count - 3) * endpoint.max_packet_size) - return BAN::Error::from_errno((ENOBUFS)); - - LockGuard _(endpoint.mutex); - - uint8_t transfer_type = + const uint8_t transfer_type = [&request]() -> uint8_t { if (request.wLength == 0) @@ -375,6 +367,12 @@ namespace Kernel return 2; }(); + const bool status_stage_dir = !((request.wLength > 0) && (request.bmRequestType & USB::RequestType::DeviceToHost)); + + auto& endpoint = m_endpoints[0]; + + LockGuard _(endpoint.mutex); + auto* transfer_trb_arr = reinterpret_cast(endpoint.transfer_ring->vaddr()); { @@ -397,32 +395,23 @@ namespace Kernel advance_endpoint_enqueue(endpoint, false); } - const uint32_t td_packet_count = BAN::Math::div_round_up(request.wLength, endpoint.max_packet_size); - uint32_t packets_transferred = 1; - - uint32_t bytes_handled = 0; - while (bytes_handled < request.wLength) + if (request.wLength) { - const uint32_t to_handle = BAN::Math::min(endpoint.max_packet_size, request.wLength - bytes_handled); - auto& trb = transfer_trb_arr[endpoint.enqueue_index]; memset(const_cast(&trb), 0, sizeof(XHCI::TRB)); trb.data_stage.trb_type = XHCI::TRBType::DataStage; - trb.data_stage.direction = 1; - trb.data_stage.trb_transfer_length = to_handle; - trb.data_stage.td_size = BAN::Math::min(td_packet_count - packets_transferred, 31); - trb.data_stage.chain_bit = (bytes_handled + to_handle < request.wLength); + trb.data_stage.direction = !!(request.bmRequestType & USB::RequestType::DeviceToHost); + trb.data_stage.trb_transfer_length = request.wLength; + trb.data_stage.td_size = 0; + trb.data_stage.chain_bit = 0; trb.data_stage.interrupt_on_completion = 0; - trb.data_stage.interrupt_on_short_packet = 1; + trb.data_stage.interrupt_on_short_packet = 0; trb.data_stage.immediate_data = 0; - trb.data_stage.data_buffer_pointer = buffer_paddr + bytes_handled; + trb.data_stage.data_buffer_pointer = buffer_paddr; trb.data_stage.cycle_bit = endpoint.cycle_bit; - bytes_handled += to_handle; - packets_transferred++; - - advance_endpoint_enqueue(endpoint, trb.data_stage.chain_bit); + advance_endpoint_enqueue(endpoint, false); } { @@ -430,10 +419,10 @@ namespace Kernel memset(const_cast(&trb), 0, sizeof(XHCI::TRB)); trb.status_stage.trb_type = XHCI::TRBType::StatusStage; - trb.status_stage.direction = 0; + trb.status_stage.direction = status_stage_dir; trb.status_stage.chain_bit = 0; trb.status_stage.interrupt_on_completion = 1; - trb.data_stage.cycle_bit = endpoint.cycle_bit; + trb.status_stage.cycle_bit = endpoint.cycle_bit; advance_endpoint_enqueue(endpoint, false); } @@ -469,7 +458,7 @@ namespace Kernel ASSERT(endpoint_id != 0); auto& endpoint = m_endpoints[endpoint_id - 1]; - ASSERT(buffer_len <= endpoint.max_packet_size); + ASSERT(buffer_len <= (1 << 16)); auto& trb = *reinterpret_cast(endpoint.transfer_ring->vaddr() + endpoint.enqueue_index * sizeof(XHCI::TRB)); memset(const_cast(&trb), 0, sizeof(XHCI::TRB));