From 58fcd2b2fe50eb7793340146f644aad665aecbb4 Mon Sep 17 00:00:00 2001 From: Bananymous Date: Tue, 16 Jul 2024 22:29:18 +0300 Subject: [PATCH] Kernel: Fix multi-interface USB device initialization --- kernel/include/kernel/USB/HID/HIDDriver.h | 5 ++--- kernel/kernel/Terminal/VirtualTTY.cpp | 1 + kernel/kernel/USB/Device.cpp | 19 +++++++++--------- kernel/kernel/USB/HID/HIDDriver.cpp | 24 +++++++++++++---------- 4 files changed, 27 insertions(+), 22 deletions(-) diff --git a/kernel/include/kernel/USB/HID/HIDDriver.h b/kernel/include/kernel/USB/HID/HIDDriver.h index 59fd1a2d9e..07f3ca8449 100644 --- a/kernel/include/kernel/USB/HID/HIDDriver.h +++ b/kernel/include/kernel/USB/HID/HIDDriver.h @@ -75,12 +75,12 @@ namespace Kernel }; public: - static BAN::ErrorOr> create(USBDevice&, const USBDevice::InterfaceDescriptor&, uint8_t interface_index); + static BAN::ErrorOr> create(USBDevice&, const USBDevice::InterfaceDescriptor&); void handle_input_data(BAN::ConstByteSpan, uint8_t endpoint_id) override; private: - USBHIDDriver(USBDevice&, const USBDevice::InterfaceDescriptor&, uint8_t interface_index); + USBHIDDriver(USBDevice&, const USBDevice::InterfaceDescriptor&); ~USBHIDDriver(); BAN::ErrorOr initialize(); @@ -90,7 +90,6 @@ namespace Kernel private: USBDevice& m_device; USBDevice::InterfaceDescriptor m_interface; - const uint8_t m_interface_index; bool m_uses_report_id { false }; BAN::Vector m_device_inputs; diff --git a/kernel/kernel/Terminal/VirtualTTY.cpp b/kernel/kernel/Terminal/VirtualTTY.cpp index 384f4d160b..51e84ee6f8 100644 --- a/kernel/kernel/Terminal/VirtualTTY.cpp +++ b/kernel/kernel/Terminal/VirtualTTY.cpp @@ -423,6 +423,7 @@ namespace Kernel } m_show_cursor = old_show_cursor; + m_column = 0; m_row--; } diff --git a/kernel/kernel/USB/Device.cpp b/kernel/kernel/USB/Device.cpp index e069e6cff4..f12b34a969 100644 --- a/kernel/kernel/USB/Device.cpp +++ b/kernel/kernel/USB/Device.cpp @@ -140,10 +140,8 @@ namespace Kernel TRY(send_request(request, 0)); } - for (size_t j = 0; j < configuration.interfaces.size(); j++) + for (const auto& interface : configuration.interfaces) { - const auto& interface = configuration.interfaces[j]; - switch (static_cast(interface.descriptor.bInterfaceClass)) { case USB::InterfaceBaseClass::Audio: @@ -153,7 +151,7 @@ namespace Kernel dprintln_if(DEBUG_USB, "Found CommunicationAndCDCControl interface"); break; case USB::InterfaceBaseClass::HID: - if (auto result = USBHIDDriver::create(*this, interface, j); !result.is_error()) + if (auto result = USBHIDDriver::create(*this, interface); !result.is_error()) TRY(m_class_drivers.push_back(result.release_value())); break; case USB::InterfaceBaseClass::Physical: @@ -217,12 +215,15 @@ namespace Kernel dprintln_if(DEBUG_USB, "Invalid interface base class {2H}", interface.descriptor.bInterfaceClass); break; } + } - if (!m_class_drivers.empty()) - { - dprintln("Successfully initialized USB interface"); - return {}; - } + if (!m_class_drivers.empty()) + { + dprintln("Successfully initialized USB device with {}/{} interfaces", + m_class_drivers.size(), + configuration.interfaces.size() + ); + return {}; } } diff --git a/kernel/kernel/USB/HID/HIDDriver.cpp b/kernel/kernel/USB/HID/HIDDriver.cpp index 761bc0a80e..736ce17643 100644 --- a/kernel/kernel/USB/HID/HIDDriver.cpp +++ b/kernel/kernel/USB/HID/HIDDriver.cpp @@ -11,6 +11,11 @@ namespace Kernel { + enum HIDRequest : uint8_t + { + SET_PROTOCOL = 0x0B, + }; + enum class HIDDescriptorType : uint8_t { HID = 0x21, @@ -64,17 +69,16 @@ namespace Kernel static BAN::ErrorOr> parse_report_descriptor(BAN::ConstByteSpan report_data, bool& out_use_report_id); - BAN::ErrorOr> USBHIDDriver::create(USBDevice& device, const USBDevice::InterfaceDescriptor& interface, uint8_t interface_index) + BAN::ErrorOr> USBHIDDriver::create(USBDevice& device, const USBDevice::InterfaceDescriptor& interface) { - auto result = TRY(BAN::UniqPtr::create(device, interface, interface_index)); + auto result = TRY(BAN::UniqPtr::create(device, interface)); TRY(result->initialize()); return result; } - USBHIDDriver::USBHIDDriver(USBDevice& device, const USBDevice::InterfaceDescriptor& interface, uint8_t interface_index) + USBHIDDriver::USBHIDDriver(USBDevice& device, const USBDevice::InterfaceDescriptor& interface) : m_device(device) , m_interface(interface) - , m_interface_index(interface_index) {} USBHIDDriver::~USBHIDDriver() @@ -119,18 +123,17 @@ namespace Kernel } // If this device supports boot protocol, make sure it is not used - if (m_interface.endpoints.front().descriptor.bDescriptorType & 0x80) + if (m_interface.descriptor.bInterfaceSubClass == 0x01) { USBDeviceRequest request; request.bmRequestType = USB::RequestType::HostToDevice | USB::RequestType::Class | USB::RequestType::Interface; - request.bRequest = USB::Request::SET_INTERFACE; + request.bRequest = HIDRequest::SET_PROTOCOL; request.wValue = 1; // report protocol - request.wIndex = m_interface_index; + request.wIndex = m_interface.descriptor.bInterfaceNumber; request.wLength = 0; TRY(m_device.send_request(request, 0)); } - const auto& hid_descriptor = *reinterpret_cast(m_interface.misc_descriptors[hid_descriptor_index].data()); dprintln_if(DEBUG_HID, "HID descriptor ({} bytes)", m_interface.misc_descriptors[hid_descriptor_index].size()); dprintln_if(DEBUG_HID, " bLength: {}", hid_descriptor.bLength); @@ -139,6 +142,7 @@ namespace Kernel dprintln_if(DEBUG_HID, " bCountryCode: {}", hid_descriptor.bCountryCode); dprintln_if(DEBUG_HID, " bNumDescriptors: {}", hid_descriptor.bNumDescriptors); + uint32_t report_descriptor_index = 0; BAN::Vector collections; for (size_t i = 0; i < hid_descriptor.bNumDescriptors; i++) { @@ -160,8 +164,8 @@ namespace Kernel USBDeviceRequest request; request.bmRequestType = USB::RequestType::DeviceToHost | USB::RequestType::Standard | USB::RequestType::Interface; request.bRequest = USB::Request::GET_DESCRIPTOR; - request.wValue = static_cast(HIDDescriptorType::Report) << 8; - request.wIndex = m_interface_index; + request.wValue = (static_cast(HIDDescriptorType::Report) << 8) | report_descriptor_index++; + request.wIndex = m_interface.descriptor.bInterfaceNumber; request.wLength = descriptor.wItemLength; auto transferred = TRY(m_device.send_request(request, dma_buffer->paddr()));