Kernel: Add support for HID Report ID and parsing all collections

Only the first top-level collection is used for the device, but that
seems to generally be what keyboard and mouse use for input.
This commit is contained in:
Bananymous 2024-07-15 15:51:07 +03:00
parent 60b396fee5
commit 42c3fa24f0
2 changed files with 88 additions and 50 deletions

View File

@ -17,6 +17,7 @@ namespace Kernel
uint16_t usage_id;
Type type;
uint8_t report_id;
uint32_t report_count;
uint32_t report_size;
@ -77,15 +78,17 @@ namespace Kernel
BAN::ErrorOr<void> initialize();
void forward_collection_inputs(const USBHID::Collection&, BAN::ConstByteSpan& data, size_t bit_offset);
void forward_collection_inputs(const USBHID::Collection&, BAN::Optional<uint8_t> report_id, BAN::ConstByteSpan& data, size_t bit_offset);
private:
USBDevice& m_device;
USBDevice::InterfaceDescriptor m_interface;
const uint8_t m_interface_index;
bool m_uses_report_id { false };
uint8_t m_endpoint_id { 0 };
USBHID::Collection m_collection;
BAN::Vector<USBHID::Collection> m_collections;
BAN::RefPtr<USBHIDDevice> m_hid_device;
friend class BAN::UniqPtr<USBHIDDriver>;

View File

