Kernel: Add support for HID Report ID and parsing all collections

Only the first top-level collection is used for the device, but that
seems to generally be what keyboard and mouse use for input.
This commit is contained in:
Bananymous 2024-07-15 15:51:07 +03:00
parent 60b396fee5
commit 42c3fa24f0
2 changed files with 88 additions and 50 deletions

View File

@ -17,6 +17,7 @@ namespace Kernel
uint16_t usage_id; uint16_t usage_id;
Type type; Type type;
uint8_t report_id;
uint32_t report_count; uint32_t report_count;
uint32_t report_size; uint32_t report_size;
@ -77,15 +78,17 @@ namespace Kernel
BAN::ErrorOr<void> initialize(); BAN::ErrorOr<void> initialize();
void forward_collection_inputs(const USBHID::Collection&, BAN::ConstByteSpan& data, size_t bit_offset); void forward_collection_inputs(const USBHID::Collection&, BAN::Optional<uint8_t> report_id, BAN::ConstByteSpan& data, size_t bit_offset);
private: private:
USBDevice& m_device; USBDevice& m_device;
USBDevice::InterfaceDescriptor m_interface; USBDevice::InterfaceDescriptor m_interface;
const uint8_t m_interface_index; const uint8_t m_interface_index;
bool m_uses_report_id { false };
uint8_t m_endpoint_id { 0 }; uint8_t m_endpoint_id { 0 };
USBHID::Collection m_collection; BAN::Vector<USBHID::Collection> m_collections;
BAN::RefPtr<USBHIDDevice> m_hid_device; BAN::RefPtr<USBHIDDevice> m_hid_device;
friend class BAN::UniqPtr<USBHIDDriver>; friend class BAN::UniqPtr<USBHIDDriver>;

View File

