From 088f77a226c8811f501af2c38ac521d0283f845e Mon Sep 17 00:00:00 2001
From: Bananymous <oskari.alaranta@bananymous.com>
Date: Tue, 11 Feb 2025 02:18:50 +0200
Subject: [PATCH] Kernel: Add super basic support for USB keyboard LEDs

This is very hacky but it seems to mostly work. Also for some reason
this fixed my Razer Mamba mouse????
---
 kernel/include/kernel/USB/HID/HIDDriver.h |  3 +
 kernel/include/kernel/USB/HID/Keyboard.h  | 12 ++-
 kernel/kernel/USB/HID/HIDDriver.cpp       | 13 ++--
 kernel/kernel/USB/HID/Keyboard.cpp        | 93 +++++++++++++++++++++++
 4 files changed, 113 insertions(+), 8 deletions(-)

diff --git a/kernel/include/kernel/USB/HID/HIDDriver.h b/kernel/include/kernel/USB/HID/HIDDriver.h
index 7f1216c94..826c5ad4e 100644
--- a/kernel/include/kernel/USB/HID/HIDDriver.h
+++ b/kernel/include/kernel/USB/HID/HIDDriver.h
@@ -78,6 +78,9 @@ namespace Kernel
 		void handle_stall(uint8_t endpoint_id) override;
 		void handle_input_data(size_t byte_count, uint8_t endpoint_id) override;
 
+		USBDevice& device() { return m_device; }
+		const USBDevice::InterfaceDescriptor& interface() const { return m_interface; }
+
 	private:
 		USBHIDDriver(USBDevice&, const USBDevice::InterfaceDescriptor&);
 		~USBHIDDriver();
diff --git a/kernel/include/kernel/USB/HID/Keyboard.h b/kernel/include/kernel/USB/HID/Keyboard.h
index df4f32865..6ae9da693 100644
--- a/kernel/include/kernel/USB/HID/Keyboard.h
+++ b/kernel/include/kernel/USB/HID/Keyboard.h
@@ -20,12 +20,15 @@ namespace Kernel
 		void update() override;
 
 	private:
-		USBKeyboard()
-			: USBHIDDevice(InputDevice::Type::Keyboard)
-		{}
+		USBKeyboard(USBHIDDriver& driver, BAN::Vector<USBHID::Report>&& outputs);
 		~USBKeyboard() = default;
 
+		void set_leds(uint16_t mask);
+		void set_leds(uint8_t report_id, uint16_t mask);
+
 	private:
+		USBHIDDriver& m_driver;
+
 		SpinLock m_keyboard_lock;
 		InterruptState m_lock_state;
 
@@ -33,6 +36,9 @@ namespace Kernel
 		BAN::Array<bool, 0x100> m_keyboard_state_temp { false };
 		uint16_t m_toggle_mask { 0 };
 
+		uint16_t m_led_mask { 0 };
+		BAN::Vector<USBHID::Report> m_outputs;
+
 		BAN::Optional<uint8_t> m_repeat_scancode;
 		uint8_t m_repeat_modifier { 0 };
 		uint64_t m_next_repeat_event_ms { 0 };
diff --git a/kernel/kernel/USB/HID/HIDDriver.cpp b/kernel/kernel/USB/HID/HIDDriver.cpp
index ee4f20d12..2f9a23702 100644
--- a/kernel/kernel/USB/HID/HIDDriver.cpp
+++ b/kernel/kernel/USB/HID/HIDDriver.cpp
@@ -213,18 +213,18 @@ namespace Kernel
 		return {};
 	}
 
