From 7b580b8f564a51287249f882835cd4f9396d1a6a Mon Sep 17 00:00:00 2001 From: Bananymous Date: Sun, 9 Nov 2025 23:36:49 +0200 Subject: [PATCH] Kernel: Implement fd passing with SCM_RIGHTS --- .../include/kernel/Networking/UNIX/Socket.h | 22 +++- kernel/include/kernel/OpenFileDescriptorSet.h | 26 +++++ kernel/include/kernel/Process.h | 2 + kernel/kernel/Networking/UNIX/Socket.cpp | 109 +++++++++++++----- kernel/kernel/OpenFileDescriptorSet.cpp | 62 ++++++++++ 5 files changed, 186 insertions(+), 35 deletions(-) diff --git a/kernel/include/kernel/Networking/UNIX/Socket.h b/kernel/include/kernel/Networking/UNIX/Socket.h index 6616df74..4e8233be 100644 --- a/kernel/include/kernel/Networking/UNIX/Socket.h +++ b/kernel/include/kernel/Networking/UNIX/Socket.h @@ -7,6 +7,7 @@ #include #include #include +#include namespace Kernel { @@ -16,6 +17,9 @@ namespace Kernel BAN_NON_COPYABLE(UnixDomainSocket); BAN_NON_MOVABLE(UnixDomainSocket); + public: + using FDWrapper = OpenFileDescriptorSet::FDWrapper; + public: static BAN::ErrorOr> create(Socket::Type, const Socket::Info&); BAN::ErrorOr make_socket_pair(UnixDomainSocket&); @@ -38,7 +42,7 @@ namespace Kernel UnixDomainSocket(Socket::Type, const Socket::Info&); ~UnixDomainSocket(); - BAN::ErrorOr add_packet(const msghdr&, size_t total_size); + BAN::ErrorOr add_packet(const msghdr&, size_t total_size, BAN::Vector&& fds_to_send); bool is_bound() const { return !m_bound_file.canonical_path.empty(); } bool is_bound_to_unused() const { return !m_bound_file.inode; } @@ -62,17 +66,23 @@ namespace Kernel BAN::String peer_address; }; + struct PacketInfo + { + size_t size; + BAN::Vector fds; + }; + private: const Socket::Type m_socket_type; VirtualFileSystem::File m_bound_file; BAN::Variant m_info; - BAN::CircularQueue m_packet_sizes; - size_t m_packet_size_total { 0 }; - BAN::UniqPtr m_packet_buffer; - Mutex m_packet_lock; - ThreadBlocker m_packet_thread_blocker; + BAN::CircularQueue m_packet_infos; + size_t m_packet_size_total { 0 }; + 
BAN::UniqPtr m_packet_buffer; + Mutex m_packet_lock; + ThreadBlocker m_packet_thread_blocker; friend class BAN::RefPtr; }; diff --git a/kernel/include/kernel/OpenFileDescriptorSet.h b/kernel/include/kernel/OpenFileDescriptorSet.h index ae990354..5dc668e1 100644 --- a/kernel/include/kernel/OpenFileDescriptorSet.h +++ b/kernel/include/kernel/OpenFileDescriptorSet.h @@ -109,6 +109,32 @@ namespace Kernel BAN::ErrorOr get_free_fd() const; BAN::ErrorOr get_free_fd_pair(int fds[2]) const; + public: + /* FDWrapper carries one open-file description between fd tables; m_fd stays -1 until open_all_fd_wrappers assigns a descriptor number. */ class FDWrapper + { + public: + FDWrapper(BAN::RefPtr); + FDWrapper(const FDWrapper& other) { *this = other; } + FDWrapper(FDWrapper&& other) { *this = BAN::move(other); } + ~FDWrapper(); + + FDWrapper& operator=(const FDWrapper&); + FDWrapper& operator=(FDWrapper&&); + + int fd() const { return m_fd; } + + void clear(); + + private: + BAN::RefPtr m_description; + int m_fd { -1 }; + + friend class OpenFileDescriptorSet; + }; + + BAN::ErrorOr get_fd_wrapper(int fd); + size_t open_all_fd_wrappers(BAN::Span fd_wrappers); + private: const Credentials& m_credentials; mutable Mutex m_mutex; diff --git a/kernel/include/kernel/Process.h b/kernel/include/kernel/Process.h index 67f9a289..8a46c6ac 100644 --- a/kernel/include/kernel/Process.h +++ b/kernel/include/kernel/Process.h @@ -240,6 +240,8 @@ namespace Kernel // FIXME: remove this API BAN::ErrorOr absolute_path_of(BAN::StringView) const; + OpenFileDescriptorSet& open_file_descriptor_set() { return m_open_file_descriptors; } + // ONLY CALLED BY TIMER INTERRUPT static void update_alarm_queue(); diff --git a/kernel/kernel/Networking/UNIX/Socket.cpp b/kernel/kernel/Networking/UNIX/Socket.cpp index 4f5a852f..3aa49a37 100644 --- a/kernel/kernel/Networking/UNIX/Socket.cpp +++ b/kernel/kernel/Networking/UNIX/Socket.cpp @@ -294,10 +294,11 @@ namespace Kernel } } - BAN::ErrorOr UnixDomainSocket::add_packet(const msghdr& packet, size_t total_size) + BAN::ErrorOr UnixDomainSocket::add_packet(const msghdr& packet, size_t total_size, 
BAN::Vector&& fds_to_send) { LockGuard _(m_packet_lock); - while (m_packet_sizes.full() || m_packet_size_total + total_size > s_packet_buffer_size) + + /* Wait, interruptibly, until the info queue has a free slot and total_size bytes fit in the packet buffer. */ while (m_packet_infos.full() || m_packet_size_total + total_size > s_packet_buffer_size) TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock)); uint8_t* packet_buffer = reinterpret_cast(m_packet_buffer->vaddr() + m_packet_size_total); @@ -311,7 +312,7 @@ namespace Kernel ASSERT(offset == total_size); m_packet_size_total += total_size; - m_packet_sizes.push(total_size); + /* fds_to_send is moved from only here, after the wait above has succeeded. */ m_packet_infos.emplace(total_size, BAN::move(fds_to_send)); m_packet_thread_blocker.unblock(); @@ -366,12 +367,6 @@ namespace Kernel return BAN::Error::from_errno(ENOTSUP); } - if (CMSG_FIRSTHDR(&message)) - { - dwarnln("ignoring recvmsg control message"); - message.msg_controllen = 0; - } - LockGuard _(m_packet_lock); while (m_packet_size_total == 0) { @@ -388,31 +383,68 @@ namespace Kernel TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock)); } + /* While filling the control buffer, cmsg_len of the header being written doubles as the remaining-space counter. */ auto* cheader = CMSG_FIRSTHDR(&message); + if (cheader != nullptr) + cheader->cmsg_len = message.msg_controllen; + size_t cheader_len = 0; + uint8_t* packet_buffer = reinterpret_cast(m_packet_buffer->vaddr()); - const size_t max_recv_size = is_streaming() ? 
m_packet_size_total : m_packet_sizes.front(); + message.msg_flags = 0; size_t total_recv = 0; /* NOTE(review): each iovec is served from at most one queued packet, and non-streaming sockets break out after the first iovec (see the break at the end of the loop body); the removed code instead filled consecutive iovecs up to max_recv_size. Confirm multi-iovec receives are still handled as intended. */ for (int i = 0; i < message.msg_iovlen; i++) { - const size_t nrecv = BAN::Math::min(message.msg_iov[i].iov_len, max_recv_size - total_recv); - memcpy(message.msg_iov[i].iov_base, packet_buffer + total_recv, nrecv); + auto& packet_info = m_packet_infos.front(); + + auto fds_to_open = BAN::move(packet_info.fds); + if (cheader == nullptr && !fds_to_open.empty()) + message.msg_flags |= MSG_CTRUNC; + else if (!fds_to_open.empty()) + { + const size_t max_fd_count = (cheader->cmsg_len - sizeof(cmsghdr)) / sizeof(int); + if (max_fd_count < fds_to_open.size()) + message.msg_flags |= MSG_CTRUNC; + + const size_t fd_count = BAN::Math::min(fds_to_open.size(), max_fd_count); + /* NOTE(review): if fewer than fd_count descriptors can be opened, only fds_opened are reported and MSG_CTRUNC is not set for the rest. */ const size_t fds_opened = Process::current().open_file_descriptor_set().open_all_fd_wrappers(fds_to_open.span().slice(0, fd_count)); + + auto* fd_data = reinterpret_cast(CMSG_DATA(cheader)); + for (size_t i = 0; i < fds_opened; i++) + fd_data[i] = fds_to_open[i].fd(); + + const size_t header_length = CMSG_LEN(fds_opened * sizeof(int)); + cheader->cmsg_level = SOL_SOCKET; + cheader->cmsg_type = SCM_RIGHTS; + cheader->cmsg_len = header_length; + cheader = CMSG_NXTHDR(&message, cheader); + if (cheader != nullptr) + /* NOTE(review): remaining space subtracts only this header_length, not the running cheader_len. */ cheader->cmsg_len = message.msg_controllen - header_length; + cheader_len += header_length; + } + + const size_t nrecv = BAN::Math::min(message.msg_iov[i].iov_len, packet_info.size); + memcpy(message.msg_iov[i].iov_base, packet_buffer, nrecv); total_recv += nrecv; + + if (!is_streaming() && nrecv < packet_info.size) + message.msg_flags |= MSG_TRUNC; + + const size_t to_discard = is_streaming() ? 
nrecv : packet_info.size; + + packet_info.size -= to_discard; + if (packet_info.size == 0) + m_packet_infos.pop(); + + // FIXME: get rid of this memmove :) + memmove(packet_buffer, packet_buffer + to_discard, m_packet_size_total - to_discard); + m_packet_size_total -= to_discard; + + if (!is_streaming() || m_packet_infos.empty()) + break; } - size_t bytes_to_handle = total_recv; - while (bytes_to_handle) - { - const size_t to_handle = BAN::Math::min(bytes_to_handle, m_packet_sizes.front()); - bytes_to_handle -= to_handle; - m_packet_sizes.front() -= to_handle; - if (m_packet_sizes.front() == 0) - m_packet_sizes.pop(); - } - - const size_t to_discard = is_streaming() ? total_recv : max_recv_size; - memmove(packet_buffer, packet_buffer + to_discard, m_packet_size_total - to_discard); - m_packet_size_total -= to_discard; + message.msg_controllen = cheader_len; m_packet_thread_blocker.unblock(); @@ -429,8 +461,27 @@ namespace Kernel return BAN::Error::from_errno(ENOTSUP); } - if (CMSG_FIRSTHDR(&message)) - dwarnln("ignoring sendmsg control message"); + /* Collect a wrapper for every fd named in SOL_SOCKET / SCM_RIGHTS control messages; other control messages are logged and skipped. */ BAN::Vector fds_to_send; + + for (const auto* header = CMSG_FIRSTHDR(&message); header; header = CMSG_NXTHDR(&message, header)) + { + if (header->cmsg_level != SOL_SOCKET) + { + dwarnln("ignoring control message with level {}", header->cmsg_level); + continue; + } + + if (header->cmsg_type != SCM_RIGHTS) + { + dwarnln("ignoring control message with type {}", header->cmsg_type); + continue; + } + + const auto* fd_data = reinterpret_cast(CMSG_DATA(header)); + /* NOTE(review): cmsg_len is read from the supplied control buffer without a visible bounds check; a value smaller than sizeof(cmsghdr) wraps this unsigned subtraction and a value larger than msg_controllen reads past the buffer. Validate before use. */ const size_t fd_count = (header->cmsg_len - sizeof(cmsghdr)) / sizeof(int); + for (size_t i = 0; i < fd_count; i++) + TRY(fds_to_send.push_back(TRY(Process::current().open_file_descriptor_set().get_fd_wrapper(fd_data[i])))); + } const size_t total_message_size = [&message]() -> size_t { @@ -449,7 +500,7 @@ namespace Kernel auto target = connection_info.connection.lock(); if (!target) return BAN::Error::from_errno(ENOTCONN); - TRY(target->add_packet(message, 
total_message_size)); + TRY(target->add_packet(message, total_message_size, BAN::move(fds_to_send))); return total_message_size; } else @@ -493,7 +544,7 @@ namespace Kernel if (!target) return BAN::Error::from_errno(EDESTADDRREQ); - TRY(target->add_packet(message, total_message_size)); + TRY(target->add_packet(message, total_message_size, BAN::move(fds_to_send))); return total_message_size; } diff --git a/kernel/kernel/OpenFileDescriptorSet.cpp b/kernel/kernel/OpenFileDescriptorSet.cpp index 95eadb3c..a88e7a2d 100644 --- a/kernel/kernel/OpenFileDescriptorSet.cpp +++ b/kernel/kernel/OpenFileDescriptorSet.cpp @@ -659,4 +659,66 @@ namespace Kernel return BAN::Error::from_errno(EMFILE); } + using FDWrapper = OpenFileDescriptorSet::FDWrapper; + + /* Takes a reference to description and, when it is non-null, notifies its inode through on_clone. */ FDWrapper::FDWrapper(BAN::RefPtr description) + : m_description(description) + { + if (m_description) + m_description->file.inode->on_clone(m_description->status_flags); + } + + FDWrapper::~FDWrapper() + { + clear(); + } + + /* NOTE(review): m_fd is not copied, and self-assignment runs on_close before on_clone. */ FDWrapper& FDWrapper::operator=(const FDWrapper& other) + { + clear(); + m_description = other.m_description; + if (m_description) + m_description->file.inode->on_clone(m_description->status_flags); + return *this; + } + + /* NOTE(review): m_fd is not taken over from other. */ FDWrapper& FDWrapper::operator=(FDWrapper&& other) + { + clear(); + m_description = BAN::move(other.m_description); + return *this; + } + + /* NOTE(review): calls on_close but leaves m_description set, so a second clear() (for example from the destructor after an explicit call) would call on_close again. */ void FDWrapper::clear() + { + if (m_description) + m_description->file.inode->on_close(m_description->status_flags); + } + + /* Validates fd while holding m_mutex and returns a wrapper around its description. */ BAN::ErrorOr OpenFileDescriptorSet::get_fd_wrapper(int fd) + { + LockGuard _(m_mutex); + TRY(validate_fd(fd)); + return FDWrapper { m_open_files[fd].description }; + } + + /* Moves each wrapped description into a free descriptor slot with descriptor_flags cleared, records the slot in the wrapper, and returns how many wrappers were opened; stops at the first get_free_fd failure. */ size_t OpenFileDescriptorSet::open_all_fd_wrappers(BAN::Span fd_wrappers) + { + LockGuard _(m_mutex); + + for (size_t i = 0; i < fd_wrappers.size(); i++) + { + auto fd_or_error = get_free_fd(); + if (fd_or_error.is_error()) + return i; + + const int fd = fd_or_error.release_value(); + m_open_files[fd].description = BAN::move(fd_wrappers[i].m_description); + 
m_open_files[fd].descriptor_flags = 0; + fd_wrappers[i].m_fd = fd; + } + + return fd_wrappers.size(); + } + }