From 1337758660a70be5cc20715d1d42a89a47e5ec5a Mon Sep 17 00:00:00 2001 From: Bananymous Date: Tue, 16 Jul 2024 00:23:26 +0300 Subject: [PATCH] Kernel: Make USB HID interfaces configure all endpoints --- kernel/include/kernel/USB/HID/HIDDriver.h | 2 - kernel/kernel/USB/HID/HIDDriver.cpp | 41 +++++------- kernel/kernel/USB/XHCI/Device.cpp | 77 +++++++++++++++-------- 3 files changed, 67 insertions(+), 53 deletions(-) diff --git a/kernel/include/kernel/USB/HID/HIDDriver.h b/kernel/include/kernel/USB/HID/HIDDriver.h index 557f03e54b..59fd1a2d9e 100644 --- a/kernel/include/kernel/USB/HID/HIDDriver.h +++ b/kernel/include/kernel/USB/HID/HIDDriver.h @@ -93,8 +93,6 @@ namespace Kernel const uint8_t m_interface_index; bool m_uses_report_id { false }; - - uint8_t m_endpoint_id { 0 }; BAN::Vector m_device_inputs; friend class BAN::UniqPtr; diff --git a/kernel/kernel/USB/HID/HIDDriver.cpp b/kernel/kernel/USB/HID/HIDDriver.cpp index 2f864bd883..761bc0a80e 100644 --- a/kernel/kernel/USB/HID/HIDDriver.cpp +++ b/kernel/kernel/USB/HID/HIDDriver.cpp @@ -90,24 +90,6 @@ namespace Kernel ASSERT(static_cast(m_interface.descriptor.bInterfaceClass) == USB::InterfaceBaseClass::HID); - size_t endpoint_index = static_cast(-1); - for (size_t i = 0; i < m_interface.endpoints.size(); i++) - { - const auto& endpoint = m_interface.endpoints[i]; - if (!(endpoint.descriptor.bEndpointAddress & 0x80)) - continue; - if (endpoint.descriptor.bmAttributes != 0x03) - continue; - endpoint_index = i; - break; - } - - if (endpoint_index >= m_interface.endpoints.size()) - { - dwarnln("HID device does not contain IN interrupt endpoint"); - return BAN::Error::from_errno(EFAULT); - } - bool hid_descriptor_invalid = false; size_t hid_descriptor_index = static_cast(-1); for (size_t i = 0; i < m_interface.misc_descriptors.size(); i++) @@ -206,9 +188,8 @@ namespace Kernel m_device_inputs = TRY(initializes_device_reports(collections)); - const auto& endpoint_descriptor = m_interface.endpoints[endpoint_index].descriptor; - m_endpoint_id = (endpoint_descriptor.bEndpointAddress & 0x0F) * 2 + !!(endpoint_descriptor.bEndpointAddress & 0x80); - TRY(m_device.initialize_endpoint(endpoint_descriptor)); + for (const auto& endpoint : m_interface.endpoints) + TRY(m_device.initialize_endpoint(endpoint.descriptor)); return {}; } @@ -274,9 +255,21 @@ namespace Kernel void USBHIDDriver::handle_input_data(BAN::ConstByteSpan data, uint8_t endpoint_id) { - // If this packet is not for us, skip it - if (m_endpoint_id != endpoint_id) - return; + { + bool found = false; + for (const auto& endpoint : m_interface.endpoints) + { + const auto& desc = endpoint.descriptor; + if (endpoint_id == (desc.bEndpointAddress & 0x0F) * 2 + !!(desc.bEndpointAddress & 0x80)) + { + found = true; + break; + } + } + // If this packet is not for us, skip it + if (!found) + return; + } if constexpr(DEBUG_HID) { diff --git a/kernel/kernel/USB/XHCI/Device.cpp b/kernel/kernel/USB/XHCI/Device.cpp index c7d337af61..4860eb8abe 100644 --- a/kernel/kernel/USB/XHCI/Device.cpp +++ b/kernel/kernel/USB/XHCI/Device.cpp @@ -15,7 +15,6 @@ namespace Kernel return TRY(BAN::UniqPtr::create(controller, port_id, slot_id)); } - uint64_t XHCIDevice::calculate_port_bits_per_second(XHCIController& controller, uint32_t port_id) { const uint32_t portsc = controller.operational_regs().ports[port_id - 1].portsc; @@ -203,7 +202,29 @@ namespace Kernel BAN::ErrorOr XHCIDevice::initialize_endpoint(const USBEndpointDescriptor& endpoint_descriptor) { - const uint32_t endpoint_id = (endpoint_descriptor.bEndpointAddress & 0x0F) * 2 + !!(endpoint_descriptor.bEndpointAddress & 0x80); + ASSERT(m_controller.port(m_port_id).revision_major == 2); + + const uint32_t endpoint_id = (endpoint_descriptor.bEndpointAddress & 0x0F) * 2 + !!(endpoint_descriptor.bEndpointAddress & 0x80); + const uint32_t max_packet_size = endpoint_descriptor.wMaxPacketSize & 0x07FF; + const uint32_t max_burst_size = (endpoint_descriptor.wMaxPacketSize >> 11) & 0x0003; + const uint32_t max_esit_payload = max_packet_size * (max_burst_size + 1); + const uint32_t interval = determine_interval(endpoint_descriptor, m_speed_class); + const uint32_t average_trb_length = ((endpoint_descriptor.bmAttributes & 3) == 0b00) ? 8 : max_esit_payload; + const uint32_t error_count = ((endpoint_descriptor.bmAttributes & 3) == 0b01) ? 0 : 3; + + XHCI::EndpointType endpoint_type; + switch ((endpoint_descriptor.bEndpointAddress & 0x80) | (endpoint_descriptor.bmAttributes & 0x03)) + { + case 0x00: + case 0x80: endpoint_type = XHCI::EndpointType::Control; break; + case 0x01: endpoint_type = XHCI::EndpointType::IsochOut; break; + case 0x81: endpoint_type = XHCI::EndpointType::IsochIn; break; + case 0x02: endpoint_type = XHCI::EndpointType::BulkOut; break; + case 0x82: endpoint_type = XHCI::EndpointType::BulkIn; break; + case 0x03: endpoint_type = XHCI::EndpointType::InterruptOut; break; + case 0x83: endpoint_type = XHCI::EndpointType::InterruptIn; break; + default: ASSERT_NOT_REACHED(); + } auto& endpoint = m_endpoints[endpoint_id - 1]; ASSERT(!endpoint.transfer_ring); @@ -214,7 +235,7 @@ namespace Kernel last_valid_endpoint_id = i + 1; endpoint.transfer_ring = TRY(DMARegion::create(m_transfer_ring_trb_count * sizeof(XHCI::TRB))); - endpoint.max_packet_size = endpoint_descriptor.wMaxPacketSize & 0x07FF; + endpoint.max_packet_size = max_packet_size; endpoint.dequeue_index = 0; endpoint.enqueue_index = 0; endpoint.cycle_bit = 1; @@ -237,23 +258,16 @@ namespace Kernel slot_context.context_entries = last_valid_endpoint_id; // FIXME: 4.5.2 hub - ASSERT(endpoint_descriptor.bEndpointAddress & 0x80); - ASSERT((endpoint_descriptor.bmAttributes & 0x03) == 3); - ASSERT(m_controller.port(m_port_id).revision_major == 2); - - const uint32_t max_esit_payload = endpoint_context.max_packet_size * (endpoint_context.max_burst_size + 1); - const uint32_t interval = determine_interval(endpoint_descriptor, m_speed_class); - memset(&endpoint_context, 0, context_size); - endpoint_context.endpoint_type = XHCI::EndpointType::InterruptIn; - endpoint_context.max_packet_size = endpoint.max_packet_size; - endpoint_context.max_burst_size = (endpoint_descriptor.wMaxPacketSize >> 11) & 0x0003; + endpoint_context.endpoint_type = endpoint_type; + endpoint_context.max_packet_size = max_packet_size; + endpoint_context.max_burst_size = max_burst_size; endpoint_context.mult = 0; - endpoint_context.error_count = 3; + endpoint_context.error_count = error_count; endpoint_context.tr_dequeue_pointer = endpoint.transfer_ring->paddr() | 1; endpoint_context.max_esit_payload_lo = max_esit_payload & 0xFFFF; endpoint_context.max_esit_payload_hi = max_esit_payload >> 16; - endpoint_context.average_trb_length = max_esit_payload; + endpoint_context.average_trb_length = average_trb_length; endpoint_context.interval = interval; } @@ -264,19 +278,28 @@ namespace Kernel configure_endpoint.configure_endpoint_command.slot_id = m_slot_id; TRY(m_controller.send_command(configure_endpoint)); - auto& trb = *reinterpret_cast(endpoint.transfer_ring->vaddr()); - memset(const_cast(&trb), 0, sizeof(XHCI::TRB)); - trb.normal.trb_type = XHCI::TRBType::Normal; - trb.normal.data_buffer_pointer = endpoint.data_region->paddr(); - trb.normal.trb_transfer_length = endpoint.data_region->size(); - trb.normal.td_size = 0; - trb.normal.interrupt_target = 0; - trb.normal.cycle_bit = 1; - trb.normal.interrupt_on_completion = 1; - trb.normal.interrupt_on_short_packet = 1; - advance_endpoint_enqueue(endpoint, false); + if (endpoint_type == XHCI::EndpointType::InterruptIn) + { + auto& trb = *reinterpret_cast(endpoint.transfer_ring->vaddr()); + memset(const_cast(&trb), 0, sizeof(XHCI::TRB)); + trb.normal.trb_type = XHCI::TRBType::Normal; + trb.normal.data_buffer_pointer = endpoint.data_region->paddr(); + trb.normal.trb_transfer_length = endpoint.data_region->size(); + trb.normal.td_size = 0; + trb.normal.interrupt_target = 0; + trb.normal.cycle_bit = 1; + trb.normal.interrupt_on_completion = 1; + trb.normal.interrupt_on_short_packet = 1; + advance_endpoint_enqueue(endpoint, false); - m_controller.doorbell_reg(m_slot_id) = endpoint_id; + m_controller.doorbell_reg(m_slot_id) = endpoint_id; + } + else + { + dwarnln("Configured unsupported endpoint {2H}", + (endpoint_descriptor.bEndpointAddress & 0x80) | (endpoint_descriptor.bmAttributes & 0x03) + ); + } return {}; }