Kernel: Add support for bulk endpoints and update endpoint API

USB device now sets its own data buffers for IN/OUT endpoints. This
allows more customization and parallelism as data buffer does not have
to be shared.
This commit is contained in:
Bananymous 2024-11-21 13:44:21 +02:00
parent 857b3e92f8
commit 1253e2a458
6 changed files with 105 additions and 88 deletions

View File

@ -19,7 +19,9 @@ namespace Kernel
USBClassDriver() = default; USBClassDriver() = default;
virtual ~USBClassDriver() = default; virtual ~USBClassDriver() = default;
virtual void handle_input_data(BAN::ConstByteSpan, uint8_t endpoint_id) = 0; virtual BAN::ErrorOr<void> initialize() { return {}; };
virtual void handle_input_data(size_t byte_count, uint8_t endpoint_id) = 0;
}; };
class USBDevice class USBDevice
@ -64,11 +66,12 @@ namespace Kernel
virtual BAN::ErrorOr<void> initialize_endpoint(const USBEndpointDescriptor&) = 0; virtual BAN::ErrorOr<void> initialize_endpoint(const USBEndpointDescriptor&) = 0;
virtual BAN::ErrorOr<size_t> send_request(const USBDeviceRequest&, paddr_t buffer) = 0; virtual BAN::ErrorOr<size_t> send_request(const USBDeviceRequest&, paddr_t buffer) = 0;
virtual void send_data_buffer(uint8_t endpoint_id, paddr_t buffer, size_t buffer_len) = 0;
static USB::SpeedClass determine_speed_class(uint64_t bits_per_second); static USB::SpeedClass determine_speed_class(uint64_t bits_per_second);
protected: protected:
void handle_input_data(BAN::ConstByteSpan, uint8_t endpoint_id); void handle_input_data(size_t byte_count, uint8_t endpoint_id);
virtual BAN::ErrorOr<void> initialize_control_endpoint() = 0; virtual BAN::ErrorOr<void> initialize_control_endpoint() = 0;
private: private:

View File

@ -75,15 +75,13 @@ namespace Kernel
}; };
public: public:
static BAN::ErrorOr<BAN::UniqPtr<USBHIDDriver>> create(USBDevice&, const USBDevice::InterfaceDescriptor&); void handle_input_data(size_t byte_count, uint8_t endpoint_id) override;
void handle_input_data(BAN::ConstByteSpan, uint8_t endpoint_id) override;
private: private:
USBHIDDriver(USBDevice&, const USBDevice::InterfaceDescriptor&); USBHIDDriver(USBDevice&, const USBDevice::InterfaceDescriptor&);
~USBHIDDriver(); ~USBHIDDriver();
BAN::ErrorOr<void> initialize(); BAN::ErrorOr<void> initialize() override;
BAN::ErrorOr<BAN::Vector<DeviceReport>> initializes_device_reports(const BAN::Vector<USBHID::Collection>&); BAN::ErrorOr<BAN::Vector<DeviceReport>> initializes_device_reports(const BAN::Vector<USBHID::Collection>&);
@ -94,6 +92,9 @@ namespace Kernel
bool m_uses_report_id { false }; bool m_uses_report_id { false };
BAN::Vector<DeviceReport> m_device_inputs; BAN::Vector<DeviceReport> m_device_inputs;
uint8_t m_data_endpoint_id = 0;
BAN::UniqPtr<DMARegion> m_data_buffer;
friend class BAN::UniqPtr<USBHIDDriver>; friend class BAN::UniqPtr<USBHIDDriver>;
}; };

View File

