From 7de689055c61ceb7b30c94163ef2ae94264efe29 Mon Sep 17 00:00:00 2001
From: Bananymous <oskari.alaranta@bananymous.com>
Date: Thu, 6 Feb 2025 22:33:45 +0200
Subject: [PATCH] Kernel: Pass xHCI device information in structs

This makes code more readable and extendable
---
 kernel/include/kernel/USB/XHCI/Controller.h |  2 +-
 kernel/include/kernel/USB/XHCI/Device.h     | 15 +++--
 kernel/kernel/USB/XHCI/Controller.cpp       | 36 ++++++++---
 kernel/kernel/USB/XHCI/Device.cpp           | 69 +++++++++------------
 4 files changed, 69 insertions(+), 53 deletions(-)

diff --git a/kernel/include/kernel/USB/XHCI/Controller.h b/kernel/include/kernel/USB/XHCI/Controller.h
index df73c4f01..93af62321 100644
--- a/kernel/include/kernel/USB/XHCI/Controller.h
+++ b/kernel/include/kernel/USB/XHCI/Controller.h
@@ -46,7 +46,7 @@ namespace Kernel
 
 		void port_updater_task();
 
-		BAN::ErrorOr<uint8_t> initialize_slot(int port_index);
+		BAN::ErrorOr<uint8_t> initialize_device(uint32_t route_string, USB::SpeedClass speed_class);
 		void deinitialize_slot(uint8_t slot_id);
 
 		BAN::ErrorOr<XHCI::TRB> send_command(const XHCI::TRB&);
diff --git a/kernel/include/kernel/USB/XHCI/Device.h b/kernel/include/kernel/USB/XHCI/Device.h
index 5017d8d2a..09c5fc066 100644
--- a/kernel/include/kernel/USB/XHCI/Device.h
+++ b/kernel/include/kernel/USB/XHCI/Device.h
@@ -30,8 +30,15 @@ namespace Kernel
 			void(XHCIDevice::*callback)(XHCI::TRB);
 		};
 
+		struct Info
+		{
+			USB::SpeedClass speed_class;
+			uint8_t slot_id;
+			uint32_t route_string;
+		};
+
 	public:
-		static BAN::ErrorOr<BAN::UniqPtr<XHCIDevice>> create(XHCIController&, uint32_t port_id, uint32_t slot_id);
+		static BAN::ErrorOr<BAN::UniqPtr<XHCIDevice>> create(XHCIController&, const Info& info);
 
 		BAN::ErrorOr<void> configure_endpoint(const USBEndpointDescriptor&) override;
 		BAN::ErrorOr<size_t> send_request(const USBDeviceRequest&, paddr_t buffer) override;
@@ -43,8 +50,9 @@ namespace Kernel
 		BAN::ErrorOr<void> initialize_control_endpoint() override;
 
 	private:
-		XHCIDevice(XHCIController& controller, uint32_t port_id, uint32_t slot_id);
+		XHCIDevice(XHCIController& controller, const Info& info);
 		~XHCIDevice();
+
 		BAN::ErrorOr<void> update_actual_max_packet_size();
 
 		void on_interrupt_or_bulk_endpoint_event(XHCI::TRB);
@@ -55,8 +63,7 @@ namespace Kernel
 		static constexpr uint32_t m_transfer_ring_trb_count = PAGE_SIZE / sizeof(XHCI::TRB);
 
 		XHCIController& m_controller;
-		const uint32_t m_port_id;
-		const uint32_t m_slot_id;
+		Info m_info;
 
 		Mutex m_mutex;
 
diff --git a/kernel/kernel/USB/XHCI/Controller.cpp b/kernel/kernel/USB/XHCI/Controller.cpp
index d765b621b..de9eab94f 100644
--- a/kernel/kernel/USB/XHCI/Controller.cpp
+++ b/kernel/kernel/USB/XHCI/Controller.cpp
@@ -363,7 +363,8 @@ namespace Kernel
 						continue;
 				}
 
-				if (auto ret = initialize_slot(i); !ret.is_error())
+				const uint8_t speed_id = (op_port.portsc >> XHCI::PORTSC::PORT_SPEED_SHIFT) & XHCI::PORTSC::PORT_SPEED_MASK;
+				if (auto ret = initialize_device(i + 1, speed_id_to_class(speed_id)); !ret.is_error())
 					my_port.slot_id = ret.value();
 				else
 				{
@@ -377,10 +378,8 @@ namespace Kernel
 		}
 	}
 
