diff --git a/kernel/CMakeLists.txt b/kernel/CMakeLists.txt index 1c2e763c4d..0dd6d04380 100644 --- a/kernel/CMakeLists.txt +++ b/kernel/CMakeLists.txt @@ -95,6 +95,7 @@ set(KERNEL_SOURCES kernel/Timer/RTC.cpp kernel/Timer/Timer.cpp kernel/USB/Device.cpp + kernel/USB/HID/HIDDriver.cpp kernel/USB/USBManager.cpp kernel/USB/XHCI/Controller.cpp kernel/USB/XHCI/Device.cpp diff --git a/kernel/include/kernel/USB/Device.h b/kernel/include/kernel/USB/Device.h index a630e325d4..59a17f272e 100644 --- a/kernel/include/kernel/USB/Device.h +++ b/kernel/include/kernel/USB/Device.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -9,6 +10,18 @@ namespace Kernel { + class USBClassDriver + { + BAN_NON_COPYABLE(USBClassDriver); + BAN_NON_MOVABLE(USBClassDriver); + + public: + USBClassDriver() = default; + virtual ~USBClassDriver() = default; + + virtual void handle_input_data(BAN::ConstByteSpan, uint8_t endpoint_id) = 0; + }; + class USBDevice { BAN_NON_COPYABLE(USBDevice); @@ -47,11 +60,13 @@ namespace Kernel const BAN::Vector& configurations() { return m_descriptor.configurations; } + virtual BAN::ErrorOr initialize_endpoint(const USBEndpointDescriptor&) = 0; virtual BAN::ErrorOr send_request(const USBDeviceRequest&, paddr_t buffer) = 0; static USB::SpeedClass determine_speed_class(uint64_t bits_per_second); protected: + void handle_input_data(BAN::ConstByteSpan, uint8_t endpoint_id); virtual BAN::ErrorOr initialize_control_endpoint() = 0; private: @@ -60,6 +75,9 @@ namespace Kernel private: DeviceDescriptor m_descriptor; BAN::UniqPtr m_dma_buffer; + + // FIXME: support more than one interface from a configuration + BAN::UniqPtr m_class_driver; }; } diff --git a/kernel/include/kernel/USB/HID/HIDDriver.h b/kernel/include/kernel/USB/HID/HIDDriver.h new file mode 100644 index 0000000000..ef2aaa9ef1 --- /dev/null +++ b/kernel/include/kernel/USB/HID/HIDDriver.h @@ -0,0 +1,94 @@ +#pragma once + +#include +#include + +namespace Kernel +{ + + namespace USBHID + { + + struct Report + { + enum class Type { Input, Output, Feature }; + + uint16_t usage_page; + uint16_t usage_id; + Type type; + + uint32_t report_count; + uint32_t report_size; + + uint32_t usage_minimum; + uint32_t usage_maximum; + + int64_t logical_minimum; + int64_t logical_maximum; + + int64_t physical_minimum; + int64_t physical_maximum; + + uint8_t flags; + }; + + struct Collection + { + uint16_t usage_page; + uint16_t usage_id; + uint8_t type; + + BAN::Vector> entries; + }; + + } + + class USBHIDDevice : public InputDevice + { + BAN_NON_COPYABLE(USBHIDDevice); + BAN_NON_MOVABLE(USBHIDDevice); + + public: + USBHIDDevice(InputDevice::Type type) + : InputDevice(type) + {} + virtual ~USBHIDDevice() = default; + + virtual void start_report() = 0; + virtual void stop_report() = 0; + + virtual void handle_variable(uint16_t usage_page, uint16_t usage, int64_t state) = 0; + virtual void handle_array(uint16_t usage_page, uint16_t usage) = 0; + }; + + class USBHIDDriver final : public USBClassDriver + { + BAN_NON_COPYABLE(USBHIDDriver); + BAN_NON_MOVABLE(USBHIDDriver); + + public: + static BAN::ErrorOr> create(USBDevice&, const USBDevice::InterfaceDescriptor&, uint8_t interface_index); + + void handle_input_data(BAN::ConstByteSpan, uint8_t endpoint_id) override; + + private: + USBHIDDriver(USBDevice&, const USBDevice::InterfaceDescriptor&, uint8_t interface_index); + ~USBHIDDriver(); + + BAN::ErrorOr initialize(); + + void forward_collection_inputs(const USBHID::Collection&, BAN::ConstByteSpan& data, size_t bit_offset); + + private: + USBDevice& m_device; + USBDevice::InterfaceDescriptor m_interface; + const uint8_t m_interface_index; + + uint8_t m_endpoint_id { 0 }; + USBHID::Collection m_collection; + BAN::RefPtr m_hid_device; + + friend class BAN::UniqPtr; + }; + +} diff --git a/kernel/include/kernel/USB/XHCI/Device.h b/kernel/include/kernel/USB/XHCI/Device.h index 2d954369f0..8e6e51b8d1 100644 --- a/kernel/include/kernel/USB/XHCI/Device.h +++ b/kernel/include/kernel/USB/XHCI/Device.h @@ -26,11 +26,15 @@ namespace Kernel Mutex mutex; volatile uint32_t transfer_count { 0 }; volatile XHCI::TRB completion_trb; + + BAN::UniqPtr data_region; + void(XHCIDevice::*callback)(XHCI::TRB); }; public: static BAN::ErrorOr> create(XHCIController&, uint32_t port_id, uint32_t slot_id); + BAN::ErrorOr initialize_endpoint(const USBEndpointDescriptor&) override; BAN::ErrorOr send_request(const USBDeviceRequest&, paddr_t buffer) override; void on_transfer_event(const volatile XHCI::TRB&); @@ -47,6 +51,8 @@ namespace Kernel ~XHCIDevice(); BAN::ErrorOr update_actual_max_packet_size(); + void on_interrupt_endpoint_event(XHCI::TRB); + void advance_endpoint_enqueue(Endpoint&, bool chain); private: diff --git a/kernel/kernel/USB/Device.cpp b/kernel/kernel/USB/Device.cpp index dab33fa369..4aedebf669 100644 --- a/kernel/kernel/USB/Device.cpp +++ b/kernel/kernel/USB/Device.cpp @@ -1,5 +1,6 @@ #include #include +#include #define DEBUG_USB 0 #define USB_DUMP_DESCRIPTORS 0 @@ -152,7 +153,8 @@ namespace Kernel dprintln_if(DEBUG_USB, "Found CommunicationAndCDCControl interface"); break; case USB::InterfaceBaseClass::HID: - dprintln_if(DEBUG_USB, "Found HID interface"); + if (auto result = USBHIDDriver::create(*this, interface, j); !result.is_error()) + m_class_driver = result.release_value(); break; case USB::InterfaceBaseClass::Physical: dprintln_if(DEBUG_USB, "Found Physical interface"); @@ -215,6 +217,12 @@ namespace Kernel dprintln_if(DEBUG_USB, "Invalid interface base class {2H}", interface.descriptor.bInterfaceClass); break; } + + if (m_class_driver) + { + dprintln("Successfully initialized USB interface"); + return {}; + } } } @@ -299,6 +307,12 @@ namespace Kernel return BAN::move(configuration); } + void USBDevice::handle_input_data(BAN::ConstByteSpan data, uint8_t endpoint_id) + { + if (m_class_driver) + m_class_driver->handle_input_data(data, endpoint_id); + } + USB::SpeedClass USBDevice::determine_speed_class(uint64_t bits_per_second) { if (bits_per_second <= 1'500'000) diff --git a/kernel/kernel/USB/HID/HIDDriver.cpp b/kernel/kernel/USB/HID/HIDDriver.cpp new file mode 100644 index 0000000000..2452bb42df --- /dev/null +++ b/kernel/kernel/USB/HID/HIDDriver.cpp @@ -0,0 +1,734 @@ +#include + +#include +#include + +#define DEBUG_HID 0 + +namespace Kernel +{ + + enum class HIDDescriptorType : uint8_t + { + HID = 0x21, + Report = 0x22, + Physical = 0x23, + }; + + struct HIDDescriptor + { + uint8_t bLength; + uint8_t bDescriptorType; + uint16_t bcdHID; + uint8_t bCountryCode; + uint8_t bNumDescriptors; + struct + { + uint8_t bDescriptorType; + uint16_t wItemLength; + } __attribute__((packed)) descriptors[]; + } __attribute__((packed)); + static_assert(sizeof(HIDDescriptor) == 6); + + + struct GlobalState + { + BAN::Optional usage_page; + BAN::Optional logical_minimum; + BAN::Optional logical_maximum_signed; + BAN::Optional logical_maximum_unsigned; + BAN::Optional physical_minimum; + BAN::Optional physical_maximum; + // FIXME: support units + BAN::Optional report_size; + // FIXME: support report id + BAN::Optional report_count; + }; + + struct LocalState + { + BAN::Vector usage_stack; + BAN::Optional usage_minimum; + BAN::Optional usage_maximum; + // FIXME: support all local items + }; + + using namespace USBHID; + +#if DEBUG_HID + static void dump_hid_collection(const Collection& collection, size_t indent); +#endif + + static BAN::ErrorOr parse_report_descriptor(BAN::ConstByteSpan report_data); + + BAN::ErrorOr> USBHIDDriver::create(USBDevice& device, const USBDevice::InterfaceDescriptor& interface, uint8_t interface_index) + { + auto result = TRY(BAN::UniqPtr::create(device, interface, interface_index)); + TRY(result->initialize()); + return result; + } + + USBHIDDriver::USBHIDDriver(USBDevice& device, const USBDevice::InterfaceDescriptor& interface, uint8_t interface_index) + : m_device(device) + , m_interface(interface) + , m_interface_index(interface_index) + {} + + USBHIDDriver::~USBHIDDriver() + {} + + BAN::ErrorOr USBHIDDriver::initialize() + { + auto dma_buffer = TRY(DMARegion::create(1024)); + + ASSERT(static_cast(m_interface.descriptor.bInterfaceClass) == USB::InterfaceBaseClass::HID); + + size_t endpoint_index = static_cast(-1); + for (size_t i = 0; i < m_interface.endpoints.size(); i++) + { + const auto& endpoint = m_interface.endpoints[i]; + if (!(endpoint.descriptor.bEndpointAddress & 0x80)) + continue; + if (endpoint.descriptor.bmAttributes != 0x03) + continue; + endpoint_index = i; + break; + } + + if (endpoint_index >= m_interface.endpoints.size()) + { + dwarnln("HID device does not contain IN interrupt endpoint"); + return BAN::Error::from_errno(EFAULT); + } + + bool hid_descriptor_invalid = false; + size_t hid_descriptor_index = static_cast(-1); + for (size_t i = 0; i < m_interface.misc_descriptors.size(); i++) + { + if (static_cast(m_interface.misc_descriptors[i][1]) != HIDDescriptorType::HID) + continue; + if (m_interface.misc_descriptors[i].size() < sizeof(HIDDescriptor)) + hid_descriptor_invalid = true; + const auto& hid_descriptor = *reinterpret_cast(m_interface.misc_descriptors[i].data()); + if (hid_descriptor.bLength != m_interface.misc_descriptors[i].size()) + hid_descriptor_invalid = true; + if (hid_descriptor.bLength != sizeof(HIDDescriptor) + hid_descriptor.bNumDescriptors * 3) + hid_descriptor_invalid = true; + hid_descriptor_index = i; + break; + } + + if (hid_descriptor_index >= m_interface.misc_descriptors.size()) + { + dwarnln("HID device does not contain HID descriptor"); + return BAN::Error::from_errno(EFAULT); + } + if (hid_descriptor_invalid) + { + dwarnln("HID device contains an invalid HID descriptor"); + return BAN::Error::from_errno(EFAULT); + } + + // If this device supports boot protocol, make sure it is not used + if (m_interface.endpoints.front().descriptor.bDescriptorType & 0x80) + { + USBDeviceRequest request; + request.bmRequestType = USB::RequestType::HostToDevice | USB::RequestType::Class | USB::RequestType::Interface; + request.bRequest = USB::Request::SET_INTERFACE; + request.wValue = 1; // report protocol + request.wIndex = m_interface_index; + request.wLength = 0; + TRY(m_device.send_request(request, 0)); + } + + Collection collection {}; + + const auto& hid_descriptor = *reinterpret_cast(m_interface.misc_descriptors[hid_descriptor_index].data()); + dprintln_if(DEBUG_HID, "HID descriptor ({} bytes)", m_interface.misc_descriptors[hid_descriptor_index].size()); + dprintln_if(DEBUG_HID, " bLength: {}", hid_descriptor.bLength); + dprintln_if(DEBUG_HID, " bDescriptorType: {}", hid_descriptor.bDescriptorType); + dprintln_if(DEBUG_HID, " bcdHID: {H}.{2H}", hid_descriptor.bcdHID >> 8, hid_descriptor.bcdHID & 0xFF); + dprintln_if(DEBUG_HID, " bCountryCode: {}", hid_descriptor.bCountryCode); + dprintln_if(DEBUG_HID, " bNumDescriptors: {}", hid_descriptor.bNumDescriptors); + + bool report_descriptor_parsed = false; + for (size_t i = 0; i < hid_descriptor.bNumDescriptors; i++) + { + auto descriptor = hid_descriptor.descriptors[i]; + + if (static_cast(descriptor.bDescriptorType) != HIDDescriptorType::Report) + { + dprintln_if(DEBUG_HID, "Skipping HID descriptor type 0x{2H}", descriptor.bDescriptorType); + continue; + } + if (report_descriptor_parsed) + { + dwarnln("Multiple report descriptors specified"); + return BAN::Error::from_errno(ENOTSUP); + } + + if (descriptor.wItemLength > dma_buffer->size()) + { + dwarnln("Too big report descriptor size {} bytes ({} supported)", +descriptor.wItemLength, dma_buffer->size()); + return BAN::Error::from_errno(ENOBUFS); + } + + { + USBDeviceRequest request; + request.bmRequestType = USB::RequestType::DeviceToHost | USB::RequestType::Standard | USB::RequestType::Interface; + request.bRequest = USB::Request::GET_DESCRIPTOR; + request.wValue = static_cast(HIDDescriptorType::Report) << 8; + request.wIndex = m_interface_index; + request.wLength = descriptor.wItemLength; + auto transferred = TRY(m_device.send_request(request, dma_buffer->paddr())); + + if (transferred < descriptor.wItemLength) + { + dwarnln("HID device did not respond with full report descriptor"); + return BAN::Error::from_errno(EFAULT); + } + } + + dprintln_if(DEBUG_HID, "Parsing {} byte report descriptor", +descriptor.wItemLength); + + auto report_data = BAN::ConstByteSpan(reinterpret_cast(dma_buffer->vaddr()), descriptor.wItemLength); + collection = TRY(parse_report_descriptor(report_data)); + + report_descriptor_parsed = true; + } + + if (!report_descriptor_parsed) + { + dwarnln("No report descriptors specified"); + return BAN::Error::from_errno(EFAULT); + } + + if (collection.usage_page != 0x01) + { + dwarnln("Top most collection is not generic desktop page"); + return BAN::Error::from_errno(EFAULT); + } + +#if DEBUG_HID + { + SpinLockGuard _(Debug::s_debug_lock); + dump_hid_collection(collection, 0); + } +#endif + + switch (collection.usage_id) + { + default: + dwarnln("Unsupported generic descript page usage 0x{2H}", collection.usage_id); + return BAN::Error::from_errno(ENOTSUP); + } + DevFileSystem::get().add_device(m_hid_device); + + const auto& endpoint_descriptor = m_interface.endpoints[endpoint_index].descriptor; + + m_endpoint_id = (endpoint_descriptor.bEndpointAddress & 0x0F) * 2 + !!(endpoint_descriptor.bEndpointAddress & 0x80); + m_collection = BAN::move(collection); + + TRY(m_device.initialize_endpoint(endpoint_descriptor)); + + return {}; + } + + void USBHIDDriver::forward_collection_inputs(const Collection& collection, BAN::ConstByteSpan& data, size_t bit_offset) + { + const auto extract_bits = + [data](size_t bit_offset, size_t bit_count, bool as_unsigned) -> int64_t + { + if (bit_offset >= data.size() * 8) + return 0; + if (bit_count + bit_offset > data.size() * 8) + bit_count = data.size() * 8 - bit_offset; + + uint32_t result = 0; + uint32_t result_offset = 0; + + while (result_offset < bit_count) + { + const uint32_t byte = bit_offset / 8; + const uint32_t bit = bit_offset % 8; + const uint32_t count = BAN::Math::min(bit_count - result_offset, 8 - bit); + const uint32_t mask = (1 << count) - 1; + + result |= static_cast((data[byte] >> bit) & mask) << result_offset; + + bit_offset += count; + result_offset += count; + } + + if (!as_unsigned && (result & (1u << (bit_count - 1)))) + { + const uint32_t mask = (1u << bit_count) - 1; + return -(static_cast(~result & mask) + 1); + } + + return result; + }; + + for (const auto& entry : collection.entries) + { + if (entry.has()) + { + forward_collection_inputs(entry.get(), data, bit_offset); + continue; + } + + ASSERT(entry.has()); + const auto& input = entry.get(); + if (input.type != Report::Type::Input) + continue; + + ASSERT(input.report_size <= 32); + + if (input.usage_id == 0 && input.usage_minimum == 0 && input.usage_maximum == 0) + { + bit_offset += input.report_size * input.report_count; + continue; + } + + for (uint32_t i = 0; i < input.report_count; i++) + { + const int64_t logical = extract_bits(bit_offset, input.report_size, input.logical_minimum >= 0); + if (logical < input.logical_minimum || logical > input.logical_maximum) + { + bit_offset += input.report_size; + continue; + } + + const int64_t physical = + (input.physical_maximum - input.physical_minimum) * + (logical - input.logical_minimum) / + (input.logical_maximum - input.logical_minimum) + + input.physical_minimum; + + const uint32_t usage_base = input.usage_id ? input.usage_id : input.usage_minimum; + if (input.flags & 0x02) + m_hid_device->handle_variable(input.usage_page, usage_base + i, physical); + else + m_hid_device->handle_array(input.usage_page, usage_base + physical); + + bit_offset += input.report_size; + } + } + } + + void USBHIDDriver::handle_input_data(BAN::ConstByteSpan data, uint8_t endpoint_id) + { + ASSERT(m_endpoint_id == endpoint_id); + + if constexpr(DEBUG_HID) + { + const auto nibble_to_hex = [](uint8_t x) -> char { return x + (x < 10 ? '0' : 'A' - 10); }; + + char buffer[512]; + char* ptr = buffer; + for (size_t i = 0; i < BAN::Math::min((sizeof(buffer) - 1) / 3, data.size()); i++) + { + *ptr++ = nibble_to_hex(data[i] >> 4); + *ptr++ = nibble_to_hex(data[i] & 0xF); + *ptr++ = ' '; + } + *ptr = '\0'; + + dprintln_if(DEBUG_HID, "Received {} bytes from endpoint {}: {}", data.size(), endpoint_id, buffer); + } + + m_hid_device->start_report(); + forward_collection_inputs(m_collection, data, 0); + m_hid_device->stop_report(); + } + + BAN::ErrorOr parse_report_descriptor(BAN::ConstByteSpan report_data) + { + BAN::Vector global_stack; + GlobalState global_state; + + LocalState local_state; + + BAN::Optional result; + BAN::Vector collection_stack; + + const auto extract_report_item = + [&](bool as_unsigned) -> int64_t + { + uint32_t value = 0; + auto value_data = report_data.slice(1); + switch (report_data[0] & 0x03) + { + case 1: value = as_unsigned ? value_data.as() : value_data.as(); break; + case 2: value = as_unsigned ? value_data.as() : value_data.as(); break; + case 3: value = as_unsigned ? value_data.as() : value_data.as(); break; + } + return value; + }; + + constexpr auto get_correct_sign = + [](int64_t min, int64_t max_signed, int64_t max_unsigned) -> int64_t + { + if (min < 0 || max_signed >= 0) + return max_signed; + return max_unsigned; + }; + + const auto add_data_item = + [&](Report::Type type, uint32_t item_data, BAN::Vector>& container) -> BAN::ErrorOr + { + if (!global_state.report_count.has_value() || !global_state.report_size.has_value()) + { + dwarnln("Report count and/or report size is not defined"); + return BAN::Error::from_errno(EFAULT); + } + if (!global_state.usage_page.has_value()) + { + dwarnln("Usage page is not defined"); + return BAN::Error::from_errno(EFAULT); + } + if (!global_state.logical_minimum.has_value() || !global_state.logical_maximum_signed.has_value()) + { + dwarnln("Logical minimum and/or logical maximum is not defined"); + return BAN::Error::from_errno(EFAULT); + } + if (global_state.physical_minimum.has_value() != global_state.physical_minimum.has_value()) + { + dwarnln("Only one of physical minimum and physical maximum is defined"); + return BAN::Error::from_errno(EFAULT); + } + if (local_state.usage_minimum.has_value() != local_state.usage_maximum.has_value()) + { + dwarnln("Only one of logical minimum and logical maximum is defined"); + return BAN::Error::from_errno(EFAULT); + } + + const int64_t logical_minimum = global_state.logical_minimum.value(); + const int64_t logical_maximum = get_correct_sign( + global_state.logical_minimum.value(), + global_state.logical_maximum_signed.value(), + global_state.logical_maximum_unsigned.value() + ); + + int64_t physical_minimum = logical_minimum; + int64_t physical_maximum = logical_maximum; + if (global_state.physical_minimum.has_value() && (global_state.physical_minimum.value() || global_state.physical_maximum.value())) + { + physical_minimum = global_state.physical_minimum.value(); + physical_maximum = global_state.physical_maximum.value(); + } + + if (local_state.usage_stack.empty()) + { + if (local_state.usage_minimum.has_value() && local_state.usage_maximum.has_value()) + { + Report item; + item.usage_page = global_state.usage_page.value(); + item.usage_id = 0; + item.usage_minimum = local_state.usage_minimum.value(); + item.usage_maximum = local_state.usage_maximum.value(); + item.type = type; + item.report_count = global_state.report_count.value(); + item.report_size = global_state.report_size.value(); + item.logical_minimum = logical_minimum; + item.logical_maximum = logical_maximum; + item.physical_minimum = physical_minimum; + item.physical_maximum = physical_maximum; + item.flags = item_data; + TRY(container.push_back(item)); + + return {}; + } + + Report item; + item.usage_page = global_state.usage_page.value(); + item.usage_id = 0; + item.usage_minimum = 0; + item.usage_maximum = 0; + item.type = type; + item.report_count = global_state.report_count.value(); + item.report_size = global_state.report_size.value(); + item.logical_minimum = 0; + item.logical_maximum = 0; + item.physical_minimum = 0; + item.physical_maximum = 0; + item.flags = item_data; + TRY(container.push_back(item)); + + return {}; + } + + for (size_t i = 0; i < global_state.report_count.value(); i++) + { + const uint32_t usage = local_state.usage_stack[BAN::Math::min(i, local_state.usage_stack.size() - 1)]; + + Report item; + item.usage_page = (usage >> 16) ? (usage >> 16) : global_state.usage_page.value(); + item.usage_id = usage & 0xFFFF; + item.usage_minimum = 0; + item.usage_maximum = 0; + item.type = type; + item.report_count = 1; + item.report_size = global_state.report_size.value(); + item.logical_minimum = logical_minimum; + item.logical_maximum = logical_maximum; + item.physical_minimum = physical_minimum; + item.physical_maximum = physical_maximum; + item.flags = item_data; + TRY(container.push_back(item)); + } + + return {}; + }; + + while (report_data.size() > 0) + { + const uint8_t item_size = report_data[0] & 0x03; + const uint8_t item_type = (report_data[0] >> 2) & 0x03; + const uint8_t item_tag = (report_data[0] >> 4) & 0x0F; + + if (item_type == 0) + { + switch (item_tag) + { + case 0b1000: // input + if (collection_stack.empty()) + { + dwarnln("Invalid input item outside of collection"); + return BAN::Error::from_errno(EFAULT); + } + TRY(add_data_item(Report::Type::Input, extract_report_item(true), collection_stack.back().entries)); + break; + case 0b1001: // output + if (collection_stack.empty()) + { + dwarnln("Invalid input item outside of collection"); + return BAN::Error::from_errno(EFAULT); + } + TRY(add_data_item(Report::Type::Output, extract_report_item(true), collection_stack.back().entries)); + break; + case 0b1011: // feature + if (collection_stack.empty()) + { + dwarnln("Invalid input item outside of collection"); + return BAN::Error::from_errno(EFAULT); + } + TRY(add_data_item(Report::Type::Feature, extract_report_item(true), collection_stack.back().entries)); + break; + case 0b1010: // collection + { + if (local_state.usage_stack.size() != 1) + { + dwarnln("{} usages specified for collection", local_state.usage_stack.empty() ? "No" : "Multiple"); + return BAN::Error::from_errno(EFAULT); + } + uint16_t usage_page = 0; + if (global_state.usage_page.has_value()) + usage_page = global_state.usage_page.value(); + if (local_state.usage_stack.front() >> 16) + usage_page = local_state.usage_stack.front() >> 16; + if (usage_page == 0) + { + dwarnln("Usage page not specified for a collection"); + return BAN::Error::from_errno(EFAULT); + } + + TRY(collection_stack.emplace_back()); + collection_stack.back().usage_page = usage_page; + collection_stack.back().usage_id = local_state.usage_stack.front(); + break; + } + case 0b1100: // end collection + if (collection_stack.empty()) + { + dwarnln("End collection outside of collection"); + return BAN::Error::from_errno(EFAULT); + } + if (collection_stack.size() == 1) + { + if (result.has_value()) + { + dwarnln("Multiple top-level collections not supported"); + return BAN::Error::from_errno(ENOTSUP); + } + result = BAN::move(collection_stack.back()); + collection_stack.pop_back(); + } + else + { + TRY(collection_stack[collection_stack.size() - 2].entries.push_back(BAN::move(collection_stack.back()))); + collection_stack.pop_back(); + } + break; + default: + dwarnln("Report has reserved main item tag 0b{4b}", item_tag); + return BAN::Error::from_errno(EFAULT); + } + + local_state = LocalState(); + } + else if (item_type == 1) + { + switch (item_tag) + { + case 0b0000: // usage page + global_state.usage_page = extract_report_item(true); + break; + case 0b0001: // logical minimum + global_state.logical_minimum = extract_report_item(false); + break; + case 0b0010: // logical maximum + global_state.logical_maximum_signed = extract_report_item(false); + global_state.logical_maximum_unsigned = extract_report_item(true); + break; + case 0b0011: // physical minimum + global_state.physical_minimum = extract_report_item(false); + break; + case 0b0100: // physical maximum + global_state.physical_maximum = extract_report_item(false); + break; + case 0b0101: // unit exponent + dwarnln("Report units are not supported"); + return BAN::Error::from_errno(ENOTSUP); + case 0b0110: // unit + dwarnln("Report units are not supported"); + return BAN::Error::from_errno(ENOTSUP); + case 0b0111: // report size + global_state.report_size = extract_report_item(true); + break; + case 0b1000: // report id + dwarnln("Report IDs are not supported"); + return BAN::Error::from_errno(ENOTSUP); + case 0b1001: // report count + global_state.report_count = extract_report_item(true); + break; + case 0b1010: // push + TRY(global_stack.push_back(global_state)); + break; + case 0b1011: // pop + if (global_stack.empty()) + { + dwarnln("Report pop from empty stack"); + return BAN::Error::from_errno(EFAULT); + } + global_state = global_stack.back(); + global_stack.pop_back(); + break; + default: + dwarnln("Report has reserved global item tag 0b{4b}", item_tag); + return BAN::Error::from_errno(EFAULT); + } + } + else if (item_type == 2) + { + switch (item_tag) + { + case 0b0000: // usage + TRY(local_state.usage_stack.emplace_back(extract_report_item(true))); + break; + case 0b0001: // usage minimum + local_state.usage_minimum = extract_report_item(true); + break; + case 0b0010: // usage maximum + local_state.usage_maximum = extract_report_item(true); + break; + case 0b0011: // designator index + case 0b0100: // designator minimum + case 0b0101: // designator maximum + case 0b0111: // string index + case 0b1000: // string minimum + case 0b1001: // string maximum + case 0b1010: // delimeter + dwarnln("Unsupported local item tag 0b{4b}", item_tag); + return BAN::Error::from_errno(ENOTSUP); + default: + dwarnln("Report has reserved local item tag 0b{4b}", item_tag); + return BAN::Error::from_errno(EFAULT); + } + } + else + { + dwarnln("Report has reserved item type 0b{2b}", item_type); + return BAN::Error::from_errno(EFAULT); + } + + report_data = report_data.slice(1 + item_size); + } + + if (!result.has_value()) + { + dwarnln("No collection defined in report descriptor"); + return BAN::Error::from_errno(EFAULT); + } + + return result.release_value(); + } + +#if DEBUG_HID + static void print_indent(size_t indent) + { + for (size_t i = 0; i < indent; i++) + Debug::putchar(' '); + } + + static void dump_hid_report(const Report& report, size_t indent) + { + const char* report_type = ""; + switch (report.type) + { + case Report::Type::Input: report_type = "input"; break; + case Report::Type::Output: report_type = "output"; break; + case Report::Type::Feature: report_type = "feature"; break; + } + print_indent(indent); + BAN::Formatter::println(Debug::putchar, "report {}", report_type); + + print_indent(indent + 4); + BAN::Formatter::println(Debug::putchar, "usage page: {2H}", report.usage_page); + + if (report.usage_id || report.usage_minimum || report.usage_maximum) + { + print_indent(indent + 4); + if (report.usage_id) + BAN::Formatter::println(Debug::putchar, "usage: {2H}", report.usage_id); + else + BAN::Formatter::println(Debug::putchar, "usage: {2H}->{2H}", report.usage_minimum, report.usage_maximum); + } + + print_indent(indent + 4); + BAN::Formatter::println(Debug::putchar, "flags: 0b{8b}", report.flags); + + print_indent(indent + 4); + BAN::Formatter::println(Debug::putchar, "size: {}", report.report_size); + print_indent(indent + 4); + BAN::Formatter::println(Debug::putchar, "count: {}", report.report_count); + + print_indent(indent + 4); + BAN::Formatter::println(Debug::putchar, "lminimum: {}", report.logical_minimum); + print_indent(indent + 4); + BAN::Formatter::println(Debug::putchar, "lmaximum: {}", report.logical_maximum); + + print_indent(indent + 4); + BAN::Formatter::println(Debug::putchar, "pminimum: {}", report.physical_minimum); + print_indent(indent + 4); + BAN::Formatter::println(Debug::putchar, "pmaximum: {}", report.physical_maximum); + } + + static void dump_hid_collection(const Collection& collection, size_t indent) + { + print_indent(indent); + BAN::Formatter::println(Debug::putchar, "collection {}", collection.type); + print_indent(indent); + BAN::Formatter::println(Debug::putchar, "usage {H}:{H}", collection.usage_page, collection.usage_id); + + for (const auto& entry : collection.entries) + { + if (entry.has()) + dump_hid_collection(entry.get(), indent + 4); + if (entry.has()) + dump_hid_report(entry.get(), indent + 4); + } + } +#endif + +} diff --git a/kernel/kernel/USB/XHCI/Controller.cpp b/kernel/kernel/USB/XHCI/Controller.cpp index d7368474f8..89a1b4e002 100644 --- a/kernel/kernel/USB/XHCI/Controller.cpp +++ b/kernel/kernel/USB/XHCI/Controller.cpp @@ -253,7 +253,8 @@ namespace Kernel case 3: if (!connection_change) continue; - break; + dprintln_if(DEBUG_XHCI, "USB 3 devices not supported"); + continue; default: continue; } diff --git a/kernel/kernel/USB/XHCI/Device.cpp b/kernel/kernel/USB/XHCI/Device.cpp index c6e46d1cfe..8437bba346 100644 --- a/kernel/kernel/USB/XHCI/Device.cpp +++ b/kernel/kernel/USB/XHCI/Device.cpp @@ -65,14 +65,18 @@ namespace Kernel auto& slot_context = *reinterpret_cast (m_input_context->vaddr() + 1 * context_size); auto& endpoint0_context = *reinterpret_cast (m_input_context->vaddr() + 2 * context_size); + memset(&input_control_context, 0, context_size); input_control_context.add_context_flags = 0b11; - slot_context.root_hub_port_number = m_port_id; + memset(&slot_context, 0, context_size); slot_context.route_string = 0; + slot_context.root_hub_port_number = m_port_id; slot_context.context_entries = 1; slot_context.interrupter_target = 0; slot_context.speed = speed_id; + // FIXME: 4.5.2 hub + memset(&endpoint0_context, 0, context_size); endpoint0_context.endpoint_type = XHCI::EndpointType::Control; endpoint0_context.max_packet_size = m_endpoints[0].max_packet_size; endpoint0_context.error_count = 3; @@ -122,18 +126,21 @@ namespace Kernel { auto& input_control_context = *reinterpret_cast(m_input_context->vaddr() + 0 * context_size); + auto& slot_context = *reinterpret_cast (m_input_context->vaddr() + 1 * context_size); auto& endpoint0_context = *reinterpret_cast (m_input_context->vaddr() + 2 * context_size); - input_control_context.add_context_flags = 0b10; + memset(&input_control_context, 0, context_size); + input_control_context.add_context_flags = 0b11; + memset(&slot_context, 0, context_size); + slot_context.max_exit_latency = 0; // FIXME: + slot_context.interrupter_target = 0; + + memset(&endpoint0_context, 0, context_size); endpoint0_context.endpoint_type = XHCI::EndpointType::Control; endpoint0_context.max_packet_size = m_endpoints[0].max_packet_size; - endpoint0_context.max_burst_size = 0; - endpoint0_context.tr_dequeue_pointer = (m_endpoints[0].transfer_ring->paddr() + (m_endpoints[0].enqueue_index * sizeof(XHCI::TRB))) | 1; - endpoint0_context.interval = 0; - endpoint0_context.max_primary_streams = 0; - endpoint0_context.mult = 0; endpoint0_context.error_count = 3; + endpoint0_context.tr_dequeue_pointer = m_endpoints[0].transfer_ring->paddr() | 1; } XHCI::TRB evaluate_context { .address_device_command = {} }; @@ -148,6 +155,114 @@ namespace Kernel return {}; } + BAN::ErrorOr XHCIDevice::initialize_endpoint(const USBEndpointDescriptor& endpoint_descriptor) + { + const uint32_t endpoint_id = (endpoint_descriptor.bEndpointAddress & 0x0F) * 2 + !!(endpoint_descriptor.bEndpointAddress & 0x80); + + auto& endpoint = m_endpoints[endpoint_id - 1]; + ASSERT(!endpoint.transfer_ring); + + uint32_t last_valid_endpoint_id = endpoint_id; + for (size_t i = endpoint_id; i < m_endpoints.size(); i++) + if (m_endpoints[i].transfer_ring) + last_valid_endpoint_id = i + 1; + + endpoint.transfer_ring = TRY(DMARegion::create(m_transfer_ring_trb_count * sizeof(XHCI::TRB))); + endpoint.max_packet_size = endpoint_descriptor.wMaxPacketSize & 0x07FF; + endpoint.dequeue_index = 0; + endpoint.enqueue_index = 0; + endpoint.cycle_bit = 1; + endpoint.callback = &XHCIDevice::on_interrupt_endpoint_event; + endpoint.data_region = TRY(DMARegion::create(endpoint.max_packet_size)); + + memset(reinterpret_cast(endpoint.transfer_ring->vaddr()), 0, endpoint.transfer_ring->size()); + + { + const uint32_t context_size = m_controller.context_size_set() ? 64 : 32; + + auto& input_control_context = *reinterpret_cast(m_input_context->vaddr()); + auto& slot_context = *reinterpret_cast (m_input_context->vaddr() + context_size); + auto& endpoint_context = *reinterpret_cast (m_input_context->vaddr() + (endpoint_id + 1) * context_size); + + memset(&input_control_context, 0, context_size); + input_control_context.add_context_flags = (1u << endpoint_id) | 1; + + memset(&slot_context, 0, context_size); + slot_context.context_entries = last_valid_endpoint_id; + // FIXME: 4.5.2 hub + + ASSERT(endpoint_descriptor.bEndpointAddress & 0x80); + ASSERT((endpoint_descriptor.bmAttributes & 0x03) == 3); + ASSERT(m_controller.port(m_port_id).revision_major == 2); + + memset(&endpoint_context, 0, context_size); + endpoint_context.endpoint_type = XHCI::EndpointType::InterruptIn; + endpoint_context.max_packet_size = endpoint.max_packet_size; + endpoint_context.max_burst_size = (endpoint_descriptor.wMaxPacketSize >> 11) & 0x0003; + endpoint_context.mult = 0; + endpoint_context.error_count = 3; + endpoint_context.tr_dequeue_pointer = endpoint.transfer_ring->paddr() | 1; + const uint32_t max_esit_payload = endpoint_context.max_packet_size * (endpoint_context.max_burst_size + 1); + endpoint_context.max_esit_payload_lo = max_esit_payload & 0xFFFF; + endpoint_context.max_esit_payload_hi = max_esit_payload >> 16; + } + + XHCI::TRB configure_endpoint { .configure_endpoint_command = {} }; + configure_endpoint.configure_endpoint_command.trb_type = XHCI::TRBType::ConfigureEndpointCommand; + configure_endpoint.configure_endpoint_command.input_context_pointer = m_input_context->paddr(); + configure_endpoint.configure_endpoint_command.deconfigure = 0; + configure_endpoint.configure_endpoint_command.slot_id = m_slot_id; + TRY(m_controller.send_command(configure_endpoint)); + + auto& trb = *reinterpret_cast(endpoint.transfer_ring->vaddr()); + memset(const_cast(&trb), 0, sizeof(XHCI::TRB)); + trb.normal.trb_type = XHCI::TRBType::Normal; + trb.normal.data_buffer_pointer = endpoint.data_region->paddr(); + trb.normal.trb_transfer_length = endpoint.data_region->size(); + trb.normal.td_size = 0; + trb.normal.interrupt_target = 0; + trb.normal.cycle_bit = 1; + trb.normal.interrupt_on_completion = 1; + trb.normal.interrupt_on_short_packet = 1; + advance_endpoint_enqueue(endpoint, false); + + m_controller.doorbell_reg(m_slot_id) = endpoint_id; + + return {}; + } + + void XHCIDevice::on_interrupt_endpoint_event(XHCI::TRB trb) + { + ASSERT(trb.trb_type == XHCI::TRBType::TransferEvent); + if (trb.transfer_event.completion_code != 1 && trb.transfer_event.completion_code != 13) + { + dwarnln("Interrupt endpoint got transfer event with completion code {}", +trb.transfer_event.completion_code); + return; + } + + const uint32_t endpoint_id = trb.transfer_event.endpoint_id; + auto& endpoint = m_endpoints[endpoint_id - 1]; + ASSERT(endpoint.transfer_ring && endpoint.data_region); + + const uint32_t transfer_length = endpoint.max_packet_size - trb.transfer_event.trb_transfer_length; + auto received_data = BAN::ConstByteSpan(reinterpret_cast(endpoint.data_region->vaddr()), transfer_length); + handle_input_data(received_data, endpoint_id); + + auto& new_trb = *reinterpret_cast(endpoint.transfer_ring->vaddr() + endpoint.enqueue_index * sizeof(XHCI::TRB)); + memset(const_cast(&new_trb), 0, sizeof(XHCI::TRB)); + new_trb.normal.trb_type = XHCI::TRBType::Normal; + new_trb.normal.data_buffer_pointer = endpoint.data_region->paddr(); + new_trb.normal.trb_transfer_length = endpoint.max_packet_size; + new_trb.normal.td_size = 0; + new_trb.normal.interrupt_target = 0; + new_trb.normal.cycle_bit = endpoint.cycle_bit; + new_trb.normal.interrupt_on_completion = 1; + new_trb.normal.interrupt_on_short_packet = 1; + advance_endpoint_enqueue(endpoint, false); + + m_controller.doorbell_reg(m_slot_id) = endpoint_id; + } + void XHCIDevice::on_transfer_event(const volatile XHCI::TRB& trb) { ASSERT(trb.trb_type == XHCI::TRBType::TransferEvent); @@ -157,10 +272,22 @@ namespace Kernel return; } + auto& endpoint = m_endpoints[trb.transfer_event.endpoint_id - 1]; + + if (endpoint.callback) + { + XHCI::TRB copy; + copy.raw.dword0 = trb.raw.dword0; + copy.raw.dword1 = trb.raw.dword1; + copy.raw.dword2 = trb.raw.dword2; + copy.raw.dword3 = trb.raw.dword3; + (this->*endpoint.callback)(copy); + return; + } + // Get received bytes from short packet if (trb.transfer_event.completion_code == 13) { - auto& endpoint = m_endpoints[trb.transfer_event.endpoint_id - 1]; auto* transfer_trb_arr = reinterpret_cast(endpoint.transfer_ring->vaddr()); const uint32_t trb_index = (trb.transfer_event.trb_pointer - endpoint.transfer_ring->paddr()) / sizeof(XHCI::TRB); @@ -179,7 +306,7 @@ namespace Kernel } // NOTE: dword2 is last (and atomic) as that is what send_request is waiting for - auto& completion_trb = m_endpoints[trb.transfer_event.endpoint_id - 1].completion_trb; + auto& completion_trb = endpoint.completion_trb; completion_trb.raw.dword0 = trb.raw.dword0; completion_trb.raw.dword1 = trb.raw.dword1; completion_trb.raw.dword3 = trb.raw.dword3; diff --git a/kernel/kernel/kernel.cpp b/kernel/kernel/kernel.cpp index 0dea8e73cf..a9ae444dda 100644 --- a/kernel/kernel/kernel.cpp +++ b/kernel/kernel/kernel.cpp @@ -213,6 +213,9 @@ static void init2(void*) PCI::PCIManager::get().initialize_devices(); dprintln("PCI devices initialized"); + // FIXME: This is very hacky way to wait until USB stack is initialized + SystemTimer::get().sleep(500); + VirtualFileSystem::initialize(cmdline.root); dprintln("VFS initialized"); diff --git a/script/qemu.sh b/script/qemu.sh index 9a49fc3cd2..ff365ebb92 100755 --- a/script/qemu.sh +++ b/script/qemu.sh @@ -31,5 +31,6 @@ qemu-system-$QEMU_ARCH \ -drive format=raw,id=disk,file=${BANAN_DISK_IMAGE_PATH},if=none \ -device e1000e,netdev=net \ -netdev user,id=net \ + -device qemu-xhci \ $DISK_ARGS \ $@ \