From 04d24bce703d4ef5b20dffb234b81b0733d06a04 Mon Sep 17 00:00:00 2001 From: Bananymous Date: Sun, 9 Nov 2025 16:23:37 +0200 Subject: [PATCH] Kernel/LibC: Implement {recv,send}msg as syscalls This also removes the now old recvfrom and sendto syscalls. These are now implemented as wrappers around recvmsg and sendmsg. Also replace unnecessary spinlocks from unix socket with mutexes --- kernel/include/kernel/FS/Inode.h | 8 +- kernel/include/kernel/FS/Socket.h | 2 - kernel/include/kernel/Networking/TCPSocket.h | 4 +- kernel/include/kernel/Networking/UDPSocket.h | 4 +- .../include/kernel/Networking/UNIX/Socket.h | 10 +- kernel/include/kernel/OpenFileDescriptorSet.h | 4 +- kernel/include/kernel/Process.h | 4 +- kernel/kernel/FS/Inode.cpp | 12 +- kernel/kernel/Networking/TCPSocket.cpp | 70 +++-- kernel/kernel/Networking/UDPSocket.cpp | 74 ++++-- kernel/kernel/Networking/UNIX/Socket.cpp | 247 +++++++++++------- kernel/kernel/OpenFileDescriptorSet.cpp | 46 +++- kernel/kernel/Process.cpp | 93 ++++--- kernel/kernel/Syscall.cpp | 4 +- .../libraries/LibC/include/sys/syscall.h | 4 +- userspace/libraries/LibC/sys/socket.cpp | 143 ++++------ 16 files changed, 422 insertions(+), 307 deletions(-) diff --git a/kernel/include/kernel/FS/Inode.h b/kernel/include/kernel/FS/Inode.h index c5cb7728..a1f1f2e0 100644 --- a/kernel/include/kernel/FS/Inode.h +++ b/kernel/include/kernel/FS/Inode.h @@ -108,8 +108,8 @@ namespace Kernel BAN::ErrorOr bind(const sockaddr* address, socklen_t address_len); BAN::ErrorOr connect(const sockaddr* address, socklen_t address_len); BAN::ErrorOr listen(int backlog); - BAN::ErrorOr sendto(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len); - BAN::ErrorOr recvfrom(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len); + BAN::ErrorOr sendmsg(const msghdr& message, int flags); + BAN::ErrorOr recvmsg(msghdr& message, int flags); BAN::ErrorOr getsockname(sockaddr* address, socklen_t* address_len); BAN::ErrorOr getpeername(sockaddr* address, socklen_t* address_len); @@ -155,8 +155,8 @@ namespace Kernel virtual BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr listen_impl(int) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } - virtual BAN::ErrorOr sendto_impl(BAN::ConstByteSpan, const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } - virtual BAN::ErrorOr recvfrom_impl(BAN::ByteSpan, sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); } + virtual BAN::ErrorOr recvmsg_impl(msghdr&, int) { return BAN::Error::from_errno(ENOTSUP); } + virtual BAN::ErrorOr sendmsg_impl(const msghdr&, int) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr getsockname_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr getpeername_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); } diff --git a/kernel/include/kernel/FS/Socket.h b/kernel/include/kernel/FS/Socket.h index 474ff279..7e719e80 100644 --- a/kernel/include/kernel/FS/Socket.h +++ b/kernel/include/kernel/FS/Socket.h @@ -51,8 +51,6 @@ namespace Kernel : m_info(info) {} - BAN::ErrorOr read_impl(off_t, BAN::ByteSpan buffer) override { return recvfrom_impl(buffer, nullptr, nullptr); } - BAN::ErrorOr write_impl(off_t, BAN::ConstByteSpan buffer) override { return sendto_impl(buffer, nullptr, 0); } BAN::ErrorOr fsync_impl() final override { return {}; } private: diff --git a/kernel/include/kernel/Networking/TCPSocket.h b/kernel/include/kernel/Networking/TCPSocket.h index 0a0b63d7..a3298d57 100644 --- a/kernel/include/kernel/Networking/TCPSocket.h +++ b/kernel/include/kernel/Networking/TCPSocket.h @@ -60,8 +60,8 @@ namespace Kernel virtual BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) override; virtual BAN::ErrorOr listen_impl(int) override; virtual BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) override; - virtual BAN::ErrorOr sendto_impl(BAN::ConstByteSpan, const sockaddr*, socklen_t) override; - virtual BAN::ErrorOr recvfrom_impl(BAN::ByteSpan, sockaddr*, socklen_t*) override; + virtual BAN::ErrorOr recvmsg_impl(msghdr& message, int flags) override; + virtual BAN::ErrorOr sendmsg_impl(const msghdr& message, int flags) override; virtual BAN::ErrorOr getpeername_impl(sockaddr*, socklen_t*) override; virtual BAN::ErrorOr ioctl_impl(int, void*) override; diff --git a/kernel/include/kernel/Networking/UDPSocket.h b/kernel/include/kernel/Networking/UDPSocket.h index 6e39381e..83e6ed0b 100644 --- a/kernel/include/kernel/Networking/UDPSocket.h +++ b/kernel/include/kernel/Networking/UDPSocket.h @@ -34,8 +34,8 @@ namespace Kernel virtual void receive_packet(BAN::ConstByteSpan, const sockaddr* sender, socklen_t sender_len) override; virtual BAN::ErrorOr bind_impl(const sockaddr* address, socklen_t address_len) override; - virtual BAN::ErrorOr sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) override; - virtual BAN::ErrorOr recvfrom_impl(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len) override; + virtual BAN::ErrorOr recvmsg_impl(msghdr& message, int flags) override; + virtual BAN::ErrorOr sendmsg_impl(const msghdr& message, int flags) override; virtual BAN::ErrorOr getpeername_impl(sockaddr*, socklen_t*) override { return BAN::Error::from_errno(ENOTCONN); } virtual BAN::ErrorOr ioctl_impl(int, void*) override; diff --git a/kernel/include/kernel/Networking/UNIX/Socket.h b/kernel/include/kernel/Networking/UNIX/Socket.h index d246bde5..6616df74 100644 --- a/kernel/include/kernel/Networking/UNIX/Socket.h +++ b/kernel/include/kernel/Networking/UNIX/Socket.h @@ -25,8 +25,8 @@ namespace Kernel virtual BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) override; virtual BAN::ErrorOr listen_impl(int) override; virtual BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) override; - virtual BAN::ErrorOr sendto_impl(BAN::ConstByteSpan, const sockaddr*, socklen_t) override; - virtual BAN::ErrorOr recvfrom_impl(BAN::ByteSpan, sockaddr*, socklen_t*) override; + virtual BAN::ErrorOr recvmsg_impl(msghdr& message, int flags) override; + virtual BAN::ErrorOr sendmsg_impl(const msghdr& message, int flags) override; virtual BAN::ErrorOr getpeername_impl(sockaddr*, socklen_t*) override; virtual bool can_read_impl() const override; @@ -38,7 +38,7 @@ namespace Kernel UnixDomainSocket(Socket::Type, const Socket::Info&); ~UnixDomainSocket(); - BAN::ErrorOr add_packet(BAN::ConstByteSpan); + BAN::ErrorOr add_packet(const msghdr&, size_t total_size); bool is_bound() const { return !m_bound_file.canonical_path.empty(); } bool is_bound_to_unused() const { return !m_bound_file.inode; } @@ -54,7 +54,7 @@ namespace Kernel BAN::WeakPtr connection; BAN::Queue> pending_connections; ThreadBlocker pending_thread_blocker; - SpinLock pending_lock; + Mutex pending_lock; }; struct ConnectionlessInfo @@ -71,7 +71,7 @@ namespace Kernel BAN::CircularQueue m_packet_sizes; size_t m_packet_size_total { 0 }; BAN::UniqPtr m_packet_buffer; - SpinLock m_packet_lock; + Mutex m_packet_lock; ThreadBlocker m_packet_thread_blocker; friend class BAN::RefPtr; diff --git a/kernel/include/kernel/OpenFileDescriptorSet.h b/kernel/include/kernel/OpenFileDescriptorSet.h index 9486627c..791b0b89 100644 --- a/kernel/include/kernel/OpenFileDescriptorSet.h +++ b/kernel/include/kernel/OpenFileDescriptorSet.h @@ -51,8 +51,8 @@ namespace Kernel BAN::ErrorOr read_dir_entries(int fd, struct dirent* list, size_t list_len); - BAN::ErrorOr recvfrom(int fd, BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len); - BAN::ErrorOr sendto(int fd, BAN::ConstByteSpan buffer, const sockaddr* address, socklen_t address_len); + BAN::ErrorOr recvmsg(int socket, msghdr& message, int flags); + BAN::ErrorOr sendmsg(int socket, const msghdr& message, int flags); BAN::ErrorOr file_of(int) const; BAN::ErrorOr path_of(int) const; diff --git a/kernel/include/kernel/Process.h b/kernel/include/kernel/Process.h index b44d4246..10a1cdcc 100644 --- a/kernel/include/kernel/Process.h +++ b/kernel/include/kernel/Process.h @@ -134,8 +134,8 @@ namespace Kernel BAN::ErrorOr sys_bind(int socket, const sockaddr* address, socklen_t address_len); BAN::ErrorOr sys_connect(int socket, const sockaddr* address, socklen_t address_len); BAN::ErrorOr sys_listen(int socket, int backlog); - BAN::ErrorOr sys_sendto(const sys_sendto_t*); - BAN::ErrorOr sys_recvfrom(sys_recvfrom_t*); + BAN::ErrorOr sys_recvmsg(int socket, msghdr* message, int flags); + BAN::ErrorOr sys_sendmsg(int socket, const msghdr* message, int flags); BAN::ErrorOr sys_ioctl(int fildes, int request, void* arg); diff --git a/kernel/kernel/FS/Inode.cpp b/kernel/kernel/FS/Inode.cpp index 250413ec..7a5724c8 100644 --- a/kernel/kernel/FS/Inode.cpp +++ b/kernel/kernel/FS/Inode.cpp @@ -176,21 +176,21 @@ namespace Kernel return listen_impl(backlog); } - BAN::ErrorOr Inode::sendto(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) + BAN::ErrorOr Inode::recvmsg(msghdr& message, int flags) { LockGuard _(m_mutex); if (!mode().ifsock()) return BAN::Error::from_errno(ENOTSOCK); - return sendto_impl(message, address, address_len); - }; + return recvmsg_impl(message, flags); + } - BAN::ErrorOr Inode::recvfrom(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len) + BAN::ErrorOr Inode::sendmsg(const msghdr& message, int flags) { LockGuard _(m_mutex); if (!mode().ifsock()) return BAN::Error::from_errno(ENOTSOCK); - return recvfrom_impl(buffer, address, address_len); - }; + return sendmsg_impl(message, flags); + } BAN::ErrorOr Inode::getsockname(sockaddr* address, socklen_t* address_len) { diff --git a/kernel/kernel/Networking/TCPSocket.cpp b/kernel/kernel/Networking/TCPSocket.cpp index b91b2d55..255d36ba 100644 --- a/kernel/kernel/Networking/TCPSocket.cpp +++ b/kernel/kernel/Networking/TCPSocket.cpp @@ -190,8 +190,20 @@ namespace Kernel return m_network_layer.bind_socket_to_address(this, address, address_len); } - BAN::ErrorOr TCPSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr*, socklen_t*) + BAN::ErrorOr TCPSocket::recvmsg_impl(msghdr& message, int flags) { + if (flags != 0) + { + dwarnln("TODO: recvmsg with flags 0x{H}", flags); + return BAN::Error::from_errno(ENOTSUP); + } + + if (CMSG_FIRSTHDR(&message)) + { + dwarnln("ignoring recvmsg control message"); + message.msg_controllen = 0; + } + if (!m_has_connected) return BAN::Error::from_errno(ENOTCONN); @@ -202,26 +214,38 @@ namespace Kernel TRY(Thread::current().block_or_eintr_indefinite(m_thread_blocker, &m_mutex)); } - const uint32_t to_recv = BAN::Math::min(buffer.size(), m_recv_window.data_size); + size_t total_recv = 0; + for (int i = 0; i < message.msg_iovlen; i++) + { + auto* recv_buffer = reinterpret_cast(m_recv_window.buffer->vaddr()); - auto* recv_buffer = reinterpret_cast(m_recv_window.buffer->vaddr()); - memcpy(buffer.data(), recv_buffer, to_recv); + const size_t nrecv = BAN::Math::min(message.msg_iov[i].iov_len, m_recv_window.data_size); + memcpy(message.msg_iov[i].iov_base, recv_buffer, nrecv); - m_recv_window.data_size -= to_recv; - m_recv_window.start_seq += to_recv; - if (m_recv_window.data_size > 0) - memmove(recv_buffer, recv_buffer + to_recv, m_recv_window.data_size); + total_recv += nrecv; + m_recv_window.data_size -= nrecv; + m_recv_window.start_seq += nrecv; + if (m_recv_window.data_size == 0) + break; - return to_recv; + // TODO: use circular buffer to avoid this + memmove(recv_buffer, recv_buffer + nrecv, m_recv_window.data_size); + } + + return total_recv; } - BAN::ErrorOr TCPSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) + BAN::ErrorOr TCPSocket::sendmsg_impl(const msghdr& message, int flags) { - (void)address; - (void)address_len; + if (flags != 0) + { + dwarnln("TODO: sendmsg with flags 0x{H}", flags); + return BAN::Error::from_errno(ENOTSUP); + } + + if (CMSG_FIRSTHDR(&message)) + dwarnln("ignoring sendmsg control message"); - if (address) - return BAN::Error::from_errno(EISCONN); if (!m_has_connected) return BAN::Error::from_errno(ENOTCONN); @@ -232,17 +256,23 @@ namespace Kernel TRY(Thread::current().block_or_eintr_indefinite(m_thread_blocker, &m_mutex)); } - const size_t to_send = BAN::Math::min(message.size(), m_send_window.buffer->size() - m_send_window.data_size); - + size_t total_sent = 0; + for (int i = 0; i < message.msg_iovlen; i++) { - auto* buffer = reinterpret_cast(m_send_window.buffer->vaddr()); - memcpy(buffer + m_send_window.data_size, message.data(), to_send); - m_send_window.data_size += to_send; + auto* send_buffer = reinterpret_cast(m_send_window.buffer->vaddr()); + + const size_t nsend = BAN::Math::min(message.msg_iov[i].iov_len, m_send_window.buffer->size() - m_send_window.data_size); + memcpy(send_buffer + m_send_window.data_size, message.msg_iov[i].iov_base, nsend); + + total_sent += nsend; + m_send_window.data_size += nsend; + if (m_send_window.data_size == m_send_window.buffer->size()) + break; } m_thread_blocker.unblock(); - return to_send; + return total_sent; } BAN::ErrorOr TCPSocket::getpeername_impl(sockaddr* address, socklen_t* address_len) diff --git a/kernel/kernel/Networking/UDPSocket.cpp b/kernel/kernel/Networking/UDPSocket.cpp index 763ab4ec..93fd2d78 100644 --- a/kernel/kernel/Networking/UDPSocket.cpp +++ b/kernel/kernel/Networking/UDPSocket.cpp @@ -86,8 +86,20 @@ namespace Kernel return m_network_layer.bind_socket_to_address(this, address, address_len); } - BAN::ErrorOr UDPSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len) + BAN::ErrorOr UDPSocket::recvmsg_impl(msghdr& message, int flags) { + if (flags != 0) + { + dwarnln("TODO: recvmsg with flags 0x{H}", flags); + return BAN::Error::from_errno(ENOTSUP); + } + + if (CMSG_FIRSTHDR(&message)) + { + dwarnln("ignoring recvmsg control message"); + message.msg_controllen = 0; + } + if (!is_bound()) { dprintln("No interface bound"); @@ -106,14 +118,16 @@ namespace Kernel auto packet_info = m_packets.front(); m_packets.pop(); - size_t nread = BAN::Math::min(packet_info.packet_size, buffer.size()); + auto* packet_buffer = reinterpret_cast(m_packet_buffer->vaddr()); + + size_t total_recv = 0; + for (int i = 0; i < message.msg_iovlen; i++) + { + const size_t nrecv = BAN::Math::min(message.msg_iov[i].iov_len, packet_info.packet_size - total_recv); + memcpy(message.msg_iov[i].iov_base, packet_buffer + total_recv, nrecv); + total_recv += nrecv; + } - uint8_t* packet_buffer = reinterpret_cast(m_packet_buffer->vaddr()); - memcpy( - buffer.data(), - packet_buffer, - nread - ); memmove( packet_buffer, packet_buffer + packet_info.packet_size, @@ -122,21 +136,49 @@ namespace Kernel m_packet_total_size -= packet_info.packet_size; - if (address && address_len) + if (message.msg_name && message.msg_namelen) { - if (*address_len > (socklen_t)sizeof(sockaddr_storage)) - *address_len = sizeof(sockaddr_storage); - memcpy(address, &packet_info.sender, *address_len); + const size_t namelen = BAN::Math::min(message.msg_namelen, sizeof(sockaddr_storage)); + memcpy(message.msg_name, &packet_info.sender, namelen); + message.msg_namelen = namelen; } - return nread; + return total_recv; } - BAN::ErrorOr UDPSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) + BAN::ErrorOr UDPSocket::sendmsg_impl(const msghdr& message, int flags) { + if (flags != 0) + { + dwarnln("TODO: recvmsg with flags 0x{H}", flags); + return BAN::Error::from_errno(ENOTSUP); + } + + if (CMSG_FIRSTHDR(&message)) + dwarnln("ignoring recvmsg control message"); + if (!is_bound()) - TRY(m_network_layer.bind_socket_to_unused(this, address, address_len)); - return TRY(m_network_layer.sendto(*this, message, address, address_len)); + TRY(m_network_layer.bind_socket_to_unused(this, static_cast(message.msg_name), message.msg_namelen)); + + const size_t total_send_size = + [&message]() -> size_t { + size_t result = 0; + for (int i = 0; i < message.msg_iovlen; i++) + result += message.msg_iov[i].iov_len; + return result; + }(); + + BAN::Vector buffer; + TRY(buffer.resize(total_send_size)); + + size_t offset = 0; + for (int i = 0; i < message.msg_iovlen; i++) + { + memcpy(buffer.data() + offset, message.msg_iov[i].iov_base, message.msg_iov[i].iov_len); + offset += message.msg_iov[i].iov_len; + } + + return TRY(m_network_layer.sendto(*this, buffer.span(), static_cast(message.msg_name), message.msg_namelen)); } BAN::ErrorOr UDPSocket::ioctl_impl(int request, void* argument) diff --git a/kernel/kernel/Networking/UNIX/Socket.cpp b/kernel/kernel/Networking/UNIX/Socket.cpp index 0d426a24..4f5a852f 100644 --- a/kernel/kernel/Networking/UNIX/Socket.cpp +++ b/kernel/kernel/Networking/UNIX/Socket.cpp @@ -1,6 +1,7 @@ #include + #include -#include +#include #include #include #include @@ -22,10 +23,28 @@ namespace Kernel }; static BAN::HashMap, BAN::WeakPtr, UnixSocketHash> s_bound_sockets; - static SpinLock s_bound_socket_lock; + static Mutex s_bound_socket_lock; static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE; + static BAN::ErrorOr validate_sockaddr_un(const sockaddr* address, socklen_t address_len) + { + if (address_len < static_cast(sizeof(sa_family_t))) + return BAN::Error::from_errno(EINVAL); + if (address_len > static_cast(sizeof(sockaddr_un))) + address_len = sizeof(sockaddr_un); + + const auto& sockaddr_un = *reinterpret_cast(address); + if (sockaddr_un.sun_family != AF_UNIX) + return BAN::Error::from_errno(EINVAL); + + size_t length = 0; + while (length < address_len - sizeof(sa_family_t) && sockaddr_un.sun_path[length]) + length++; + + return BAN::StringView { sockaddr_un.sun_path, length }; + } + // FIXME: why is this using spinlocks instead of mutexes?? BAN::ErrorOr> UnixDomainSocket::create(Socket::Type socket_type, const Socket::Info& info) @@ -64,7 +83,7 @@ namespace Kernel { if (is_bound() && !is_bound_to_unused()) { - SpinLockGuard _(s_bound_socket_lock); + LockGuard _(s_bound_socket_lock); s_bound_sockets.remove(m_bound_file.inode); } if (m_info.has()) @@ -105,11 +124,9 @@ namespace Kernel BAN::RefPtr pending; { - SpinLockGuard guard(connection_info.pending_lock); - - SpinLockGuardAsMutex smutex(guard); + LockGuard _(connection_info.pending_lock); while (connection_info.pending_connections.empty()) - TRY(Thread::current().block_or_eintr_indefinite(connection_info.pending_thread_blocker, &smutex)); + TRY(Thread::current().block_or_eintr_indefinite(connection_info.pending_thread_blocker, &connection_info.pending_lock)); pending = connection_info.pending_connections.front(); connection_info.pending_connections.pop(); @@ -146,15 +163,11 @@ namespace Kernel BAN::ErrorOr UnixDomainSocket::connect_impl(const sockaddr* address, socklen_t address_len) { - if (address_len != sizeof(sockaddr_un)) - return BAN::Error::from_errno(EINVAL); - auto& sockaddr_un = *reinterpret_cast(address); - if (sockaddr_un.sun_family != AF_UNIX) - return BAN::Error::from_errno(EAFNOSUPPORT); + const auto sun_path = TRY(validate_sockaddr_un(address, address_len)); if (!is_bound()) TRY(m_bound_file.canonical_path.push_back('X')); - auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path)); + auto absolute_path = TRY(Process::current().absolute_path_of(sun_path)); auto file = TRY(VirtualFileSystem::get().file_from_absolute_path( Process::current().root_file().inode, Process::current().credentials(), @@ -165,7 +178,7 @@ namespace Kernel BAN::RefPtr target; { - SpinLockGuard _(s_bound_socket_lock); + LockGuard _(s_bound_socket_lock); auto it = s_bound_sockets.find(file.inode); if (it == s_bound_sockets.end()) return BAN::Error::from_errno(ECONNREFUSED); @@ -196,7 +209,7 @@ namespace Kernel { auto& target_info = target->m_info.get(); - SpinLockGuard guard(target_info.pending_lock); + LockGuard _(target_info.pending_lock); if (target_info.pending_connections.size() < target_info.pending_connections.capacity()) { @@ -205,8 +218,7 @@ namespace Kernel break; } - SpinLockGuardAsMutex smutex(guard); - TRY(Thread::current().block_or_eintr_indefinite(target_info.pending_thread_blocker, &smutex)); + TRY(Thread::current().block_or_eintr_indefinite(target_info.pending_thread_blocker, &target_info.pending_lock)); } target->epoll_notify(EPOLLIN); @@ -236,21 +248,16 @@ namespace Kernel { if (is_bound()) return BAN::Error::from_errno(EINVAL); - if (address_len != sizeof(sockaddr_un)) - return BAN::Error::from_errno(EINVAL); - auto& sockaddr_un = *reinterpret_cast(address); - if (sockaddr_un.sun_family != AF_UNIX) - return BAN::Error::from_errno(EAFNOSUPPORT); - auto bind_path = BAN::StringView(sockaddr_un.sun_path); - if (bind_path.empty()) + const auto sun_path = TRY(validate_sockaddr_un(address, address_len)); + if (sun_path.empty()) return BAN::Error::from_errno(EINVAL); // FIXME: This feels sketchy - auto parent_file = bind_path.front() == '/' + auto parent_file = sun_path.front() == '/' ? TRY(Process::current().root_file().clone()) : TRY(Process::current().working_directory().clone()); - if (auto ret = Process::current().create_file_or_dir(AT_FDCWD, bind_path.data(), 0755 | S_IFSOCK); ret.is_error()) + if (auto ret = Process::current().create_file_or_dir(AT_FDCWD, sun_path.data(), 0755 | S_IFSOCK); ret.is_error()) { if (ret.error().get_error_code() == EEXIST) return BAN::Error::from_errno(EADDRINUSE); @@ -260,11 +267,11 @@ namespace Kernel Process::current().root_file().inode, parent_file, Process::current().credentials(), - bind_path, + sun_path, O_RDWR )); - SpinLockGuard _(s_bound_socket_lock); + LockGuard _(s_bound_socket_lock); if (s_bound_sockets.contains(file.inode)) return BAN::Error::from_errno(EADDRINUSE); TRY(s_bound_sockets.emplace(file.inode, TRY(get_weak_ptr()))); @@ -287,21 +294,24 @@ namespace Kernel } } - BAN::ErrorOr UnixDomainSocket::add_packet(BAN::ConstByteSpan packet) + BAN::ErrorOr UnixDomainSocket::add_packet(const msghdr& packet, size_t total_size) { - SpinLockGuard guard(m_packet_lock); - while (m_packet_sizes.full() || m_packet_size_total + packet.size() > s_packet_buffer_size) - { - SpinLockGuardAsMutex smutex(guard); - TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &smutex)); - } + LockGuard _(m_packet_lock); + while (m_packet_sizes.full() || m_packet_size_total + total_size > s_packet_buffer_size) + TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock)); uint8_t* packet_buffer = reinterpret_cast(m_packet_buffer->vaddr() + m_packet_size_total); - memcpy(packet_buffer, packet.data(), packet.size()); - m_packet_size_total += packet.size(); - if (!is_streaming()) - m_packet_sizes.push(packet.size()); + size_t offset = 0; + for (int i = 0; i < packet.msg_iovlen; i++) + { + memcpy(packet_buffer + offset, packet.msg_iov[i].iov_base, packet.msg_iov[i].iov_len); + offset += packet.msg_iov[i].iov_len; + } + + ASSERT(offset == total_size); + m_packet_size_total += total_size; + m_packet_sizes.push(total_size); m_packet_thread_blocker.unblock(); @@ -348,27 +358,105 @@ namespace Kernel return false; } - BAN::ErrorOr UnixDomainSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) + BAN::ErrorOr UnixDomainSocket::recvmsg_impl(msghdr& message, int flags) { - if (message.size() > s_packet_buffer_size) + if (flags != 0) + { + dwarnln("TODO: recvmsg with flags 0x{H}", flags); + return BAN::Error::from_errno(ENOTSUP); + } + + if (CMSG_FIRSTHDR(&message)) + { + dwarnln("ignoring recvmsg control message"); + message.msg_controllen = 0; + } + + LockGuard _(m_packet_lock); + while (m_packet_size_total == 0) + { + if (m_info.has()) + { + auto& connection_info = m_info.get(); + bool expected = true; + if (connection_info.target_closed.compare_exchange(expected, false)) + return 0; + if (!connection_info.connection) + return BAN::Error::from_errno(ENOTCONN); + } + + TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock)); + } + + uint8_t* packet_buffer = reinterpret_cast(m_packet_buffer->vaddr()); + + const size_t max_recv_size = is_streaming() ? m_packet_size_total : m_packet_sizes.front(); + + size_t total_recv = 0; + for (int i = 0; i < message.msg_iovlen; i++) + { + const size_t nrecv = BAN::Math::min(message.msg_iov[i].iov_len, max_recv_size - total_recv); + memcpy(message.msg_iov[i].iov_base, packet_buffer + total_recv, nrecv); + total_recv += nrecv; + } + + size_t bytes_to_handle = total_recv; + while (bytes_to_handle) + { + const size_t to_handle = BAN::Math::min(bytes_to_handle, m_packet_sizes.front()); + bytes_to_handle -= to_handle; + m_packet_sizes.front() -= to_handle; + if (m_packet_sizes.front() == 0) + m_packet_sizes.pop(); + } + + const size_t to_discard = is_streaming() ? total_recv : max_recv_size; + memmove(packet_buffer, packet_buffer + to_discard, m_packet_size_total - to_discard); + m_packet_size_total -= to_discard; + + m_packet_thread_blocker.unblock(); + + epoll_notify(EPOLLOUT); + + return total_recv; + } + + BAN::ErrorOr UnixDomainSocket::sendmsg_impl(const msghdr& message, int flags) + { + if (flags != 0) + { + dwarnln("TODO: sendmsg with flags 0x{H}", flags); + return BAN::Error::from_errno(ENOTSUP); + } + + if (CMSG_FIRSTHDR(&message)) + dwarnln("ignoring sendmsg control message"); + + const size_t total_message_size = + [&message]() -> size_t { + size_t result = 0; + for (int i = 0; i < message.msg_iovlen; i++) + result += message.msg_iov[i].iov_len; + return result; + }(); + + if (total_message_size > s_packet_buffer_size) return BAN::Error::from_errno(ENOBUFS); if (m_info.has()) { auto& connection_info = m_info.get(); - if (address) - return BAN::Error::from_errno(EISCONN); auto target = connection_info.connection.lock(); if (!target) return BAN::Error::from_errno(ENOTCONN); - TRY(target->add_packet(message)); - return message.size(); + TRY(target->add_packet(message, total_message_size)); + return total_message_size; } else { BAN::RefPtr target_inode; - if (!address) + if (!message.msg_name) { auto& connectionless_info = m_info.get(); if (connectionless_info.peer_address.empty()) @@ -384,13 +472,8 @@ namespace Kernel } else { - if (address_len != sizeof(sockaddr_un)) - return BAN::Error::from_errno(EINVAL); - auto& sockaddr_un = *reinterpret_cast(address); - if (sockaddr_un.sun_family != AF_UNIX) - return BAN::Error::from_errno(EAFNOSUPPORT); - - auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path)); + const auto sun_path = TRY(validate_sockaddr_un(static_cast(message.msg_name), message.msg_namelen)); + auto absolute_path = TRY(Process::current().absolute_path_of(sun_path)); target_inode = TRY(VirtualFileSystem::get().file_from_absolute_path( Process::current().root_file().inode, Process::current().credentials(), @@ -399,57 +482,21 @@ namespace Kernel )).inode; } - SpinLockGuard _(s_bound_socket_lock); - auto it = s_bound_sockets.find(target_inode); - if (it == s_bound_sockets.end()) - return BAN::Error::from_errno(EDESTADDRREQ); - auto target = it->value.lock(); - if (!target) - return BAN::Error::from_errno(EDESTADDRREQ); - TRY(target->add_packet(message)); - return message.size(); - } - } - - BAN::ErrorOr UnixDomainSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr*, socklen_t*) - { - SpinLockGuard guard(m_packet_lock); - while (m_packet_size_total == 0) - { - if (m_info.has()) + BAN::RefPtr target; { - auto& connection_info = m_info.get(); - bool expected = true; - if (connection_info.target_closed.compare_exchange(expected, false)) - return 0; - if (!connection_info.connection) - return BAN::Error::from_errno(ENOTCONN); + LockGuard _(s_bound_socket_lock); + auto it = s_bound_sockets.find(target_inode); + if (it == s_bound_sockets.end()) + return BAN::Error::from_errno(EDESTADDRREQ); + target = it->value.lock(); } - SpinLockGuardAsMutex smutex(guard); - TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &smutex)); + if (!target) + return BAN::Error::from_errno(EDESTADDRREQ); + TRY(target->add_packet(message, total_message_size)); + + return total_message_size; } - - uint8_t* packet_buffer = reinterpret_cast(m_packet_buffer->vaddr()); - - size_t nread = 0; - if (is_streaming()) - nread = BAN::Math::min(buffer.size(), m_packet_size_total); - else - { - nread = BAN::Math::min(buffer.size(), m_packet_sizes.front()); - m_packet_sizes.pop(); - } - - memcpy(buffer.data(), packet_buffer, nread); - memmove(packet_buffer, packet_buffer + nread, m_packet_size_total - nread); - m_packet_size_total -= nread; - - m_packet_thread_blocker.unblock(); - - epoll_notify(EPOLLOUT); - - return nread; } BAN::ErrorOr UnixDomainSocket::getpeername_impl(sockaddr* address, socklen_t* address_len) diff --git a/kernel/kernel/OpenFileDescriptorSet.cpp b/kernel/kernel/OpenFileDescriptorSet.cpp index ab8470c9..4bb03f8a 100644 --- a/kernel/kernel/OpenFileDescriptorSet.cpp +++ b/kernel/kernel/OpenFileDescriptorSet.cpp @@ -424,7 +424,24 @@ namespace Kernel } if (inode->mode().ifsock()) - return recvfrom(fd, buffer, nullptr, nullptr); + { + iovec iov { + .iov_base = buffer.data(), + .iov_len = buffer.size(), + }; + + msghdr message { + .msg_name = nullptr, + .msg_namelen = 0, + .msg_iov = &iov, + .msg_iovlen = 1, + .msg_control = nullptr, + .msg_controllen = 0, + .msg_flags = 0, + }; + + return recvmsg(fd, message, 0); + } size_t nread; { @@ -461,7 +478,24 @@ namespace Kernel } if (inode->mode().ifsock()) - return sendto(fd, buffer, nullptr, 0); + { + iovec iov { + .iov_base = const_cast(buffer.data()), + .iov_len = buffer.size(), + }; + + msghdr message { + .msg_name = nullptr, + .msg_namelen = 0, + .msg_iov = &iov, + .msg_iovlen = 1, + .msg_control = nullptr, + .msg_controllen = 0, + .msg_flags = 0, + }; + + return sendmsg(fd, message, 0); + } size_t nwrite; { @@ -515,7 +549,7 @@ namespace Kernel } } - BAN::ErrorOr OpenFileDescriptorSet::recvfrom(int fd, BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len) + BAN::ErrorOr OpenFileDescriptorSet::recvmsg(int fd, msghdr& message, int flags) { BAN::RefPtr inode; bool is_nonblock; @@ -533,10 +567,10 @@ namespace Kernel LockGuard _(inode->m_mutex); if (is_nonblock && !inode->can_read()) return BAN::Error::from_errno(EWOULDBLOCK); - return inode->recvfrom(buffer, address, address_len); + return inode->recvmsg(message, flags); } - BAN::ErrorOr OpenFileDescriptorSet::sendto(int fd, BAN::ConstByteSpan buffer, const sockaddr* address, socklen_t address_len) + BAN::ErrorOr OpenFileDescriptorSet::sendmsg(int fd, const msghdr& message, int flags) { BAN::RefPtr inode; bool is_nonblock; @@ -559,7 +593,7 @@ namespace Kernel } if (is_nonblock && !inode->can_write()) return BAN::Error::from_errno(EWOULDBLOCK); - return inode->sendto(buffer, address, address_len); + return inode->sendmsg(message, flags); } BAN::ErrorOr OpenFileDescriptorSet::file_of(int fd) const diff --git a/kernel/kernel/Process.cpp b/kernel/kernel/Process.cpp index 935ece78..15adc08d 100644 --- a/kernel/kernel/Process.cpp +++ b/kernel/kernel/Process.cpp @@ -1569,73 +1569,72 @@ namespace Kernel return 0; } - BAN::ErrorOr Process::sys_sendto(const sys_sendto_t* _arguments) + BAN::ErrorOr Process::sys_recvmsg(int socket, msghdr* _message, int flags) { - sys_sendto_t arguments; + msghdr message; { LockGuard _(m_process_lock); - TRY(validate_pointer_access(_arguments, sizeof(sys_sendto_t), false)); - arguments = *_arguments; + TRY(validate_pointer_access(_message, sizeof(msghdr), true)); + message = *_message; } - if (arguments.length == 0) - return BAN::Error::from_errno(EINVAL); - - MemoryRegion* message_region = nullptr; - MemoryRegion* address_region = nullptr; - - BAN::ScopeGuard _([&] { - if (message_region) - message_region->unpin(); - if (address_region) - address_region->unpin(); + BAN::Vector regions; + BAN::ScopeGuard _([®ions] { + for (auto* region : regions) + if (region != nullptr) + region->unpin(); }); - message_region = TRY(validate_and_pin_pointer_access(arguments.message, arguments.length, false)); - if (arguments.dest_addr) - address_region = TRY(validate_and_pin_pointer_access(arguments.dest_addr, arguments.dest_len, false)); + if (message.msg_name) + TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_name, message.msg_namelen, true)))); + if (message.msg_control) + TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_control, message.msg_controllen, true)))); + if (message.msg_iov) + { + TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_iov, message.msg_iovlen * sizeof(iovec), true)))); + for (int i = 0; i < message.msg_iovlen; i++) + TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_iov[i].iov_base, message.msg_iov[i].iov_len, true)))); + } - auto message = BAN::ConstByteSpan(static_cast(arguments.message), arguments.length); - return TRY(m_open_file_descriptors.sendto(arguments.socket, message, arguments.dest_addr, arguments.dest_len)); + auto ret = TRY(m_open_file_descriptors.recvmsg(socket, message, flags)); + + { + LockGuard _(m_process_lock); + TRY(validate_pointer_access(_message, sizeof(msghdr), true)); + *_message = message; + } + + return ret; } - BAN::ErrorOr Process::sys_recvfrom(sys_recvfrom_t* _arguments) + BAN::ErrorOr Process::sys_sendmsg(int socket, const msghdr* _message, int flags) { - sys_recvfrom_t arguments; + msghdr message; { LockGuard _(m_process_lock); - TRY(validate_pointer_access(_arguments, sizeof(sys_sendto_t), false)); - arguments = *_arguments; + TRY(validate_pointer_access(_message, sizeof(msghdr), false)); + message = *_message; } - if (!arguments.address != !arguments.address_len) - return BAN::Error::from_errno(EINVAL); - if (arguments.length == 0) - return BAN::Error::from_errno(EINVAL); - - MemoryRegion* buffer_region = nullptr; - MemoryRegion* address_region1 = nullptr; - MemoryRegion* address_region2 = nullptr; - - BAN::ScopeGuard _([&] { - if (buffer_region) - buffer_region->unpin(); - if (address_region1) - address_region1->unpin(); - if (address_region2) - address_region2->unpin(); + BAN::Vector regions; + BAN::ScopeGuard _([®ions] { + for (auto* region : regions) + if (region != nullptr) + region->unpin(); }); - buffer_region = TRY(validate_and_pin_pointer_access(arguments.buffer, arguments.length, true)); - - if (arguments.address_len) + if (message.msg_name) + TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_name, message.msg_namelen, false)))); + if (message.msg_control) + TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_control, message.msg_controllen, false)))); + if (message.msg_iov) { - address_region1 = TRY(validate_and_pin_pointer_access(arguments.address_len, sizeof(*arguments.address_len), true)); - address_region2 = TRY(validate_and_pin_pointer_access(arguments.address, *arguments.address_len, true)); + TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_iov, message.msg_iovlen * sizeof(iovec), false)))); + for (int i = 0; i < message.msg_iovlen; i++) + TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_iov[i].iov_base, message.msg_iov[i].iov_len, false)))); } - auto message = BAN::ByteSpan(static_cast(arguments.buffer), arguments.length); - return TRY(m_open_file_descriptors.recvfrom(arguments.socket, message, arguments.address, arguments.address_len)); + return TRY(m_open_file_descriptors.sendmsg(socket, message, flags)); } BAN::ErrorOr Process::sys_ioctl(int fildes, int request, void* arg) diff --git a/kernel/kernel/Syscall.cpp b/kernel/kernel/Syscall.cpp index 2a6e1fae..ff5d8798 100644 --- a/kernel/kernel/Syscall.cpp +++ b/kernel/kernel/Syscall.cpp @@ -124,8 +124,8 @@ namespace Kernel case SYS_WAIT: case SYS_ACCEPT: case SYS_CONNECT: - case SYS_RECVFROM: - case SYS_SENDTO: + case SYS_RECVMSG: + case SYS_SENDMSG: case SYS_FLOCK: return true; default: diff --git a/userspace/libraries/LibC/include/sys/syscall.h b/userspace/libraries/LibC/include/sys/syscall.h index 52b8d530..7d236e52 100644 --- a/userspace/libraries/LibC/include/sys/syscall.h +++ b/userspace/libraries/LibC/include/sys/syscall.h @@ -67,8 +67,8 @@ __BEGIN_DECLS O(SYS_SOCKET, socket) \ O(SYS_SOCKETPAIR, socketpair) \ O(SYS_BIND, bind) \ - O(SYS_SENDTO, sendto) \ - O(SYS_RECVFROM, recvfrom) \ + O(SYS_RECVMSG, recvmsg) \ + O(SYS_SENDMSG, sendmsg) \ O(SYS_IOCTL, ioctl) \ O(SYS_ACCEPT, accept) \ O(SYS_CONNECT, connect) \ diff --git a/userspace/libraries/LibC/sys/socket.cpp b/userspace/libraries/LibC/sys/socket.cpp index f3c392fb..b3a12935 100644 --- a/userspace/libraries/LibC/sys/socket.cpp +++ b/userspace/libraries/LibC/sys/socket.cpp @@ -42,104 +42,69 @@ ssize_t recv(int socket, void* __restrict buffer, size_t length, int flags) ssize_t recvfrom(int socket, void* __restrict buffer, size_t length, int flags, struct sockaddr* __restrict address, socklen_t* __restrict address_len) { - pthread_testcancel(); - sys_recvfrom_t arguments { - .socket = socket, - .buffer = buffer, - .length = length, - .flags = flags, - .address = address, - .address_len = address_len + // cancellation point in recvmsg + + iovec iov { + .iov_base = buffer, + .iov_len = length, }; - return syscall(SYS_RECVFROM, &arguments); + + msghdr message { + .msg_name = address, + .msg_namelen = address_len ? *address_len : 0, + .msg_iov = &iov, + .msg_iovlen = 1, + .msg_control = NULL, + .msg_controllen = 0, + .msg_flags = 0, + }; + + const ssize_t ret = recvmsg(socket, &message, flags); + + if (address_len) + *address_len = message.msg_namelen; + + return ret; } -ssize_t send(int socket, const void* message, size_t length, int flags) +ssize_t recvmsg(int socket, struct msghdr* message, int flags) +{ + pthread_testcancel(); + return syscall(SYS_RECVMSG, socket, message, flags); +} + +ssize_t send(int socket, const void* buffer, size_t length, int flags) { // cancellation point in sendto - return sendto(socket, message, length, flags, nullptr, 0); + return sendto(socket, buffer, length, flags, nullptr, 0); } -ssize_t sendto(int socket, const void* message, size_t length, int flags, const struct sockaddr* dest_addr, socklen_t dest_len) +ssize_t sendto(int socket, const void* buffer, size_t length, int flags, const struct sockaddr* address, socklen_t address_len) +{ + // cancellation point in sendmsg + + iovec iov { + .iov_base = const_cast(buffer), + .iov_len = length, + }; + + msghdr message { + .msg_name = const_cast(address), + .msg_namelen = address_len, + .msg_iov = &iov, + .msg_iovlen = 1, + .msg_control = NULL, + .msg_controllen = 0, + .msg_flags = 0, + }; + + return sendmsg(socket, &message, flags); +} + +ssize_t sendmsg(int socket, const struct msghdr* message, int flags) { pthread_testcancel(); - sys_sendto_t arguments { - .socket = socket, - .message = message, - .length = length, - .flags = flags, - .dest_addr = dest_addr, - .dest_len = dest_len - }; - return syscall(SYS_SENDTO, &arguments); -} - -ssize_t recvmsg(int socket, struct msghdr* message, int flags) -{ - if (CMSG_FIRSTHDR(message)) - { - dwarnln("TODO: recvmsg ancillary data"); - errno = ENOTSUP; - return -1; - } - - size_t total_recv = 0; - - for (int i = 0; i < message->msg_iovlen; i++) - { - const ssize_t nrecv = recvfrom( - socket, - message->msg_iov[i].iov_base, - message->msg_iov[i].iov_len, - flags, - static_cast(message->msg_name), - &message->msg_namelen - ); - - if (nrecv < 0) - return -1; - - total_recv += nrecv; - - if (static_cast(nrecv) < message->msg_iov[i].iov_len) - break; - } - - return total_recv; -} - -ssize_t sendmsg(int socket, const struct msghdr* message, int flags) -{ - if (CMSG_FIRSTHDR(message)) - { - dwarnln("TODO: sendmsg ancillary data"); - errno = ENOTSUP; - return -1; - } - - size_t total_sent = 0; - - for (int i = 0; i < message->msg_iovlen; i++) - { - const ssize_t nsend = sendto( - socket, - message->msg_iov[i].iov_base, - message->msg_iov[i].iov_len, - flags, - static_cast(message->msg_name), - message->msg_namelen - ); - - if (nsend < 0) - return -1; - - total_sent += nsend; - - if (static_cast(nsend) < message->msg_iov[i].iov_len) - break; - } - - return total_sent; + return syscall(SYS_SENDMSG, socket, message, flags); } int socket(int domain, int type, int protocol)