Kernel: Update USB HID code to support multiple top-level collections

This allows me to use my laptops own keyboard!
This commit is contained in:
Bananymous 2024-07-16 00:01:53 +03:00
parent a5cb4057f9
commit a60b460701
2 changed files with 142 additions and 109 deletions

View File

@ -67,6 +67,13 @@ namespace Kernel
BAN_NON_COPYABLE(USBHIDDriver); BAN_NON_COPYABLE(USBHIDDriver);
BAN_NON_MOVABLE(USBHIDDriver); BAN_NON_MOVABLE(USBHIDDriver);
public:
struct DeviceReport
{
BAN::Vector<USBHID::Report> inputs;
BAN::RefPtr<USBHIDDevice> device;
};
public: public:
static BAN::ErrorOr<BAN::UniqPtr<USBHIDDriver>> create(USBDevice&, const USBDevice::InterfaceDescriptor&, uint8_t interface_index); static BAN::ErrorOr<BAN::UniqPtr<USBHIDDriver>> create(USBDevice&, const USBDevice::InterfaceDescriptor&, uint8_t interface_index);
@ -78,7 +85,7 @@ namespace Kernel
BAN::ErrorOr<void> initialize(); BAN::ErrorOr<void> initialize();
void forward_collection_inputs(const USBHID::Collection&, BAN::Optional<uint8_t> report_id, BAN::ConstByteSpan& data, size_t bit_offset); BAN::ErrorOr<BAN::Vector<DeviceReport>> initializes_device_reports(const BAN::Vector<USBHID::Collection>&);
private: private:
USBDevice& m_device; USBDevice& m_device;
@ -88,8 +95,7 @@ namespace Kernel
bool m_uses_report_id { false }; bool m_uses_report_id { false };
uint8_t m_endpoint_id { 0 }; uint8_t m_endpoint_id { 0 };
BAN::Vector<USBHID::Collection> m_collections; BAN::Vector<DeviceReport> m_device_inputs;
BAN::RefPtr<USBHIDDevice> m_hid_device;
friend class BAN::UniqPtr<USBHIDDriver>; friend class BAN::UniqPtr<USBHIDDriver>;
}; };

View File