@ -27,7 +27,6 @@ namespace Kernel
volatile uint32_t transfer_count { 0 }; volatile uint32_t transfer_count { 0 };
volatile XHCI::TRB completion_trb; volatile XHCI::TRB completion_trb;
BAN::UniqPtr<DMARegion> data_region;
void(XHCIDevice::*callback)(XHCI::TRB); void(XHCIDevice::*callback)(XHCI::TRB);
}; };
@ -36,6 +35,7 @@ namespace Kernel
BAN::ErrorOr<void> initialize_endpoint(const USBEndpointDescriptor&) override; BAN::ErrorOr<void> initialize_endpoint(const USBEndpointDescriptor&) override;
BAN::ErrorOr<size_t> send_request(const USBDeviceRequest&, paddr_t buffer) override; BAN::ErrorOr<size_t> send_request(const USBDeviceRequest&, paddr_t buffer) override;
void send_data_buffer(uint8_t endpoint_id, paddr_t buffer, size_t buffer_size) override;
void on_transfer_event(const volatile XHCI::TRB&); void on_transfer_event(const volatile XHCI::TRB&);
@ -47,7 +47,7 @@ namespace Kernel
~XHCIDevice(); ~XHCIDevice();
BAN::ErrorOr<void> update_actual_max_packet_size(); BAN::ErrorOr<void> update_actual_max_packet_size();
void on_interrupt_endpoint_event(XHCI::TRB); void on_interrupt_or_bulk_endpoint_event(XHCI::TRB);
void advance_endpoint_enqueue(Endpoint&, bool chain); void advance_endpoint_enqueue(Endpoint&, bool chain);

View File

