diff --git a/kernel/include/kernel/FS/Inode.h b/kernel/include/kernel/FS/Inode.h index c5cb7728..a1f1f2e0 100644 --- a/kernel/include/kernel/FS/Inode.h +++ b/kernel/include/kernel/FS/Inode.h @@ -108,8 +108,8 @@ namespace Kernel BAN::ErrorOr bind(const sockaddr* address, socklen_t address_len); BAN::ErrorOr connect(const sockaddr* address, socklen_t address_len); BAN::ErrorOr listen(int backlog); - BAN::ErrorOr sendto(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len); - BAN::ErrorOr recvfrom(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len); + BAN::ErrorOr sendmsg(const msghdr& message, int flags); + BAN::ErrorOr recvmsg(msghdr& message, int flags); BAN::ErrorOr getsockname(sockaddr* address, socklen_t* address_len); BAN::ErrorOr getpeername(sockaddr* address, socklen_t* address_len); @@ -155,8 +155,8 @@ namespace Kernel virtual BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr listen_impl(int) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } - virtual BAN::ErrorOr sendto_impl(BAN::ConstByteSpan, const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } - virtual BAN::ErrorOr recvfrom_impl(BAN::ByteSpan, sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); } + virtual BAN::ErrorOr recvmsg_impl(msghdr&, int) { return BAN::Error::from_errno(ENOTSUP); } + virtual BAN::ErrorOr sendmsg_impl(const msghdr&, int) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr getsockname_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr getpeername_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); } diff --git a/kernel/include/kernel/FS/Socket.h b/kernel/include/kernel/FS/Socket.h index 474ff279..7e719e80 100644 --- a/kernel/include/kernel/FS/Socket.h +++ 
b/kernel/include/kernel/FS/Socket.h @@ -51,8 +51,6 @@ namespace Kernel : m_info(info) {} - BAN::ErrorOr read_impl(off_t, BAN::ByteSpan buffer) override { return recvfrom_impl(buffer, nullptr, nullptr); } - BAN::ErrorOr write_impl(off_t, BAN::ConstByteSpan buffer) override { return sendto_impl(buffer, nullptr, 0); } BAN::ErrorOr fsync_impl() final override { return {}; } private: diff --git a/kernel/include/kernel/Networking/TCPSocket.h b/kernel/include/kernel/Networking/TCPSocket.h index 0a0b63d7..a3298d57 100644 --- a/kernel/include/kernel/Networking/TCPSocket.h +++ b/kernel/include/kernel/Networking/TCPSocket.h @@ -60,8 +60,8 @@ namespace Kernel virtual BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) override; virtual BAN::ErrorOr listen_impl(int) override; virtual BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) override; - virtual BAN::ErrorOr sendto_impl(BAN::ConstByteSpan, const sockaddr*, socklen_t) override; - virtual BAN::ErrorOr recvfrom_impl(BAN::ByteSpan, sockaddr*, socklen_t*) override; + virtual BAN::ErrorOr recvmsg_impl(msghdr& message, int flags) override; + virtual BAN::ErrorOr sendmsg_impl(const msghdr& message, int flags) override; virtual BAN::ErrorOr getpeername_impl(sockaddr*, socklen_t*) override; virtual BAN::ErrorOr ioctl_impl(int, void*) override; diff --git a/kernel/include/kernel/Networking/UDPSocket.h b/kernel/include/kernel/Networking/UDPSocket.h index 6e39381e..83e6ed0b 100644 --- a/kernel/include/kernel/Networking/UDPSocket.h +++ b/kernel/include/kernel/Networking/UDPSocket.h @@ -34,8 +34,8 @@ namespace Kernel virtual void receive_packet(BAN::ConstByteSpan, const sockaddr* sender, socklen_t sender_len) override; virtual BAN::ErrorOr bind_impl(const sockaddr* address, socklen_t address_len) override; - virtual BAN::ErrorOr sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) override; - virtual BAN::ErrorOr recvfrom_impl(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len) 
override; + virtual BAN::ErrorOr recvmsg_impl(msghdr& message, int flags) override; + virtual BAN::ErrorOr sendmsg_impl(const msghdr& message, int flags) override; virtual BAN::ErrorOr getpeername_impl(sockaddr*, socklen_t*) override { return BAN::Error::from_errno(ENOTCONN); } virtual BAN::ErrorOr ioctl_impl(int, void*) override; diff --git a/kernel/include/kernel/Networking/UNIX/Socket.h b/kernel/include/kernel/Networking/UNIX/Socket.h index d246bde5..6616df74 100644 --- a/kernel/include/kernel/Networking/UNIX/Socket.h +++ b/kernel/include/kernel/Networking/UNIX/Socket.h @@ -25,8 +25,8 @@ namespace Kernel virtual BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) override; virtual BAN::ErrorOr listen_impl(int) override; virtual BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) override; - virtual BAN::ErrorOr sendto_impl(BAN::ConstByteSpan, const sockaddr*, socklen_t) override; - virtual BAN::ErrorOr recvfrom_impl(BAN::ByteSpan, sockaddr*, socklen_t*) override; + virtual BAN::ErrorOr recvmsg_impl(msghdr& message, int flags) override; + virtual BAN::ErrorOr sendmsg_impl(const msghdr& message, int flags) override; virtual BAN::ErrorOr getpeername_impl(sockaddr*, socklen_t*) override; virtual bool can_read_impl() const override; @@ -38,7 +38,7 @@ namespace Kernel UnixDomainSocket(Socket::Type, const Socket::Info&); ~UnixDomainSocket(); - BAN::ErrorOr add_packet(BAN::ConstByteSpan); + BAN::ErrorOr add_packet(const msghdr&, size_t total_size); bool is_bound() const { return !m_bound_file.canonical_path.empty(); } bool is_bound_to_unused() const { return !m_bound_file.inode; } @@ -54,7 +54,7 @@ namespace Kernel BAN::WeakPtr connection; BAN::Queue> pending_connections; ThreadBlocker pending_thread_blocker; - SpinLock pending_lock; + Mutex pending_lock; }; struct ConnectionlessInfo @@ -71,7 +71,7 @@ namespace Kernel BAN::CircularQueue m_packet_sizes; size_t m_packet_size_total { 0 }; BAN::UniqPtr m_packet_buffer; - SpinLock m_packet_lock; + Mutex m_packet_lock; 
ThreadBlocker m_packet_thread_blocker; friend class BAN::RefPtr; diff --git a/kernel/include/kernel/OpenFileDescriptorSet.h b/kernel/include/kernel/OpenFileDescriptorSet.h index 9486627c..791b0b89 100644 --- a/kernel/include/kernel/OpenFileDescriptorSet.h +++ b/kernel/include/kernel/OpenFileDescriptorSet.h @@ -51,8 +51,8 @@ namespace Kernel BAN::ErrorOr read_dir_entries(int fd, struct dirent* list, size_t list_len); - BAN::ErrorOr recvfrom(int fd, BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len); - BAN::ErrorOr sendto(int fd, BAN::ConstByteSpan buffer, const sockaddr* address, socklen_t address_len); + BAN::ErrorOr recvmsg(int socket, msghdr& message, int flags); + BAN::ErrorOr sendmsg(int socket, const msghdr& message, int flags); BAN::ErrorOr file_of(int) const; BAN::ErrorOr path_of(int) const; diff --git a/kernel/include/kernel/Process.h b/kernel/include/kernel/Process.h index b44d4246..10a1cdcc 100644 --- a/kernel/include/kernel/Process.h +++ b/kernel/include/kernel/Process.h @@ -134,8 +134,8 @@ namespace Kernel BAN::ErrorOr sys_bind(int socket, const sockaddr* address, socklen_t address_len); BAN::ErrorOr sys_connect(int socket, const sockaddr* address, socklen_t address_len); BAN::ErrorOr sys_listen(int socket, int backlog); - BAN::ErrorOr sys_sendto(const sys_sendto_t*); - BAN::ErrorOr sys_recvfrom(sys_recvfrom_t*); + BAN::ErrorOr sys_recvmsg(int socket, msghdr* message, int flags); + BAN::ErrorOr sys_sendmsg(int socket, const msghdr* message, int flags); BAN::ErrorOr sys_ioctl(int fildes, int request, void* arg); diff --git a/kernel/kernel/FS/Inode.cpp b/kernel/kernel/FS/Inode.cpp index 250413ec..7a5724c8 100644 --- a/kernel/kernel/FS/Inode.cpp +++ b/kernel/kernel/FS/Inode.cpp @@ -176,21 +176,21 @@ namespace Kernel return listen_impl(backlog); } - BAN::ErrorOr Inode::sendto(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) + BAN::ErrorOr Inode::recvmsg(msghdr& message, int flags) { LockGuard _(m_mutex); if 
(!mode().ifsock()) return BAN::Error::from_errno(ENOTSOCK); - return sendto_impl(message, address, address_len); - }; + return recvmsg_impl(message, flags); + } - BAN::ErrorOr Inode::recvfrom(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len) + BAN::ErrorOr Inode::sendmsg(const msghdr& message, int flags) { LockGuard _(m_mutex); if (!mode().ifsock()) return BAN::Error::from_errno(ENOTSOCK); - return recvfrom_impl(buffer, address, address_len); - }; + return sendmsg_impl(message, flags); + } BAN::ErrorOr Inode::getsockname(sockaddr* address, socklen_t* address_len) { diff --git a/kernel/kernel/Networking/TCPSocket.cpp b/kernel/kernel/Networking/TCPSocket.cpp index b91b2d55..255d36ba 100644 --- a/kernel/kernel/Networking/TCPSocket.cpp +++ b/kernel/kernel/Networking/TCPSocket.cpp @@ -190,8 +190,20 @@ namespace Kernel return m_network_layer.bind_socket_to_address(this, address, address_len); } - BAN::ErrorOr TCPSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr*, socklen_t*) + BAN::ErrorOr TCPSocket::recvmsg_impl(msghdr& message, int flags) { + if (flags != 0) + { + dwarnln("TODO: recvmsg with flags 0x{H}", flags); + return BAN::Error::from_errno(ENOTSUP); + } + + if (CMSG_FIRSTHDR(&message)) + { + dwarnln("ignoring recvmsg control message"); + message.msg_controllen = 0; + } + if (!m_has_connected) return BAN::Error::from_errno(ENOTCONN); @@ -202,26 +214,38 @@ namespace Kernel TRY(Thread::current().block_or_eintr_indefinite(m_thread_blocker, &m_mutex)); } - const uint32_t to_recv = BAN::Math::min(buffer.size(), m_recv_window.data_size); + size_t total_recv = 0; + for (int i = 0; i < message.msg_iovlen; i++) + { + auto* recv_buffer = reinterpret_cast(m_recv_window.buffer->vaddr()); - auto* recv_buffer = reinterpret_cast(m_recv_window.buffer->vaddr()); - memcpy(buffer.data(), recv_buffer, to_recv); + const size_t nrecv = BAN::Math::min(message.msg_iov[i].iov_len, m_recv_window.data_size); + memcpy(message.msg_iov[i].iov_base, recv_buffer, nrecv); - 
m_recv_window.data_size -= to_recv; - m_recv_window.start_seq += to_recv; - if (m_recv_window.data_size > 0) - memmove(recv_buffer, recv_buffer + to_recv, m_recv_window.data_size); + total_recv += nrecv; + m_recv_window.data_size -= nrecv; + m_recv_window.start_seq += nrecv; + if (m_recv_window.data_size == 0) + break; - return to_recv; + // TODO: use circular buffer to avoid this + memmove(recv_buffer, recv_buffer + nrecv, m_recv_window.data_size); + } + + return total_recv; } - BAN::ErrorOr TCPSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) + BAN::ErrorOr TCPSocket::sendmsg_impl(const msghdr& message, int flags) { - (void)address; - (void)address_len; + if (flags != 0) + { + dwarnln("TODO: sendmsg with flags 0x{H}", flags); + return BAN::Error::from_errno(ENOTSUP); + } + + if (CMSG_FIRSTHDR(&message)) + dwarnln("ignoring sendmsg control message"); - if (address) - return BAN::Error::from_errno(EISCONN); if (!m_has_connected) return BAN::Error::from_errno(ENOTCONN); @@ -232,17 +256,23 @@ namespace Kernel TRY(Thread::current().block_or_eintr_indefinite(m_thread_blocker, &m_mutex)); } - const size_t to_send = BAN::Math::min(message.size(), m_send_window.buffer->size() - m_send_window.data_size); - + size_t total_sent = 0; + for (int i = 0; i < message.msg_iovlen; i++) { - auto* buffer = reinterpret_cast(m_send_window.buffer->vaddr()); - memcpy(buffer + m_send_window.data_size, message.data(), to_send); - m_send_window.data_size += to_send; + auto* send_buffer = reinterpret_cast(m_send_window.buffer->vaddr()); + + const size_t nsend = BAN::Math::min(message.msg_iov[i].iov_len, m_send_window.buffer->size() - m_send_window.data_size); + memcpy(send_buffer + m_send_window.data_size, message.msg_iov[i].iov_base, nsend); + + total_sent += nsend; + m_send_window.data_size += nsend; + if (m_send_window.data_size == m_send_window.buffer->size()) + break; } m_thread_blocker.unblock(); - return to_send; + return total_sent; } 
BAN::ErrorOr TCPSocket::getpeername_impl(sockaddr* address, socklen_t* address_len) diff --git a/kernel/kernel/Networking/UDPSocket.cpp b/kernel/kernel/Networking/UDPSocket.cpp index 763ab4ec..93fd2d78 100644 --- a/kernel/kernel/Networking/UDPSocket.cpp +++ b/kernel/kernel/Networking/UDPSocket.cpp @@ -86,8 +86,20 @@ namespace Kernel return m_network_layer.bind_socket_to_address(this, address, address_len); } - BAN::ErrorOr UDPSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len) + BAN::ErrorOr UDPSocket::recvmsg_impl(msghdr& message, int flags) { + if (flags != 0) + { + dwarnln("TODO: recvmsg with flags 0x{H}", flags); + return BAN::Error::from_errno(ENOTSUP); + } + + if (CMSG_FIRSTHDR(&message)) + { + dwarnln("ignoring recvmsg control message"); + message.msg_controllen = 0; + } + if (!is_bound()) { dprintln("No interface bound"); @@ -106,14 +118,16 @@ namespace Kernel auto packet_info = m_packets.front(); m_packets.pop(); - size_t nread = BAN::Math::min(packet_info.packet_size, buffer.size()); + auto* packet_buffer = reinterpret_cast(m_packet_buffer->vaddr()); + + size_t total_recv = 0; + for (int i = 0; i < message.msg_iovlen; i++) + { + const size_t nrecv = BAN::Math::min(message.msg_iov[i].iov_len, packet_info.packet_size - total_recv); + memcpy(message.msg_iov[i].iov_base, packet_buffer + total_recv, nrecv); + total_recv += nrecv; + } - uint8_t* packet_buffer = reinterpret_cast(m_packet_buffer->vaddr()); - memcpy( - buffer.data(), - packet_buffer, - nread - ); memmove( packet_buffer, packet_buffer + packet_info.packet_size, @@ -122,21 +136,49 @@ namespace Kernel m_packet_total_size -= packet_info.packet_size; - if (address && address_len) + if (message.msg_name && message.msg_namelen) { - if (*address_len > (socklen_t)sizeof(sockaddr_storage)) - *address_len = sizeof(sockaddr_storage); - memcpy(address, &packet_info.sender, *address_len); + const size_t namelen = BAN::Math::min(message.msg_namelen, sizeof(sockaddr_storage)); + 
memcpy(message.msg_name, &packet_info.sender, namelen); + message.msg_namelen = namelen; } - return nread; + return total_recv; } - BAN::ErrorOr UDPSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) + BAN::ErrorOr UDPSocket::sendmsg_impl(const msghdr& message, int flags) { + if (flags != 0) + { + dwarnln("TODO: sendmsg with flags 0x{H}", flags); + return BAN::Error::from_errno(ENOTSUP); + } + + if (CMSG_FIRSTHDR(&message)) + dwarnln("ignoring sendmsg control message"); + if (!is_bound()) - TRY(m_network_layer.bind_socket_to_unused(this, address, address_len)); - return TRY(m_network_layer.sendto(*this, message, address, address_len)); + TRY(m_network_layer.bind_socket_to_unused(this, static_cast(message.msg_name), message.msg_namelen)); + + const size_t total_send_size = + [&message]() -> size_t { + size_t result = 0; + for (int i = 0; i < message.msg_iovlen; i++) + result += message.msg_iov[i].iov_len; + return result; + }(); + + BAN::Vector buffer; + TRY(buffer.resize(total_send_size)); + + size_t offset = 0; + for (int i = 0; i < message.msg_iovlen; i++) + { + memcpy(buffer.data() + offset, message.msg_iov[i].iov_base, message.msg_iov[i].iov_len); + offset += message.msg_iov[i].iov_len; + } + + return TRY(m_network_layer.sendto(*this, buffer.span(), static_cast(message.msg_name), message.msg_namelen)); } BAN::ErrorOr UDPSocket::ioctl_impl(int request, void* argument) diff --git a/kernel/kernel/Networking/UNIX/Socket.cpp b/kernel/kernel/Networking/UNIX/Socket.cpp index 0d426a24..4f5a852f 100644 --- a/kernel/kernel/Networking/UNIX/Socket.cpp +++ b/kernel/kernel/Networking/UNIX/Socket.cpp @@ -1,6 +1,7 @@ #include + #include -#include +#include #include #include #include @@ -22,10 +23,28 @@ namespace Kernel }; static BAN::HashMap, BAN::WeakPtr, UnixSocketHash> s_bound_sockets; - static SpinLock s_bound_socket_lock; + static Mutex s_bound_socket_lock; static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE; + static 
BAN::ErrorOr validate_sockaddr_un(const sockaddr* address, socklen_t address_len) + { + if (address_len < static_cast(sizeof(sa_family_t))) + return BAN::Error::from_errno(EINVAL); + if (address_len > static_cast(sizeof(sockaddr_un))) + address_len = sizeof(sockaddr_un); + + const auto& sockaddr_un = *reinterpret_cast(address); + if (sockaddr_un.sun_family != AF_UNIX) + return BAN::Error::from_errno(EINVAL); + + size_t length = 0; + while (length < address_len - sizeof(sa_family_t) && sockaddr_un.sun_path[length]) + length++; + + return BAN::StringView { sockaddr_un.sun_path, length }; + } + // FIXME: why is this using spinlocks instead of mutexes?? BAN::ErrorOr> UnixDomainSocket::create(Socket::Type socket_type, const Socket::Info& info) @@ -64,7 +83,7 @@ namespace Kernel { if (is_bound() && !is_bound_to_unused()) { - SpinLockGuard _(s_bound_socket_lock); + LockGuard _(s_bound_socket_lock); s_bound_sockets.remove(m_bound_file.inode); } if (m_info.has()) @@ -105,11 +124,9 @@ namespace Kernel BAN::RefPtr pending; { - SpinLockGuard guard(connection_info.pending_lock); - - SpinLockGuardAsMutex smutex(guard); + LockGuard _(connection_info.pending_lock); while (connection_info.pending_connections.empty()) - TRY(Thread::current().block_or_eintr_indefinite(connection_info.pending_thread_blocker, &smutex)); + TRY(Thread::current().block_or_eintr_indefinite(connection_info.pending_thread_blocker, &connection_info.pending_lock)); pending = connection_info.pending_connections.front(); connection_info.pending_connections.pop(); @@ -146,15 +163,11 @@ namespace Kernel BAN::ErrorOr UnixDomainSocket::connect_impl(const sockaddr* address, socklen_t address_len) { - if (address_len != sizeof(sockaddr_un)) - return BAN::Error::from_errno(EINVAL); - auto& sockaddr_un = *reinterpret_cast(address); - if (sockaddr_un.sun_family != AF_UNIX) - return BAN::Error::from_errno(EAFNOSUPPORT); + const auto sun_path = TRY(validate_sockaddr_un(address, address_len)); if (!is_bound()) 
TRY(m_bound_file.canonical_path.push_back('X')); - auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path)); + auto absolute_path = TRY(Process::current().absolute_path_of(sun_path)); auto file = TRY(VirtualFileSystem::get().file_from_absolute_path( Process::current().root_file().inode, Process::current().credentials(), @@ -165,7 +178,7 @@ namespace Kernel BAN::RefPtr target; { - SpinLockGuard _(s_bound_socket_lock); + LockGuard _(s_bound_socket_lock); auto it = s_bound_sockets.find(file.inode); if (it == s_bound_sockets.end()) return BAN::Error::from_errno(ECONNREFUSED); @@ -196,7 +209,7 @@ namespace Kernel { auto& target_info = target->m_info.get(); - SpinLockGuard guard(target_info.pending_lock); + LockGuard _(target_info.pending_lock); if (target_info.pending_connections.size() < target_info.pending_connections.capacity()) { @@ -205,8 +218,7 @@ namespace Kernel break; } - SpinLockGuardAsMutex smutex(guard); - TRY(Thread::current().block_or_eintr_indefinite(target_info.pending_thread_blocker, &smutex)); + TRY(Thread::current().block_or_eintr_indefinite(target_info.pending_thread_blocker, &target_info.pending_lock)); } target->epoll_notify(EPOLLIN); @@ -236,21 +248,16 @@ namespace Kernel { if (is_bound()) return BAN::Error::from_errno(EINVAL); - if (address_len != sizeof(sockaddr_un)) - return BAN::Error::from_errno(EINVAL); - auto& sockaddr_un = *reinterpret_cast(address); - if (sockaddr_un.sun_family != AF_UNIX) - return BAN::Error::from_errno(EAFNOSUPPORT); - auto bind_path = BAN::StringView(sockaddr_un.sun_path); - if (bind_path.empty()) + const auto sun_path = TRY(validate_sockaddr_un(address, address_len)); + if (sun_path.empty()) return BAN::Error::from_errno(EINVAL); // FIXME: This feels sketchy - auto parent_file = bind_path.front() == '/' + auto parent_file = sun_path.front() == '/' ? 
TRY(Process::current().root_file().clone()) : TRY(Process::current().working_directory().clone()); - if (auto ret = Process::current().create_file_or_dir(AT_FDCWD, bind_path.data(), 0755 | S_IFSOCK); ret.is_error()) + if (auto ret = Process::current().create_file_or_dir(AT_FDCWD, sun_path.data(), 0755 | S_IFSOCK); ret.is_error()) { if (ret.error().get_error_code() == EEXIST) return BAN::Error::from_errno(EADDRINUSE); @@ -260,11 +267,11 @@ namespace Kernel Process::current().root_file().inode, parent_file, Process::current().credentials(), - bind_path, + sun_path, O_RDWR )); - SpinLockGuard _(s_bound_socket_lock); + LockGuard _(s_bound_socket_lock); if (s_bound_sockets.contains(file.inode)) return BAN::Error::from_errno(EADDRINUSE); TRY(s_bound_sockets.emplace(file.inode, TRY(get_weak_ptr()))); @@ -287,21 +294,24 @@ namespace Kernel } } - BAN::ErrorOr UnixDomainSocket::add_packet(BAN::ConstByteSpan packet) + BAN::ErrorOr UnixDomainSocket::add_packet(const msghdr& packet, size_t total_size) { - SpinLockGuard guard(m_packet_lock); - while (m_packet_sizes.full() || m_packet_size_total + packet.size() > s_packet_buffer_size) - { - SpinLockGuardAsMutex smutex(guard); - TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &smutex)); - } + LockGuard _(m_packet_lock); + while (m_packet_sizes.full() || m_packet_size_total + total_size > s_packet_buffer_size) + TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock)); uint8_t* packet_buffer = reinterpret_cast(m_packet_buffer->vaddr() + m_packet_size_total); - memcpy(packet_buffer, packet.data(), packet.size()); - m_packet_size_total += packet.size(); - if (!is_streaming()) - m_packet_sizes.push(packet.size()); + size_t offset = 0; + for (int i = 0; i < packet.msg_iovlen; i++) + { + memcpy(packet_buffer + offset, packet.msg_iov[i].iov_base, packet.msg_iov[i].iov_len); + offset += packet.msg_iov[i].iov_len; + } + + ASSERT(offset == total_size); + m_packet_size_total += 
total_size; + m_packet_sizes.push(total_size); m_packet_thread_blocker.unblock(); @@ -348,27 +358,105 @@ namespace Kernel return false; } - BAN::ErrorOr UnixDomainSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) + BAN::ErrorOr UnixDomainSocket::recvmsg_impl(msghdr& message, int flags) { - if (message.size() > s_packet_buffer_size) + if (flags != 0) + { + dwarnln("TODO: recvmsg with flags 0x{H}", flags); + return BAN::Error::from_errno(ENOTSUP); + } + + if (CMSG_FIRSTHDR(&message)) + { + dwarnln("ignoring recvmsg control message"); + message.msg_controllen = 0; + } + + LockGuard _(m_packet_lock); + while (m_packet_size_total == 0) + { + if (m_info.has()) + { + auto& connection_info = m_info.get(); + bool expected = true; + if (connection_info.target_closed.compare_exchange(expected, false)) + return 0; + if (!connection_info.connection) + return BAN::Error::from_errno(ENOTCONN); + } + + TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock)); + } + + uint8_t* packet_buffer = reinterpret_cast(m_packet_buffer->vaddr()); + + const size_t max_recv_size = is_streaming() ? m_packet_size_total : m_packet_sizes.front(); + + size_t total_recv = 0; + for (int i = 0; i < message.msg_iovlen; i++) + { + const size_t nrecv = BAN::Math::min(message.msg_iov[i].iov_len, max_recv_size - total_recv); + memcpy(message.msg_iov[i].iov_base, packet_buffer + total_recv, nrecv); + total_recv += nrecv; + } + + size_t bytes_to_handle = total_recv; + while (bytes_to_handle) + { + const size_t to_handle = BAN::Math::min(bytes_to_handle, m_packet_sizes.front()); + bytes_to_handle -= to_handle; + m_packet_sizes.front() -= to_handle; + if (m_packet_sizes.front() == 0) + m_packet_sizes.pop(); + } + + const size_t to_discard = is_streaming() ? 
total_recv : max_recv_size; + memmove(packet_buffer, packet_buffer + to_discard, m_packet_size_total - to_discard); + m_packet_size_total -= to_discard; + + m_packet_thread_blocker.unblock(); + + epoll_notify(EPOLLOUT); + + return total_recv; + } + + BAN::ErrorOr UnixDomainSocket::sendmsg_impl(const msghdr& message, int flags) + { + if (flags != 0) + { + dwarnln("TODO: sendmsg with flags 0x{H}", flags); + return BAN::Error::from_errno(ENOTSUP); + } + + if (CMSG_FIRSTHDR(&message)) + dwarnln("ignoring sendmsg control message"); + + const size_t total_message_size = + [&message]() -> size_t { + size_t result = 0; + for (int i = 0; i < message.msg_iovlen; i++) + result += message.msg_iov[i].iov_len; + return result; + }(); + + if (total_message_size > s_packet_buffer_size) return BAN::Error::from_errno(ENOBUFS); if (m_info.has()) { auto& connection_info = m_info.get(); - if (address) - return BAN::Error::from_errno(EISCONN); auto target = connection_info.connection.lock(); if (!target) return BAN::Error::from_errno(ENOTCONN); - TRY(target->add_packet(message)); - return message.size(); + TRY(target->add_packet(message, total_message_size)); + return total_message_size; } else { BAN::RefPtr target_inode; - if (!address) + if (!message.msg_name) { auto& connectionless_info = m_info.get(); if (connectionless_info.peer_address.empty()) @@ -384,13 +472,8 @@ namespace Kernel } else { - if (address_len != sizeof(sockaddr_un)) - return BAN::Error::from_errno(EINVAL); - auto& sockaddr_un = *reinterpret_cast(address); - if (sockaddr_un.sun_family != AF_UNIX) - return BAN::Error::from_errno(EAFNOSUPPORT); - - auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path)); + const auto sun_path = TRY(validate_sockaddr_un(static_cast(message.msg_name), message.msg_namelen)); + auto absolute_path = TRY(Process::current().absolute_path_of(sun_path)); target_inode = TRY(VirtualFileSystem::get().file_from_absolute_path( Process::current().root_file().inode, 
Process::current().credentials(), @@ -399,57 +482,21 @@ namespace Kernel )).inode; } - SpinLockGuard _(s_bound_socket_lock); - auto it = s_bound_sockets.find(target_inode); - if (it == s_bound_sockets.end()) - return BAN::Error::from_errno(EDESTADDRREQ); - auto target = it->value.lock(); - if (!target) - return BAN::Error::from_errno(EDESTADDRREQ); - TRY(target->add_packet(message)); - return message.size(); - } - } - - BAN::ErrorOr UnixDomainSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr*, socklen_t*) - { - SpinLockGuard guard(m_packet_lock); - while (m_packet_size_total == 0) - { - if (m_info.has()) + BAN::RefPtr target; { - auto& connection_info = m_info.get(); - bool expected = true; - if (connection_info.target_closed.compare_exchange(expected, false)) - return 0; - if (!connection_info.connection) - return BAN::Error::from_errno(ENOTCONN); + LockGuard _(s_bound_socket_lock); + auto it = s_bound_sockets.find(target_inode); + if (it == s_bound_sockets.end()) + return BAN::Error::from_errno(EDESTADDRREQ); + target = it->value.lock(); } - SpinLockGuardAsMutex smutex(guard); - TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &smutex)); + if (!target) + return BAN::Error::from_errno(EDESTADDRREQ); + TRY(target->add_packet(message, total_message_size)); + + return total_message_size; } - - uint8_t* packet_buffer = reinterpret_cast(m_packet_buffer->vaddr()); - - size_t nread = 0; - if (is_streaming()) - nread = BAN::Math::min(buffer.size(), m_packet_size_total); - else - { - nread = BAN::Math::min(buffer.size(), m_packet_sizes.front()); - m_packet_sizes.pop(); - } - - memcpy(buffer.data(), packet_buffer, nread); - memmove(packet_buffer, packet_buffer + nread, m_packet_size_total - nread); - m_packet_size_total -= nread; - - m_packet_thread_blocker.unblock(); - - epoll_notify(EPOLLOUT); - - return nread; } BAN::ErrorOr UnixDomainSocket::getpeername_impl(sockaddr* address, socklen_t* address_len) diff --git 
a/kernel/kernel/OpenFileDescriptorSet.cpp b/kernel/kernel/OpenFileDescriptorSet.cpp index ab8470c9..4bb03f8a 100644 --- a/kernel/kernel/OpenFileDescriptorSet.cpp +++ b/kernel/kernel/OpenFileDescriptorSet.cpp @@ -424,7 +424,24 @@ namespace Kernel } if (inode->mode().ifsock()) - return recvfrom(fd, buffer, nullptr, nullptr); + { + iovec iov { + .iov_base = buffer.data(), + .iov_len = buffer.size(), + }; + + msghdr message { + .msg_name = nullptr, + .msg_namelen = 0, + .msg_iov = &iov, + .msg_iovlen = 1, + .msg_control = nullptr, + .msg_controllen = 0, + .msg_flags = 0, + }; + + return recvmsg(fd, message, 0); + } size_t nread; { @@ -461,7 +478,24 @@ namespace Kernel } if (inode->mode().ifsock()) - return sendto(fd, buffer, nullptr, 0); + { + iovec iov { + .iov_base = const_cast(buffer.data()), + .iov_len = buffer.size(), + }; + + msghdr message { + .msg_name = nullptr, + .msg_namelen = 0, + .msg_iov = &iov, + .msg_iovlen = 1, + .msg_control = nullptr, + .msg_controllen = 0, + .msg_flags = 0, + }; + + return sendmsg(fd, message, 0); + } size_t nwrite; { @@ -515,7 +549,7 @@ namespace Kernel } } - BAN::ErrorOr OpenFileDescriptorSet::recvfrom(int fd, BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len) + BAN::ErrorOr OpenFileDescriptorSet::recvmsg(int fd, msghdr& message, int flags) { BAN::RefPtr inode; bool is_nonblock; @@ -533,10 +567,10 @@ namespace Kernel LockGuard _(inode->m_mutex); if (is_nonblock && !inode->can_read()) return BAN::Error::from_errno(EWOULDBLOCK); - return inode->recvfrom(buffer, address, address_len); + return inode->recvmsg(message, flags); } - BAN::ErrorOr OpenFileDescriptorSet::sendto(int fd, BAN::ConstByteSpan buffer, const sockaddr* address, socklen_t address_len) + BAN::ErrorOr OpenFileDescriptorSet::sendmsg(int fd, const msghdr& message, int flags) { BAN::RefPtr inode; bool is_nonblock; @@ -559,7 +593,7 @@ namespace Kernel } if (is_nonblock && !inode->can_write()) return BAN::Error::from_errno(EWOULDBLOCK); - return 
inode->sendto(buffer, address, address_len); + return inode->sendmsg(message, flags); } BAN::ErrorOr OpenFileDescriptorSet::file_of(int fd) const diff --git a/kernel/kernel/Process.cpp b/kernel/kernel/Process.cpp index 935ece78..15adc08d 100644 --- a/kernel/kernel/Process.cpp +++ b/kernel/kernel/Process.cpp @@ -1569,73 +1569,72 @@ namespace Kernel return 0; } - BAN::ErrorOr Process::sys_sendto(const sys_sendto_t* _arguments) + BAN::ErrorOr Process::sys_recvmsg(int socket, msghdr* _message, int flags) { - sys_sendto_t arguments; + msghdr message; { LockGuard _(m_process_lock); - TRY(validate_pointer_access(_arguments, sizeof(sys_sendto_t), false)); - arguments = *_arguments; + TRY(validate_pointer_access(_message, sizeof(msghdr), true)); + message = *_message; } - if (arguments.length == 0) - return BAN::Error::from_errno(EINVAL); - - MemoryRegion* message_region = nullptr; - MemoryRegion* address_region = nullptr; - - BAN::ScopeGuard _([&] { - if (message_region) - message_region->unpin(); - if (address_region) - address_region->unpin(); + BAN::Vector regions; + BAN::ScopeGuard _([&regions] { + for (auto* region : regions) + if (region != nullptr) + region->unpin(); }); - message_region = TRY(validate_and_pin_pointer_access(arguments.message, arguments.length, false)); - if (arguments.dest_addr) - address_region = TRY(validate_and_pin_pointer_access(arguments.dest_addr, arguments.dest_len, false)); + if (message.msg_name) + TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_name, message.msg_namelen, true)))); + if (message.msg_control) + TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_control, message.msg_controllen, true)))); + if (message.msg_iov) + { + TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_iov, message.msg_iovlen * sizeof(iovec), true)))); + for (int i = 0; i < message.msg_iovlen; i++) + TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_iov[i].iov_base, 
message.msg_iov[i].iov_len, true)))); + } - auto message = BAN::ConstByteSpan(static_cast(arguments.message), arguments.length); - return TRY(m_open_file_descriptors.sendto(arguments.socket, message, arguments.dest_addr, arguments.dest_len)); + auto ret = TRY(m_open_file_descriptors.recvmsg(socket, message, flags)); + + { + LockGuard _(m_process_lock); + TRY(validate_pointer_access(_message, sizeof(msghdr), true)); + *_message = message; + } + + return ret; } - BAN::ErrorOr Process::sys_recvfrom(sys_recvfrom_t* _arguments) + BAN::ErrorOr Process::sys_sendmsg(int socket, const msghdr* _message, int flags) { - sys_recvfrom_t arguments; + msghdr message; { LockGuard _(m_process_lock); - TRY(validate_pointer_access(_arguments, sizeof(sys_sendto_t), false)); - arguments = *_arguments; + TRY(validate_pointer_access(_message, sizeof(msghdr), false)); + message = *_message; } - if (!arguments.address != !arguments.address_len) - return BAN::Error::from_errno(EINVAL); - if (arguments.length == 0) - return BAN::Error::from_errno(EINVAL); - - MemoryRegion* buffer_region = nullptr; - MemoryRegion* address_region1 = nullptr; - MemoryRegion* address_region2 = nullptr; - - BAN::ScopeGuard _([&] { - if (buffer_region) - buffer_region->unpin(); - if (address_region1) - address_region1->unpin(); - if (address_region2) - address_region2->unpin(); + BAN::Vector regions; + BAN::ScopeGuard _([&regions] { + for (auto* region : regions) + if (region != nullptr) + region->unpin(); }); - buffer_region = TRY(validate_and_pin_pointer_access(arguments.buffer, arguments.length, true)); - - if (arguments.address_len) + if (message.msg_name) + TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_name, message.msg_namelen, false)))); + if (message.msg_control) + TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_control, message.msg_controllen, false)))); + if (message.msg_iov) { - address_region1 = TRY(validate_and_pin_pointer_access(arguments.address_len, 
sizeof(*arguments.address_len), true)); - address_region2 = TRY(validate_and_pin_pointer_access(arguments.address, *arguments.address_len, true)); + TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_iov, message.msg_iovlen * sizeof(iovec), false)))); + for (int i = 0; i < message.msg_iovlen; i++) + TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_iov[i].iov_base, message.msg_iov[i].iov_len, false)))); } - auto message = BAN::ByteSpan(static_cast(arguments.buffer), arguments.length); - return TRY(m_open_file_descriptors.recvfrom(arguments.socket, message, arguments.address, arguments.address_len)); + return TRY(m_open_file_descriptors.sendmsg(socket, message, flags)); } BAN::ErrorOr Process::sys_ioctl(int fildes, int request, void* arg) diff --git a/kernel/kernel/Syscall.cpp b/kernel/kernel/Syscall.cpp index 2a6e1fae..ff5d8798 100644 --- a/kernel/kernel/Syscall.cpp +++ b/kernel/kernel/Syscall.cpp @@ -124,8 +124,8 @@ namespace Kernel case SYS_WAIT: case SYS_ACCEPT: case SYS_CONNECT: - case SYS_RECVFROM: - case SYS_SENDTO: + case SYS_RECVMSG: + case SYS_SENDMSG: case SYS_FLOCK: return true; default: diff --git a/userspace/libraries/LibC/include/sys/syscall.h b/userspace/libraries/LibC/include/sys/syscall.h index 52b8d530..7d236e52 100644 --- a/userspace/libraries/LibC/include/sys/syscall.h +++ b/userspace/libraries/LibC/include/sys/syscall.h @@ -67,8 +67,8 @@ __BEGIN_DECLS O(SYS_SOCKET, socket) \ O(SYS_SOCKETPAIR, socketpair) \ O(SYS_BIND, bind) \ - O(SYS_SENDTO, sendto) \ - O(SYS_RECVFROM, recvfrom) \ + O(SYS_RECVMSG, recvmsg) \ + O(SYS_SENDMSG, sendmsg) \ O(SYS_IOCTL, ioctl) \ O(SYS_ACCEPT, accept) \ O(SYS_CONNECT, connect) \ diff --git a/userspace/libraries/LibC/sys/socket.cpp b/userspace/libraries/LibC/sys/socket.cpp index f3c392fb..b3a12935 100644 --- a/userspace/libraries/LibC/sys/socket.cpp +++ b/userspace/libraries/LibC/sys/socket.cpp @@ -42,104 +42,69 @@ ssize_t recv(int socket, void* __restrict buffer, size_t length, 
int flags) ssize_t recvfrom(int socket, void* __restrict buffer, size_t length, int flags, struct sockaddr* __restrict address, socklen_t* __restrict address_len) { - pthread_testcancel(); - sys_recvfrom_t arguments { - .socket = socket, - .buffer = buffer, - .length = length, - .flags = flags, - .address = address, - .address_len = address_len + // cancellation point in recvmsg + + iovec iov { + .iov_base = buffer, + .iov_len = length, }; - return syscall(SYS_RECVFROM, &arguments); + + msghdr message { + .msg_name = address, + .msg_namelen = address_len ? *address_len : 0, + .msg_iov = &iov, + .msg_iovlen = 1, + .msg_control = NULL, + .msg_controllen = 0, + .msg_flags = 0, + }; + + const ssize_t ret = recvmsg(socket, &message, flags); + + if (address_len) + *address_len = message.msg_namelen; + + return ret; } -ssize_t send(int socket, const void* message, size_t length, int flags) +ssize_t recvmsg(int socket, struct msghdr* message, int flags) +{ + pthread_testcancel(); + return syscall(SYS_RECVMSG, socket, message, flags); +} + +ssize_t send(int socket, const void* buffer, size_t length, int flags) { // cancellation point in sendto - return sendto(socket, message, length, flags, nullptr, 0); + return sendto(socket, buffer, length, flags, nullptr, 0); } -ssize_t sendto(int socket, const void* message, size_t length, int flags, const struct sockaddr* dest_addr, socklen_t dest_len) +ssize_t sendto(int socket, const void* buffer, size_t length, int flags, const struct sockaddr* address, socklen_t address_len) +{ + // cancellation point in sendmsg + + iovec iov { + .iov_base = const_cast(buffer), + .iov_len = length, + }; + + msghdr message { + .msg_name = const_cast(address), + .msg_namelen = address_len, + .msg_iov = &iov, + .msg_iovlen = 1, + .msg_control = NULL, + .msg_controllen = 0, + .msg_flags = 0, + }; + + return sendmsg(socket, &message, flags); +} + +ssize_t sendmsg(int socket, const struct msghdr* message, int flags) { pthread_testcancel(); - sys_sendto_t 
arguments { - .socket = socket, - .message = message, - .length = length, - .flags = flags, - .dest_addr = dest_addr, - .dest_len = dest_len - }; - return syscall(SYS_SENDTO, &arguments); -} - -ssize_t recvmsg(int socket, struct msghdr* message, int flags) -{ - if (CMSG_FIRSTHDR(message)) - { - dwarnln("TODO: recvmsg ancillary data"); - errno = ENOTSUP; - return -1; - } - - size_t total_recv = 0; - - for (int i = 0; i < message->msg_iovlen; i++) - { - const ssize_t nrecv = recvfrom( - socket, - message->msg_iov[i].iov_base, - message->msg_iov[i].iov_len, - flags, - static_cast(message->msg_name), - &message->msg_namelen - ); - - if (nrecv < 0) - return -1; - - total_recv += nrecv; - - if (static_cast(nrecv) < message->msg_iov[i].iov_len) - break; - } - - return total_recv; -} - -ssize_t sendmsg(int socket, const struct msghdr* message, int flags) -{ - if (CMSG_FIRSTHDR(message)) - { - dwarnln("TODO: sendmsg ancillary data"); - errno = ENOTSUP; - return -1; - } - - size_t total_sent = 0; - - for (int i = 0; i < message->msg_iovlen; i++) - { - const ssize_t nsend = sendto( - socket, - message->msg_iov[i].iov_base, - message->msg_iov[i].iov_len, - flags, - static_cast(message->msg_name), - message->msg_namelen - ); - - if (nsend < 0) - return -1; - - total_sent += nsend; - - if (static_cast(nsend) < message->msg_iov[i].iov_len) - break; - } - - return total_sent; + return syscall(SYS_SENDMSG, socket, message, flags); } int socket(int domain, int type, int protocol)