Kernel/LibC: Implement SOCK_CLOEXEC and SOCK_NONBLOCK

This removes the need for fcntl after creating a socket :)
This commit is contained in:
Bananymous 2024-09-11 21:35:41 +03:00
parent c77ad5fb34
commit 467ac6c365
11 changed files with 38 additions and 15 deletions

View File

@ -97,7 +97,7 @@ namespace Kernel
BAN::ErrorOr<BAN::String> link_target(); BAN::ErrorOr<BAN::String> link_target();
// Socket API // Socket API
BAN::ErrorOr<long> accept(sockaddr* address, socklen_t* address_len); BAN::ErrorOr<long> accept(sockaddr* address, socklen_t* address_len, int flags);
BAN::ErrorOr<void> bind(const sockaddr* address, socklen_t address_len); BAN::ErrorOr<void> bind(const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<void> connect(const sockaddr* address, socklen_t address_len); BAN::ErrorOr<void> connect(const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<void> listen(int backlog); BAN::ErrorOr<void> listen(int backlog);
@ -131,7 +131,7 @@ namespace Kernel
virtual BAN::ErrorOr<BAN::String> link_target_impl() { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<BAN::String> link_target_impl() { return BAN::Error::from_errno(ENOTSUP); }
// Socket API // Socket API
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*, int) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> listen_impl(int) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<void> listen_impl(int) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }

View File

@ -55,7 +55,7 @@ namespace Kernel
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) override; virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) override;
protected: protected:
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) override; virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*, int) override;
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override; virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<void> listen_impl(int) override; virtual BAN::ErrorOr<void> listen_impl(int) override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) override; virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) override;

View File

@ -18,7 +18,7 @@ namespace Kernel
static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(Socket::Type, const Socket::Info&); static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(Socket::Type, const Socket::Info&);
protected: protected:
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) override; virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*, int) override;
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override; virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<void> listen_impl(int) override; virtual BAN::ErrorOr<void> listen_impl(int) override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) override; virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) override;

View File

@ -129,7 +129,7 @@ namespace Kernel
BAN::ErrorOr<long> sys_getsockopt(int socket, int level, int option_name, void* option_value, socklen_t* option_len); BAN::ErrorOr<long> sys_getsockopt(int socket, int level, int option_name, void* option_value, socklen_t* option_len);
BAN::ErrorOr<long> sys_setsockopt(int socket, int level, int option_name, const void* option_value, socklen_t option_len); BAN::ErrorOr<long> sys_setsockopt(int socket, int level, int option_name, const void* option_value, socklen_t option_len);
BAN::ErrorOr<long> sys_accept(int socket, sockaddr* address, socklen_t* address_len); BAN::ErrorOr<long> sys_accept(int socket, sockaddr* address, socklen_t* address_len, int flags);
BAN::ErrorOr<long> sys_bind(int socket, const sockaddr* address, socklen_t address_len); BAN::ErrorOr<long> sys_bind(int socket, const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<long> sys_connect(int socket, const sockaddr* address, socklen_t address_len); BAN::ErrorOr<long> sys_connect(int socket, const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<long> sys_listen(int socket, int backlog); BAN::ErrorOr<long> sys_listen(int socket, int backlog);

View File

@ -110,12 +110,12 @@ namespace Kernel
return link_target_impl(); return link_target_impl();
} }
BAN::ErrorOr<long> Inode::accept(sockaddr* address, socklen_t* address_len) BAN::ErrorOr<long> Inode::accept(sockaddr* address, socklen_t* address_len, int flags)
{ {
LockGuard _(m_mutex); LockGuard _(m_mutex);
if (!mode().ifsock()) if (!mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK); return BAN::Error::from_errno(ENOTSOCK);
return accept_impl(address, address_len); return accept_impl(address, address_len, flags);
} }
BAN::ErrorOr<void> Inode::bind(const sockaddr* address, socklen_t address_len) BAN::ErrorOr<void> Inode::bind(const sockaddr* address, socklen_t address_len)

View File

@ -67,7 +67,7 @@ namespace Kernel
dprintln_if(DEBUG_TCP, "Socket destroyed"); dprintln_if(DEBUG_TCP, "Socket destroyed");
} }
BAN::ErrorOr<long> TCPSocket::accept_impl(sockaddr* address, socklen_t* address_len) BAN::ErrorOr<long> TCPSocket::accept_impl(sockaddr* address, socklen_t* address_len, int flags)
{ {
if (m_state != State::Listen) if (m_state != State::Listen)
return BAN::Error::from_errno(EINVAL); return BAN::Error::from_errno(EINVAL);
@ -123,7 +123,7 @@ namespace Kernel
memcpy(address, &connection.target.address, *address_len); memcpy(address, &connection.target.address, *address_len);
} }
return TRY(Process::current().open_inode(return_inode, O_RDWR)); return TRY(Process::current().open_inode(return_inode, O_RDWR | flags));
} }
BAN::ErrorOr<void> TCPSocket::connect_impl(const sockaddr* address, socklen_t address_len) BAN::ErrorOr<void> TCPSocket::connect_impl(const sockaddr* address, socklen_t address_len)

View File