@ -43,8 +43,8 @@ namespace Kernel
BAN::Optional<int32_t> physical_minimum;
BAN::Optional<int32_t> physical_maximum;
// FIXME: support units
BAN::Optional<uint8_t> report_id;
BAN::Optional<uint32_t> report_size;
// FIXME: support report id
BAN::Optional<uint32_t> report_count;
};
@ -59,10 +59,10 @@ namespace Kernel
using namespace USBHID;
#if DUMP_HID_REPORT
static void dump_hid_collection(const Collection& collection, size_t indent);
static void dump_hid_collection(const Collection& collection, size_t indent, bool use_report_id);
#endif
static BAN::ErrorOr<Collection> parse_report_descriptor(BAN::ConstByteSpan report_data);
static BAN::ErrorOr<BAN::Vector<Collection>> parse_report_descriptor(BAN::ConstByteSpan report_data, bool& out_use_report_id);
BAN::ErrorOr<BAN::UniqPtr<USBHIDDriver>> USBHIDDriver::create(USBDevice& device, const USBDevice::InterfaceDescriptor& interface, uint8_t interface_index)
{
@ -144,7 +144,6 @@ namespace Kernel
TRY(m_device.send_request(request, 0));
}
Collection collection {};
const auto& hid_descriptor = *reinterpret_cast<const HIDDescriptor*>(m_interface.misc_descriptors[hid_descriptor_index].data());
dprintln_if(DEBUG_HID, "HID descriptor ({} bytes)", m_interface.misc_descriptors[hid_descriptor_index].size());
@ -154,7 +153,7 @@ namespace Kernel
dprintln_if(DEBUG_HID, " bCountryCode: {}", hid_descriptor.bCountryCode);
dprintln_if(DEBUG_HID, " bNumDescriptors: {}", hid_descriptor.bNumDescriptors);
bool report_descriptor_parsed = false;
BAN::Vector<Collection> collections;
for (size_t i = 0; i < hid_descriptor.bNumDescriptors; i++)
{
auto descriptor = hid_descriptor.descriptors[i];
@ -164,11 +163,6 @@ namespace Kernel
dprintln_if(DEBUG_HID, "Skipping HID descriptor type 0x{2H}", descriptor.bDescriptorType);
continue;
}
if (report_descriptor_parsed)
{
dwarnln("Multiple report descriptors specified");
return BAN::Error::from_errno(ENOTSUP);
}
if (descriptor.wItemLength > dma_buffer->size())
{
@ -195,31 +189,26 @@ namespace Kernel
dprintln_if(DEBUG_HID, "Parsing {} byte report descriptor", +descriptor.wItemLength);
auto report_data = BAN::ConstByteSpan(reinterpret_cast<uint8_t*>(dma_buffer->vaddr()), descriptor.wItemLength);
collection = TRY(parse_report_descriptor(report_data));
report_descriptor_parsed = true;
auto new_collections = TRY(parse_report_descriptor(report_data, m_uses_report_id));
for (auto& collection : new_collections)
TRY(collections.push_back(BAN::move(collection)));
}
if (!report_descriptor_parsed)
if (collections.empty())
{
dwarnln("No report descriptors specified");
dwarnln("No collections specified for HID device");
return BAN::Error::from_errno(EFAULT);
}
if (collection.usage_page != 0x01)
// FIXME: Handle other collections?
if (collections.front().usage_page != 0x01)
{
dwarnln("Top most collection is not generic desktop page");
return BAN::Error::from_errno(EFAULT);
}
#if DUMP_HID_REPORT
{
SpinLockGuard _(Debug::s_debug_lock);
dump_hid_collection(collection, 0);
}
#endif
switch (collection.usage_id)
switch (collections.front().usage_id)
{
case 0x02:
m_hid_device = TRY(BAN::RefPtr<USBMouse>::create());
@ -230,7 +219,7 @@ namespace Kernel
dprintln("Initialized an USB Keyboard");
break;
default:
dwarnln("Unsupported generic descript page usage 0x{2H}", collection.usage_id);
dwarnln("Unsupported generic descript page usage 0x{2H}", collections.front().usage_id);
return BAN::Error::from_errno(ENOTSUP);
}
DevFileSystem::get().add_device(m_hid_device);
@ -238,14 +227,14 @@ namespace Kernel
const auto& endpoint_descriptor = m_interface.endpoints[endpoint_index].descriptor;
m_endpoint_id = (endpoint_descriptor.bEndpointAddress & 0x0F) * 2 + !!(endpoint_descriptor.bEndpointAddress & 0x80);
m_collection = BAN::move(collection);
m_collections = BAN::move(collections);
TRY(m_device.initialize_endpoint(endpoint_descriptor));
return {};
}
void USBHIDDriver::forward_collection_inputs(const Collection& collection, BAN::ConstByteSpan& data, size_t bit_offset)
void USBHIDDriver::forward_collection_inputs(const Collection& collection, BAN::Optional<uint8_t> report_id, BAN::ConstByteSpan& data, size_t bit_offset)
{
const auto extract_bits =
[data](size_t bit_offset, size_t bit_count, bool as_unsigned) -> int64_t
@ -284,7 +273,7 @@ namespace Kernel
{
if (entry.has<Collection>())
{
forward_collection_inputs(entry.get<Collection>(), data, bit_offset);
forward_collection_inputs(entry.get<Collection>(), report_id, data, bit_offset);
continue;
}
@ -292,6 +281,8 @@ namespace Kernel
const auto& input = entry.get<Report>();
if (input.type != Report::Type::Input)
continue;
if (report_id.value_or(input.report_id) != input.report_id)
continue;
ASSERT(input.report_size <= 32);
@ -350,21 +341,32 @@ namespace Kernel
dprintln_if(DEBUG_HID, "Received {} bytes from endpoint {}: {}", data.size(), endpoint_id, buffer);
}
BAN::Optional<uint8_t> report_id;
if (m_uses_report_id)
{
report_id = data[0];
data = data.slice(1);
}
m_hid_device->start_report();
forward_collection_inputs(m_collection, data, 0);
// FIXME: Handle other collections?
forward_collection_inputs(m_collections.front(), report_id, data, 0);
m_hid_device->stop_report();
}
BAN::ErrorOr<Collection> parse_report_descriptor(BAN::ConstByteSpan report_data)
BAN::ErrorOr<BAN::Vector<Collection>> parse_report_descriptor(BAN::ConstByteSpan report_data, bool& out_use_report_id)
{
BAN::Vector<GlobalState> global_stack;
GlobalState global_state;
LocalState local_state;
BAN::Optional<Collection> result;
BAN::Vector<Collection> result_stack;
BAN::Vector<Collection> collection_stack;
bool one_has_report_id = false;
bool all_has_report_id = true;
const auto extract_report_item =
[&](bool as_unsigned) -> int64_t
{
@ -416,6 +418,11 @@ namespace Kernel
return BAN::Error::from_errno(EFAULT);
}
if (global_state.report_id.has_value())
one_has_report_id = true;
else
all_has_report_id = false;
const int64_t logical_minimum = global_state.logical_minimum.value();
const int64_t logical_maximum = get_correct_sign(
global_state.logical_minimum.value(),
@ -441,6 +448,7 @@ namespace Kernel
item.usage_minimum = local_state.usage_minimum.value();
item.usage_maximum = local_state.usage_maximum.value();
item.type = type;
item.report_id = global_state.report_id.value_or(0);
item.report_count = global_state.report_count.value();
item.report_size = global_state.report_size.value();
item.logical_minimum = logical_minimum;
@ -459,6 +467,7 @@ namespace Kernel
item.usage_minimum = 0;
item.usage_maximum = 0;
item.type = type;
item.report_id = global_state.report_id.value_or(0);
item.report_count = global_state.report_count.value();
item.report_size = global_state.report_size.value();
item.logical_minimum = 0;
@ -471,9 +480,10 @@ namespace Kernel
return {};
}
for (size_t i = 0; i < global_state.report_count.value(); i++)
for (size_t i = 0; i < local_state.usage_stack.size(); i++)
{
const uint32_t usage = local_state.usage_stack[BAN::Math::min<size_t>(i, local_state.usage_stack.size() - 1)];
const uint32_t usage = local_state.usage_stack[i];
const uint32_t count = (i + 1 < local_state.usage_stack.size()) ? 1 : global_state.report_count.value() - i;
Report item;
item.usage_page = (usage >> 16) ? (usage >> 16) : global_state.usage_page.value();
@ -481,7 +491,8 @@ namespace Kernel
item.usage_minimum = 0;
item.usage_maximum = 0;
item.type = type;
item.report_count = 1;
item.report_id = global_state.report_id.value_or(0);
item.report_count = count;
item.report_size = global_state.report_size.value();
item.logical_minimum = logical_minimum;
item.logical_maximum = logical_maximum;
@ -559,12 +570,7 @@ namespace Kernel
}
if (collection_stack.size() == 1)
{
if (result.has_value())
{
dwarnln("Multiple top-level collections not supported");
return BAN::Error::from_errno(ENOTSUP);
}
result = BAN::move(collection_stack.back());
TRY(result_stack.push_back(BAN::move(collection_stack.back())));
collection_stack.pop_back();
}
else
@ -610,8 +616,16 @@ namespace Kernel
global_state.report_size = extract_report_item(true);
break;
case 0b1000: // report id
dwarnln("Report IDs are not supported");
return BAN::Error::from_errno(ENOTSUP);
{
auto report_id = extract_report_item(true);
if (report_id > 0xFF)
{
dwarnln("Multi-byte report id");
return BAN::Error::from_errno(EFAULT);
}
global_state.report_id = report_id;
break;
}
case 0b1001: // report count
global_state.report_count = extract_report_item(true);
break;
@ -668,13 +682,28 @@ namespace Kernel
report_data = report_data.slice(1 + item_size);
}
if (!result.has_value())
if (result_stack.empty())
{
dwarnln("No collection defined in report descriptor");
return BAN::Error::from_errno(EFAULT);
}
return result.release_value();
if (one_has_report_id != all_has_report_id)
{
dwarnln("Some but not all reports have report id");
return BAN::Error::from_errno(EFAULT);
}
#if DUMP_HID_REPORT
{
SpinLockGuard _(Debug::s_debug_lock);
for (const auto& collection : result_stack)
dump_hid_collection(collection, 0, one_has_report_id);
}
#endif
out_use_report_id = one_has_report_id;
return BAN::move(result_stack);
}
#if DUMP_HID_REPORT
@ -684,7 +713,7 @@ namespace Kernel
Debug::putchar(' ');
}
static void dump_hid_report(const Report& report, size_t indent)
static void dump_hid_report(const Report& report, size_t indent, bool use_report_id)
{
const char* report_type = "";
switch (report.type)
@ -696,6 +725,12 @@ namespace Kernel
print_indent(indent);
BAN::Formatter::println(Debug::putchar, "report {}", report_type);
if (use_report_id)
{
print_indent(indent + 4);
BAN::Formatter::println(Debug::putchar, "report id: {2H}", report.report_id);
}
print_indent(indent + 4);
BAN::Formatter::println(Debug::putchar, "usage page: {2H}", report.usage_page);
@ -727,7 +762,7 @@ namespace Kernel
BAN::Formatter::println(Debug::putchar, "pmaximum: {}", report.physical_maximum);
}
static void dump_hid_collection(const Collection& collection, size_t indent)
static void dump_hid_collection(const Collection& collection, size_t indent, bool use_report_id)
{
print_indent(indent);
BAN::Formatter::println(Debug::putchar, "collection {}", collection.type);
@ -737,9 +772,9 @@ namespace Kernel
for (const auto& entry : collection.entries)
{
if (entry.has<Collection>())
dump_hid_collection(entry.get<Collection>(), indent + 4);
dump_hid_collection(entry.get<Collection>(), indent + 4, use_report_id);
if (entry.has<Report>())
dump_hid_report(entry.get<Report>(), indent + 4);
dump_hid_report(entry.get<Report>(), indent + 4, use_report_id);
}
}
#endif