@ -79,8 +79,9 @@ namespace Kernel
USBHIDDriver::~USBHIDDriver() USBHIDDriver::~USBHIDDriver()
{ {
if (m_hid_device) for (auto& device_input : m_device_inputs)
DevFileSystem::get().remove_device(m_hid_device); if (device_input.device)
DevFileSystem::get().remove_device(device_input.device);
} }
BAN::ErrorOr<void> USBHIDDriver::initialize() BAN::ErrorOr<void> USBHIDDriver::initialize()
@ -203,44 +204,99 @@ namespace Kernel
return BAN::Error::from_errno(EFAULT); return BAN::Error::from_errno(EFAULT);
} }
// FIXME: Handle other collections? m_device_inputs = TRY(initializes_device_reports(collections));
if (collections.front().usage_page != 0x01)
{
dwarnln("Top most collection is not generic desktop page");
return BAN::Error::from_errno(EFAULT);
}
switch (collections.front().usage_id)
{
case 0x02:
m_hid_device = TRY(BAN::RefPtr<USBMouse>::create());
dprintln("Initialized an USB Mouse");
break;
case 0x06:
m_hid_device = TRY(BAN::RefPtr<USBKeyboard>::create());
dprintln("Initialized an USB Keyboard");
break;
default:
dwarnln("Unsupported generic descript page usage 0x{2H}", collections.front().usage_id);
return BAN::Error::from_errno(ENOTSUP);
}
DevFileSystem::get().add_device(m_hid_device);
const auto& endpoint_descriptor = m_interface.endpoints[endpoint_index].descriptor; const auto& endpoint_descriptor = m_interface.endpoints[endpoint_index].descriptor;
m_endpoint_id = (endpoint_descriptor.bEndpointAddress & 0x0F) * 2 + !!(endpoint_descriptor.bEndpointAddress & 0x80); m_endpoint_id = (endpoint_descriptor.bEndpointAddress & 0x0F) * 2 + !!(endpoint_descriptor.bEndpointAddress & 0x80);
m_collections = BAN::move(collections);
TRY(m_device.initialize_endpoint(endpoint_descriptor)); TRY(m_device.initialize_endpoint(endpoint_descriptor));
return {}; return {};
} }
void USBHIDDriver::forward_collection_inputs(const Collection& collection, BAN::Optional<uint8_t> report_id, BAN::ConstByteSpan& data, size_t bit_offset) static BAN::ErrorOr<void> gather_collection_inputs(const USBHID::Collection& collection, BAN::Vector<USBHID::Report>& output)
{ {
for (const auto& entry : collection.entries)
{
if (entry.has<USBHID::Collection>())
{
TRY(gather_collection_inputs(entry.get<USBHID::Collection>(), output));
continue;
}
const auto& report = entry.get<USBHID::Report>();
if (report.type != USBHID::Report::Type::Input)
continue;
TRY(output.push_back(report));
}
return {};
}
BAN::ErrorOr<BAN::Vector<USBHIDDriver::DeviceReport>> USBHIDDriver::initializes_device_reports(const BAN::Vector<USBHID::Collection>& collection_list)
{
BAN::Vector<USBHIDDriver::DeviceReport> result;
TRY(result.reserve(collection_list.size()));
for (size_t i = 0; i < collection_list.size(); i++)
{
const auto& collection = collection_list[i];
USBHIDDriver::DeviceReport report;
TRY(gather_collection_inputs(collection, report.inputs));
if (collection.usage_page == 0x01)
{
switch (collection.usage_id)
{
case 0x02:
report.device = TRY(BAN::RefPtr<USBMouse>::create());
dprintln("Initialized an USB Mouse");
break;
case 0x06:
report.device = TRY(BAN::RefPtr<USBKeyboard>::create());
dprintln("Initialized an USB Keyboard");
break;
default:
dwarnln("Unsupported generic descript page usage 0x{2H}", collection.usage_id);
break;
}
}
TRY(result.push_back(BAN::move(report)));
}
for (auto& report : result)
if (report.device)
DevFileSystem::get().add_device(report.device);
return BAN::move(result);
}
void USBHIDDriver::handle_input_data(BAN::ConstByteSpan data, uint8_t endpoint_id)
{
// If this packet is not for us, skip it
if (m_endpoint_id != endpoint_id)
return;
if constexpr(DEBUG_HID)
{
const auto nibble_to_hex = [](uint8_t x) -> char { return x + (x < 10 ? '0' : 'A' - 10); };
char buffer[512];
char* ptr = buffer;
for (size_t i = 0; i < BAN::Math::min<size_t>((sizeof(buffer) - 1) / 3, data.size()); i++)
{
*ptr++ = nibble_to_hex(data[i] >> 4);
*ptr++ = nibble_to_hex(data[i] & 0xF);
*ptr++ = ' ';
}
*ptr = '\0';
dprintln_if(DEBUG_HID, "Received {} bytes from endpoint {}: {}", data.size(), endpoint_id, buffer);
}
const auto extract_bits = const auto extract_bits =
[data](size_t bit_offset, size_t bit_count, bool as_unsigned) -> int64_t [&data](size_t bit_offset, size_t bit_count, bool as_unsigned) -> int64_t
{ {
if (bit_offset >= data.size() * 8) if (bit_offset >= data.size() * 8)
return 0; return 0;
@ -272,24 +328,27 @@ namespace Kernel
return result; return result;
}; };
for (const auto& entry : collection.entries) BAN::Optional<uint8_t> report_id;
if (m_uses_report_id)
{ {
if (entry.has<Collection>()) report_id = data[0];
{ data = data.slice(1);
forward_collection_inputs(entry.get<Collection>(), report_id, data, bit_offset);
continue;
} }
ASSERT(entry.has<Report>()); size_t bit_offset = 0;
const auto& input = entry.get<Report>(); for (auto& device_input : m_device_inputs)
if (input.type != Report::Type::Input) {
continue; if (device_input.device)
device_input.device->start_report();
for (const auto& input : device_input.inputs)
{
if (report_id.value_or(input.report_id) != input.report_id) if (report_id.value_or(input.report_id) != input.report_id)
continue; continue;
ASSERT(input.report_size <= 32); ASSERT(input.report_size <= 32);
if (input.usage_id == 0 && input.usage_minimum == 0 && input.usage_maximum == 0) if (!device_input.device || (input.usage_id == 0 && input.usage_minimum == 0 && input.usage_maximum == 0))
{ {
bit_offset += input.report_size * input.report_count; bit_offset += input.report_size * input.report_count;
continue; continue;
@ -312,49 +371,17 @@ namespace Kernel
const uint32_t usage_base = input.usage_id ? input.usage_id : input.usage_minimum; const uint32_t usage_base = input.usage_id ? input.usage_id : input.usage_minimum;
if (input.flags & 0x02) if (input.flags & 0x02)
m_hid_device->handle_variable(input.usage_page, usage_base + i, physical); device_input.device->handle_variable(input.usage_page, usage_base + i, physical);
else else
m_hid_device->handle_array(input.usage_page, usage_base + physical); device_input.device->handle_array(input.usage_page, usage_base + physical);
bit_offset += input.report_size; bit_offset += input.report_size;
} }
} }
if (device_input.device)
device_input.device->stop_report();
} }
void USBHIDDriver::handle_input_data(BAN::ConstByteSpan data, uint8_t endpoint_id)
{
// If this packet is not for us, skip it
if (m_endpoint_id != endpoint_id)
return;
if constexpr(DEBUG_HID)
{
const auto nibble_to_hex = [](uint8_t x) -> char { return x + (x < 10 ? '0' : 'A' - 10); };
char buffer[512];
char* ptr = buffer;
for (size_t i = 0; i < BAN::Math::min<size_t>((sizeof(buffer) - 1) / 3, data.size()); i++)
{
*ptr++ = nibble_to_hex(data[i] >> 4);
*ptr++ = nibble_to_hex(data[i] & 0xF);
*ptr++ = ' ';
}
*ptr = '\0';
dprintln_if(DEBUG_HID, "Received {} bytes from endpoint {}: {}", data.size(), endpoint_id, buffer);
}
BAN::Optional<uint8_t> report_id;
if (m_uses_report_id)
{
report_id = data[0];
data = data.slice(1);
}
m_hid_device->start_report();
// FIXME: Handle other collections?
forward_collection_inputs(m_collections.front(), report_id, data, 0);
m_hid_device->stop_report();
} }
BAN::ErrorOr<BAN::Vector<Collection>> parse_report_descriptor(BAN::ConstByteSpan report_data, bool& out_use_report_id) BAN::ErrorOr<BAN::Vector<Collection>> parse_report_descriptor(BAN::ConstByteSpan report_data, bool& out_use_report_id)