From e7dd03e5514aefe0a382c68cd354c9c24e48714e Mon Sep 17 00:00:00 2001 From: Bananymous Date: Thu, 8 Feb 2024 02:28:19 +0200 Subject: [PATCH] Kernel: Implement basic connection-mode unix domain sockets --- kernel/CMakeLists.txt | 1 + kernel/include/kernel/FS/Inode.h | 4 +- kernel/include/kernel/FS/Socket.h | 20 ++ .../kernel/Networking/NetworkManager.h | 2 +- .../include/kernel/Networking/NetworkSocket.h | 15 +- .../include/kernel/Networking/UNIX/Socket.h | 67 ++++ kernel/include/kernel/OpenFileDescriptorSet.h | 1 + kernel/include/kernel/Process.h | 2 + kernel/kernel/FS/Inode.cpp | 2 +- kernel/kernel/Networking/NetworkManager.cpp | 14 +- kernel/kernel/Networking/UNIX/Socket.cpp | 323 ++++++++++++++++++ kernel/kernel/OpenFileDescriptorSet.cpp | 15 + kernel/kernel/Process.cpp | 10 +- 13 files changed, 454 insertions(+), 22 deletions(-) create mode 100644 kernel/include/kernel/FS/Socket.h create mode 100644 kernel/include/kernel/Networking/UNIX/Socket.h create mode 100644 kernel/kernel/Networking/UNIX/Socket.cpp diff --git a/kernel/CMakeLists.txt b/kernel/CMakeLists.txt index 51cac192df..2927131758 100644 --- a/kernel/CMakeLists.txt +++ b/kernel/CMakeLists.txt @@ -58,6 +58,7 @@ set(KERNEL_SOURCES kernel/Networking/NetworkManager.cpp kernel/Networking/NetworkSocket.cpp kernel/Networking/UDPSocket.cpp + kernel/Networking/UNIX/Socket.cpp kernel/OpenFileDescriptorSet.cpp kernel/Panic.cpp kernel/PCI.cpp diff --git a/kernel/include/kernel/FS/Inode.h b/kernel/include/kernel/FS/Inode.h index 3142175ff3..d4c9b57715 100644 --- a/kernel/include/kernel/FS/Inode.h +++ b/kernel/include/kernel/FS/Inode.h @@ -100,7 +100,7 @@ namespace Kernel BAN::ErrorOr link_target(); // Socket API - BAN::ErrorOr accept(sockaddr* address, socklen_t* address_len); + BAN::ErrorOr accept(sockaddr* address, socklen_t* address_len); BAN::ErrorOr bind(const sockaddr* address, socklen_t address_len); BAN::ErrorOr connect(const sockaddr* address, socklen_t address_len); BAN::ErrorOr listen(int backlog); @@ -131,7 +131,7 @@ namespace Kernel virtual BAN::ErrorOr link_target_impl() { return BAN::Error::from_errno(ENOTSUP); } // Socket API - virtual BAN::ErrorOr accept_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); } + virtual BAN::ErrorOr accept_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr listen_impl(int) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } diff --git a/kernel/include/kernel/FS/Socket.h b/kernel/include/kernel/FS/Socket.h new file mode 100644 index 0000000000..487843b2c9 --- /dev/null +++ b/kernel/include/kernel/FS/Socket.h @@ -0,0 +1,20 @@ +#pragma once + +namespace Kernel +{ + + enum class SocketDomain + { + INET, + INET6, + UNIX, + }; + + enum class SocketType + { + STREAM, + DGRAM, + SEQPACKET, + }; + +} diff --git a/kernel/include/kernel/Networking/NetworkManager.h b/kernel/include/kernel/Networking/NetworkManager.h index 475cbde8d6..4c07bfe406 100644 --- a/kernel/include/kernel/Networking/NetworkManager.h +++ b/kernel/include/kernel/Networking/NetworkManager.h @@ -25,7 +25,7 @@ namespace Kernel BAN::Vector> interfaces() { return m_interfaces; } - BAN::ErrorOr> create_socket(SocketDomain, SocketType, mode_t, uid_t, gid_t); + BAN::ErrorOr> create_socket(SocketDomain, SocketType, mode_t, uid_t, gid_t); void on_receive(NetworkInterface&, BAN::ConstByteSpan); diff --git a/kernel/include/kernel/Networking/NetworkSocket.h b/kernel/include/kernel/Networking/NetworkSocket.h index 71ee94ac70..da25acd7c2 100644 --- a/kernel/include/kernel/Networking/NetworkSocket.h +++ b/kernel/include/kernel/Networking/NetworkSocket.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -16,20 +17,6 @@ namespace Kernel UDP = 0x11, }; - enum class SocketDomain - { - INET, - INET6, - UNIX, - }; - - enum class SocketType - { - STREAM, - DGRAM, - SEQPACKET, - }; - class NetworkSocket : public TmpInode, public BAN::Weakable { BAN_NON_COPYABLE(NetworkSocket); diff --git a/kernel/include/kernel/Networking/UNIX/Socket.h b/kernel/include/kernel/Networking/UNIX/Socket.h new file mode 100644 index 0000000000..fed37942ac --- /dev/null +++ b/kernel/include/kernel/Networking/UNIX/Socket.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include +#include + +namespace Kernel +{ + + class UnixDomainSocket final : public TmpInode, public BAN::Weakable + { + BAN_NON_COPYABLE(UnixDomainSocket); + BAN_NON_MOVABLE(UnixDomainSocket); + + public: + static BAN::ErrorOr> create(SocketType, ino_t, const TmpInodeInfo&); + + protected: + virtual BAN::ErrorOr accept_impl(sockaddr*, socklen_t*) override; + virtual BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) override; + virtual BAN::ErrorOr listen_impl(int) override; + virtual BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) override; + virtual BAN::ErrorOr sendto_impl(const sys_sendto_t*) override; + virtual BAN::ErrorOr recvfrom_impl(sys_recvfrom_t*) override; + + private: + UnixDomainSocket(SocketType, ino_t, const TmpInodeInfo&); + + BAN::ErrorOr add_packet(BAN::ConstByteSpan); + + bool is_bound() const { return !m_bound_path.empty(); } + bool is_bound_to_unused() const { return m_bound_path == "X"sv; } + + bool is_streaming() const; + + private: + struct ConnectionInfo + { + bool listening { false }; + BAN::Atomic connection_done { false }; + BAN::WeakPtr connection; + BAN::Queue> pending_connections; + Semaphore pending_semaphore; + SpinLock pending_lock; + }; + + struct ConnectionlessInfo + { + + }; + + private: + const SocketType m_socket_type; + BAN::String m_bound_path; + + BAN::Variant m_info; + + BAN::CircularQueue m_packet_sizes; + size_t m_packet_size_total { 0 }; + BAN::UniqPtr m_packet_buffer; + Semaphore m_packet_semaphore; + + friend class BAN::RefPtr; + }; + +} diff --git a/kernel/include/kernel/OpenFileDescriptorSet.h b/kernel/include/kernel/OpenFileDescriptorSet.h index 857a499586..c799fd2219 100644 --- a/kernel/include/kernel/OpenFileDescriptorSet.h +++ b/kernel/include/kernel/OpenFileDescriptorSet.h @@ -21,6 +21,7 @@ namespace Kernel BAN::ErrorOr clone_from(const OpenFileDescriptorSet&); + BAN::ErrorOr open(BAN::RefPtr, int flags); BAN::ErrorOr open(BAN::StringView absolute_path, int flags); BAN::ErrorOr socket(int domain, int type, int protocol); diff --git a/kernel/include/kernel/Process.h b/kernel/include/kernel/Process.h index c5b62fd26a..7a97eb23dc 100644 --- a/kernel/include/kernel/Process.h +++ b/kernel/include/kernel/Process.h @@ -95,6 +95,8 @@ namespace Kernel BAN::ErrorOr sys_getegid() const { return m_credentials.egid(); } BAN::ErrorOr sys_getpgid(pid_t); + BAN::ErrorOr open_inode(BAN::RefPtr, int flags); + BAN::ErrorOr create_file_or_dir(BAN::StringView name, mode_t mode); BAN::ErrorOr open_file(BAN::StringView path, int oflag, mode_t = 0); BAN::ErrorOr sys_open(const char* path, int, mode_t); diff --git a/kernel/kernel/FS/Inode.cpp b/kernel/kernel/FS/Inode.cpp index 0f775c3c45..a9dc22ef7f 100644 --- a/kernel/kernel/FS/Inode.cpp +++ b/kernel/kernel/FS/Inode.cpp @@ -116,7 +116,7 @@ namespace Kernel return link_target_impl(); } - BAN::ErrorOr Inode::accept(sockaddr* address, socklen_t* address_len) + BAN::ErrorOr Inode::accept(sockaddr* address, socklen_t* address_len) { LockGuard _(m_lock); if (!mode().ifsock()) diff --git a/kernel/kernel/Networking/NetworkManager.cpp b/kernel/kernel/Networking/NetworkManager.cpp index 884e47464d..4fe77f947f 100644 --- a/kernel/kernel/Networking/NetworkManager.cpp +++ b/kernel/kernel/Networking/NetworkManager.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #define DEBUG_ETHERTYPE 0 @@ -70,7 +71,7 @@ namespace Kernel return {}; } - BAN::ErrorOr> NetworkManager::create_socket(SocketDomain domain, SocketType type, mode_t mode, uid_t uid, gid_t gid) + BAN::ErrorOr> NetworkManager::create_socket(SocketDomain domain, SocketType type, mode_t mode, uid_t uid, gid_t gid) { switch (domain) { @@ -80,6 +81,10 @@ namespace Kernel return BAN::Error::from_errno(EPROTOTYPE); break; } + case SocketDomain::UNIX: + { + break; + } default: return BAN::Error::from_errno(EAFNOSUPPORT); } @@ -90,7 +95,7 @@ namespace Kernel auto inode_info = create_inode_info(mode, uid, gid); ino_t ino = TRY(allocate_inode(inode_info)); - BAN::RefPtr socket; + BAN::RefPtr socket; switch (domain) { case SocketDomain::INET: @@ -99,6 +104,11 @@ namespace Kernel socket = TRY(UDPSocket::create(*m_ipv4_layer, ino, inode_info)); break; } + case SocketDomain::UNIX: + { + socket = TRY(UnixDomainSocket::create(type, ino, inode_info)); + break; + } default: ASSERT_NOT_REACHED(); } diff --git a/kernel/kernel/Networking/UNIX/Socket.cpp b/kernel/kernel/Networking/UNIX/Socket.cpp new file mode 100644 index 0000000000..7120355475 --- /dev/null +++ b/kernel/kernel/Networking/UNIX/Socket.cpp @@ -0,0 +1,323 @@ +#include +#include +#include +#include +#include + +#include +#include + +namespace Kernel +{ + + static BAN::HashMap> s_bound_sockets; + static SpinLock s_bound_socket_lock; + + static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE; + + BAN::ErrorOr> UnixDomainSocket::create(SocketType socket_type, ino_t ino, const TmpInodeInfo& inode_info) + { + auto socket = TRY(BAN::RefPtr::create(socket_type, ino, inode_info)); + socket->m_packet_buffer = TRY(VirtualRange::create_to_vaddr_range( + PageTable::kernel(), + KERNEL_OFFSET, + ~(uintptr_t)0, + s_packet_buffer_size, + PageTable::Flags::ReadWrite | PageTable::Flags::Present, + true + )); + return socket; + } + + UnixDomainSocket::UnixDomainSocket(SocketType socket_type, ino_t ino, const TmpInodeInfo& inode_info) + : TmpInode(NetworkManager::get(), ino, inode_info) + , m_socket_type(socket_type) + { + switch (socket_type) + { + case SocketType::STREAM: + case SocketType::SEQPACKET: + m_info.emplace(); + break; + case SocketType::DGRAM: + m_info.emplace(); + break; + default: + ASSERT_NOT_REACHED(); + } + } + + BAN::ErrorOr UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len) + { + if (!m_info.has()) + return BAN::Error::from_errno(EOPNOTSUPP); + auto& connection_info = m_info.get(); + if (!connection_info.listening) + return BAN::Error::from_errno(EINVAL); + + while (connection_info.pending_connections.empty()) + TRY(Thread::current().block_or_eintr(connection_info.pending_semaphore)); + + BAN::RefPtr pending; + + { + LockGuard _(connection_info.pending_lock); + pending = connection_info.pending_connections.front(); + connection_info.pending_connections.pop(); + connection_info.pending_semaphore.unblock(); + } + + BAN::RefPtr return_inode; + + { + auto return_inode_tmp = TRY(NetworkManager::get().create_socket(SocketDomain::UNIX, m_socket_type, mode().mode & ~Mode::TYPE_MASK, uid(), gid())); + return_inode = reinterpret_cast(return_inode_tmp.ptr()); + } + + TRY(return_inode->m_bound_path.append(m_bound_path)); + return_inode->m_info.get().connection = TRY(pending->get_weak_ptr()); + pending->m_info.get().connection = TRY(return_inode->get_weak_ptr()); + pending->m_info.get().connection_done = true; + + if (address && address_len && !is_bound_to_unused()) + { + size_t copy_len = BAN::Math::min(*address_len, sizeof(sockaddr) + m_bound_path.size() + 1); + auto& sockaddr_un = *reinterpret_cast(address); + sockaddr_un.sun_family = AF_UNIX; + strncpy(sockaddr_un.sun_path, pending->m_bound_path.data(), copy_len); + } + + return TRY(Process::current().open_inode(return_inode, O_RDWR)); + } + + BAN::ErrorOr UnixDomainSocket::connect_impl(const sockaddr* address, socklen_t address_len) + { + if (address_len != sizeof(sockaddr_un)) + return BAN::Error::from_errno(EINVAL); + auto& sockaddr_un = *reinterpret_cast(address); + if (sockaddr_un.sun_family != AF_UNIX) + return BAN::Error::from_errno(EAFNOSUPPORT); + if (!is_bound()) + TRY(m_bound_path.push_back('X')); + + auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path)); + auto file = TRY(VirtualFileSystem::get().file_from_absolute_path( + Process::current().credentials(), + absolute_path, + O_RDWR + )); + + BAN::RefPtr target; + + { + LockGuard _(s_bound_socket_lock); + if (!s_bound_sockets.contains(file.canonical_path)) + return BAN::Error::from_errno(ECONNREFUSED); + target = s_bound_sockets[file.canonical_path]; + } + + if (m_socket_type != target->m_socket_type) + return BAN::Error::from_errno(EPROTOTYPE); + + if (m_info.has()) + { + auto& connection_info = m_info.get(); + if (connection_info.connection) + return BAN::Error::from_errno(ECONNREFUSED); + if (connection_info.listening) + return BAN::Error::from_errno(EOPNOTSUPP); + + connection_info.connection_done = false; + + for (;;) + { + auto& target_info = target->m_info.get(); + { + LockGuard _(target_info.pending_lock); + if (target_info.pending_connections.size() < target_info.pending_connections.capacity()) + { + MUST(target_info.pending_connections.push(this)); + target_info.pending_semaphore.unblock(); + break; + } + } + TRY(Thread::current().block_or_eintr(target_info.pending_semaphore)); + } + + while (!connection_info.connection_done) + Scheduler::get().reschedule(); + + return {}; + } + else + { + return BAN::Error::from_errno(ENOTSUP); + } + } + + BAN::ErrorOr UnixDomainSocket::listen_impl(int backlog) + { + backlog = BAN::Math::clamp(backlog, 1, SOMAXCONN); + if (!is_bound()) + return BAN::Error::from_errno(EDESTADDRREQ); + if (!m_info.has()) + return BAN::Error::from_errno(EOPNOTSUPP); + auto& connection_info = m_info.get(); + if (connection_info.connection) + return BAN::Error::from_errno(EINVAL); + TRY(connection_info.pending_connections.reserve(backlog)); + connection_info.listening = true; + return {}; + } + + BAN::ErrorOr UnixDomainSocket::bind_impl(const sockaddr* address, socklen_t address_len) + { + if (is_bound()) + return BAN::Error::from_errno(EINVAL); + if (address_len != sizeof(sockaddr_un)) + return BAN::Error::from_errno(EINVAL); + auto& sockaddr_un = *reinterpret_cast(address); + if (sockaddr_un.sun_family != AF_UNIX) + return BAN::Error::from_errno(EAFNOSUPPORT); + + auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path)); + if (auto ret = Process::current().create_file_or_dir(absolute_path, 0755 | S_IFSOCK); ret.is_error()) + { + if (ret.error().get_error_code() == EEXIST) + return BAN::Error::from_errno(EADDRINUSE); + return ret.release_error(); + } + auto file = TRY(VirtualFileSystem::get().file_from_absolute_path( + Process::current().credentials(), + absolute_path, + O_RDWR + )); + + LockGuard _(s_bound_socket_lock); + ASSERT(!s_bound_sockets.contains(file.canonical_path)); + TRY(s_bound_sockets.emplace(file.canonical_path, this)); + m_bound_path = BAN::move(file.canonical_path); + + return {}; + } + + bool UnixDomainSocket::is_streaming() const + { + switch (m_socket_type) + { + case SocketType::STREAM: + return true; + case SocketType::SEQPACKET: + case SocketType::DGRAM: + return false; + default: + ASSERT_NOT_REACHED(); + } + } + + // This to feels too hacky to expose out of here + struct LockFreeGuard + { + LockFreeGuard(RecursivePrioritySpinLock& lock) + : m_lock(lock) + , m_depth(lock.lock_depth()) + { + for (uint32_t i = 0; i < m_depth; i++) + m_lock.unlock(); + } + + ~LockFreeGuard() + { + for (uint32_t i = 0; i < m_depth; i++) + m_lock.lock(); + } + + private: + RecursivePrioritySpinLock& m_lock; + const uint32_t m_depth; + }; + + BAN::ErrorOr UnixDomainSocket::add_packet(BAN::ConstByteSpan packet) + { + LockGuard _(m_lock); + + while (m_packet_sizes.full() || m_packet_size_total + packet.size() > s_packet_buffer_size) + { + LockFreeGuard _(m_lock); + TRY(Thread::current().block_or_eintr(m_packet_semaphore)); + } + + uint8_t* packet_buffer = reinterpret_cast(m_packet_buffer->vaddr() + m_packet_size_total); + memcpy(packet_buffer, packet.data(), packet.size()); + m_packet_size_total += packet.size(); + + if (!is_streaming()) + m_packet_sizes.push(packet.size()); + + m_packet_semaphore.unblock(); + return {}; + } + + BAN::ErrorOr UnixDomainSocket::sendto_impl(const sys_sendto_t* arguments) + { + if (arguments->flags) + return BAN::Error::from_errno(ENOTSUP); + if (arguments->length > s_packet_buffer_size) + return BAN::Error::from_errno(ENOBUFS); + + if (m_info.has()) + { + auto& connection_info = m_info.get(); + if (arguments->dest_addr) + return BAN::Error::from_errno(EISCONN); + auto target = connection_info.connection.lock(); + if (!target) + return BAN::Error::from_errno(ENOTCONN); + TRY(target->add_packet({ reinterpret_cast(arguments->message), arguments->length })); + return arguments->length; + } + else + { + return BAN::Error::from_errno(ENOTSUP); + } + } + + BAN::ErrorOr UnixDomainSocket::recvfrom_impl(sys_recvfrom_t* arguments) + { + if (arguments->flags) + return BAN::Error::from_errno(ENOTSUP); + + if (m_info.has()) + { + auto& connection_info = m_info.get(); + if (!connection_info.connection) + return BAN::Error::from_errno(ENOTCONN); + } + + while (m_packet_size_total == 0) + { + LockFreeGuard _(m_lock); + TRY(Thread::current().block_or_eintr(m_packet_semaphore)); + } + + uint8_t* packet_buffer = reinterpret_cast(m_packet_buffer->vaddr()); + + size_t nread = 0; + if (is_streaming()) + nread = BAN::Math::min(arguments->length, m_packet_size_total); + else + { + nread = BAN::Math::min(arguments->length, m_packet_sizes.front()); + m_packet_sizes.pop(); + } + + memcpy(arguments->buffer, packet_buffer, nread); + memmove(packet_buffer, packet_buffer + nread, m_packet_size_total - nread); + m_packet_size_total -= nread; + + m_packet_semaphore.unblock(); + + return nread; + } + +} diff --git a/kernel/kernel/OpenFileDescriptorSet.cpp b/kernel/kernel/OpenFileDescriptorSet.cpp index 89cfa94db5..a4518745b4 100644 --- a/kernel/kernel/OpenFileDescriptorSet.cpp +++ b/kernel/kernel/OpenFileDescriptorSet.cpp @@ -55,6 +55,21 @@ namespace Kernel return {}; } + BAN::ErrorOr OpenFileDescriptorSet::open(BAN::RefPtr inode, int flags) + { + ASSERT(inode); + ASSERT(!inode->mode().ifdir()); + + if (flags & ~(O_RDONLY | O_WRONLY)) + return BAN::Error::from_errno(ENOTSUP); + + int fd = TRY(get_free_fd()); + // FIXME: path? + m_open_files[fd] = TRY(BAN::RefPtr::create(inode, ""sv, 0, flags)); + + return fd; + } + BAN::ErrorOr OpenFileDescriptorSet::open(BAN::StringView absolute_path, int flags) { if (flags & ~(O_RDONLY | O_WRONLY | O_NOFOLLOW | O_SEARCH | O_APPEND | O_TRUNC | O_CLOEXEC | O_TTY_INIT | O_DIRECTORY | O_NONBLOCK)) diff --git a/kernel/kernel/Process.cpp b/kernel/kernel/Process.cpp index d544557f50..cddf84ae32 100644 --- a/kernel/kernel/Process.cpp +++ b/kernel/kernel/Process.cpp @@ -708,6 +708,13 @@ namespace Kernel return false; } + BAN::ErrorOr Process::open_inode(BAN::RefPtr inode, int flags) + { + ASSERT(inode); + LockGuard _(m_lock); + return TRY(m_open_file_descriptors.open(inode, flags)); + } + BAN::ErrorOr Process::open_file(BAN::StringView path, int flags, mode_t mode) { LockGuard _(m_lock); @@ -924,8 +931,7 @@ namespace Kernel if (!inode->mode().ifsock()) return BAN::Error::from_errno(ENOTSOCK); - TRY(inode->accept(address, address_len)); - return 0; + return TRY(inode->accept(address, address_len)); } BAN::ErrorOr Process::sys_bind(int socket, const sockaddr* address, socklen_t address_len)