From 89c9bfd052b20b90194282fb51643c5ddbe8ace3 Mon Sep 17 00:00:00 2001 From: Bananymous Date: Tue, 27 May 2025 07:09:04 +0300 Subject: [PATCH] Kernel/LibC: Implement `socketpair` for UNIX sockets --- .../include/kernel/Networking/UNIX/Socket.h | 1 + kernel/include/kernel/OpenFileDescriptorSet.h | 1 + kernel/include/kernel/Process.h | 1 + kernel/kernel/Networking/NetworkManager.cpp | 19 ++++++ kernel/kernel/Networking/UNIX/Socket.cpp | 14 ++++ kernel/kernel/OpenFileDescriptorSet.cpp | 65 ++++++++++++++----- kernel/kernel/Process.cpp | 8 +++ .../libraries/LibC/include/sys/syscall.h | 1 + userspace/libraries/LibC/sys/socket.cpp | 5 ++ 9 files changed, 98 insertions(+), 17 deletions(-) diff --git a/kernel/include/kernel/Networking/UNIX/Socket.h b/kernel/include/kernel/Networking/UNIX/Socket.h index eff296200a..7ad12625a6 100644 --- a/kernel/include/kernel/Networking/UNIX/Socket.h +++ b/kernel/include/kernel/Networking/UNIX/Socket.h @@ -17,6 +17,7 @@ namespace Kernel public: static BAN::ErrorOr> create(Socket::Type, const Socket::Info&); + BAN::ErrorOr make_socket_pair(UnixDomainSocket&); protected: virtual BAN::ErrorOr accept_impl(sockaddr*, socklen_t*, int) override; diff --git a/kernel/include/kernel/OpenFileDescriptorSet.h b/kernel/include/kernel/OpenFileDescriptorSet.h index 9376b32db0..1396653e2e 100644 --- a/kernel/include/kernel/OpenFileDescriptorSet.h +++ b/kernel/include/kernel/OpenFileDescriptorSet.h @@ -26,6 +26,7 @@ namespace Kernel BAN::ErrorOr open(BAN::StringView absolute_path, int flags); BAN::ErrorOr socket(int domain, int type, int protocol); + BAN::ErrorOr socketpair(int domain, int type, int protocol, int socket_vector[2]); BAN::ErrorOr pipe(int fds[2]); diff --git a/kernel/include/kernel/Process.h b/kernel/include/kernel/Process.h index 9e1007445f..4540c5540b 100644 --- a/kernel/include/kernel/Process.h +++ b/kernel/include/kernel/Process.h @@ -118,6 +118,7 @@ namespace Kernel BAN::ErrorOr sys_fchownat(int fd, const char* path, uid_t uid, gid_t gid, int flag); BAN::ErrorOr sys_socket(int domain, int type, int protocol); + BAN::ErrorOr sys_socketpair(int domain, int type, int protocol, int socket_vector[2]); BAN::ErrorOr sys_getsockname(int socket, sockaddr* address, socklen_t* address_len); BAN::ErrorOr sys_getpeername(int socket, sockaddr* address, socklen_t* address_len); BAN::ErrorOr sys_getsockopt(int socket, int level, int option_name, void* option_value, socklen_t* option_len); diff --git a/kernel/kernel/Networking/NetworkManager.cpp b/kernel/kernel/Networking/NetworkManager.cpp index b0b3920f73..838d820e14 100644 --- a/kernel/kernel/Networking/NetworkManager.cpp +++ b/kernel/kernel/Networking/NetworkManager.cpp @@ -129,6 +129,25 @@ namespace Kernel return socket; } + BAN::ErrorOr NetworkManager::connect_sockets(Socket::Domain domain, BAN::RefPtr socket1, BAN::RefPtr socket2) + { + switch (domain) + { + case Socket::Domain::UNIX: + { + auto* usocket1 = static_cast(socket1.ptr()); + auto* usocket2 = static_cast(socket2.ptr()); + TRY(usocket1->make_socket_pair(*usocket2)); + break; + } + default: + dwarnln("TODO: connect {} domain sockets", static_cast(domain)); + return BAN::Error::from_errno(ENOTSUP); + } + + return {}; + } + void NetworkManager::on_receive(NetworkInterface& interface, BAN::ConstByteSpan packet) { if (packet.size() < sizeof(EthernetHeader)) diff --git a/kernel/kernel/Networking/UNIX/Socket.cpp b/kernel/kernel/Networking/UNIX/Socket.cpp index 84dc06dcb3..b7588c14f8 100644 --- a/kernel/kernel/Networking/UNIX/Socket.cpp +++ b/kernel/kernel/Networking/UNIX/Socket.cpp @@ -69,6 +69,20 @@ namespace Kernel } } + BAN::ErrorOr UnixDomainSocket::make_socket_pair(UnixDomainSocket& other) + { + if (!m_info.has() || !other.m_info.has()) + return BAN::Error::from_errno(EINVAL); + + TRY(this->get_weak_ptr()); + TRY(other.get_weak_ptr()); + + this->m_info.get().connection = MUST(other.get_weak_ptr()); + other.m_info.get().connection = MUST(this->get_weak_ptr()); + + return {}; + } + BAN::ErrorOr UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len, int flags) { if (!m_info.has()) diff --git a/kernel/kernel/OpenFileDescriptorSet.cpp b/kernel/kernel/OpenFileDescriptorSet.cpp index f66b02cba5..1e60b66fe9 100644 --- a/kernel/kernel/OpenFileDescriptorSet.cpp +++ b/kernel/kernel/OpenFileDescriptorSet.cpp @@ -83,50 +83,57 @@ namespace Kernel return open(TRY(VirtualFileSystem::get().file_from_absolute_path(m_credentials, absolute_path, flags)), flags); } - BAN::ErrorOr OpenFileDescriptorSet::socket(int domain, int type, int protocol) + struct SocketInfo { - bool valid_protocol = true; + Socket::Domain domain; + Socket::Type type; + int status_flags; + int descriptor_flags; + }; - Socket::Domain sock_domain; + static BAN::ErrorOr parse_socket_info(int domain, int type, int protocol) + { + SocketInfo info; + + bool valid_protocol = true; switch (domain) { case AF_INET: - sock_domain = Socket::Domain::INET; + info.domain = Socket::Domain::INET; break; case AF_INET6: - sock_domain = Socket::Domain::INET6; + info.domain = Socket::Domain::INET6; break; case AF_UNIX: - sock_domain = Socket::Domain::UNIX; + info.domain = Socket::Domain::UNIX; valid_protocol = false; break; default: return BAN::Error::from_errno(EPROTOTYPE); } - int status_flags = 0; - int descriptor_flags = 0; + info.status_flags = 0; + info.descriptor_flags = 0; if (type & SOCK_NONBLOCK) - status_flags |= O_NONBLOCK; + info.status_flags |= O_NONBLOCK; if (type & SOCK_CLOEXEC) - descriptor_flags |= O_CLOEXEC; + info.descriptor_flags |= O_CLOEXEC; type &= ~(SOCK_NONBLOCK | SOCK_CLOEXEC); - Socket::Type sock_type; switch (type) { case SOCK_STREAM: - sock_type = Socket::Type::STREAM; + info.type = Socket::Type::STREAM; if (protocol != IPPROTO_TCP) valid_protocol = false; break; case SOCK_DGRAM: - sock_type = Socket::Type::DGRAM; + info.type = Socket::Type::DGRAM; if (protocol != IPPROTO_UDP) valid_protocol = false; break; case SOCK_SEQPACKET: - sock_type = Socket::Type::SEQPACKET; + info.type = Socket::Type::SEQPACKET; valid_protocol = false; break; default: @@ -136,15 +143,39 @@ namespace Kernel if (protocol && !valid_protocol) return BAN::Error::from_errno(EPROTONOSUPPORT); - auto socket = TRY(NetworkManager::get().create_socket(sock_domain, sock_type, 0777, m_credentials.euid(), m_credentials.egid())); + return info; + } + + BAN::ErrorOr OpenFileDescriptorSet::socket(int domain, int type, int protocol) + { + auto sock_info = TRY(parse_socket_info(domain, type, protocol)); + auto socket = TRY(NetworkManager::get().create_socket(sock_info.domain, sock_info.type, 0777, m_credentials.euid(), m_credentials.egid())); LockGuard _(m_mutex); int fd = TRY(get_free_fd()); - m_open_files[fd].description = TRY(BAN::RefPtr::create(VirtualFileSystem::File(socket, ""_sv), 0, O_RDWR | status_flags)); - m_open_files[fd].descriptor_flags = descriptor_flags; + m_open_files[fd].description = TRY(BAN::RefPtr::create(VirtualFileSystem::File(socket, ""_sv), 0, O_RDWR | sock_info.status_flags)); + m_open_files[fd].descriptor_flags = sock_info.descriptor_flags; return fd; } + BAN::ErrorOr OpenFileDescriptorSet::socketpair(int domain, int type, int protocol, int socket_vector[2]) + { + auto sock_info = TRY(parse_socket_info(domain, type, protocol)); + + auto socket1 = TRY(NetworkManager::get().create_socket(sock_info.domain, sock_info.type, 0600, m_credentials.euid(), m_credentials.egid())); + auto socket2 = TRY(NetworkManager::get().create_socket(sock_info.domain, sock_info.type, 0600, m_credentials.euid(), m_credentials.egid())); + TRY(NetworkManager::get().connect_sockets(sock_info.domain, socket1, socket2)); + + LockGuard _(m_mutex); + + TRY(get_free_fd_pair(socket_vector)); + m_open_files[socket_vector[0]].description = TRY(BAN::RefPtr::create(VirtualFileSystem::File(socket1, ""_sv), 0, O_RDWR | sock_info.status_flags)); + m_open_files[socket_vector[0]].descriptor_flags = sock_info.descriptor_flags; + m_open_files[socket_vector[1]].description = TRY(BAN::RefPtr::create(VirtualFileSystem::File(socket2, ""_sv), 0, O_RDWR | sock_info.status_flags)); + m_open_files[socket_vector[1]].descriptor_flags = sock_info.descriptor_flags; + return {}; + } + BAN::ErrorOr OpenFileDescriptorSet::pipe(int fds[2]) { LockGuard _(m_mutex); diff --git a/kernel/kernel/Process.cpp b/kernel/kernel/Process.cpp index 97d5effc94..0d3d4fc16a 100644 --- a/kernel/kernel/Process.cpp +++ b/kernel/kernel/Process.cpp @@ -1270,6 +1270,14 @@ namespace Kernel return TRY(m_open_file_descriptors.socket(domain, type, protocol)); } + BAN::ErrorOr Process::sys_socketpair(int domain, int type, int protocol, int socket_vector[2]) + { + LockGuard _(m_process_lock); + TRY(validate_pointer_access(socket_vector, sizeof(int) * 2, true)); + TRY(m_open_file_descriptors.socketpair(domain, type, protocol, socket_vector)); + return 0; + } + BAN::ErrorOr Process::sys_getsockname(int socket, sockaddr* address, socklen_t* address_len) { LockGuard _(m_process_lock); diff --git a/userspace/libraries/LibC/include/sys/syscall.h b/userspace/libraries/LibC/include/sys/syscall.h index 0e7cb42e56..e298c8870d 100644 --- a/userspace/libraries/LibC/include/sys/syscall.h +++ b/userspace/libraries/LibC/include/sys/syscall.h @@ -64,6 +64,7 @@ __BEGIN_DECLS O(SYS_FCHOWNAT, fchownat) \ O(SYS_LOAD_KEYMAP, load_keymap) \ O(SYS_SOCKET, socket) \ + O(SYS_SOCKETPAIR, socketpair) \ O(SYS_BIND, bind) \ O(SYS_SENDTO, sendto) \ O(SYS_RECVFROM, recvfrom) \ diff --git a/userspace/libraries/LibC/sys/socket.cpp b/userspace/libraries/LibC/sys/socket.cpp index 93435cc35b..f3770cc829 100644 --- a/userspace/libraries/LibC/sys/socket.cpp +++ b/userspace/libraries/LibC/sys/socket.cpp @@ -68,6 +68,11 @@ int socket(int domain, int type, int protocol) return syscall(SYS_SOCKET, domain, type, protocol); } +int socketpair(int domain, int type, int protocol, int socket_vector[2]) +{ + return syscall(SYS_SOCKETPAIR, domain, type, protocol, socket_vector); +} + int getsockname(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len) { return syscall(SYS_GETSOCKNAME, socket, address, address_len);