From c648ea12f24a5633cdadfd2f9c39746831ad6f73 Mon Sep 17 00:00:00 2001 From: Bananymous Date: Sun, 8 Feb 2026 19:38:28 +0200 Subject: [PATCH] Kernel: Cleanup and fix UNIX sockets EPOLLOUT is now sent to the correct socket and buffer is now a ring buffer to avoid unnecessary memmove on every packet --- .../include/kernel/Networking/UNIX/Socket.h | 7 +- kernel/kernel/Networking/UNIX/Socket.cpp | 137 ++++++++++++------ 2 files changed, 101 insertions(+), 43 deletions(-) diff --git a/kernel/include/kernel/Networking/UNIX/Socket.h b/kernel/include/kernel/Networking/UNIX/Socket.h index 4282c856..0423b39e 100644 --- a/kernel/include/kernel/Networking/UNIX/Socket.h +++ b/kernel/include/kernel/Networking/UNIX/Socket.h @@ -70,6 +70,7 @@ namespace Kernel size_t size; BAN::Vector fds; BAN::Optional ucred; + BAN::WeakPtr sender; }; BAN::ErrorOr add_packet(const msghdr&, PacketInfo&&, bool dont_block); @@ -82,10 +83,14 @@ namespace Kernel BAN::CircularQueue m_packet_infos; size_t m_packet_size_total { 0 }; + size_t m_packet_buffer_tail { 0 }; BAN::UniqPtr m_packet_buffer; - Mutex m_packet_lock; + mutable Mutex m_packet_lock; ThreadBlocker m_packet_thread_blocker; + BAN::Atomic m_sndbuf { 0 }; + BAN::Atomic m_bytes_sent { 0 }; + friend class BAN::RefPtr; }; diff --git a/kernel/kernel/Networking/UNIX/Socket.cpp b/kernel/kernel/Networking/UNIX/Socket.cpp index 4fdb4685..eb372859 100644 --- a/kernel/kernel/Networking/UNIX/Socket.cpp +++ b/kernel/kernel/Networking/UNIX/Socket.cpp @@ -25,7 +25,7 @@ namespace Kernel static BAN::HashMap, BAN::WeakPtr, UnixSocketHash> s_bound_sockets; static Mutex s_bound_socket_lock; - static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE; + static constexpr size_t s_packet_buffer_size = 0x10000; static BAN::ErrorOr validate_sockaddr_un(const sockaddr* address, socklen_t address_len) { @@ -45,8 +45,6 @@ namespace Kernel return BAN::StringView { sockaddr_un.sun_path, length }; } - // FIXME: why is this using spinlocks instead of mutexes?? - BAN::ErrorOr> UnixDomainSocket::create(Socket::Type socket_type, const Socket::Info& info) { auto socket = TRY(BAN::RefPtr::create(socket_type, info)); @@ -64,6 +62,7 @@ namespace Kernel UnixDomainSocket::UnixDomainSocket(Socket::Type socket_type, const Socket::Info& info) : Socket(info) , m_socket_type(socket_type) + , m_sndbuf(s_packet_buffer_size) { switch (socket_type) { @@ -290,29 +289,55 @@ namespace Kernel case Socket::Type::DGRAM: return false; } + ASSERT_NOT_REACHED(); } BAN::ErrorOr UnixDomainSocket::add_packet(const msghdr& packet, PacketInfo&& packet_info, bool dont_block) { LockGuard _(m_packet_lock); - while (m_packet_infos.full() || m_packet_size_total + packet_info.size > s_packet_buffer_size) + const auto has_space = + [&]() -> bool + { + if (m_packet_infos.full()) + return false; + if (is_streaming()) + return m_packet_size_total < m_packet_buffer->size(); + return m_packet_size_total + packet_info.size <= m_packet_buffer->size(); + }; + + while (!has_space()) { if (dont_block) return BAN::Error::from_errno(EAGAIN); TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock)); } - uint8_t* packet_buffer = reinterpret_cast(m_packet_buffer->vaddr() + m_packet_size_total); - - size_t offset = 0; - for (int i = 0; i < packet.msg_iovlen; i++) + if (auto available = m_packet_buffer->size() - m_packet_size_total; available < packet_info.size) { - memcpy(packet_buffer + offset, packet.msg_iov[i].iov_base, packet.msg_iov[i].iov_len); - offset += packet.msg_iov[i].iov_len; + ASSERT(is_streaming()); + packet_info.size = available; } - ASSERT(offset == packet_info.size); + uint8_t* packet_buffer_base_u8 = reinterpret_cast(m_packet_buffer->vaddr()); + + size_t bytes_copied = 0; + for (int i = 0; i < packet.msg_iovlen && bytes_copied < packet_info.size; i++) + { + const uint8_t* iov_base_u8 = static_cast(packet.msg_iov[i].iov_base); + + const size_t to_copy = BAN::Math::min(packet.msg_iov[i].iov_len, packet_info.size - bytes_copied); + + const size_t copy_offset = (m_packet_buffer_tail + m_packet_size_total + bytes_copied) % m_packet_buffer->size(); + const size_t before_wrap = BAN::Math::min(to_copy, m_packet_buffer->size() - copy_offset); + memcpy(packet_buffer_base_u8 + copy_offset, iov_base_u8, before_wrap); + if (const size_t after_wrap = to_copy - before_wrap) + memcpy(packet_buffer_base_u8, iov_base_u8 + before_wrap, after_wrap); + + bytes_copied += to_copy; + } + + ASSERT(bytes_copied == packet_info.size); m_packet_size_total += packet_info.size; m_packet_infos.emplace(BAN::move(packet_info)); @@ -320,7 +345,7 @@ namespace Kernel epoll_notify(EPOLLIN); - return {}; + return bytes_copied; } bool UnixDomainSocket::can_read_impl() const @@ -336,24 +361,13 @@ namespace Kernel return false; } + LockGuard _(m_packet_lock); return m_packet_size_total > 0; } bool UnixDomainSocket::can_write_impl() const { - if (m_info.has()) - { - auto& connection_info = m_info.get(); - auto connection = connection_info.connection.lock(); - if (!connection) - return false; - if (connection->m_packet_infos.full()) - return false; - if (connection->m_packet_size_total >= s_packet_buffer_size) - return false; - } - - return true; + return m_bytes_sent < m_sndbuf; } bool UnixDomainSocket::has_hungup_impl() const @@ -377,6 +391,7 @@ namespace Kernel } LockGuard _(m_packet_lock); + while (m_packet_size_total == 0) { if (m_info.has()) @@ -397,7 +412,7 @@ namespace Kernel cheader->cmsg_len = message.msg_controllen; size_t cheader_len = 0; - uint8_t* packet_buffer = reinterpret_cast(m_packet_buffer->vaddr()); + uint8_t* packet_buffer_base_u8 = reinterpret_cast(m_packet_buffer->vaddr()); message.msg_flags = 0; @@ -471,7 +486,12 @@ namespace Kernel uint8_t* iov_base = static_cast(iov.iov_base); const size_t nrecv = BAN::Math::min(iov.iov_len - iov_offset, packet_info.size - packet_received); - memcpy(iov_base + iov_offset, packet_buffer + packet_received, nrecv); + + const size_t copy_offset = (m_packet_buffer_tail + packet_received) % m_packet_buffer->size(); + const size_t before_wrap = BAN::Math::min(nrecv, m_packet_buffer->size() - copy_offset); + memcpy(iov_base + iov_offset, packet_buffer_base_u8 + copy_offset, before_wrap); + if (const size_t after_wrap = nrecv - before_wrap) + memcpy(iov_base + iov_offset + before_wrap, packet_buffer_base_u8, after_wrap); packet_received += nrecv; @@ -492,12 +512,17 @@ namespace Kernel if (packet_info.size == 0) m_packet_infos.pop(); - // FIXME: get rid of this memmove :) - memmove(packet_buffer, packet_buffer + to_discard, m_packet_size_total - to_discard); + m_packet_buffer_tail = (m_packet_buffer_tail + to_discard) % m_packet_buffer->size(); m_packet_size_total -= to_discard; total_recv += packet_received; + if (auto sender = packet_info.sender.lock()) + { + sender->m_bytes_sent -= to_discard; + sender->epoll_notify(EPOLLOUT); + } + // on linux ancillary data is a barrier on stream sockets, lets do the same if (!is_streaming() || had_ancillary_data) break; @@ -507,8 +532,6 @@ namespace Kernel m_packet_thread_blocker.unblock(); - epoll_notify(EPOLLOUT); - return total_recv; } @@ -528,13 +551,14 @@ namespace Kernel return result; }(); - if (total_message_size > s_packet_buffer_size) - return BAN::Error::from_errno(ENOBUFS); + if (!is_streaming() && total_message_size > m_packet_buffer->size()) + return BAN::Error::from_errno(EMSGSIZE); PacketInfo packet_info { - .size = total_message_size, - .fds = {}, - .ucred = {}, + .size = total_message_size, + .fds = {}, + .ucred = {}, + .sender = TRY(get_weak_ptr()), }; for (const auto* header = CMSG_FIRSTHDR(&message); header; header = CMSG_NXTHDR(&message, header)) @@ -607,8 +631,9 @@ namespace Kernel auto target = connection_info.connection.lock(); if (!target) return BAN::Error::from_errno(ENOTCONN); - TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT)); - return total_message_size; + const size_t bytes_sent = TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT)); + m_bytes_sent += bytes_sent; + return bytes_sent; } else { @@ -625,7 +650,7 @@ namespace Kernel Process::current().root_file().inode, Process::current().credentials(), absolute_path, - O_RDWR + O_WRONLY )).inode; } else @@ -651,8 +676,11 @@ namespace Kernel if (!target) return BAN::Error::from_errno(EDESTADDRREQ); - TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT)); - return total_message_size; + if (target->m_socket_type != m_socket_type) + return BAN::Error::from_errno(EPROTOTYPE); + const auto bytes_sent = TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT)); + m_bytes_sent += bytes_sent; + return bytes_sent; } } @@ -688,7 +716,7 @@ namespace Kernel result = 0; break; case SO_SNDBUF: - result = m_packet_buffer->size(); + result = m_sndbuf; break; case SO_RCVBUF: result = m_packet_buffer->size(); @@ -705,4 +733,29 @@ namespace Kernel return {}; } + BAN::ErrorOr UnixDomainSocket::setsockopt_impl(int level, int option, const void* value, socklen_t value_len) + { + if (level != SOL_SOCKET) + return BAN::Error::from_errno(EINVAL); + + switch (option) + { + case SO_SNDBUF: + { + if (value_len != sizeof(int)) + return BAN::Error::from_errno(EINVAL); + const int new_sndbuf = *static_cast(value); + if (new_sndbuf < 0) + return BAN::Error::from_errno(EINVAL); + m_sndbuf = new_sndbuf; + break; + } + default: + dwarnln("setsockopt(SOL_SOCKET, {})", option); + return BAN::Error::from_errno(ENOTSUP); + } + + return {}; + } + }