@ -64,7 +64,7 @@ namespace Kernel
} }
} }
BAN::ErrorOr<long> UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len) BAN::ErrorOr<long> UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len, int flags)
{ {
if (!m_info.has<ConnectionInfo>()) if (!m_info.has<ConnectionInfo>())
return BAN::Error::from_errno(EOPNOTSUPP); return BAN::Error::from_errno(EOPNOTSUPP);
@ -104,7 +104,7 @@ namespace Kernel
strncpy(sockaddr_un.sun_path, pending->m_bound_path.data(), copy_len); strncpy(sockaddr_un.sun_path, pending->m_bound_path.data(), copy_len);
} }
return TRY(Process::current().open_inode(return_inode, O_RDWR)); return TRY(Process::current().open_inode(return_inode, O_RDWR | flags));
} }
BAN::ErrorOr<void> UnixDomainSocket::connect_impl(const sockaddr* address, socklen_t address_len) BAN::ErrorOr<void> UnixDomainSocket::connect_impl(const sockaddr* address, socklen_t address_len)

View File

@ -106,6 +106,13 @@ namespace Kernel
return BAN::Error::from_errno(EPROTOTYPE); return BAN::Error::from_errno(EPROTOTYPE);
} }
int extra_flags = 0;
if (type & SOCK_NONBLOCK)
extra_flags |= O_NONBLOCK;
if (type & SOCK_CLOEXEC)
extra_flags |= O_CLOEXEC;
type &= ~(SOCK_NONBLOCK | SOCK_CLOEXEC);
Socket::Type sock_type; Socket::Type sock_type;
switch (type) switch (type)
{ {
@ -133,7 +140,7 @@ namespace Kernel
auto socket = TRY(NetworkManager::get().create_socket(sock_domain, sock_type, 0777, m_credentials.euid(), m_credentials.egid())); auto socket = TRY(NetworkManager::get().create_socket(sock_domain, sock_type, 0777, m_credentials.euid(), m_credentials.egid()));
int fd = TRY(get_free_fd()); int fd = TRY(get_free_fd());
m_open_files[fd] = TRY(BAN::RefPtr<OpenFileDescription>::create(socket, "no-path"_sv, 0, O_RDWR)); m_open_files[fd] = TRY(BAN::RefPtr<OpenFileDescription>::create(socket, "<socket>"_sv, 0, O_RDWR | extra_flags));
return fd; return fd;
} }

View File

@ -1192,12 +1192,14 @@ namespace Kernel
return BAN::Error::from_errno(ENOTSUP); return BAN::Error::from_errno(ENOTSUP);
} }
BAN::ErrorOr<long> Process::sys_accept(int socket, sockaddr* address, socklen_t* address_len) BAN::ErrorOr<long> Process::sys_accept(int socket, sockaddr* address, socklen_t* address_len, int flags)
{ {
if (address && !address_len) if (address && !address_len)
return BAN::Error::from_errno(EINVAL); return BAN::Error::from_errno(EINVAL);
if (!address && address_len) if (!address && address_len)
return BAN::Error::from_errno(EINVAL); return BAN::Error::from_errno(EINVAL);
if (flags & ~(SOCK_NONBLOCK | SOCK_CLOEXEC))
return BAN::Error::from_errno(EINVAL);
LockGuard _(m_process_lock); LockGuard _(m_process_lock);
if (address) if (address)
@ -1210,7 +1212,13 @@ namespace Kernel
if (!inode->mode().ifsock()) if (!inode->mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK); return BAN::Error::from_errno(ENOTSOCK);
return TRY(inode->accept(address, address_len)); int open_flags = 0;
if (flags & SOCK_NONBLOCK)
open_flags |= O_NONBLOCK;
if (flags & SOCK_CLOEXEC)
open_flags |= O_CLOEXEC;
return TRY(inode->accept(address, address_len, open_flags));
} }
BAN::ErrorOr<long> Process::sys_bind(int socket, const sockaddr* address, socklen_t address_len) BAN::ErrorOr<long> Process::sys_bind(int socket, const sockaddr* address, socklen_t address_len)

View File

@ -71,6 +71,8 @@ struct linger
#define SOCK_RAW 2 #define SOCK_RAW 2
#define SOCK_SEQPACKET 3 #define SOCK_SEQPACKET 3
#define SOCK_STREAM 4 #define SOCK_STREAM 4
#define SOCK_CLOEXEC 0x10
#define SOCK_NONBLOCK 0x20
#define SOL_SOCKET 1 #define SOL_SOCKET 1
@ -137,6 +139,7 @@ struct sys_recvfrom_t
}; };
int accept(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len); int accept(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len);
int accept4(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len, int flags);
int bind(int socket, const struct sockaddr* address, socklen_t address_len); int bind(int socket, const struct sockaddr* address, socklen_t address_len);
int connect(int socket, const struct sockaddr* address, socklen_t address_len); int connect(int socket, const struct sockaddr* address, socklen_t address_len);
int getpeername(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len); int getpeername(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len);

View File

@ -4,7 +4,12 @@
int accept(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len) int accept(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len)
{ {
return syscall(SYS_ACCEPT, socket, address, address_len); return accept4(socket, address, address_len, 0);
}
int accept4(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len, int flags)
{
return syscall(SYS_ACCEPT, socket, address, address_len, flags);
} }
int bind(int socket, const struct sockaddr* address, socklen_t address_len) int bind(int socket, const struct sockaddr* address, socklen_t address_len)