From 467ac6c365fa68b4ec03389715b538c73b645b2e Mon Sep 17 00:00:00 2001 From: Bananymous Date: Wed, 11 Sep 2024 21:35:41 +0300 Subject: [PATCH] Kernel/LibC: Implement SOCK_CLOEXEC and SOCK_NONBLOCK This removes the need for fcntl after creating a socket :) --- kernel/include/kernel/FS/Inode.h | 4 ++-- kernel/include/kernel/Networking/TCPSocket.h | 2 +- kernel/include/kernel/Networking/UNIX/Socket.h | 2 +- kernel/include/kernel/Process.h | 2 +- kernel/kernel/FS/Inode.cpp | 4 ++-- kernel/kernel/Networking/TCPSocket.cpp | 4 ++-- kernel/kernel/Networking/UNIX/Socket.cpp | 4 ++-- kernel/kernel/OpenFileDescriptorSet.cpp | 9 ++++++++- kernel/kernel/Process.cpp | 12 ++++++++++-- userspace/libraries/LibC/include/sys/socket.h | 3 +++ userspace/libraries/LibC/sys/socket.cpp | 7 ++++++- 11 files changed, 38 insertions(+), 15 deletions(-) diff --git a/kernel/include/kernel/FS/Inode.h b/kernel/include/kernel/FS/Inode.h index 4a8d60f0..fb81c6be 100644 --- a/kernel/include/kernel/FS/Inode.h +++ b/kernel/include/kernel/FS/Inode.h @@ -97,7 +97,7 @@ namespace Kernel BAN::ErrorOr link_target(); // Socket API - BAN::ErrorOr accept(sockaddr* address, socklen_t* address_len); + BAN::ErrorOr accept(sockaddr* address, socklen_t* address_len, int flags); BAN::ErrorOr bind(const sockaddr* address, socklen_t address_len); BAN::ErrorOr connect(const sockaddr* address, socklen_t address_len); BAN::ErrorOr listen(int backlog); @@ -131,7 +131,7 @@ namespace Kernel virtual BAN::ErrorOr link_target_impl() { return BAN::Error::from_errno(ENOTSUP); } // Socket API - virtual BAN::ErrorOr accept_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); } + virtual BAN::ErrorOr accept_impl(sockaddr*, socklen_t*, int) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr listen_impl(int) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } diff --git a/kernel/include/kernel/Networking/TCPSocket.h b/kernel/include/kernel/Networking/TCPSocket.h index bbdbc19a..130e848b 100644 --- a/kernel/include/kernel/Networking/TCPSocket.h +++ b/kernel/include/kernel/Networking/TCPSocket.h @@ -55,7 +55,7 @@ namespace Kernel virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) override; protected: - virtual BAN::ErrorOr accept_impl(sockaddr*, socklen_t*) override; + virtual BAN::ErrorOr accept_impl(sockaddr*, socklen_t*, int) override; virtual BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) override; virtual BAN::ErrorOr listen_impl(int) override; virtual BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) override; diff --git a/kernel/include/kernel/Networking/UNIX/Socket.h b/kernel/include/kernel/Networking/UNIX/Socket.h index a2a78945..cfbff8e1 100644 --- a/kernel/include/kernel/Networking/UNIX/Socket.h +++ b/kernel/include/kernel/Networking/UNIX/Socket.h @@ -18,7 +18,7 @@ namespace Kernel static BAN::ErrorOr> create(Socket::Type, const Socket::Info&); protected: - virtual BAN::ErrorOr accept_impl(sockaddr*, socklen_t*) override; + virtual BAN::ErrorOr accept_impl(sockaddr*, socklen_t*, int) override; virtual BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) override; virtual BAN::ErrorOr listen_impl(int) override; virtual BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) override; diff --git a/kernel/include/kernel/Process.h b/kernel/include/kernel/Process.h index 6c9b4749..f4f41eb0 100644 --- a/kernel/include/kernel/Process.h +++ b/kernel/include/kernel/Process.h @@ -129,7 +129,7 @@ namespace Kernel BAN::ErrorOr sys_getsockopt(int socket, int level, int option_name, void* option_value, socklen_t* option_len); BAN::ErrorOr sys_setsockopt(int socket, int level, int option_name, const void* option_value, socklen_t option_len); - BAN::ErrorOr sys_accept(int socket, sockaddr* address, socklen_t* address_len); + BAN::ErrorOr sys_accept(int socket, sockaddr* address, socklen_t* address_len, int flags); BAN::ErrorOr sys_bind(int socket, const sockaddr* address, socklen_t address_len); BAN::ErrorOr sys_connect(int socket, const sockaddr* address, socklen_t address_len); BAN::ErrorOr sys_listen(int socket, int backlog); diff --git a/kernel/kernel/FS/Inode.cpp b/kernel/kernel/FS/Inode.cpp index 02ac3532..4a42960d 100644 --- a/kernel/kernel/FS/Inode.cpp +++ b/kernel/kernel/FS/Inode.cpp @@ -110,12 +110,12 @@ namespace Kernel return link_target_impl(); } - BAN::ErrorOr Inode::accept(sockaddr* address, socklen_t* address_len) + BAN::ErrorOr Inode::accept(sockaddr* address, socklen_t* address_len, int flags) { LockGuard _(m_mutex); if (!mode().ifsock()) return BAN::Error::from_errno(ENOTSOCK); - return accept_impl(address, address_len); + return accept_impl(address, address_len, flags); } BAN::ErrorOr Inode::bind(const sockaddr* address, socklen_t address_len) diff --git a/kernel/kernel/Networking/TCPSocket.cpp b/kernel/kernel/Networking/TCPSocket.cpp index 09a09c74..5eb4abf2 100644 --- a/kernel/kernel/Networking/TCPSocket.cpp +++ b/kernel/kernel/Networking/TCPSocket.cpp @@ -67,7 +67,7 @@ namespace Kernel dprintln_if(DEBUG_TCP, "Socket destroyed"); } - BAN::ErrorOr TCPSocket::accept_impl(sockaddr* address, socklen_t* address_len) + BAN::ErrorOr TCPSocket::accept_impl(sockaddr* address, socklen_t* address_len, int flags) { if (m_state != State::Listen) return BAN::Error::from_errno(EINVAL); @@ -123,7 +123,7 @@ namespace Kernel memcpy(address, &connection.target.address, *address_len); } - return TRY(Process::current().open_inode(return_inode, O_RDWR)); + return TRY(Process::current().open_inode(return_inode, O_RDWR | flags)); } BAN::ErrorOr TCPSocket::connect_impl(const sockaddr* address, socklen_t address_len) diff --git a/kernel/kernel/Networking/UNIX/Socket.cpp b/kernel/kernel/Networking/UNIX/Socket.cpp index 4bc8c462..63bdd078 100644 --- a/kernel/kernel/Networking/UNIX/Socket.cpp +++ b/kernel/kernel/Networking/UNIX/Socket.cpp @@ -64,7 +64,7 @@ namespace Kernel } } - BAN::ErrorOr UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len) + BAN::ErrorOr UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len, int flags) { if (!m_info.has()) return BAN::Error::from_errno(EOPNOTSUPP); @@ -104,7 +104,7 @@ namespace Kernel strncpy(sockaddr_un.sun_path, pending->m_bound_path.data(), copy_len); } - return TRY(Process::current().open_inode(return_inode, O_RDWR)); + return TRY(Process::current().open_inode(return_inode, O_RDWR | flags)); } BAN::ErrorOr UnixDomainSocket::connect_impl(const sockaddr* address, socklen_t address_len) diff --git a/kernel/kernel/OpenFileDescriptorSet.cpp b/kernel/kernel/OpenFileDescriptorSet.cpp index 403ed184..b7baa09e 100644 --- a/kernel/kernel/OpenFileDescriptorSet.cpp +++ b/kernel/kernel/OpenFileDescriptorSet.cpp @@ -106,6 +106,13 @@ namespace Kernel return BAN::Error::from_errno(EPROTOTYPE); } + int extra_flags = 0; + if (type & SOCK_NONBLOCK) + extra_flags |= O_NONBLOCK; + if (type & SOCK_CLOEXEC) + extra_flags |= O_CLOEXEC; + type &= ~(SOCK_NONBLOCK | SOCK_CLOEXEC); + Socket::Type sock_type; switch (type) { @@ -133,7 +140,7 @@ namespace Kernel auto socket = TRY(NetworkManager::get().create_socket(sock_domain, sock_type, 0777, m_credentials.euid(), m_credentials.egid())); int fd = TRY(get_free_fd()); - m_open_files[fd] = TRY(BAN::RefPtr::create(socket, "no-path"_sv, 0, O_RDWR)); + m_open_files[fd] = TRY(BAN::RefPtr::create(socket, ""_sv, 0, O_RDWR | extra_flags)); return fd; } diff --git a/kernel/kernel/Process.cpp b/kernel/kernel/Process.cpp index c6f3e254..ddbd4d05 100644 --- a/kernel/kernel/Process.cpp +++ b/kernel/kernel/Process.cpp @@ -1192,12 +1192,14 @@ namespace Kernel return BAN::Error::from_errno(ENOTSUP); } - BAN::ErrorOr Process::sys_accept(int socket, sockaddr* address, socklen_t* address_len) + BAN::ErrorOr Process::sys_accept(int socket, sockaddr* address, socklen_t* address_len, int flags) { if (address && !address_len) return BAN::Error::from_errno(EINVAL); if (!address && address_len) return BAN::Error::from_errno(EINVAL); + if (flags & ~(SOCK_NONBLOCK | SOCK_CLOEXEC)) + return BAN::Error::from_errno(EINVAL); LockGuard _(m_process_lock); if (address) @@ -1210,7 +1212,13 @@ namespace Kernel if (!inode->mode().ifsock()) return BAN::Error::from_errno(ENOTSOCK); - return TRY(inode->accept(address, address_len)); + int open_flags = 0; + if (flags & SOCK_NONBLOCK) + open_flags |= O_NONBLOCK; + if (flags & SOCK_CLOEXEC) + open_flags |= O_CLOEXEC; + + return TRY(inode->accept(address, address_len, open_flags)); } BAN::ErrorOr Process::sys_bind(int socket, const sockaddr* address, socklen_t address_len) diff --git a/userspace/libraries/LibC/include/sys/socket.h b/userspace/libraries/LibC/include/sys/socket.h index 3581d262..f633455c 100644 --- a/userspace/libraries/LibC/include/sys/socket.h +++ b/userspace/libraries/LibC/include/sys/socket.h @@ -71,6 +71,8 @@ struct linger #define SOCK_RAW 2 #define SOCK_SEQPACKET 3 #define SOCK_STREAM 4 +#define SOCK_CLOEXEC 0x10 +#define SOCK_NONBLOCK 0x20 #define SOL_SOCKET 1 @@ -137,6 +139,7 @@ struct sys_recvfrom_t }; int accept(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len); +int accept4(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len, int flags); int bind(int socket, const struct sockaddr* address, socklen_t address_len); int connect(int socket, const struct sockaddr* address, socklen_t address_len); int getpeername(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len); diff --git a/userspace/libraries/LibC/sys/socket.cpp b/userspace/libraries/LibC/sys/socket.cpp index d97747c8..4e62e0d4 100644 --- a/userspace/libraries/LibC/sys/socket.cpp +++ b/userspace/libraries/LibC/sys/socket.cpp @@ -4,7 +4,12 @@ int accept(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len) { - return syscall(SYS_ACCEPT, socket, address, address_len); + return accept4(socket, address, address_len, 0); +} + +int accept4(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len, int flags) +{ + return syscall(SYS_ACCEPT, socket, address, address_len, flags); } int bind(int socket, const struct sockaddr* address, socklen_t address_len)