-	static BAN::ErrorOr<void> gather_collection_inputs(const USBHID::Collection& collection, BAN::Vector<USBHID::Report>& output)
+	static BAN::ErrorOr<void> gather_collection_reports(const USBHID::Collection& collection, BAN::Vector<USBHID::Report>& output, USBHID::Report::Type type)
 	{
 		for (const auto& entry : collection.entries)
 		{
 			if (entry.has<USBHID::Collection>())
 			{
-				TRY(gather_collection_inputs(entry.get<USBHID::Collection>(), output));
+				TRY(gather_collection_reports(entry.get<USBHID::Collection>(), output, type));
 				continue;
 			}
 
 			const auto& report = entry.get<USBHID::Report>();
-			if (report.type != USBHID::Report::Type::Input)
+			if (report.type != type)
 				continue;
 
 			TRY(output.push_back(report));
@@ -243,7 +243,10 @@ namespace Kernel
 			const auto& collection = collection_list[i];
 
 			USBHIDDriver::DeviceReport report;
-			TRY(gather_collection_inputs(collection, report.inputs));
+			TRY(gather_collection_reports(collection, report.inputs,  USBHID::Report::Type::Input));
+
+			BAN::Vector<USBHID::Report> outputs;
+			TRY(gather_collection_reports(collection, outputs, USBHID::Report::Type::Output));
 
 			if (collection.usage_page == 0x01)
 			{
@@ -254,7 +257,7 @@ namespace Kernel
 						dprintln("Initialized an USB Mouse");
 						break;
 					case 0x06:
-						report.device = TRY(BAN::RefPtr<USBKeyboard>::create());
+						report.device = TRY(BAN::RefPtr<USBKeyboard>::create(*this, BAN::move(outputs)));
 						dprintln("Initialized an USB Keyboard");
 						break;
 					default:
diff --git a/kernel/kernel/USB/HID/Keyboard.cpp b/kernel/kernel/USB/HID/Keyboard.cpp
index 404b3282c..6242f8839 100644
--- a/kernel/kernel/USB/HID/Keyboard.cpp
+++ b/kernel/kernel/USB/HID/Keyboard.cpp
@@ -14,6 +14,14 @@ namespace Kernel
 	static void initialize_scancode_to_keycode();
 	static constexpr bool is_repeatable_scancode(uint8_t scancode);
 
+	USBKeyboard::USBKeyboard(USBHIDDriver& driver, BAN::Vector<USBHID::Report>&& outputs)
+		: USBHIDDevice(InputDevice::Type::Keyboard)
+		, m_driver(driver)
+		, m_outputs(BAN::move(outputs))
+	{
+		set_leds(0);
+	}
+
 	void USBKeyboard::start_report()
 	{
 		m_lock_state = m_keyboard_lock.lock();
@@ -126,6 +134,12 @@ namespace Kernel
 	{
 		using KeyModifier = LibInput::KeyEvent::Modifier;
 
+		const auto toggle_mask = ({ SpinLockGuard _(m_keyboard_lock); m_toggle_mask; });
+
+		if (m_led_mask != toggle_mask)
+			set_leds(toggle_mask);
+		m_led_mask = toggle_mask;
+
 		SpinLockGuard _(m_keyboard_lock);
 
 		if (!m_repeat_scancode.has_value() || SystemTimer::get().ms_since_boot() < m_next_repeat_event_ms)
@@ -143,6 +157,85 @@ namespace Kernel
 		m_next_repeat_event_ms += s_repeat_interval_ms;
 	}
 
+	void USBKeyboard::set_leds(uint16_t mask)
+	{
+		uint8_t report_ids_done[0x100 / 8] {};
+
+		for (const auto& report : m_outputs)
+		{
+			if (report.usage_page != 0x08)
+				continue;
+
+			const auto byte = report.report_id / 8;
+			const auto bit  = report.report_id % 8;
+			if (report_ids_done[byte] & (1u << bit))
+				continue;
+
+			set_leds(report.report_id, mask);
+			report_ids_done[byte] |= (1u << bit);
+		}
+	}
+
+	void USBKeyboard::set_leds(uint8_t report_id, uint16_t mask)
+	{
+		using KeyModifier = LibInput::KeyEvent::Modifier;
+
+		size_t report_bits = 0;
+		for (const auto& report : m_outputs)
+		{
+			if (report.report_id != report_id)
+				continue;
+			report_bits += report.report_size * report.report_count;
+		}
+
+		const size_t report_bytes = (report_bits + 7) / 8;
+
+		uint8_t* data = static_cast<uint8_t*>(kmalloc(report_bytes));
+		if (data == nullptr)
+			return;
+		memset(data, 0, report_bytes);
+
+		size_t bit_offset = 0;
+		for (const auto& report : m_outputs)
+		{
+			if (report.report_id != report_id)
+				continue;
+
+			for (size_t i = 0; report.report_size == 1 && i < report.report_count; i++, bit_offset++)
+			{
+				const size_t usage = (report.usage_id ? report.usage_id : report.usage_minimum) + bit_offset;
+				switch (usage)
+				{
+					case 0x01:
+						if (mask & KeyModifier::NumLock)
+							data[bit_offset / 8] |= 1u << (bit_offset % 8);
+						break;
+					case 0x02:
+						if (mask & KeyModifier::CapsLock)
+							data[bit_offset / 8] |= 1u << (bit_offset % 8);
+						break;
+					case 0x03:
+						if (mask & KeyModifier::ScrollLock)
+							data[bit_offset / 8] |= 1u << (bit_offset % 8);
+						break;
+				}
+			}
+
+			bit_offset += report.report_size * report.report_count;
+		}
+
+		USBDeviceRequest request;
+		request.bmRequestType = USB::RequestType::HostToDevice | USB::RequestType::Class | USB::RequestType::Interface;
+		request.bRequest = 0x09;
+		request.wValue = 0x0200 | report_id;
+		request.wIndex = m_driver.interface().descriptor.bInterfaceNumber;
+		request.wLength = report_bytes;
+		if (auto ret = m_driver.device().send_request(request, kmalloc_paddr_of(reinterpret_cast<vaddr_t>(data)).value()); ret.is_error())
+			dprintln_if(DEBUG_USB_KEYBOARD, "Failed to update LEDs: {}", ret.error());
+
+		kfree(data);
+	}
+
 	void initialize_scancode_to_keycode()
 	{
 		using LibInput::keycode_function;