-	BAN::ErrorOr<uint8_t> XHCIController::initialize_slot(int port_index)
+	BAN::ErrorOr<uint8_t> XHCIController::initialize_device(uint32_t route_string, USB::SpeedClass speed_class)
 	{
-		auto& my_port = m_ports[port_index];
-
 		XHCI::TRB enable_slot { .enable_slot_command {} };
 		enable_slot.enable_slot_command.trb_type  = XHCI::TRBType::EnableSlotCommand;
 		// 7.2.2.1.4: The Protocol Slot Type field of a USB3 or USB2 xHCI Supported Protocol Capability shall be set to ‘0’.
@@ -393,9 +392,24 @@ namespace Kernel
 			dwarnln("EnableSlotCommand returned an invalid slot {}", slot_id);
 			return BAN::Error::from_errno(EFAULT);
 		}
-		dprintln_if(DEBUG_XHCI, "allocated slot {} for port {}", slot_id, port_index + 1);
 
-		m_slots[slot_id - 1] = TRY(XHCIDevice::create(*this, port_index + 1, slot_id));
+#if DEBUG_XHCI
+		const auto& root_port = m_ports[(route_string & 0x0F) - 1];
+
+		dprintln("Initializing USB {H}.{H} device on slot {}",
+			root_port.revision_major,
+			root_port.revision_minor,
+			slot_id
+		);
+#endif
+
+		const XHCIDevice::Info info {
+			.speed_class = speed_class,
+			.slot_id = slot_id,
+			.route_string = route_string,
+		};
+
+		m_slots[slot_id - 1] = TRY(XHCIDevice::create(*this, info));
 		if (auto ret = m_slots[slot_id - 1]->initialize(); ret.is_error())
 		{
 			dwarnln("Could not initialize device on slot {}: {}", slot_id, ret.error());
@@ -403,9 +417,13 @@ namespace Kernel
 			return ret.release_error();
 		}
 
-		my_port.slot_id = slot_id;
-
-		dprintln_if(DEBUG_XHCI, "device on slot {} initialized", slot_id);
+#if DEBUG_XHCI
+		dprintln("USB {H}.{H} device on slot {} initialized",
+			root_port.revision_major,
+			root_port.revision_minor,
+			slot_id
+		);
+#endif
 
 		return slot_id;
 	}
