Kernel: Make USB HID interfaces configure all endpoints

This commit is contained in:
Bananymous 2024-07-16 00:23:26 +03:00
parent 339e8a7910
commit 1337758660
3 changed files with 67 additions and 53 deletions

View File

@ -93,8 +93,6 @@ namespace Kernel
const uint8_t m_interface_index; const uint8_t m_interface_index;
bool m_uses_report_id { false }; bool m_uses_report_id { false };
uint8_t m_endpoint_id { 0 };
BAN::Vector<DeviceReport> m_device_inputs; BAN::Vector<DeviceReport> m_device_inputs;
friend class BAN::UniqPtr<USBHIDDriver>; friend class BAN::UniqPtr<USBHIDDriver>;

View File

@ -90,24 +90,6 @@ namespace Kernel
ASSERT(static_cast<USB::InterfaceBaseClass>(m_interface.descriptor.bInterfaceClass) == USB::InterfaceBaseClass::HID); ASSERT(static_cast<USB::InterfaceBaseClass>(m_interface.descriptor.bInterfaceClass) == USB::InterfaceBaseClass::HID);
size_t endpoint_index = static_cast<size_t>(-1);
for (size_t i = 0; i < m_interface.endpoints.size(); i++)
{
const auto& endpoint = m_interface.endpoints[i];
if (!(endpoint.descriptor.bEndpointAddress & 0x80))
continue;
if (endpoint.descriptor.bmAttributes != 0x03)
continue;
endpoint_index = i;
break;
}
if (endpoint_index >= m_interface.endpoints.size())
{
dwarnln("HID device does not contain IN interrupt endpoint");
return BAN::Error::from_errno(EFAULT);
}
bool hid_descriptor_invalid = false; bool hid_descriptor_invalid = false;
size_t hid_descriptor_index = static_cast<size_t>(-1); size_t hid_descriptor_index = static_cast<size_t>(-1);
for (size_t i = 0; i < m_interface.misc_descriptors.size(); i++) for (size_t i = 0; i < m_interface.misc_descriptors.size(); i++)
@ -206,9 +188,8 @@ namespace Kernel
m_device_inputs = TRY(initializes_device_reports(collections)); m_device_inputs = TRY(initializes_device_reports(collections));
const auto& endpoint_descriptor = m_interface.endpoints[endpoint_index].descriptor; for (const auto& endpoint : m_interface.endpoints)
m_endpoint_id = (endpoint_descriptor.bEndpointAddress & 0x0F) * 2 + !!(endpoint_descriptor.bEndpointAddress & 0x80); TRY(m_device.initialize_endpoint(endpoint.descriptor));
TRY(m_device.initialize_endpoint(endpoint_descriptor));
return {}; return {};
} }
@ -274,9 +255,21 @@ namespace Kernel
void USBHIDDriver::handle_input_data(BAN::ConstByteSpan data, uint8_t endpoint_id) void USBHIDDriver::handle_input_data(BAN::ConstByteSpan data, uint8_t endpoint_id)
{ {
{
bool found = false;
for (const auto& endpoint : m_interface.endpoints)
{
const auto& desc = endpoint.descriptor;
if (endpoint_id == (desc.bEndpointAddress & 0x0F) * 2 + !!(desc.bEndpointAddress & 0x80))
{
found = true;
break;
}
}
// If this packet is not for us, skip it // If this packet is not for us, skip it
if (m_endpoint_id != endpoint_id) if (!found)
return; return;
}
if constexpr(DEBUG_HID) if constexpr(DEBUG_HID)
{ {

View File

@ -15,7 +15,6 @@ namespace Kernel
return TRY(BAN::UniqPtr<XHCIDevice>::create(controller, port_id, slot_id)); return TRY(BAN::UniqPtr<XHCIDevice>::create(controller, port_id, slot_id));
} }
uint64_t XHCIDevice::calculate_port_bits_per_second(XHCIController& controller, uint32_t port_id) uint64_t XHCIDevice::calculate_port_bits_per_second(XHCIController& controller, uint32_t port_id)
{ {
const uint32_t portsc = controller.operational_regs().ports[port_id - 1].portsc; const uint32_t portsc = controller.operational_regs().ports[port_id - 1].portsc;
@ -203,7 +202,29 @@ namespace Kernel
BAN::ErrorOr<void> XHCIDevice::initialize_endpoint(const USBEndpointDescriptor& endpoint_descriptor) BAN::ErrorOr<void> XHCIDevice::initialize_endpoint(const USBEndpointDescriptor& endpoint_descriptor)
{ {
ASSERT(m_controller.port(m_port_id).revision_major == 2);
const uint32_t endpoint_id = (endpoint_descriptor.bEndpointAddress & 0x0F) * 2 + !!(endpoint_descriptor.bEndpointAddress & 0x80); const uint32_t endpoint_id = (endpoint_descriptor.bEndpointAddress & 0x0F) * 2 + !!(endpoint_descriptor.bEndpointAddress & 0x80);
const uint32_t max_packet_size = endpoint_descriptor.wMaxPacketSize & 0x07FF;
const uint32_t max_burst_size = (endpoint_descriptor.wMaxPacketSize >> 11) & 0x0003;
const uint32_t max_esit_payload = max_packet_size * (max_burst_size + 1);
const uint32_t interval = determine_interval(endpoint_descriptor, m_speed_class);
const uint32_t average_trb_length = ((endpoint_descriptor.bmAttributes & 3) == 0b00) ? 8 : max_esit_payload;
const uint32_t error_count = ((endpoint_descriptor.bmAttributes & 3) == 0b01) ? 0 : 3;
XHCI::EndpointType endpoint_type;
switch ((endpoint_descriptor.bEndpointAddress & 0x80) | (endpoint_descriptor.bmAttributes & 0x03))
{
case 0x00:
case 0x80: endpoint_type = XHCI::EndpointType::Control; break;
case 0x01: endpoint_type = XHCI::EndpointType::IsochOut; break;
case 0x81: endpoint_type = XHCI::EndpointType::IsochIn; break;
case 0x02: endpoint_type = XHCI::EndpointType::BulkOut; break;
case 0x82: endpoint_type = XHCI::EndpointType::BulkIn; break;
case 0x03: endpoint_type = XHCI::EndpointType::InterruptOut; break;
case 0x83: endpoint_type = XHCI::EndpointType::InterruptIn; break;
default: ASSERT_NOT_REACHED();
}
auto& endpoint = m_endpoints[endpoint_id - 1]; auto& endpoint = m_endpoints[endpoint_id - 1];
ASSERT(!endpoint.transfer_ring); ASSERT(!endpoint.transfer_ring);
@ -214,7 +235,7 @@ namespace Kernel
last_valid_endpoint_id = i + 1; last_valid_endpoint_id = i + 1;
endpoint.transfer_ring = TRY(DMARegion::create(m_transfer_ring_trb_count * sizeof(XHCI::TRB))); endpoint.transfer_ring = TRY(DMARegion::create(m_transfer_ring_trb_count * sizeof(XHCI::TRB)));
endpoint.max_packet_size = endpoint_descriptor.wMaxPacketSize & 0x07FF; endpoint.max_packet_size = max_packet_size;
endpoint.dequeue_index = 0; endpoint.dequeue_index = 0;
endpoint.enqueue_index = 0; endpoint.enqueue_index = 0;
endpoint.cycle_bit = 1; endpoint.cycle_bit = 1;
@ -237,23 +258,16 @@ namespace Kernel
slot_context.context_entries = last_valid_endpoint_id; slot_context.context_entries = last_valid_endpoint_id;
// FIXME: 4.5.2 hub // FIXME: 4.5.2 hub
ASSERT(endpoint_descriptor.bEndpointAddress & 0x80);
ASSERT((endpoint_descriptor.bmAttributes & 0x03) == 3);
ASSERT(m_controller.port(m_port_id).revision_major == 2);
const uint32_t max_esit_payload = endpoint_context.max_packet_size * (endpoint_context.max_burst_size + 1);
const uint32_t interval = determine_interval(endpoint_descriptor, m_speed_class);
memset(&endpoint_context, 0, context_size); memset(&endpoint_context, 0, context_size);
endpoint_context.endpoint_type = XHCI::EndpointType::InterruptIn; endpoint_context.endpoint_type = endpoint_type;
endpoint_context.max_packet_size = endpoint.max_packet_size; endpoint_context.max_packet_size = max_packet_size;
endpoint_context.max_burst_size = (endpoint_descriptor.wMaxPacketSize >> 11) & 0x0003; endpoint_context.max_burst_size = max_burst_size;
endpoint_context.mult = 0; endpoint_context.mult = 0;
endpoint_context.error_count = 3; endpoint_context.error_count = error_count;
endpoint_context.tr_dequeue_pointer = endpoint.transfer_ring->paddr() | 1; endpoint_context.tr_dequeue_pointer = endpoint.transfer_ring->paddr() | 1;
endpoint_context.max_esit_payload_lo = max_esit_payload & 0xFFFF; endpoint_context.max_esit_payload_lo = max_esit_payload & 0xFFFF;
endpoint_context.max_esit_payload_hi = max_esit_payload >> 16; endpoint_context.max_esit_payload_hi = max_esit_payload >> 16;
endpoint_context.average_trb_length = max_esit_payload; endpoint_context.average_trb_length = average_trb_length;
endpoint_context.interval = interval; endpoint_context.interval = interval;
} }
@ -264,6 +278,8 @@ namespace Kernel
configure_endpoint.configure_endpoint_command.slot_id = m_slot_id; configure_endpoint.configure_endpoint_command.slot_id = m_slot_id;
TRY(m_controller.send_command(configure_endpoint)); TRY(m_controller.send_command(configure_endpoint));
if (endpoint_type == XHCI::EndpointType::InterruptIn)
{
auto& trb = *reinterpret_cast<volatile XHCI::TRB*>(endpoint.transfer_ring->vaddr()); auto& trb = *reinterpret_cast<volatile XHCI::TRB*>(endpoint.transfer_ring->vaddr());
memset(const_cast<XHCI::TRB*>(&trb), 0, sizeof(XHCI::TRB)); memset(const_cast<XHCI::TRB*>(&trb), 0, sizeof(XHCI::TRB));
trb.normal.trb_type = XHCI::TRBType::Normal; trb.normal.trb_type = XHCI::TRBType::Normal;
@ -277,6 +293,13 @@ namespace Kernel
advance_endpoint_enqueue(endpoint, false); advance_endpoint_enqueue(endpoint, false);
m_controller.doorbell_reg(m_slot_id) = endpoint_id; m_controller.doorbell_reg(m_slot_id) = endpoint_id;
}
else
{
dwarnln("Configured unsupported endpoint {2H}",
(endpoint_descriptor.bEndpointAddress & 0x80) | (endpoint_descriptor.bmAttributes & 0x03)
);
}
return {}; return {};
} }