@ -43,8 +43,8 @@ namespace Kernel
BAN::Optional<int32_t> physical_minimum; BAN::Optional<int32_t> physical_minimum;
BAN::Optional<int32_t> physical_maximum; BAN::Optional<int32_t> physical_maximum;
// FIXME: support units // FIXME: support units
BAN::Optional<uint8_t> report_id;
BAN::Optional<uint32_t> report_size; BAN::Optional<uint32_t> report_size;
// FIXME: support report id
BAN::Optional<uint32_t> report_count; BAN::Optional<uint32_t> report_count;
}; };
@ -59,10 +59,10 @@ namespace Kernel
using namespace USBHID; using namespace USBHID;
#if DUMP_HID_REPORT #if DUMP_HID_REPORT
static void dump_hid_collection(const Collection& collection, size_t indent); static void dump_hid_collection(const Collection& collection, size_t indent, bool use_report_id);
#endif #endif
static BAN::ErrorOr<Collection> parse_report_descriptor(BAN::ConstByteSpan report_data); static BAN::ErrorOr<BAN::Vector<Collection>> parse_report_descriptor(BAN::ConstByteSpan report_data, bool& out_use_report_id);
BAN::ErrorOr<BAN::UniqPtr<USBHIDDriver>> USBHIDDriver::create(USBDevice& device, const USBDevice::InterfaceDescriptor& interface, uint8_t interface_index) BAN::ErrorOr<BAN::UniqPtr<USBHIDDriver>> USBHIDDriver::create(USBDevice& device, const USBDevice::InterfaceDescriptor& interface, uint8_t interface_index)
{ {
@ -144,7 +144,6 @@ namespace Kernel
TRY(m_device.send_request(request, 0)); TRY(m_device.send_request(request, 0));
} }
Collection collection {};
const auto& hid_descriptor = *reinterpret_cast<const HIDDescriptor*>(m_interface.misc_descriptors[hid_descriptor_index].data()); const auto& hid_descriptor = *reinterpret_cast<const HIDDescriptor*>(m_interface.misc_descriptors[hid_descriptor_index].data());
dprintln_if(DEBUG_HID, "HID descriptor ({} bytes)", m_interface.misc_descriptors[hid_descriptor_index].size()); dprintln_if(DEBUG_HID, "HID descriptor ({} bytes)", m_interface.misc_descriptors[hid_descriptor_index].size());
@ -154,7 +153,7 @@ namespace Kernel
dprintln_if(DEBUG_HID, " bCountryCode: {}", hid_descriptor.bCountryCode); dprintln_if(DEBUG_HID, " bCountryCode: {}", hid_descriptor.bCountryCode);
dprintln_if(DEBUG_HID, " bNumDescriptors: {}", hid_descriptor.bNumDescriptors); dprintln_if(DEBUG_HID, " bNumDescriptors: {}", hid_descriptor.bNumDescriptors);
bool report_descriptor_parsed = false; BAN::Vector<Collection> collections;
for (size_t i = 0; i < hid_descriptor.bNumDescriptors; i++) for (size_t i = 0; i < hid_descriptor.bNumDescriptors; i++)
{ {
auto descriptor = hid_descriptor.descriptors[i]; auto descriptor = hid_descriptor.descriptors[i];
@ -164,11 +163,6 @@ namespace Kernel
dprintln_if(DEBUG_HID, "Skipping HID descriptor type 0x{2H}", descriptor.bDescriptorType); dprintln_if(DEBUG_HID, "Skipping HID descriptor type 0x{2H}", descriptor.bDescriptorType);
continue; continue;
} }
if (report_descriptor_parsed)
{
dwarnln("Multiple report descriptors specified");
return BAN::Error::from_errno(ENOTSUP);
}
if (descriptor.wItemLength > dma_buffer->size()) if (descriptor.wItemLength > dma_buffer->size())
{ {
@ -195,31 +189,26 @@ namespace Kernel
dprintln_if(DEBUG_HID, "Parsing {} byte report descriptor", +descriptor.wItemLength); dprintln_if(DEBUG_HID, "Parsing {} byte report descriptor", +descriptor.wItemLength);
auto report_data = BAN::ConstByteSpan(reinterpret_cast<uint8_t*>(dma_buffer->vaddr()), descriptor.wItemLength); auto report_data = BAN::ConstByteSpan(reinterpret_cast<uint8_t*>(dma_buffer->vaddr()), descriptor.wItemLength);
collection = TRY(parse_report_descriptor(report_data)); auto new_collections = TRY(parse_report_descriptor(report_data, m_uses_report_id));
for (auto& collection : new_collections)
report_descriptor_parsed = true; TRY(collections.push_back(BAN::move(collection)));
} }
if (!report_descriptor_parsed) if (collections.empty())
{ {
dwarnln("No report descriptors specified"); dwarnln("No collections specified for HID device");
return BAN::Error::from_errno(EFAULT); return BAN::Error::from_errno(EFAULT);
} }
if (collection.usage_page != 0x01) // FIXME: Handle other collections?
if (collections.front().usage_page != 0x01)
{ {
dwarnln("Top most collection is not generic desktop page"); dwarnln("Top most collection is not generic desktop page");
return BAN::Error::from_errno(EFAULT); return BAN::Error::from_errno(EFAULT);
} }
#if DUMP_HID_REPORT switch (collections.front().usage_id)
{
SpinLockGuard _(Debug::s_debug_lock);
dump_hid_collection(collection, 0);
}
#endif
switch (collection.usage_id)
{ {
case 0x02: case 0x02:
m_hid_device = TRY(BAN::RefPtr<USBMouse>::create()); m_hid_device = TRY(BAN::RefPtr<USBMouse>::create());
@ -230,7 +219,7 @@ namespace Kernel
dprintln("Initialized an USB Keyboard"); dprintln("Initialized an USB Keyboard");
break; break;
default: default:
dwarnln("Unsupported generic descript page usage 0x{2H}", collection.usage_id); dwarnln("Unsupported generic descript page usage 0x{2H}", collections.front().usage_id);
return BAN::Error::from_errno(ENOTSUP); return BAN::Error::from_errno(ENOTSUP);
} }
DevFileSystem::get().add_device(m_hid_device); DevFileSystem::get().add_device(m_hid_device);
@ -238,14 +227,14 @@ namespace Kernel
const auto& endpoint_descriptor = m_interface.endpoints[endpoint_index].descriptor; const auto& endpoint_descriptor = m_interface.endpoints[endpoint_index].descriptor;
m_endpoint_id = (endpoint_descriptor.bEndpointAddress & 0x0F) * 2 + !!(endpoint_descriptor.bEndpointAddress & 0x80); m_endpoint_id = (endpoint_descriptor.bEndpointAddress & 0x0F) * 2 + !!(endpoint_descriptor.bEndpointAddress & 0x80);
m_collection = BAN::move(collection); m_collections = BAN::move(collections);
TRY(m_device.initialize_endpoint(endpoint_descriptor)); TRY(m_device.initialize_endpoint(endpoint_descriptor));
return {}; return {};
} }
void USBHIDDriver::forward_collection_inputs(const Collection& collection, BAN::ConstByteSpan& data, size_t bit_offset) void USBHIDDriver::forward_collection_inputs(const Collection& collection, BAN::Optional<uint8_t> report_id, BAN::ConstByteSpan& data, size_t bit_offset)
{ {
const auto extract_bits = const auto extract_bits =
[data](size_t bit_offset, size_t bit_count, bool as_unsigned) -> int64_t [data](size_t bit_offset, size_t bit_count, bool as_unsigned) -> int64_t
@ -284,7 +273,7 @@ namespace Kernel
{ {
if (entry.has<Collection>()) if (entry.has<Collection>())
{ {
forward_collection_inputs(entry.get<Collection>(), data, bit_offset); forward_collection_inputs(entry.get<Collection>(), report_id, data, bit_offset);
continue; continue;
} }
@ -292,6 +281,8 @@ namespace Kernel
const auto& input = entry.get<Report>(); const auto& input = entry.get<Report>();
if (input.type != Report::Type::Input) if (input.type != Report::Type::Input)
continue; continue;
if (report_id.value_or(input.report_id) != input.report_id)
continue;
ASSERT(input.report_size <= 32); ASSERT(input.report_size <= 32);
@ -350,21 +341,32 @@ namespace Kernel
dprintln_if(DEBUG_HID, "Received {} bytes from endpoint {}: {}", data.size(), endpoint_id, buffer); dprintln_if(DEBUG_HID, "Received {} bytes from endpoint {}: {}", data.size(), endpoint_id, buffer);
} }
BAN::Optional<uint8_t> report_id;
if (m_uses_report_id)
{
report_id = data[0];
data = data.slice(1);
}
m_hid_device->start_report(); m_hid_device->start_report();
forward_collection_inputs(m_collection, data, 0); // FIXME: Handle other collections?
forward_collection_inputs(m_collections.front(), report_id, data, 0);
m_hid_device->stop_report(); m_hid_device->stop_report();
} }
BAN::ErrorOr<Collection> parse_report_descriptor(BAN::ConstByteSpan report_data) BAN::ErrorOr<BAN::Vector<Collection>> parse_report_descriptor(BAN::ConstByteSpan report_data, bool& out_use_report_id)
{ {
BAN::Vector<GlobalState> global_stack; BAN::Vector<GlobalState> global_stack;
GlobalState global_state; GlobalState global_state;
LocalState local_state; LocalState local_state;
BAN::Optional<Collection> result; BAN::Vector<Collection> result_stack;
BAN::Vector<Collection> collection_stack; BAN::Vector<Collection> collection_stack;
bool one_has_report_id = false;
bool all_has_report_id = true;
const auto extract_report_item = const auto extract_report_item =
[&](bool as_unsigned) -> int64_t [&](bool as_unsigned) -> int64_t
{ {
@ -416,6 +418,11 @@ namespace Kernel
return BAN::Error::from_errno(EFAULT); return BAN::Error::from_errno(EFAULT);
} }
if (global_state.report_id.has_value())
one_has_report_id = true;
else
all_has_report_id = false;
const int64_t logical_minimum = global_state.logical_minimum.value(); const int64_t logical_minimum = global_state.logical_minimum.value();
const int64_t logical_maximum = get_correct_sign( const int64_t logical_maximum = get_correct_sign(
global_state.logical_minimum.value(), global_state.logical_minimum.value(),
@ -441,6 +448,7 @@ namespace Kernel
item.usage_minimum = local_state.usage_minimum.value(); item.usage_minimum = local_state.usage_minimum.value();
item.usage_maximum = local_state.usage_maximum.value(); item.usage_maximum = local_state.usage_maximum.value();
item.type = type; item.type = type;
item.report_id = global_state.report_id.value_or(0);
item.report_count = global_state.report_count.value(); item.report_count = global_state.report_count.value();
item.report_size = global_state.report_size.value(); item.report_size = global_state.report_size.value();
item.logical_minimum = logical_minimum; item.logical_minimum = logical_minimum;
@ -459,6 +467,7 @@ namespace Kernel
item.usage_minimum = 0; item.usage_minimum = 0;
item.usage_maximum = 0; item.usage_maximum = 0;
item.type = type; item.type = type;
item.report_id = global_state.report_id.value_or(0);
item.report_count = global_state.report_count.value(); item.report_count = global_state.report_count.value();
item.report_size = global_state.report_size.value(); item.report_size = global_state.report_size.value();
item.logical_minimum = 0; item.logical_minimum = 0;
@ -471,9 +480,10 @@ namespace Kernel
return {}; return {};
} }
for (size_t i = 0; i < global_state.report_count.value(); i++) for (size_t i = 0; i < local_state.usage_stack.size(); i++)
{ {
const uint32_t usage = local_state.usage_stack[BAN::Math::min<size_t>(i, local_state.usage_stack.size() - 1)]; const uint32_t usage = local_state.usage_stack[i];
const uint32_t count = (i + 1 < local_state.usage_stack.size()) ? 1 : global_state.report_count.value() - i;
Report item; Report item;
item.usage_page = (usage >> 16) ? (usage >> 16) : global_state.usage_page.value(); item.usage_page = (usage >> 16) ? (usage >> 16) : global_state.usage_page.value();
@ -481,7 +491,8 @@ namespace Kernel
item.usage_minimum = 0; item.usage_minimum = 0;
item.usage_maximum = 0; item.usage_maximum = 0;
item.type = type; item.type = type;
item.report_count = 1; item.report_id = global_state.report_id.value_or(0);
item.report_count = count;
item.report_size = global_state.report_size.value(); item.report_size = global_state.report_size.value();
item.logical_minimum = logical_minimum; item.logical_minimum = logical_minimum;
item.logical_maximum = logical_maximum; item.logical_maximum = logical_maximum;
@ -559,12 +570,7 @@ namespace Kernel
} }
if (collection_stack.size() == 1) if (collection_stack.size() == 1)
{ {
if (result.has_value()) TRY(result_stack.push_back(BAN::move(collection_stack.back())));
{
dwarnln("Multiple top-level collections not supported");
return BAN::Error::from_errno(ENOTSUP);
}
result = BAN::move(collection_stack.back());
collection_stack.pop_back(); collection_stack.pop_back();
} }
else else
@ -610,8 +616,16 @@ namespace Kernel
global_state.report_size = extract_report_item(true); global_state.report_size = extract_report_item(true);
break; break;
case 0b1000: // report id case 0b1000: // report id
dwarnln("Report IDs are not supported"); {
return BAN::Error::from_errno(ENOTSUP); auto report_id = extract_report_item(true);
if (report_id > 0xFF)
{
dwarnln("Multi-byte report id");
return BAN::Error::from_errno(EFAULT);
}
global_state.report_id = report_id;
break;
}
case 0b1001: // report count case 0b1001: // report count
global_state.report_count = extract_report_item(true); global_state.report_count = extract_report_item(true);
break; break;
@ -668,13 +682,28 @@ namespace Kernel
report_data = report_data.slice(1 + item_size); report_data = report_data.slice(1 + item_size);
} }
if (!result.has_value()) if (result_stack.empty())
{ {
dwarnln("No collection defined in report descriptor"); dwarnln("No collection defined in report descriptor");
return BAN::Error::from_errno(EFAULT); return BAN::Error::from_errno(EFAULT);
} }
return result.release_value(); if (one_has_report_id != all_has_report_id)
{
dwarnln("Some but not all reports have report id");
return BAN::Error::from_errno(EFAULT);
}
#if DUMP_HID_REPORT
{
SpinLockGuard _(Debug::s_debug_lock);
for (const auto& collection : result_stack)
dump_hid_collection(collection, 0, one_has_report_id);
}
#endif
out_use_report_id = one_has_report_id;
return BAN::move(result_stack);
} }
#if DUMP_HID_REPORT #if DUMP_HID_REPORT
@ -684,7 +713,7 @@ namespace Kernel
Debug::putchar(' '); Debug::putchar(' ');
} }
static void dump_hid_report(const Report& report, size_t indent) static void dump_hid_report(const Report& report, size_t indent, bool use_report_id)
{ {
const char* report_type = ""; const char* report_type = "";
switch (report.type) switch (report.type)
@ -696,6 +725,12 @@ namespace Kernel
print_indent(indent); print_indent(indent);
BAN::Formatter::println(Debug::putchar, "report {}", report_type); BAN::Formatter::println(Debug::putchar, "report {}", report_type);
if (use_report_id)
{
print_indent(indent + 4);
BAN::Formatter::println(Debug::putchar, "report id: {2H}", report.report_id);
}
print_indent(indent + 4); print_indent(indent + 4);
BAN::Formatter::println(Debug::putchar, "usage page: {2H}", report.usage_page); BAN::Formatter::println(Debug::putchar, "usage page: {2H}", report.usage_page);
@ -727,7 +762,7 @@ namespace Kernel
BAN::Formatter::println(Debug::putchar, "pmaximum: {}", report.physical_maximum); BAN::Formatter::println(Debug::putchar, "pmaximum: {}", report.physical_maximum);
} }
static void dump_hid_collection(const Collection& collection, size_t indent) static void dump_hid_collection(const Collection& collection, size_t indent, bool use_report_id)
{ {
print_indent(indent); print_indent(indent);
BAN::Formatter::println(Debug::putchar, "collection {}", collection.type); BAN::Formatter::println(Debug::putchar, "collection {}", collection.type);
@ -737,9 +772,9 @@ namespace Kernel
for (const auto& entry : collection.entries) for (const auto& entry : collection.entries)
{ {
if (entry.has<Collection>()) if (entry.has<Collection>())
dump_hid_collection(entry.get<Collection>(), indent + 4); dump_hid_collection(entry.get<Collection>(), indent + 4, use_report_id);
if (entry.has<Report>()) if (entry.has<Report>())
dump_hid_report(entry.get<Report>(), indent + 4); dump_hid_report(entry.get<Report>(), indent + 4, use_report_id);
} }
} }
#endif #endif