@ -154,7 +154,7 @@ namespace Kernel
dprintln_if(DEBUG_USB, "Found CommunicationAndCDCControl interface"); dprintln_if(DEBUG_USB, "Found CommunicationAndCDCControl interface");
break; break;
case USB::InterfaceBaseClass::HID: case USB::InterfaceBaseClass::HID:
if (auto result = USBHIDDriver::create(*this, interface); !result.is_error()) if (auto result = BAN::UniqPtr<USBHIDDriver>::create(*this, interface); !result.is_error())
TRY(m_class_drivers.push_back(result.release_value())); TRY(m_class_drivers.push_back(result.release_value()));
break; break;
case USB::InterfaceBaseClass::Physical: case USB::InterfaceBaseClass::Physical:
@ -220,6 +220,15 @@ namespace Kernel
} }
} }
for (size_t i = 0; i < m_class_drivers.size(); i++)
{
if (auto ret = m_class_drivers[i]->initialize(); ret.is_error())
{
dwarnln("Could not initialize USB interface {}", ret.error());
m_class_drivers.remove(i--);
}
}
if (!m_class_drivers.empty()) if (!m_class_drivers.empty())
{ {
dprintln("Successfully initialized USB device with {}/{} interfaces", dprintln("Successfully initialized USB device with {}/{} interfaces",
@ -317,10 +326,10 @@ namespace Kernel
return BAN::move(configuration); return BAN::move(configuration);
} }
void USBDevice::handle_input_data(BAN::ConstByteSpan data, uint8_t endpoint_id) void USBDevice::handle_input_data(size_t byte_count, uint8_t endpoint_id)
{ {
for (auto& driver : m_class_drivers) for (auto& driver : m_class_drivers)
driver->handle_input_data(data, endpoint_id); driver->handle_input_data(byte_count, endpoint_id);
} }
USB::SpeedClass USBDevice::determine_speed_class(uint64_t bits_per_second) USB::SpeedClass USBDevice::determine_speed_class(uint64_t bits_per_second)

View File

@ -1,4 +1,5 @@
#include <BAN/ByteSpan.h> #include <BAN/ByteSpan.h>
#include <BAN/ScopeGuard.h>
#include <kernel/FS/DevFS/FileSystem.h> #include <kernel/FS/DevFS/FileSystem.h>
#include <kernel/USB/HID/HIDDriver.h> #include <kernel/USB/HID/HIDDriver.h>
@ -68,13 +69,6 @@ namespace Kernel
static BAN::ErrorOr<BAN::Vector<Collection>> parse_report_descriptor(BAN::ConstByteSpan report_data, bool& out_use_report_id); static BAN::ErrorOr<BAN::Vector<Collection>> parse_report_descriptor(BAN::ConstByteSpan report_data, bool& out_use_report_id);
BAN::ErrorOr<BAN::UniqPtr<USBHIDDriver>> USBHIDDriver::create(USBDevice& device, const USBDevice::InterfaceDescriptor& interface)
{
auto result = TRY(BAN::UniqPtr<USBHIDDriver>::create(device, interface));
TRY(result->initialize());
return result;
}
USBHIDDriver::USBHIDDriver(USBDevice& device, const USBDevice::InterfaceDescriptor& interface) USBHIDDriver::USBHIDDriver(USBDevice& device, const USBDevice::InterfaceDescriptor& interface)
: m_device(device) : m_device(device)
, m_interface(interface) , m_interface(interface)
@ -192,7 +186,29 @@ namespace Kernel
m_device_inputs = TRY(initializes_device_reports(collections)); m_device_inputs = TRY(initializes_device_reports(collections));
for (const auto& endpoint : m_interface.endpoints) for (const auto& endpoint : m_interface.endpoints)
TRY(m_device.initialize_endpoint(endpoint.descriptor)); {
const auto& desc = endpoint.descriptor;
if (!(desc.bEndpointAddress & 0x80))
continue;
if ((desc.bmAttributes & 0x03) != 0x03)
continue;
TRY(m_device.initialize_endpoint(desc));
m_data_buffer = TRY(DMARegion::create(desc.wMaxPacketSize & 0x07FF));
m_data_endpoint_id = (desc.bEndpointAddress & 0x0F) * 2 + !!(desc.bEndpointAddress & 0x80);
break;
}
if (m_data_endpoint_id == 0)
{
dwarnln("HID device does not an interrupt IN endpoints");
return BAN::Error::from_errno(EINVAL);
}
m_device.send_data_buffer(m_data_endpoint_id, m_data_buffer->paddr(), m_data_buffer->size());
return {}; return {};
} }
@ -256,23 +272,15 @@ namespace Kernel
return BAN::move(result); return BAN::move(result);
} }
void USBHIDDriver::handle_input_data(BAN::ConstByteSpan data, uint8_t endpoint_id) void USBHIDDriver::handle_input_data(size_t byte_count, uint8_t endpoint_id)
{ {
{ if (m_data_endpoint_id != endpoint_id)
bool found = false; return;
for (const auto& endpoint : m_interface.endpoints)
{ auto data = BAN::ConstByteSpan(reinterpret_cast<uint8_t*>(m_data_buffer->vaddr()), byte_count);
const auto& desc = endpoint.descriptor; BAN::ScopeGuard _([&] {
if (endpoint_id == (desc.bEndpointAddress & 0x0F) * 2 + !!(desc.bEndpointAddress & 0x80)) m_device.send_data_buffer(m_data_endpoint_id, m_data_buffer->paddr(), m_data_buffer->size());
{ });
found = true;
break;
}
}
// If this packet is not for us, skip it
if (!found)
return;
}
if constexpr(DEBUG_USB_HID) if constexpr(DEBUG_USB_HID)
{ {

View File

@ -200,20 +200,20 @@ namespace Kernel
BAN::ErrorOr<void> XHCIDevice::initialize_endpoint(const USBEndpointDescriptor& endpoint_descriptor) BAN::ErrorOr<void> XHCIDevice::initialize_endpoint(const USBEndpointDescriptor& endpoint_descriptor)
{ {
ASSERT(m_controller.port(m_port_id).revision_major == 2); const bool is_control { (endpoint_descriptor.bmAttributes & 0x03) == 0x00 };
const bool is_isoch { (endpoint_descriptor.bmAttributes & 0x03) == 0x01 };
const bool is_bulk { (endpoint_descriptor.bmAttributes & 0x03) == 0x02 };
const bool is_interrupt { (endpoint_descriptor.bmAttributes & 0x03) == 0x03 };
const uint32_t endpoint_id = (endpoint_descriptor.bEndpointAddress & 0x0F) * 2 + !!(endpoint_descriptor.bEndpointAddress & 0x80); (void)is_control;
const uint32_t max_packet_size = endpoint_descriptor.wMaxPacketSize & 0x07FF; (void)is_isoch;
const uint32_t max_burst_size = (endpoint_descriptor.wMaxPacketSize >> 11) & 0x0003; (void)is_bulk;
const uint32_t max_esit_payload = max_packet_size * (max_burst_size + 1); (void)is_interrupt;
const uint32_t interval = determine_interval(endpoint_descriptor, m_speed_class);
const uint32_t average_trb_length = ((endpoint_descriptor.bmAttributes & 3) == 0b00) ? 8 : max_esit_payload;
const uint32_t error_count = ((endpoint_descriptor.bmAttributes & 3) == 0b01) ? 0 : 3;
XHCI::EndpointType endpoint_type; XHCI::EndpointType endpoint_type;
switch ((endpoint_descriptor.bEndpointAddress & 0x80) | (endpoint_descriptor.bmAttributes & 0x03)) switch ((endpoint_descriptor.bEndpointAddress & 0x80) | (endpoint_descriptor.bmAttributes & 0x03))
{ {
case 0x00: case 0x00: ASSERT_NOT_REACHED();
case 0x80: endpoint_type = XHCI::EndpointType::Control; break; case 0x80: endpoint_type = XHCI::EndpointType::Control; break;
case 0x01: endpoint_type = XHCI::EndpointType::IsochOut; break; case 0x01: endpoint_type = XHCI::EndpointType::IsochOut; break;
case 0x81: endpoint_type = XHCI::EndpointType::IsochIn; break; case 0x81: endpoint_type = XHCI::EndpointType::IsochIn; break;
@ -224,6 +224,16 @@ namespace Kernel
default: ASSERT_NOT_REACHED(); default: ASSERT_NOT_REACHED();
} }
// FIXME: Streams
const uint32_t endpoint_id = (endpoint_descriptor.bEndpointAddress & 0x0F) * 2 + !!(endpoint_descriptor.bEndpointAddress & 0x80);
const uint32_t max_packet_size = (is_control || is_bulk) ? endpoint_descriptor.wMaxPacketSize : endpoint_descriptor.wMaxPacketSize & 0x07FF;
const uint32_t max_burst_size = (is_control || is_bulk) ? 0 : (endpoint_descriptor.wMaxPacketSize & 0x1800) >> 11;
const uint32_t max_esit_payload = max_packet_size * (max_burst_size + 1);
const uint32_t interval = determine_interval(endpoint_descriptor, m_speed_class);
const uint32_t average_trb_length = (is_control) ? 8 : max_esit_payload;
const uint32_t error_count = (is_isoch) ? 0 : 3;
auto& endpoint = m_endpoints[endpoint_id - 1]; auto& endpoint = m_endpoints[endpoint_id - 1];
ASSERT(!endpoint.transfer_ring); ASSERT(!endpoint.transfer_ring);
@ -237,8 +247,7 @@ namespace Kernel
endpoint.dequeue_index = 0; endpoint.dequeue_index = 0;
endpoint.enqueue_index = 0; endpoint.enqueue_index = 0;
endpoint.cycle_bit = 1; endpoint.cycle_bit = 1;
endpoint.callback = &XHCIDevice::on_interrupt_endpoint_event; endpoint.callback = (is_interrupt || is_bulk) ? &XHCIDevice::on_interrupt_or_bulk_endpoint_event : nullptr;
endpoint.data_region = TRY(DMARegion::create(endpoint.max_packet_size));
memset(reinterpret_cast<void*>(endpoint.transfer_ring->vaddr()), 0, endpoint.transfer_ring->size()); memset(reinterpret_cast<void*>(endpoint.transfer_ring->vaddr()), 0, endpoint.transfer_ring->size());
@ -276,62 +285,27 @@ namespace Kernel
configure_endpoint.configure_endpoint_command.slot_id = m_slot_id; configure_endpoint.configure_endpoint_command.slot_id = m_slot_id;
TRY(m_controller.send_command(configure_endpoint)); TRY(m_controller.send_command(configure_endpoint));
if (endpoint_type == XHCI::EndpointType::InterruptIn)
{
auto& trb = *reinterpret_cast<volatile XHCI::TRB*>(endpoint.transfer_ring->vaddr());
memset(const_cast<XHCI::TRB*>(&trb), 0, sizeof(XHCI::TRB));
trb.normal.trb_type = XHCI::TRBType::Normal;
trb.normal.data_buffer_pointer = endpoint.data_region->paddr();
trb.normal.trb_transfer_length = endpoint.data_region->size();
trb.normal.td_size = 0;
trb.normal.interrupt_target = 0;
trb.normal.cycle_bit = 1;
trb.normal.interrupt_on_completion = 1;
trb.normal.interrupt_on_short_packet = 1;
advance_endpoint_enqueue(endpoint, false);
m_controller.doorbell_reg(m_slot_id) = endpoint_id;
}
else
{
dwarnln("Configured unsupported endpoint {2H}",
(endpoint_descriptor.bEndpointAddress & 0x80) | (endpoint_descriptor.bmAttributes & 0x03)
);
}
return {}; return {};
} }
void XHCIDevice::on_interrupt_endpoint_event(XHCI::TRB trb) void XHCIDevice::on_interrupt_or_bulk_endpoint_event(XHCI::TRB trb)
{ {
ASSERT(trb.trb_type == XHCI::TRBType::TransferEvent); ASSERT(trb.trb_type == XHCI::TRBType::TransferEvent);
if (trb.transfer_event.completion_code != 1 && trb.transfer_event.completion_code != 13) if (trb.transfer_event.completion_code != 1 && trb.transfer_event.completion_code != 13)
{ {
dwarnln("Interrupt endpoint got transfer event with completion code {}", +trb.transfer_event.completion_code); dwarnln("Interrupt or bulk endpoint got transfer event with completion code {}", +trb.transfer_event.completion_code);
return; return;
} }
const uint32_t endpoint_id = trb.transfer_event.endpoint_id; const uint32_t endpoint_id = trb.transfer_event.endpoint_id;
auto& endpoint = m_endpoints[endpoint_id - 1]; auto& endpoint = m_endpoints[endpoint_id - 1];
ASSERT(endpoint.transfer_ring && endpoint.data_region);
const uint32_t transfer_length = endpoint.max_packet_size - trb.transfer_event.trb_transfer_length; const auto* transfer_trb_arr = reinterpret_cast<volatile XHCI::TRB*>(endpoint.transfer_ring->vaddr());
auto received_data = BAN::ConstByteSpan(reinterpret_cast<uint8_t*>(endpoint.data_region->vaddr()), transfer_length); const uint32_t transfer_trb_index = (trb.transfer_event.trb_pointer - endpoint.transfer_ring->paddr()) / sizeof(XHCI::TRB);
handle_input_data(received_data, endpoint_id); const uint32_t original_len = transfer_trb_arr[transfer_trb_index].normal.trb_transfer_length;
auto& new_trb = *reinterpret_cast<volatile XHCI::TRB*>(endpoint.transfer_ring->vaddr() + endpoint.enqueue_index * sizeof(XHCI::TRB)); const uint32_t transfer_length = original_len - trb.transfer_event.trb_transfer_length;
memset(const_cast<XHCI::TRB*>(&new_trb), 0, sizeof(XHCI::TRB)); handle_input_data(transfer_length, endpoint_id);
new_trb.normal.trb_type = XHCI::TRBType::Normal;
new_trb.normal.data_buffer_pointer = endpoint.data_region->paddr();
new_trb.normal.trb_transfer_length = endpoint.max_packet_size;
new_trb.normal.td_size = 0;
new_trb.normal.interrupt_target = 0;
new_trb.normal.cycle_bit = endpoint.cycle_bit;
new_trb.normal.interrupt_on_completion = 1;
new_trb.normal.interrupt_on_short_packet = 1;
advance_endpoint_enqueue(endpoint, false);
m_controller.doorbell_reg(m_slot_id) = endpoint_id;
} }
void XHCIDevice::on_transfer_event(const volatile XHCI::TRB& trb) void XHCIDevice::on_transfer_event(const volatile XHCI::TRB& trb)
@ -495,6 +469,28 @@ namespace Kernel
return endpoint.transfer_count; return endpoint.transfer_count;
} }
void XHCIDevice::send_data_buffer(uint8_t endpoint_id, paddr_t buffer, size_t buffer_len)
{
ASSERT(endpoint_id != 0);
auto& endpoint = m_endpoints[endpoint_id - 1];
ASSERT(buffer_len <= endpoint.max_packet_size);
auto& trb = *reinterpret_cast<volatile XHCI::TRB*>(endpoint.transfer_ring->vaddr() + endpoint.enqueue_index * sizeof(XHCI::TRB));
memset(const_cast<XHCI::TRB*>(&trb), 0, sizeof(XHCI::TRB));
trb.normal.trb_type = XHCI::TRBType::Normal;
trb.normal.data_buffer_pointer = buffer;
trb.normal.trb_transfer_length = buffer_len;
trb.normal.td_size = 0;
trb.normal.interrupt_target = 0;
trb.normal.cycle_bit = endpoint.cycle_bit;
trb.normal.interrupt_on_completion = 1;
trb.normal.interrupt_on_short_packet = 1;
advance_endpoint_enqueue(endpoint, false);
m_controller.doorbell_reg(m_slot_id) = endpoint_id;
}
void XHCIDevice::advance_endpoint_enqueue(Endpoint& endpoint, bool chain) void XHCIDevice::advance_endpoint_enqueue(Endpoint& endpoint, bool chain)
{ {
endpoint.enqueue_index++; endpoint.enqueue_index++;