diff --git a/kernel/kernel/USB/XHCI/Device.cpp b/kernel/kernel/USB/XHCI/Device.cpp
index b3b6c35e9..fe2267bc5 100644
--- a/kernel/kernel/USB/XHCI/Device.cpp
+++ b/kernel/kernel/USB/XHCI/Device.cpp
@@ -8,35 +8,32 @@
 namespace Kernel
 {
 
-	BAN::ErrorOr<BAN::UniqPtr<XHCIDevice>> XHCIDevice::create(XHCIController& controller, uint32_t port_id, uint32_t slot_id)
+	BAN::ErrorOr<BAN::UniqPtr<XHCIDevice>> XHCIDevice::create(XHCIController& controller, const Info& info)
 	{
-		return TRY(BAN::UniqPtr<XHCIDevice>::create(controller, port_id, slot_id));
+		return TRY(BAN::UniqPtr<XHCIDevice>::create(controller, info));
 	}
 
-	XHCIDevice::XHCIDevice(XHCIController& controller, uint32_t port_id, uint32_t slot_id)
-		: USBDevice(controller.speed_id_to_class((controller.operational_regs().ports[port_id - 1].portsc >> XHCI::PORTSC::PORT_SPEED_SHIFT) & XHCI::PORTSC::PORT_SPEED_MASK))
+	XHCIDevice::XHCIDevice(XHCIController& controller, const Info& info)
+		: USBDevice(info.speed_class)
 		, m_controller(controller)
-		, m_port_id(port_id)
-		, m_slot_id(slot_id)
+		, m_info(info)
 	{}
 
 	XHCIDevice::~XHCIDevice()
 	{
 		XHCI::TRB disable_slot { .disable_slot_command {} };
 		disable_slot.disable_slot_command.trb_type = XHCI::TRBType::DisableSlotCommand;
-		disable_slot.disable_slot_command.slot_id = m_slot_id;
+		disable_slot.disable_slot_command.slot_id = m_info.slot_id;
 		if (auto ret = m_controller.send_command(disable_slot); ret.is_error())
-			dwarnln("Could not disable slot {}: {}", m_slot_id, ret.error());
+			dwarnln("Could not disable slot {}: {}", m_info.slot_id, ret.error());
 		else
-			dprintln_if(DEBUG_XHCI, "Slot {} disabled", m_slot_id);
+			dprintln_if(DEBUG_XHCI, "Slot {} disabled", m_info.slot_id);
 	}
 
 	BAN::ErrorOr<void> XHCIDevice::initialize_control_endpoint()
 	{
 		const uint32_t context_size = m_controller.context_size_set() ? 64 : 32;
 
-		const uint32_t portsc = m_controller.operational_regs().ports[m_port_id - 1].portsc;
-		const uint32_t speed_id = (portsc >> XHCI::PORTSC::PORT_SPEED_SHIFT) & XHCI::PORTSC::PORT_SPEED_MASK;
 
 		m_endpoints[0].max_packet_size = 0;
 		switch (m_speed_class)
@@ -68,25 +65,27 @@ namespace Kernel
 			auto& slot_context          = *reinterpret_cast<XHCI::SlotContext*>        (m_input_context->vaddr() + 1 * context_size);
 			auto& endpoint0_context     = *reinterpret_cast<XHCI::EndpointContext*>    (m_input_context->vaddr() + 2 * context_size);
 
-			memset(&input_control_context, 0, context_size);
-			input_control_context.add_context_flags = 0b11;
+			input_control_context.add_context_flags = (1 << 1) | (1 << 0);
 
-			memset(&slot_context, 0, context_size);
-			slot_context.route_string         = 0;
-			slot_context.root_hub_port_number = m_port_id;
+			slot_context.route_string         = m_info.route_string >> 4;
+			slot_context.root_hub_port_number = m_info.route_string & 0x0F;
 			slot_context.context_entries      = 1;
 			slot_context.interrupter_target   = 0;
-			slot_context.speed                = speed_id;
+			slot_context.speed                = m_controller.speed_class_to_id(m_info.speed_class);
 			// FIXME: 4.5.2 hub
 
-			memset(&endpoint0_context, 0, context_size);
 			endpoint0_context.endpoint_type       = XHCI::EndpointType::Control;
 			endpoint0_context.max_packet_size     = m_endpoints[0].max_packet_size;
-			endpoint0_context.error_count         = 3;
+			endpoint0_context.max_burst_size      = 0; // FIXME: SuperSpeed
+			endpoint0_context.interval            = 0;
 			endpoint0_context.tr_dequeue_pointer  = m_endpoints[0].transfer_ring->paddr() | 1;
+			endpoint0_context.max_primary_streams = 0;
+			endpoint0_context.error_count         = 3;
 		}
 
-		m_controller.dcbaa_reg(m_slot_id) = m_output_context->paddr();
+		m_controller.dcbaa_reg(m_info.slot_id) = m_output_context->paddr();
+
+		dprintln_if(DEBUG_XHCI, "Addressing device on slot {}", m_info.slot_id);
 
 		for (int i = 0; i < 2; i++)
 		{
@@ -95,7 +94,7 @@ namespace Kernel
 			address_device.address_device_command.input_context_pointer     = m_input_context->paddr();
 			// NOTE: some legacy devices require sending request with BSR=1 before actual BSR=0
 			address_device.address_device_command.block_set_address_request = (i == 0);
-			address_device.address_device_command.slot_id                   = m_slot_id;
+			address_device.address_device_command.slot_id                   = m_info.slot_id;
 			TRY(m_controller.send_command(address_device));
 		}
 
@@ -116,12 +115,12 @@ namespace Kernel
 		USBDeviceRequest request;
 		request.bmRequestType = USB::RequestType::DeviceToHost | USB::RequestType::Standard | USB::RequestType::Device;
 		request.bRequest      = USB::Request::GET_DESCRIPTOR;
-		request.wValue        = 0x0100;
+		request.wValue        = USB::DescriptorType::DEVICE << 8;
 		request.wIndex        = 0;
 		request.wLength       = 8;
 		TRY(send_request(request, kmalloc_paddr_of((vaddr_t)buffer.data()).value()));
 
-		const bool is_usb3 = m_controller.port(m_port_id).revision_major == 3;
+		const bool is_usb3 = (m_speed_class == USB::SpeedClass::SuperSpeed);
 		const uint32_t new_max_packet_size = is_usb3 ? 1u << buffer.back() : buffer.back();
 
 		if (m_endpoints[0].max_packet_size == new_max_packet_size)
@@ -133,31 +132,23 @@ namespace Kernel
 
 		{
 			auto& input_control_context = *reinterpret_cast<XHCI::InputControlContext*>(m_input_context->vaddr() + 0 * context_size);
-			auto& slot_context          = *reinterpret_cast<XHCI::SlotContext*>        (m_input_context->vaddr() + 1 * context_size);
 			auto& endpoint0_context     = *reinterpret_cast<XHCI::EndpointContext*>    (m_input_context->vaddr() + 2 * context_size);
 
 			memset(&input_control_context, 0, context_size);
-			input_control_context.add_context_flags = 0b11;
+			input_control_context.add_context_flags = (1 << 1);
 
-			memset(&slot_context, 0, context_size);
-			slot_context.max_exit_latency   = 0; // FIXME:
-			slot_context.interrupter_target = 0;
-
-			memset(&endpoint0_context, 0, context_size);
-			endpoint0_context.endpoint_type       = XHCI::EndpointType::Control;
-			endpoint0_context.max_packet_size     = m_endpoints[0].max_packet_size;
-			endpoint0_context.error_count         = 3;
-			endpoint0_context.tr_dequeue_pointer  = m_endpoints[0].transfer_ring->paddr() | 1;
+			// Only update max packet size. Other fields should be fine from initial configuration
+			endpoint0_context.max_packet_size = new_max_packet_size;
 		}
 
 		XHCI::TRB evaluate_context { .address_device_command = {} };
 		evaluate_context.address_device_command.trb_type                  = XHCI::TRBType::EvaluateContextCommand;
 		evaluate_context.address_device_command.input_context_pointer     = m_input_context->paddr();
 		evaluate_context.address_device_command.block_set_address_request = 0;
-		evaluate_context.address_device_command.slot_id                   = m_slot_id;
+		evaluate_context.address_device_command.slot_id                   = m_info.slot_id;
 		TRY(m_controller.send_command(evaluate_context));
 
-		dprintln_if(DEBUG_XHCI, "successfully updated max packet size to {}", m_endpoints[0].max_packet_size);
+		dprintln_if(DEBUG_XHCI, "Updated max packet size to {}", new_max_packet_size);
 
 		return {};
 	}
