From 1253e2a4583e81776ac7422e783da4a0fb0f75fe Mon Sep 17 00:00:00 2001 From: Bananymous Date: Thu, 21 Nov 2024 13:44:21 +0200 Subject: [PATCH] Kernel: Add support for bulk endpoints and update endpoint API USB device now sets its own data buffers for IN/OUT endpoints. This allows more customization and parallelism as data buffer does not have to be shared. --- kernel/include/kernel/USB/Device.h | 7 +- kernel/include/kernel/USB/HID/HIDDriver.h | 9 +- kernel/include/kernel/USB/XHCI/Device.h | 4 +- kernel/kernel/USB/Device.cpp | 15 +++- kernel/kernel/USB/HID/HIDDriver.cpp | 56 +++++++----- kernel/kernel/USB/XHCI/Device.cpp | 102 +++++++++++----------- 6 files changed, 105 insertions(+), 88 deletions(-) diff --git a/kernel/include/kernel/USB/Device.h b/kernel/include/kernel/USB/Device.h index dc03ebc2..973fc5ef 100644 --- a/kernel/include/kernel/USB/Device.h +++ b/kernel/include/kernel/USB/Device.h @@ -19,7 +19,9 @@ namespace Kernel USBClassDriver() = default; virtual ~USBClassDriver() = default; - virtual void handle_input_data(BAN::ConstByteSpan, uint8_t endpoint_id) = 0; + virtual BAN::ErrorOr initialize() { return {}; }; + + virtual void handle_input_data(size_t byte_count, uint8_t endpoint_id) = 0; }; class USBDevice @@ -64,11 +66,12 @@ namespace Kernel virtual BAN::ErrorOr initialize_endpoint(const USBEndpointDescriptor&) = 0; virtual BAN::ErrorOr send_request(const USBDeviceRequest&, paddr_t buffer) = 0; + virtual void send_data_buffer(uint8_t endpoint_id, paddr_t buffer, size_t buffer_len) = 0; static USB::SpeedClass determine_speed_class(uint64_t bits_per_second); protected: - void handle_input_data(BAN::ConstByteSpan, uint8_t endpoint_id); + void handle_input_data(size_t byte_count, uint8_t endpoint_id); virtual BAN::ErrorOr initialize_control_endpoint() = 0; private: diff --git a/kernel/include/kernel/USB/HID/HIDDriver.h b/kernel/include/kernel/USB/HID/HIDDriver.h index 07f3ca84..8a63d6b9 100644 --- a/kernel/include/kernel/USB/HID/HIDDriver.h +++ b/kernel/include/kernel/USB/HID/HIDDriver.h @@ -75,15 +75,13 @@ namespace Kernel }; public: - static BAN::ErrorOr> create(USBDevice&, const USBDevice::InterfaceDescriptor&); - - void handle_input_data(BAN::ConstByteSpan, uint8_t endpoint_id) override; + void handle_input_data(size_t byte_count, uint8_t endpoint_id) override; private: USBHIDDriver(USBDevice&, const USBDevice::InterfaceDescriptor&); ~USBHIDDriver(); - BAN::ErrorOr initialize(); + BAN::ErrorOr initialize() override; BAN::ErrorOr> initializes_device_reports(const BAN::Vector&); @@ -94,6 +92,9 @@ namespace Kernel bool m_uses_report_id { false }; BAN::Vector m_device_inputs; + uint8_t m_data_endpoint_id = 0; + BAN::UniqPtr m_data_buffer; + friend class BAN::UniqPtr; }; diff --git a/kernel/include/kernel/USB/XHCI/Device.h b/kernel/include/kernel/USB/XHCI/Device.h index 7e09a004..20034735 100644 --- a/kernel/include/kernel/USB/XHCI/Device.h +++ b/kernel/include/kernel/USB/XHCI/Device.h @@ -27,7 +27,6 @@ namespace Kernel volatile uint32_t transfer_count { 0 }; volatile XHCI::TRB completion_trb; - BAN::UniqPtr data_region; void(XHCIDevice::*callback)(XHCI::TRB); }; @@ -36,6 +35,7 @@ namespace Kernel BAN::ErrorOr initialize_endpoint(const USBEndpointDescriptor&) override; BAN::ErrorOr send_request(const USBDeviceRequest&, paddr_t buffer) override; + void send_data_buffer(uint8_t endpoint_id, paddr_t buffer, size_t buffer_size) override; void on_transfer_event(const volatile XHCI::TRB&); @@ -47,7 +47,7 @@ namespace Kernel ~XHCIDevice(); BAN::ErrorOr update_actual_max_packet_size(); - void on_interrupt_endpoint_event(XHCI::TRB); + void on_interrupt_or_bulk_endpoint_event(XHCI::TRB); void advance_endpoint_enqueue(Endpoint&, bool chain); diff --git a/kernel/kernel/USB/Device.cpp b/kernel/kernel/USB/Device.cpp index b820f694..3bbdb3c1 100644 --- a/kernel/kernel/USB/Device.cpp +++ b/kernel/kernel/USB/Device.cpp @@ -154,7 +154,7 @@ namespace Kernel dprintln_if(DEBUG_USB, "Found CommunicationAndCDCControl interface"); break; case USB::InterfaceBaseClass::HID: - if (auto result = USBHIDDriver::create(*this, interface); !result.is_error()) + if (auto result = BAN::UniqPtr::create(*this, interface); !result.is_error()) TRY(m_class_drivers.push_back(result.release_value())); break; case USB::InterfaceBaseClass::Physical: @@ -220,6 +220,15 @@ namespace Kernel } } + for (size_t i = 0; i < m_class_drivers.size(); i++) + { + if (auto ret = m_class_drivers[i]->initialize(); ret.is_error()) + { + dwarnln("Could not initialize USB interface {}", ret.error()); + m_class_drivers.remove(i--); + } + } + if (!m_class_drivers.empty()) { dprintln("Successfully initialized USB device with {}/{} interfaces", @@ -317,10 +326,10 @@ namespace Kernel return BAN::move(configuration); } - void USBDevice::handle_input_data(BAN::ConstByteSpan data, uint8_t endpoint_id) + void USBDevice::handle_input_data(size_t byte_count, uint8_t endpoint_id) { for (auto& driver : m_class_drivers) - driver->handle_input_data(data, endpoint_id); + driver->handle_input_data(byte_count, endpoint_id); } USB::SpeedClass USBDevice::determine_speed_class(uint64_t bits_per_second) diff --git a/kernel/kernel/USB/HID/HIDDriver.cpp b/kernel/kernel/USB/HID/HIDDriver.cpp index 5cecb990..1df32a70 100644 --- a/kernel/kernel/USB/HID/HIDDriver.cpp +++ b/kernel/kernel/USB/HID/HIDDriver.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -68,13 +69,6 @@ namespace Kernel static BAN::ErrorOr> parse_report_descriptor(BAN::ConstByteSpan report_data, bool& out_use_report_id); - BAN::ErrorOr> USBHIDDriver::create(USBDevice& device, const USBDevice::InterfaceDescriptor& interface) - { - auto result = TRY(BAN::UniqPtr::create(device, interface)); - TRY(result->initialize()); - return result; - } - USBHIDDriver::USBHIDDriver(USBDevice& device, const USBDevice::InterfaceDescriptor& interface) : m_device(device) , m_interface(interface) @@ -192,7 +186,29 @@ namespace Kernel m_device_inputs = TRY(initializes_device_reports(collections)); for (const auto& endpoint : m_interface.endpoints) - TRY(m_device.initialize_endpoint(endpoint.descriptor)); + { + const auto& desc = endpoint.descriptor; + + if (!(desc.bEndpointAddress & 0x80)) + continue; + if ((desc.bmAttributes & 0x03) != 0x03) + continue; + + TRY(m_device.initialize_endpoint(desc)); + m_data_buffer = TRY(DMARegion::create(desc.wMaxPacketSize & 0x07FF)); + + m_data_endpoint_id = (desc.bEndpointAddress & 0x0F) * 2 + !!(desc.bEndpointAddress & 0x80); + + break; + } + + if (m_data_endpoint_id == 0) + { + dwarnln("HID device does not an interrupt IN endpoints"); + return BAN::Error::from_errno(EINVAL); + } + + m_device.send_data_buffer(m_data_endpoint_id, m_data_buffer->paddr(), m_data_buffer->size()); return {}; } @@ -256,23 +272,15 @@ namespace Kernel return BAN::move(result); } - void USBHIDDriver::handle_input_data(BAN::ConstByteSpan data, uint8_t endpoint_id) + void USBHIDDriver::handle_input_data(size_t byte_count, uint8_t endpoint_id) { - { - bool found = false; - for (const auto& endpoint : m_interface.endpoints) - { - const auto& desc = endpoint.descriptor; - if (endpoint_id == (desc.bEndpointAddress & 0x0F) * 2 + !!(desc.bEndpointAddress & 0x80)) - { - found = true; - break; - } - } - // If this packet is not for us, skip it - if (!found) - return; - } + if (m_data_endpoint_id != endpoint_id) + return; + + auto data = BAN::ConstByteSpan(reinterpret_cast(m_data_buffer->vaddr()), byte_count); + BAN::ScopeGuard _([&] { + m_device.send_data_buffer(m_data_endpoint_id, m_data_buffer->paddr(), m_data_buffer->size()); + }); if constexpr(DEBUG_USB_HID) { diff --git a/kernel/kernel/USB/XHCI/Device.cpp b/kernel/kernel/USB/XHCI/Device.cpp index dfb5a3e5..3a8e80a8 100644 --- a/kernel/kernel/USB/XHCI/Device.cpp +++ b/kernel/kernel/USB/XHCI/Device.cpp @@ -200,20 +200,20 @@ namespace Kernel BAN::ErrorOr XHCIDevice::initialize_endpoint(const USBEndpointDescriptor& endpoint_descriptor) { - ASSERT(m_controller.port(m_port_id).revision_major == 2); + const bool is_control { (endpoint_descriptor.bmAttributes & 0x03) == 0x00 }; + const bool is_isoch { (endpoint_descriptor.bmAttributes & 0x03) == 0x01 }; + const bool is_bulk { (endpoint_descriptor.bmAttributes & 0x03) == 0x02 }; + const bool is_interrupt { (endpoint_descriptor.bmAttributes & 0x03) == 0x03 }; - const uint32_t endpoint_id = (endpoint_descriptor.bEndpointAddress & 0x0F) * 2 + !!(endpoint_descriptor.bEndpointAddress & 0x80); - const uint32_t max_packet_size = endpoint_descriptor.wMaxPacketSize & 0x07FF; - const uint32_t max_burst_size = (endpoint_descriptor.wMaxPacketSize >> 11) & 0x0003; - const uint32_t max_esit_payload = max_packet_size * (max_burst_size + 1); - const uint32_t interval = determine_interval(endpoint_descriptor, m_speed_class); - const uint32_t average_trb_length = ((endpoint_descriptor.bmAttributes & 3) == 0b00) ? 8 : max_esit_payload; - const uint32_t error_count = ((endpoint_descriptor.bmAttributes & 3) == 0b01) ? 0 : 3; + (void)is_control; + (void)is_isoch; + (void)is_bulk; + (void)is_interrupt; XHCI::EndpointType endpoint_type; switch ((endpoint_descriptor.bEndpointAddress & 0x80) | (endpoint_descriptor.bmAttributes & 0x03)) { - case 0x00: + case 0x00: ASSERT_NOT_REACHED(); case 0x80: endpoint_type = XHCI::EndpointType::Control; break; case 0x01: endpoint_type = XHCI::EndpointType::IsochOut; break; case 0x81: endpoint_type = XHCI::EndpointType::IsochIn; break; @@ -224,6 +224,16 @@ namespace Kernel default: ASSERT_NOT_REACHED(); } + // FIXME: Streams + + const uint32_t endpoint_id = (endpoint_descriptor.bEndpointAddress & 0x0F) * 2 + !!(endpoint_descriptor.bEndpointAddress & 0x80); + const uint32_t max_packet_size = (is_control || is_bulk) ? endpoint_descriptor.wMaxPacketSize : endpoint_descriptor.wMaxPacketSize & 0x07FF; + const uint32_t max_burst_size = (is_control || is_bulk) ? 0 : (endpoint_descriptor.wMaxPacketSize & 0x1800) >> 11; + const uint32_t max_esit_payload = max_packet_size * (max_burst_size + 1); + const uint32_t interval = determine_interval(endpoint_descriptor, m_speed_class); + const uint32_t average_trb_length = (is_control) ? 8 : max_esit_payload; + const uint32_t error_count = (is_isoch) ? 0 : 3; + auto& endpoint = m_endpoints[endpoint_id - 1]; ASSERT(!endpoint.transfer_ring); @@ -237,8 +247,7 @@ namespace Kernel endpoint.dequeue_index = 0; endpoint.enqueue_index = 0; endpoint.cycle_bit = 1; - endpoint.callback = &XHCIDevice::on_interrupt_endpoint_event; - endpoint.data_region = TRY(DMARegion::create(endpoint.max_packet_size)); + endpoint.callback = (is_interrupt || is_bulk) ? &XHCIDevice::on_interrupt_or_bulk_endpoint_event : nullptr; memset(reinterpret_cast(endpoint.transfer_ring->vaddr()), 0, endpoint.transfer_ring->size()); @@ -276,62 +285,27 @@ namespace Kernel configure_endpoint.configure_endpoint_command.slot_id = m_slot_id; TRY(m_controller.send_command(configure_endpoint)); - if (endpoint_type == XHCI::EndpointType::InterruptIn) - { - auto& trb = *reinterpret_cast(endpoint.transfer_ring->vaddr()); - memset(const_cast(&trb), 0, sizeof(XHCI::TRB)); - trb.normal.trb_type = XHCI::TRBType::Normal; - trb.normal.data_buffer_pointer = endpoint.data_region->paddr(); - trb.normal.trb_transfer_length = endpoint.data_region->size(); - trb.normal.td_size = 0; - trb.normal.interrupt_target = 0; - trb.normal.cycle_bit = 1; - trb.normal.interrupt_on_completion = 1; - trb.normal.interrupt_on_short_packet = 1; - advance_endpoint_enqueue(endpoint, false); - - m_controller.doorbell_reg(m_slot_id) = endpoint_id; - } - else - { - dwarnln("Configured unsupported endpoint {2H}", - (endpoint_descriptor.bEndpointAddress & 0x80) | (endpoint_descriptor.bmAttributes & 0x03) - ); - } - return {}; } - void XHCIDevice::on_interrupt_endpoint_event(XHCI::TRB trb) + void XHCIDevice::on_interrupt_or_bulk_endpoint_event(XHCI::TRB trb) { ASSERT(trb.trb_type == XHCI::TRBType::TransferEvent); if (trb.transfer_event.completion_code != 1 && trb.transfer_event.completion_code != 13) { - dwarnln("Interrupt endpoint got transfer event with completion code {}", +trb.transfer_event.completion_code); + dwarnln("Interrupt or bulk endpoint got transfer event with completion code {}", +trb.transfer_event.completion_code); return; } const uint32_t endpoint_id = trb.transfer_event.endpoint_id; auto& endpoint = m_endpoints[endpoint_id - 1]; - ASSERT(endpoint.transfer_ring && endpoint.data_region); - const uint32_t transfer_length = endpoint.max_packet_size - trb.transfer_event.trb_transfer_length; - auto received_data = BAN::ConstByteSpan(reinterpret_cast(endpoint.data_region->vaddr()), transfer_length); - handle_input_data(received_data, endpoint_id); + const auto* transfer_trb_arr = reinterpret_cast(endpoint.transfer_ring->vaddr()); + const uint32_t transfer_trb_index = (trb.transfer_event.trb_pointer - endpoint.transfer_ring->paddr()) / sizeof(XHCI::TRB); + const uint32_t original_len = transfer_trb_arr[transfer_trb_index].normal.trb_transfer_length; - auto& new_trb = *reinterpret_cast(endpoint.transfer_ring->vaddr() + endpoint.enqueue_index * sizeof(XHCI::TRB)); - memset(const_cast(&new_trb), 0, sizeof(XHCI::TRB)); - new_trb.normal.trb_type = XHCI::TRBType::Normal; - new_trb.normal.data_buffer_pointer = endpoint.data_region->paddr(); - new_trb.normal.trb_transfer_length = endpoint.max_packet_size; - new_trb.normal.td_size = 0; - new_trb.normal.interrupt_target = 0; - new_trb.normal.cycle_bit = endpoint.cycle_bit; - new_trb.normal.interrupt_on_completion = 1; - new_trb.normal.interrupt_on_short_packet = 1; - advance_endpoint_enqueue(endpoint, false); - - m_controller.doorbell_reg(m_slot_id) = endpoint_id; + const uint32_t transfer_length = original_len - trb.transfer_event.trb_transfer_length; + handle_input_data(transfer_length, endpoint_id); } void XHCIDevice::on_transfer_event(const volatile XHCI::TRB& trb) @@ -495,6 +469,28 @@ namespace Kernel return endpoint.transfer_count; } + void XHCIDevice::send_data_buffer(uint8_t endpoint_id, paddr_t buffer, size_t buffer_len) + { + ASSERT(endpoint_id != 0); + auto& endpoint = m_endpoints[endpoint_id - 1]; + + ASSERT(buffer_len <= endpoint.max_packet_size); + + auto& trb = *reinterpret_cast(endpoint.transfer_ring->vaddr() + endpoint.enqueue_index * sizeof(XHCI::TRB)); + memset(const_cast(&trb), 0, sizeof(XHCI::TRB)); + trb.normal.trb_type = XHCI::TRBType::Normal; + trb.normal.data_buffer_pointer = buffer; + trb.normal.trb_transfer_length = buffer_len; + trb.normal.td_size = 0; + trb.normal.interrupt_target = 0; + trb.normal.cycle_bit = endpoint.cycle_bit; + trb.normal.interrupt_on_completion = 1; + trb.normal.interrupt_on_short_packet = 1; + advance_endpoint_enqueue(endpoint, false); + + m_controller.doorbell_reg(m_slot_id) = endpoint_id; + } + void XHCIDevice::advance_endpoint_enqueue(Endpoint& endpoint, bool chain) { endpoint.enqueue_index++;