From 42c3fa24f0cb5c94c2d3d167f81e3b54861b04d7 Mon Sep 17 00:00:00 2001 From: Bananymous Date: Mon, 15 Jul 2024 15:51:07 +0300 Subject: [PATCH] Kernel: Add support for HID Report ID and parsing all collections Only the first top-level collection is used for the device, but that seems to generally be what keyboard and mouse use for input. --- kernel/include/kernel/USB/HID/HIDDriver.h | 7 +- kernel/kernel/USB/HID/HIDDriver.cpp | 131 ++++++++++++++-------- 2 files changed, 88 insertions(+), 50 deletions(-) diff --git a/kernel/include/kernel/USB/HID/HIDDriver.h b/kernel/include/kernel/USB/HID/HIDDriver.h index ef2aaa9ef1..3ec91e5982 100644 --- a/kernel/include/kernel/USB/HID/HIDDriver.h +++ b/kernel/include/kernel/USB/HID/HIDDriver.h @@ -17,6 +17,7 @@ namespace Kernel uint16_t usage_id; Type type; + uint8_t report_id; uint32_t report_count; uint32_t report_size; @@ -77,15 +78,17 @@ namespace Kernel BAN::ErrorOr initialize(); - void forward_collection_inputs(const USBHID::Collection&, BAN::ConstByteSpan& data, size_t bit_offset); + void forward_collection_inputs(const USBHID::Collection&, BAN::Optional report_id, BAN::ConstByteSpan& data, size_t bit_offset); private: USBDevice& m_device; USBDevice::InterfaceDescriptor m_interface; const uint8_t m_interface_index; + bool m_uses_report_id { false }; + uint8_t m_endpoint_id { 0 }; - USBHID::Collection m_collection; + BAN::Vector m_collections; BAN::RefPtr m_hid_device; friend class BAN::UniqPtr; diff --git a/kernel/kernel/USB/HID/HIDDriver.cpp b/kernel/kernel/USB/HID/HIDDriver.cpp index 360263c5f7..19f878b34d 100644 --- a/kernel/kernel/USB/HID/HIDDriver.cpp +++ b/kernel/kernel/USB/HID/HIDDriver.cpp @@ -43,8 +43,8 @@ namespace Kernel BAN::Optional physical_minimum; BAN::Optional physical_maximum; // FIXME: support units + BAN::Optional report_id; BAN::Optional report_size; - // FIXME: support report id BAN::Optional report_count; }; @@ -59,10 +59,10 @@ namespace Kernel using namespace USBHID; #if DUMP_HID_REPORT - static void dump_hid_collection(const Collection& collection, size_t indent); + static void dump_hid_collection(const Collection& collection, size_t indent, bool use_report_id); #endif - static BAN::ErrorOr parse_report_descriptor(BAN::ConstByteSpan report_data); + static BAN::ErrorOr> parse_report_descriptor(BAN::ConstByteSpan report_data, bool& out_use_report_id); BAN::ErrorOr> USBHIDDriver::create(USBDevice& device, const USBDevice::InterfaceDescriptor& interface, uint8_t interface_index) { @@ -144,7 +144,6 @@ namespace Kernel TRY(m_device.send_request(request, 0)); } - Collection collection {}; const auto& hid_descriptor = *reinterpret_cast(m_interface.misc_descriptors[hid_descriptor_index].data()); dprintln_if(DEBUG_HID, "HID descriptor ({} bytes)", m_interface.misc_descriptors[hid_descriptor_index].size()); @@ -154,7 +153,7 @@ namespace Kernel dprintln_if(DEBUG_HID, " bCountryCode: {}", hid_descriptor.bCountryCode); dprintln_if(DEBUG_HID, " bNumDescriptors: {}", hid_descriptor.bNumDescriptors); - bool report_descriptor_parsed = false; + BAN::Vector collections; for (size_t i = 0; i < hid_descriptor.bNumDescriptors; i++) { auto descriptor = hid_descriptor.descriptors[i]; @@ -164,11 +163,6 @@ namespace Kernel dprintln_if(DEBUG_HID, "Skipping HID descriptor type 0x{2H}", descriptor.bDescriptorType); continue; } - if (report_descriptor_parsed) - { - dwarnln("Multiple report descriptors specified"); - return BAN::Error::from_errno(ENOTSUP); - } if (descriptor.wItemLength > dma_buffer->size()) { @@ -195,31 +189,26 @@ namespace Kernel dprintln_if(DEBUG_HID, "Parsing {} byte report descriptor", +descriptor.wItemLength); auto report_data = BAN::ConstByteSpan(reinterpret_cast(dma_buffer->vaddr()), descriptor.wItemLength); - collection = TRY(parse_report_descriptor(report_data)); - - report_descriptor_parsed = true; + auto new_collections = TRY(parse_report_descriptor(report_data, m_uses_report_id)); + for (auto& collection : new_collections) + TRY(collections.push_back(BAN::move(collection))); } - if (!report_descriptor_parsed) + if (collections.empty()) { - dwarnln("No report descriptors specified"); + dwarnln("No collections specified for HID device"); return BAN::Error::from_errno(EFAULT); } - if (collection.usage_page != 0x01) + // FIXME: Handle other collections? + + if (collections.front().usage_page != 0x01) { dwarnln("Top most collection is not generic desktop page"); return BAN::Error::from_errno(EFAULT); } -#if DUMP_HID_REPORT - { - SpinLockGuard _(Debug::s_debug_lock); - dump_hid_collection(collection, 0); - } -#endif - - switch (collection.usage_id) + switch (collections.front().usage_id) { case 0x02: m_hid_device = TRY(BAN::RefPtr::create()); @@ -230,7 +219,7 @@ namespace Kernel dprintln("Initialized an USB Keyboard"); break; default: - dwarnln("Unsupported generic descript page usage 0x{2H}", collection.usage_id); + dwarnln("Unsupported generic descript page usage 0x{2H}", collections.front().usage_id); return BAN::Error::from_errno(ENOTSUP); } DevFileSystem::get().add_device(m_hid_device); @@ -238,14 +227,14 @@ namespace Kernel const auto& endpoint_descriptor = m_interface.endpoints[endpoint_index].descriptor; m_endpoint_id = (endpoint_descriptor.bEndpointAddress & 0x0F) * 2 + !!(endpoint_descriptor.bEndpointAddress & 0x80); - m_collection = BAN::move(collection); + m_collections = BAN::move(collections); TRY(m_device.initialize_endpoint(endpoint_descriptor)); return {}; } - void USBHIDDriver::forward_collection_inputs(const Collection& collection, BAN::ConstByteSpan& data, size_t bit_offset) + void USBHIDDriver::forward_collection_inputs(const Collection& collection, BAN::Optional report_id, BAN::ConstByteSpan& data, size_t bit_offset) { const auto extract_bits = [data](size_t bit_offset, size_t bit_count, bool as_unsigned) -> int64_t @@ -284,7 +273,7 @@ namespace Kernel { if (entry.has()) { - forward_collection_inputs(entry.get(), data, bit_offset); + forward_collection_inputs(entry.get(), report_id, data, bit_offset); continue; } @@ -292,6 +281,8 @@ namespace Kernel const auto& input = entry.get(); if (input.type != Report::Type::Input) continue; + if (report_id.value_or(input.report_id) != input.report_id) + continue; ASSERT(input.report_size <= 32); @@ -350,21 +341,32 @@ namespace Kernel dprintln_if(DEBUG_HID, "Received {} bytes from endpoint {}: {}", data.size(), endpoint_id, buffer); } + BAN::Optional report_id; + if (m_uses_report_id) + { + report_id = data[0]; + data = data.slice(1); + } + m_hid_device->start_report(); - forward_collection_inputs(m_collection, data, 0); + // FIXME: Handle other collections? + forward_collection_inputs(m_collections.front(), report_id, data, 0); m_hid_device->stop_report(); } - BAN::ErrorOr parse_report_descriptor(BAN::ConstByteSpan report_data) + BAN::ErrorOr> parse_report_descriptor(BAN::ConstByteSpan report_data, bool& out_use_report_id) { BAN::Vector global_stack; GlobalState global_state; LocalState local_state; - BAN::Optional result; + BAN::Vector result_stack; BAN::Vector collection_stack; + bool one_has_report_id = false; + bool all_has_report_id = true; + const auto extract_report_item = [&](bool as_unsigned) -> int64_t { @@ -416,6 +418,11 @@ namespace Kernel return BAN::Error::from_errno(EFAULT); } + if (global_state.report_id.has_value()) + one_has_report_id = true; + else + all_has_report_id = false; + const int64_t logical_minimum = global_state.logical_minimum.value(); const int64_t logical_maximum = get_correct_sign( global_state.logical_minimum.value(), @@ -441,6 +448,7 @@ namespace Kernel item.usage_minimum = local_state.usage_minimum.value(); item.usage_maximum = local_state.usage_maximum.value(); item.type = type; + item.report_id = global_state.report_id.value_or(0); item.report_count = global_state.report_count.value(); item.report_size = global_state.report_size.value(); item.logical_minimum = logical_minimum; @@ -459,6 +467,7 @@ namespace Kernel item.usage_minimum = 0; item.usage_maximum = 0; item.type = type; + item.report_id = global_state.report_id.value_or(0); item.report_count = global_state.report_count.value(); item.report_size = global_state.report_size.value(); item.logical_minimum = 0; @@ -471,9 +480,10 @@ namespace Kernel return {}; } - for (size_t i = 0; i < global_state.report_count.value(); i++) + for (size_t i = 0; i < local_state.usage_stack.size(); i++) { - const uint32_t usage = local_state.usage_stack[BAN::Math::min(i, local_state.usage_stack.size() - 1)]; + const uint32_t usage = local_state.usage_stack[i]; + const uint32_t count = (i + 1 < local_state.usage_stack.size()) ? 1 : global_state.report_count.value() - i; Report item; item.usage_page = (usage >> 16) ? (usage >> 16) : global_state.usage_page.value(); @@ -481,7 +491,8 @@ namespace Kernel item.usage_minimum = 0; item.usage_maximum = 0; item.type = type; - item.report_count = 1; + item.report_id = global_state.report_id.value_or(0); + item.report_count = count; item.report_size = global_state.report_size.value(); item.logical_minimum = logical_minimum; item.logical_maximum = logical_maximum; @@ -559,12 +570,7 @@ namespace Kernel } if (collection_stack.size() == 1) { - if (result.has_value()) - { - dwarnln("Multiple top-level collections not supported"); - return BAN::Error::from_errno(ENOTSUP); - } - result = BAN::move(collection_stack.back()); + TRY(result_stack.push_back(BAN::move(collection_stack.back()))); collection_stack.pop_back(); } else @@ -610,8 +616,16 @@ namespace Kernel global_state.report_size = extract_report_item(true); break; case 0b1000: // report id - dwarnln("Report IDs are not supported"); - return BAN::Error::from_errno(ENOTSUP); + { + auto report_id = extract_report_item(true); + if (report_id > 0xFF) + { + dwarnln("Multi-byte report id"); + return BAN::Error::from_errno(EFAULT); + } + global_state.report_id = report_id; + break; + } case 0b1001: // report count global_state.report_count = extract_report_item(true); break; @@ -668,13 +682,28 @@ namespace Kernel report_data = report_data.slice(1 + item_size); } - if (!result.has_value()) + if (result_stack.empty()) { dwarnln("No collection defined in report descriptor"); return BAN::Error::from_errno(EFAULT); } - return result.release_value(); + if (one_has_report_id != all_has_report_id) + { + dwarnln("Some but not all reports have report id"); + return BAN::Error::from_errno(EFAULT); + } + +#if DUMP_HID_REPORT + { + SpinLockGuard _(Debug::s_debug_lock); + for (const auto& collection : result_stack) + dump_hid_collection(collection, 0, one_has_report_id); + } +#endif + + out_use_report_id = one_has_report_id; + return BAN::move(result_stack); } #if DUMP_HID_REPORT @@ -684,7 +713,7 @@ namespace Kernel Debug::putchar(' '); } - static void dump_hid_report(const Report& report, size_t indent) + static void dump_hid_report(const Report& report, size_t indent, bool use_report_id) { const char* report_type = ""; switch (report.type) @@ -696,6 +725,12 @@ namespace Kernel print_indent(indent); BAN::Formatter::println(Debug::putchar, "report {}", report_type); + if (use_report_id) + { + print_indent(indent + 4); + BAN::Formatter::println(Debug::putchar, "report id: {2H}", report.report_id); + } + print_indent(indent + 4); BAN::Formatter::println(Debug::putchar, "usage page: {2H}", report.usage_page); @@ -727,7 +762,7 @@ namespace Kernel BAN::Formatter::println(Debug::putchar, "pmaximum: {}", report.physical_maximum); } - static void dump_hid_collection(const Collection& collection, size_t indent) + static void dump_hid_collection(const Collection& collection, size_t indent, bool use_report_id) { print_indent(indent); BAN::Formatter::println(Debug::putchar, "collection {}", collection.type); @@ -737,9 +772,9 @@ namespace Kernel for (const auto& entry : collection.entries) { if (entry.has()) - dump_hid_collection(entry.get(), indent + 4); + dump_hid_collection(entry.get(), indent + 4, use_report_id); if (entry.has()) - dump_hid_report(entry.get(), indent + 4); + dump_hid_report(entry.get(), indent + 4, use_report_id); } } #endif