diff --git a/kernel/include/kernel/USB/Definitions.h b/kernel/include/kernel/USB/Definitions.h index 2adad2ee..d36f31df 100644 --- a/kernel/include/kernel/USB/Definitions.h +++ b/kernel/include/kernel/USB/Definitions.h @@ -53,6 +53,18 @@ namespace Kernel VendorSpecific = 0xFF, }; + enum DescriptorType : uint8_t + { + DEVICE = 1, + CONFIGURATION = 2, + STRING = 3, + INTERFACE = 4, + ENDPOINT = 5, + DEVICE_QUALIFIER = 6, + OTHER_SPEED_CONFIGURATION = 7, + INTERFACE_POWER = 8, + }; + enum RequestType : uint8_t { HostToDevice = 0b0 << 7, @@ -117,7 +129,7 @@ namespace Kernel static_assert(sizeof(USBConfigurationDescriptor) == 9); static constexpr size_t foo = sizeof(USBConfigurationDescriptor); - struct USBInterfaceDescritor + struct USBInterfaceDescriptor { uint8_t bLength; uint8_t bDescriptorType; @@ -129,7 +141,7 @@ namespace Kernel uint8_t bInterfaceProtocol; uint8_t iInterface; }; - static_assert(sizeof(USBInterfaceDescritor) == 9); + static_assert(sizeof(USBInterfaceDescriptor) == 9); struct USBEndpointDescriptor { diff --git a/kernel/include/kernel/USB/Device.h b/kernel/include/kernel/USB/Device.h index d6b144d2..a630e325 100644 --- a/kernel/include/kernel/USB/Device.h +++ b/kernel/include/kernel/USB/Device.h @@ -2,6 +2,7 @@ #include +#include #include #include @@ -21,8 +22,9 @@ namespace Kernel struct InterfaceDescriptor { - USBInterfaceDescritor descriptor; + USBInterfaceDescriptor descriptor; BAN::Vector endpoints; + BAN::Vector> misc_descriptors; }; struct ConfigurationDescriptor @@ -43,14 +45,21 @@ namespace Kernel BAN::ErrorOr initialize(); + const BAN::Vector& configurations() { return m_descriptor.configurations; } + + virtual BAN::ErrorOr send_request(const USBDeviceRequest&, paddr_t buffer) = 0; + static USB::SpeedClass determine_speed_class(uint64_t bits_per_second); protected: virtual BAN::ErrorOr initialize_control_endpoint() = 0; - virtual BAN::ErrorOr send_request(const USBDeviceRequest&, paddr_t buffer) = 0; + + private: + BAN::ErrorOr parse_configuration(size_t index); private: DeviceDescriptor m_descriptor; + BAN::UniqPtr m_dma_buffer; }; } diff --git a/kernel/include/kernel/USB/XHCI/Definitions.h b/kernel/include/kernel/USB/XHCI/Definitions.h index a148fa1c..8cf704f7 100644 --- a/kernel/include/kernel/USB/XHCI/Definitions.h +++ b/kernel/include/kernel/USB/XHCI/Definitions.h @@ -227,7 +227,7 @@ namespace Kernel::XHCI uint64_t data_buffer_pointer : 64; uint32_t trb_transfer_length : 17; - uint32_t : 5; + uint32_t td_size : 5; uint32_t interrupt_target : 10; uint32_t cycle_bit : 1; diff --git a/kernel/include/kernel/USB/XHCI/Device.h b/kernel/include/kernel/USB/XHCI/Device.h index 14f808d2..c52400d9 100644 --- a/kernel/include/kernel/USB/XHCI/Device.h +++ b/kernel/include/kernel/USB/XHCI/Device.h @@ -18,21 +18,24 @@ namespace Kernel struct Endpoint { BAN::UniqPtr transfer_ring; + uint32_t dequeue_index { 0 }; uint32_t enqueue_index { 0 }; bool cycle_bit { 1 }; Mutex mutex; + volatile uint32_t transfer_count { 0 }; volatile XHCI::TRB completion_trb; }; public: static BAN::ErrorOr> create(XHCIController&, uint32_t port_id, uint32_t slot_id); + BAN::ErrorOr send_request(const USBDeviceRequest&, paddr_t buffer) override; + void on_transfer_event(const volatile XHCI::TRB&); protected: BAN::ErrorOr initialize_control_endpoint() override; - BAN::ErrorOr send_request(const USBDeviceRequest&, paddr_t buffer) override; private: XHCIDevice(XHCIController& controller, uint32_t port_id, uint32_t slot_id) @@ -46,7 +49,7 @@ namespace Kernel void advance_endpoint_enqueue(Endpoint&, bool chain); private: - static constexpr uint32_t m_transfer_ring_trb_count = 256; + static constexpr uint32_t m_transfer_ring_trb_count = PAGE_SIZE / sizeof(XHCI::TRB); XHCIController& m_controller; const uint32_t m_port_id; diff --git a/kernel/kernel/USB/Device.cpp b/kernel/kernel/USB/Device.cpp index a0889ff6..4bc32fed 100644 --- a/kernel/kernel/USB/Device.cpp +++ b/kernel/kernel/USB/Device.cpp @@ -11,7 +11,7 @@ namespace Kernel { TRY(initialize_control_endpoint()); - auto buffer = TRY(DMARegion::create(PAGE_SIZE)); + m_dma_buffer = TRY(DMARegion::create(1024)); USBDeviceRequest request; request.bmRequestType = USB::RequestType::DeviceToHost | USB::RequestType::Standard | USB::RequestType::Device; @@ -19,73 +19,20 @@ namespace Kernel request.wValue = 0x0100; request.wIndex = 0; request.wLength = sizeof(USBDeviceDescriptor); - TRY(send_request(request, buffer->paddr())); + auto transferred = TRY(send_request(request, m_dma_buffer->paddr())); + + m_descriptor.descriptor = *reinterpret_cast(m_dma_buffer->vaddr()); + if (transferred < sizeof(USBDeviceDescriptor) || transferred < m_descriptor.descriptor.bLength) + { + dprintln("invalid device descriptor response {}"); + return BAN::Error::from_errno(EINVAL); + } - m_descriptor.descriptor = *reinterpret_cast(buffer->vaddr()); dprintln_if(DEBUG_USB, "device has {} configurations", m_descriptor.descriptor.bNumConfigurations); for (uint32_t i = 0; i < m_descriptor.descriptor.bNumConfigurations; i++) - { - { - USBDeviceRequest request; - request.bmRequestType = USB::RequestType::DeviceToHost | USB::RequestType::Standard | USB::RequestType::Device; - request.bRequest = USB::Request::GET_DESCRIPTOR; - request.wValue = 0x0200 | i; - request.wIndex = 0; - request.wLength = sizeof(USBConfigurationDescriptor); - TRY(send_request(request, buffer->paddr())); - - auto configuration = *reinterpret_cast(buffer->vaddr()); - - dprintln_if(DEBUG_USB, " configuration {} is {} bytes", i, +configuration.wTotalLength); - if (configuration.wTotalLength > buffer->size()) - { - dwarnln(" our buffer is only {} bytes, skipping some fields..."); - configuration.wTotalLength = buffer->size(); - } - - if (configuration.wTotalLength > request.wLength) - { - request.wLength = configuration.wTotalLength; - TRY(send_request(request, buffer->paddr())); - } - } - - auto configuration = *reinterpret_cast(buffer->vaddr()); - - BAN::Vector interfaces; - TRY(interfaces.reserve(configuration.bNumInterfaces)); - - dprintln_if(DEBUG_USB, " configuration {} has {} interfaces", i, configuration.bNumInterfaces); - - uintptr_t offset = configuration.bLength; - for (uint32_t j = 0; j < configuration.bNumInterfaces; j++) - { - if (offset + sizeof(USBInterfaceDescritor) > buffer->size()) - break; - auto interface = *reinterpret_cast(buffer->vaddr() + offset); - - BAN::Vector endpoints; - TRY(endpoints.reserve(interface.bNumEndpoints)); - - dprintln_if(DEBUG_USB, " interface {} has {} endpoints", j, interface.bNumEndpoints); - - offset += interface.bLength; - for (uint32_t k = 0; k < interface.bNumEndpoints; k++) - { - if (offset + sizeof(USBEndpointDescriptor) > buffer->size()) - break; - auto endpoint = *reinterpret_cast(buffer->vaddr() + offset); - offset += endpoint.bLength; - - TRY(endpoints.emplace_back(endpoint)); - } - - TRY(interfaces.emplace_back(interface, BAN::move(endpoints))); - } - - TRY(m_descriptor.configurations.emplace_back(configuration, BAN::move(interfaces))); - } + if (auto opt_configuration = parse_configuration(i); !opt_configuration.is_error()) + TRY(m_descriptor.configurations.push_back(opt_configuration.release_value())); #if USB_DUMP_DESCRIPTORS const auto& descriptor = m_descriptor.descriptor; @@ -176,10 +123,14 @@ namespace Kernel ASSERT_NOT_REACHED(); } - for (const auto& configuration : m_descriptor.configurations) + for (size_t i = 0; i < m_descriptor.configurations.size(); i++) { - for (const auto& interface : configuration.interfaces) + const auto& configuration = m_descriptor.configurations[i]; + + for (size_t j = 0; j < configuration.interfaces.size(); j++) { + const auto& interface = configuration.interfaces[j]; + switch (static_cast(interface.descriptor.bInterfaceClass)) { case USB::InterfaceBaseClass::Audio: @@ -258,6 +209,84 @@ namespace Kernel return BAN::Error::from_errno(ENOTSUP); } + BAN::ErrorOr USBDevice::parse_configuration(size_t index) + { + { + USBDeviceRequest request; + request.bmRequestType = USB::RequestType::DeviceToHost | USB::RequestType::Standard | USB::RequestType::Device; + request.bRequest = USB::Request::GET_DESCRIPTOR; + request.wValue = 0x0200 | index; + request.wIndex = 0; + request.wLength = m_dma_buffer->size(); + auto transferred = TRY(send_request(request, m_dma_buffer->paddr())); + + auto configuration = *reinterpret_cast(m_dma_buffer->vaddr()); + + dprintln_if(DEBUG_USB, "configuration {} is {} bytes", index, +configuration.wTotalLength); + if (configuration.bLength < sizeof(USBConfigurationDescriptor) || transferred < configuration.wTotalLength) + { + dwarnln("invalid configuration descriptor size: {} length, {} total", configuration.bLength, +configuration.wTotalLength); + return BAN::Error::from_errno(EINVAL); + } + } + + ConfigurationDescriptor configuration; + configuration.desciptor = *reinterpret_cast(m_dma_buffer->vaddr()); + + ptrdiff_t offset = configuration.desciptor.bLength; + while (offset < configuration.desciptor.wTotalLength) + { + const uint8_t length = *reinterpret_cast(m_dma_buffer->vaddr() + offset + 0); + const uint8_t type = *reinterpret_cast(m_dma_buffer->vaddr() + offset + 1); + + switch (type) + { + case USB::DescriptorType::INTERFACE: + if (length < sizeof(USBInterfaceDescriptor)) + { + dwarnln("invalid interface descriptor size {}", length); + return BAN::Error::from_errno(EINVAL); + } + TRY(configuration.interfaces.emplace_back( + *reinterpret_cast(m_dma_buffer->vaddr() + offset), + BAN::Vector(), + BAN::Vector>() + )); + break; + case USB::DescriptorType::ENDPOINT: + if (length < sizeof(USBEndpointDescriptor)) + { + dwarnln("invalid interface descriptor size {}", length); + return BAN::Error::from_errno(EINVAL); + } + if (configuration.interfaces.empty()) + { + dwarnln("invalid endpoint descriptor before interface descriptor"); + return BAN::Error::from_errno(EINVAL); + } + TRY(configuration.interfaces.back().endpoints.emplace_back( + *reinterpret_cast(m_dma_buffer->vaddr() + offset) + )); + break; + default: + if (configuration.interfaces.empty()) + dprintln_if(DEBUG_USB, "skipping descriptor type {}", type); + else + { + BAN::Vector descriptor; + TRY(descriptor.resize(length)); + memcpy(descriptor.data(), reinterpret_cast(m_dma_buffer->vaddr() + offset), length); + TRY(configuration.interfaces.back().misc_descriptors.push_back(BAN::move(descriptor))); + } + break; + } + + offset += length; + } + + return BAN::move(configuration); + } + USB::SpeedClass USBDevice::determine_speed_class(uint64_t bits_per_second) { if (bits_per_second <= 1'500'000) diff --git a/kernel/kernel/USB/XHCI/Device.cpp b/kernel/kernel/USB/XHCI/Device.cpp index 681c65db..dd7d3de7 100644 --- a/kernel/kernel/USB/XHCI/Device.cpp +++ b/kernel/kernel/USB/XHCI/Device.cpp @@ -105,6 +105,8 @@ namespace Kernel BAN::ErrorOr XHCIDevice::update_actual_max_packet_size() { + // FIXME: This is more or less generic USB code + dprintln_if(DEBUG_XHCI, "Retrieving actual max packet size of full speed device"); BAN::Vector buffer; @@ -159,6 +161,27 @@ namespace Kernel return; } + // Get received bytes from short packet + if (trb.transfer_event.completion_code == 13) + { + auto& endpoint = m_endpoints[trb.transfer_event.endpoint_id - 1]; + auto* transfer_trb_arr = reinterpret_cast(endpoint.transfer_ring->vaddr()); + + const uint32_t trb_index = (trb.transfer_event.trb_pointer - endpoint.transfer_ring->paddr()) / sizeof(XHCI::TRB); + + const uint32_t full_trbs_transferred = (trb_index >= endpoint.dequeue_index) + ? trb_index - 1 - endpoint.dequeue_index + : trb_index + m_transfer_ring_trb_count - 2 - endpoint.dequeue_index; + + const uint32_t full_trb_data = full_trbs_transferred * m_max_packet_size; + const uint32_t short_data = transfer_trb_arr[trb_index].data_stage.trb_transfer_length - trb.transfer_event.trb_transfer_length; + + endpoint.transfer_count = full_trb_data + short_data; + + ASSERT(trb_index >= endpoint.dequeue_index); + return; + } + // NOTE: dword2 is last (and atomic) as that is what send_request is waiting for auto& completion_trb = m_endpoints[trb.transfer_event.endpoint_id - 1].completion_trb; completion_trb.raw.dword0 = trb.raw.dword0; @@ -167,23 +190,35 @@ namespace Kernel __atomic_store_n(&completion_trb.raw.dword2, trb.raw.dword2, __ATOMIC_SEQ_CST); } - BAN::ErrorOr XHCIDevice::send_request(const USBDeviceRequest& request, paddr_t buffer_paddr) + BAN::ErrorOr XHCIDevice::send_request(const USBDeviceRequest& request, paddr_t buffer_paddr) { - // minus 3: Setup, Status, Link + // FIXME: This is more or less generic USB code + + // minus 3: Setup, Status, Link (this is probably too generous and will result in STALL) if (request.wLength > (m_transfer_ring_trb_count - 3) * m_max_packet_size) return BAN::Error::from_errno((ENOBUFS)); auto& endpoint = m_endpoints[0]; LockGuard _(endpoint.mutex); + uint8_t transfer_type = + [&request]() -> uint8_t + { + if (request.wLength == 0) + return 0; + if (request.bmRequestType & USB::RequestType::DeviceToHost) + return 3; + return 2; + }(); + auto* transfer_trb_arr = reinterpret_cast(endpoint.transfer_ring->vaddr()); { auto& trb = transfer_trb_arr[endpoint.enqueue_index]; - memset((void*)&trb, 0, sizeof(XHCI::TRB)); + memset(const_cast(&trb), 0, sizeof(XHCI::TRB)); trb.setup_stage.trb_type = XHCI::TRBType::SetupStage; - trb.setup_stage.transfer_type = 3; + trb.setup_stage.transfer_type = transfer_type; trb.setup_stage.trb_transfer_length = 8; trb.setup_stage.interrupt_on_completion = 0; trb.setup_stage.immediate_data = 1; @@ -198,31 +233,37 @@ namespace Kernel advance_endpoint_enqueue(endpoint, false); } + const uint32_t td_packet_count = BAN::Math::div_round_up(request.wLength, m_max_packet_size); + uint32_t packets_transferred = 1; + uint32_t bytes_handled = 0; while (bytes_handled < request.wLength) { const uint32_t to_handle = BAN::Math::min(m_max_packet_size, request.wLength - bytes_handled); auto& trb = transfer_trb_arr[endpoint.enqueue_index]; - memset((void*)&trb, 0, sizeof(XHCI::TRB)); + memset(const_cast(&trb), 0, sizeof(XHCI::TRB)); - trb.data_stage.trb_type = XHCI::TRBType::DataStage; - trb.data_stage.direction = 1; - trb.data_stage.trb_transfer_length = to_handle; - trb.data_stage.chain_bit = (bytes_handled + to_handle < request.wLength); - trb.data_stage.interrupt_on_completion = 0; - trb.data_stage.immediate_data = 0; - trb.data_stage.data_buffer_pointer = buffer_paddr + bytes_handled; - trb.data_stage.cycle_bit = endpoint.cycle_bit; + trb.data_stage.trb_type = XHCI::TRBType::DataStage; + trb.data_stage.direction = 1; + trb.data_stage.trb_transfer_length = to_handle; + trb.data_stage.td_size = BAN::Math::min(td_packet_count - packets_transferred, 31); + trb.data_stage.chain_bit = (bytes_handled + to_handle < request.wLength); + trb.data_stage.interrupt_on_completion = 0; + trb.data_stage.interrupt_on_short_packet = 1; + trb.data_stage.immediate_data = 0; + trb.data_stage.data_buffer_pointer = buffer_paddr + bytes_handled; + trb.data_stage.cycle_bit = endpoint.cycle_bit; bytes_handled += to_handle; + packets_transferred++; - advance_endpoint_enqueue(endpoint, false); + advance_endpoint_enqueue(endpoint, trb.data_stage.chain_bit); } { auto& trb = transfer_trb_arr[endpoint.enqueue_index]; - memset((void*)&trb, 0, sizeof(XHCI::TRB)); + memset(const_cast(&trb), 0, sizeof(XHCI::TRB)); trb.status_stage.trb_type = XHCI::TRBType::StatusStage; trb.status_stage.direction = 0; @@ -239,6 +280,8 @@ namespace Kernel completion_trb.raw.dword2 = 0; completion_trb.raw.dword3 = 0; + endpoint.transfer_count = request.wLength; + m_controller.doorbell_reg(m_slot_id) = 1; const uint64_t timeout_ms = SystemTimer::get().ms_since_boot() + 1000; @@ -246,13 +289,15 @@ namespace Kernel if (SystemTimer::get().ms_since_boot() > timeout_ms) return BAN::Error::from_errno(ETIMEDOUT); + endpoint.dequeue_index = endpoint.enqueue_index; + if (completion_trb.transfer_event.completion_code != 1) { dwarnln("Completion error: {}", +completion_trb.transfer_event.completion_code); return BAN::Error::from_errno(EFAULT); } - return {}; + return endpoint.transfer_count; } void XHCIDevice::advance_endpoint_enqueue(Endpoint& endpoint, bool chain)