diff --git a/kernel/include/kernel/USB/Device.h b/kernel/include/kernel/USB/Device.h index 973fc5ef..2dd91348 100644 --- a/kernel/include/kernel/USB/Device.h +++ b/kernel/include/kernel/USB/Device.h @@ -21,6 +21,7 @@ namespace Kernel virtual BAN::ErrorOr initialize() { return {}; }; + virtual void handle_stall(uint8_t endpoint_id) = 0; virtual void handle_input_data(size_t byte_count, uint8_t endpoint_id) = 0; }; @@ -71,6 +72,7 @@ namespace Kernel static USB::SpeedClass determine_speed_class(uint64_t bits_per_second); protected: + void handle_stall(uint8_t endpoint_id); void handle_input_data(size_t byte_count, uint8_t endpoint_id); virtual BAN::ErrorOr initialize_control_endpoint() = 0; diff --git a/kernel/include/kernel/USB/HID/HIDDriver.h b/kernel/include/kernel/USB/HID/HIDDriver.h index 8a63d6b9..7f1216c9 100644 --- a/kernel/include/kernel/USB/HID/HIDDriver.h +++ b/kernel/include/kernel/USB/HID/HIDDriver.h @@ -75,6 +75,7 @@ namespace Kernel }; public: + void handle_stall(uint8_t endpoint_id) override; void handle_input_data(size_t byte_count, uint8_t endpoint_id) override; private: diff --git a/kernel/include/kernel/USB/MassStorage/MassStorageDriver.h b/kernel/include/kernel/USB/MassStorage/MassStorageDriver.h index 5cee4fe5..f24b2b9c 100644 --- a/kernel/include/kernel/USB/MassStorage/MassStorageDriver.h +++ b/kernel/include/kernel/USB/MassStorage/MassStorageDriver.h @@ -15,6 +15,7 @@ namespace Kernel BAN_NON_MOVABLE(USBMassStorageDriver); public: + void handle_stall(uint8_t endpoint_id) override; void handle_input_data(size_t byte_count, uint8_t endpoint_id) override; BAN::ErrorOr send_bytes(paddr_t, size_t count); diff --git a/kernel/kernel/USB/Device.cpp b/kernel/kernel/USB/Device.cpp index 4967a3c7..87517cf6 100644 --- a/kernel/kernel/USB/Device.cpp +++ b/kernel/kernel/USB/Device.cpp @@ -328,6 +328,12 @@ namespace Kernel return BAN::move(configuration); } + void USBDevice::handle_stall(uint8_t endpoint_id) + { + for (auto& driver : m_class_drivers) + driver->handle_stall(endpoint_id); + } + void USBDevice::handle_input_data(size_t byte_count, uint8_t endpoint_id) { for (auto& driver : m_class_drivers) diff --git a/kernel/kernel/USB/HID/HIDDriver.cpp b/kernel/kernel/USB/HID/HIDDriver.cpp index 1df32a70..a8fa37a4 100644 --- a/kernel/kernel/USB/HID/HIDDriver.cpp +++ b/kernel/kernel/USB/HID/HIDDriver.cpp @@ -272,6 +272,12 @@ namespace Kernel return BAN::move(result); } + void USBHIDDriver::handle_stall(uint8_t endpoint_id) + { + (void)endpoint_id; + // FIXME: do something :) + } + void USBHIDDriver::handle_input_data(size_t byte_count, uint8_t endpoint_id) { if (m_data_endpoint_id != endpoint_id) diff --git a/kernel/kernel/USB/MassStorage/MassStorageDriver.cpp b/kernel/kernel/USB/MassStorage/MassStorageDriver.cpp index 23947dd8..acfa4117 100644 --- a/kernel/kernel/USB/MassStorage/MassStorageDriver.cpp +++ b/kernel/kernel/USB/MassStorage/MassStorageDriver.cpp @@ -169,6 +169,12 @@ namespace Kernel return static_cast(bytes_recv); } + void USBMassStorageDriver::handle_stall(uint8_t endpoint_id) + { + (void)endpoint_id; + // FIXME: do something :) + } + void USBMassStorageDriver::handle_input_data(size_t byte_count, uint8_t endpoint_id) { if (endpoint_id != m_in_endpoint_id && endpoint_id != m_out_endpoint_id) diff --git a/kernel/kernel/USB/XHCI/Device.cpp b/kernel/kernel/USB/XHCI/Device.cpp index 92d858be..59ffbb6e 100644 --- a/kernel/kernel/USB/XHCI/Device.cpp +++ b/kernel/kernel/USB/XHCI/Device.cpp @@ -295,15 +295,19 @@ namespace Kernel void XHCIDevice::on_interrupt_or_bulk_endpoint_event(XHCI::TRB trb) { ASSERT(trb.trb_type == XHCI::TRBType::TransferEvent); + + const uint32_t endpoint_id = trb.transfer_event.endpoint_id; + auto& endpoint = m_endpoints[endpoint_id - 1]; + + if (trb.transfer_event.completion_code == 6) + return handle_stall(endpoint_id); + if (trb.transfer_event.completion_code != 1 && trb.transfer_event.completion_code != 13) { dwarnln("Interrupt or bulk endpoint got transfer event with completion code {}", +trb.transfer_event.completion_code); return; } - const uint32_t endpoint_id = trb.transfer_event.endpoint_id; - auto& endpoint = m_endpoints[endpoint_id - 1]; - const auto* transfer_trb_arr = reinterpret_cast(endpoint.transfer_ring->vaddr()); const uint32_t transfer_trb_index = (trb.transfer_event.trb_pointer - endpoint.transfer_ring->paddr()) / sizeof(XHCI::TRB); const uint32_t original_len = transfer_trb_arr[transfer_trb_index].normal.trb_transfer_length;