@@ -282,7 +273,7 @@ namespace Kernel
 		configure_endpoint.configure_endpoint_command.trb_type              = XHCI::TRBType::ConfigureEndpointCommand;
 		configure_endpoint.configure_endpoint_command.input_context_pointer = m_input_context->paddr();
 		configure_endpoint.configure_endpoint_command.deconfigure           = 0;
-		configure_endpoint.configure_endpoint_command.slot_id               = m_slot_id;
+		configure_endpoint.configure_endpoint_command.slot_id               = m_info.slot_id;
 		TRY(m_controller.send_command(configure_endpoint));
 
 		return {};
@@ -455,7 +446,7 @@ namespace Kernel
 
 		endpoint.transfer_count = request.wLength;
 
-		m_controller.doorbell_reg(m_slot_id) = 1;
+		m_controller.doorbell_reg(m_info.slot_id) = 1;
 
 		const uint64_t timeout_ms = SystemTimer::get().ms_since_boot() + 1000;
 		while ((__atomic_load_n(&completion_trb.raw.dword2, __ATOMIC_SEQ_CST) >> 24) == 0)
@@ -492,7 +483,7 @@ namespace Kernel
 		trb.normal.interrupt_on_short_packet = 1;
 		advance_endpoint_enqueue(endpoint, false);
 
-		m_controller.doorbell_reg(m_slot_id) = endpoint_id;
+		m_controller.doorbell_reg(m_info.slot_id) = endpoint_id;
 	}
 
 	void XHCIDevice::advance_endpoint_enqueue(Endpoint& endpoint, bool chain)