Kernel: Implement basic connection-mode unix domain sockets

This commit is contained in:
Bananymous 2024-02-08 02:28:19 +02:00
parent 0c8e9fe095
commit e7dd03e551
13 changed files with 454 additions and 22 deletions

View File

@ -58,6 +58,7 @@ set(KERNEL_SOURCES
kernel/Networking/NetworkManager.cpp
kernel/Networking/NetworkSocket.cpp
kernel/Networking/UDPSocket.cpp
kernel/Networking/UNIX/Socket.cpp
kernel/OpenFileDescriptorSet.cpp
kernel/Panic.cpp
kernel/PCI.cpp

View File

@ -100,7 +100,7 @@ namespace Kernel
BAN::ErrorOr<BAN::String> link_target();
// Socket API
BAN::ErrorOr<void> accept(sockaddr* address, socklen_t* address_len);
BAN::ErrorOr<long> accept(sockaddr* address, socklen_t* address_len);
BAN::ErrorOr<void> bind(const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<void> connect(const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<void> listen(int backlog);
@ -131,7 +131,7 @@ namespace Kernel
virtual BAN::ErrorOr<BAN::String> link_target_impl() { return BAN::Error::from_errno(ENOTSUP); }
// Socket API
virtual BAN::ErrorOr<void> accept_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> listen_impl(int) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }

View File

@ -0,0 +1,20 @@
#pragma once
namespace Kernel
{
enum class SocketDomain
{
INET,
INET6,
UNIX,
};
enum class SocketType
{
STREAM,
DGRAM,
SEQPACKET,
};
}

View File

@ -25,7 +25,7 @@ namespace Kernel
BAN::Vector<BAN::RefPtr<NetworkInterface>> interfaces() { return m_interfaces; }
BAN::ErrorOr<BAN::RefPtr<NetworkSocket>> create_socket(SocketDomain, SocketType, mode_t, uid_t, gid_t);
BAN::ErrorOr<BAN::RefPtr<TmpInode>> create_socket(SocketDomain, SocketType, mode_t, uid_t, gid_t);
void on_receive(NetworkInterface&, BAN::ConstByteSpan);

View File

@ -1,6 +1,7 @@
#pragma once
#include <BAN/WeakPtr.h>
#include <kernel/FS/Socket.h>
#include <kernel/FS/TmpFS/Inode.h>
#include <kernel/Networking/NetworkInterface.h>
#include <kernel/Networking/NetworkLayer.h>
@ -16,20 +17,6 @@ namespace Kernel
UDP = 0x11,
};
enum class SocketDomain
{
INET,
INET6,
UNIX,
};
enum class SocketType
{
STREAM,
DGRAM,
SEQPACKET,
};
class NetworkSocket : public TmpInode, public BAN::Weakable<NetworkSocket>
{
BAN_NON_COPYABLE(NetworkSocket);

View File

@ -0,0 +1,67 @@
#pragma once
#include <BAN/Queue.h>
#include <BAN/WeakPtr.h>
#include <kernel/FS/Socket.h>
#include <kernel/FS/TmpFS/Inode.h>
namespace Kernel
{
class UnixDomainSocket final : public TmpInode, public BAN::Weakable<UnixDomainSocket>
{
BAN_NON_COPYABLE(UnixDomainSocket);
BAN_NON_MOVABLE(UnixDomainSocket);
public:
static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(SocketType, ino_t, const TmpInodeInfo&);
protected:
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) override;
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<void> listen_impl(int) override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<size_t> sendto_impl(const sys_sendto_t*) override;
virtual BAN::ErrorOr<size_t> recvfrom_impl(sys_recvfrom_t*) override;
private:
UnixDomainSocket(SocketType, ino_t, const TmpInodeInfo&);
BAN::ErrorOr<void> add_packet(BAN::ConstByteSpan);
bool is_bound() const { return !m_bound_path.empty(); }
bool is_bound_to_unused() const { return m_bound_path == "X"sv; }
bool is_streaming() const;
private:
struct ConnectionInfo
{
bool listening { false };
BAN::Atomic<bool> connection_done { false };
BAN::WeakPtr<UnixDomainSocket> connection;
BAN::Queue<BAN::RefPtr<UnixDomainSocket>> pending_connections;
Semaphore pending_semaphore;
SpinLock pending_lock;
};
struct ConnectionlessInfo
{
};
private:
const SocketType m_socket_type;
BAN::String m_bound_path;
BAN::Variant<ConnectionInfo, ConnectionlessInfo> m_info;
BAN::CircularQueue<size_t, 128> m_packet_sizes;
size_t m_packet_size_total { 0 };
BAN::UniqPtr<VirtualRange> m_packet_buffer;
Semaphore m_packet_semaphore;
friend class BAN::RefPtr<UnixDomainSocket>;
};
}

View File

@ -21,6 +21,7 @@ namespace Kernel
BAN::ErrorOr<void> clone_from(const OpenFileDescriptorSet&);
BAN::ErrorOr<int> open(BAN::RefPtr<Inode>, int flags);
BAN::ErrorOr<int> open(BAN::StringView absolute_path, int flags);
BAN::ErrorOr<int> socket(int domain, int type, int protocol);

View File

@ -95,6 +95,8 @@ namespace Kernel
BAN::ErrorOr<long> sys_getegid() const { return m_credentials.egid(); }
BAN::ErrorOr<long> sys_getpgid(pid_t);
BAN::ErrorOr<long> open_inode(BAN::RefPtr<Inode>, int flags);
BAN::ErrorOr<void> create_file_or_dir(BAN::StringView name, mode_t mode);
BAN::ErrorOr<long> open_file(BAN::StringView path, int oflag, mode_t = 0);
BAN::ErrorOr<long> sys_open(const char* path, int, mode_t);

View File

@ -116,7 +116,7 @@ namespace Kernel
return link_target_impl();
}
BAN::ErrorOr<void> Inode::accept(sockaddr* address, socklen_t* address_len)
BAN::ErrorOr<long> Inode::accept(sockaddr* address, socklen_t* address_len)
{
LockGuard _(m_lock);
if (!mode().ifsock())

View File

@ -6,6 +6,7 @@
#include <kernel/Networking/ICMP.h>
#include <kernel/Networking/NetworkManager.h>
#include <kernel/Networking/UDPSocket.h>
#include <kernel/Networking/UNIX/Socket.h>
#define DEBUG_ETHERTYPE 0
@ -70,7 +71,7 @@ namespace Kernel
return {};
}
BAN::ErrorOr<BAN::RefPtr<NetworkSocket>> NetworkManager::create_socket(SocketDomain domain, SocketType type, mode_t mode, uid_t uid, gid_t gid)
BAN::ErrorOr<BAN::RefPtr<TmpInode>> NetworkManager::create_socket(SocketDomain domain, SocketType type, mode_t mode, uid_t uid, gid_t gid)
{
switch (domain)
{
@ -80,6 +81,10 @@ namespace Kernel
return BAN::Error::from_errno(EPROTOTYPE);
break;
}
case SocketDomain::UNIX:
{
break;
}
default:
return BAN::Error::from_errno(EAFNOSUPPORT);
}
@ -90,7 +95,7 @@ namespace Kernel
auto inode_info = create_inode_info(mode, uid, gid);
ino_t ino = TRY(allocate_inode(inode_info));
BAN::RefPtr<NetworkSocket> socket;
BAN::RefPtr<TmpInode> socket;
switch (domain)
{
case SocketDomain::INET:
@ -99,6 +104,11 @@ namespace Kernel
socket = TRY(UDPSocket::create(*m_ipv4_layer, ino, inode_info));
break;
}
case SocketDomain::UNIX:
{
socket = TRY(UnixDomainSocket::create(type, ino, inode_info));
break;
}
default:
ASSERT_NOT_REACHED();
}

View File

@ -0,0 +1,323 @@
#include <BAN/HashMap.h>
#include <kernel/FS/VirtualFileSystem.h>
#include <kernel/Networking/NetworkManager.h>
#include <kernel/Networking/UNIX/Socket.h>
#include <kernel/Scheduler.h>
#include <fcntl.h>
#include <sys/un.h>
namespace Kernel
{
static BAN::HashMap<BAN::String, BAN::RefPtr<UnixDomainSocket>> s_bound_sockets;
static SpinLock s_bound_socket_lock;
static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE;
BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> UnixDomainSocket::create(SocketType socket_type, ino_t ino, const TmpInodeInfo& inode_info)
{
auto socket = TRY(BAN::RefPtr<UnixDomainSocket>::create(socket_type, ino, inode_info));
socket->m_packet_buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
~(uintptr_t)0,
s_packet_buffer_size,
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
true
));
return socket;
}
UnixDomainSocket::UnixDomainSocket(SocketType socket_type, ino_t ino, const TmpInodeInfo& inode_info)
: TmpInode(NetworkManager::get(), ino, inode_info)
, m_socket_type(socket_type)
{
switch (socket_type)
{
case SocketType::STREAM:
case SocketType::SEQPACKET:
m_info.emplace<ConnectionInfo>();
break;
case SocketType::DGRAM:
m_info.emplace<ConnectionlessInfo>();
break;
default:
ASSERT_NOT_REACHED();
}
}
BAN::ErrorOr<long> UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len)
{
if (!m_info.has<ConnectionInfo>())
return BAN::Error::from_errno(EOPNOTSUPP);
auto& connection_info = m_info.get<ConnectionInfo>();
if (!connection_info.listening)
return BAN::Error::from_errno(EINVAL);
while (connection_info.pending_connections.empty())
TRY(Thread::current().block_or_eintr(connection_info.pending_semaphore));
BAN::RefPtr<UnixDomainSocket> pending;
{
LockGuard _(connection_info.pending_lock);
pending = connection_info.pending_connections.front();
connection_info.pending_connections.pop();
connection_info.pending_semaphore.unblock();
}
BAN::RefPtr<UnixDomainSocket> return_inode;
{
auto return_inode_tmp = TRY(NetworkManager::get().create_socket(SocketDomain::UNIX, m_socket_type, mode().mode & ~Mode::TYPE_MASK, uid(), gid()));
return_inode = reinterpret_cast<UnixDomainSocket*>(return_inode_tmp.ptr());
}
TRY(return_inode->m_bound_path.append(m_bound_path));
return_inode->m_info.get<ConnectionInfo>().connection = TRY(pending->get_weak_ptr());
pending->m_info.get<ConnectionInfo>().connection = TRY(return_inode->get_weak_ptr());
pending->m_info.get<ConnectionInfo>().connection_done = true;
if (address && address_len && !is_bound_to_unused())
{
size_t copy_len = BAN::Math::min<size_t>(*address_len, sizeof(sockaddr) + m_bound_path.size() + 1);
auto& sockaddr_un = *reinterpret_cast<struct sockaddr_un*>(address);
sockaddr_un.sun_family = AF_UNIX;
strncpy(sockaddr_un.sun_path, pending->m_bound_path.data(), copy_len);
}
return TRY(Process::current().open_inode(return_inode, O_RDWR));
}
BAN::ErrorOr<void> UnixDomainSocket::connect_impl(const sockaddr* address, socklen_t address_len)
{
if (address_len != sizeof(sockaddr_un))
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_un = *reinterpret_cast<const struct sockaddr_un*>(address);
if (sockaddr_un.sun_family != AF_UNIX)
return BAN::Error::from_errno(EAFNOSUPPORT);
if (!is_bound())
TRY(m_bound_path.push_back('X'));
auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path));
auto file = TRY(VirtualFileSystem::get().file_from_absolute_path(
Process::current().credentials(),
absolute_path,
O_RDWR
));
BAN::RefPtr<UnixDomainSocket> target;
{
LockGuard _(s_bound_socket_lock);
if (!s_bound_sockets.contains(file.canonical_path))
return BAN::Error::from_errno(ECONNREFUSED);
target = s_bound_sockets[file.canonical_path];
}
if (m_socket_type != target->m_socket_type)
return BAN::Error::from_errno(EPROTOTYPE);
if (m_info.has<ConnectionInfo>())
{
auto& connection_info = m_info.get<ConnectionInfo>();
if (connection_info.connection)
return BAN::Error::from_errno(ECONNREFUSED);
if (connection_info.listening)
return BAN::Error::from_errno(EOPNOTSUPP);
connection_info.connection_done = false;
for (;;)
{
auto& target_info = target->m_info.get<ConnectionInfo>();
{
LockGuard _(target_info.pending_lock);
if (target_info.pending_connections.size() < target_info.pending_connections.capacity())
{
MUST(target_info.pending_connections.push(this));
target_info.pending_semaphore.unblock();
break;
}
}
TRY(Thread::current().block_or_eintr(target_info.pending_semaphore));
}
while (!connection_info.connection_done)
Scheduler::get().reschedule();
return {};
}
else
{
return BAN::Error::from_errno(ENOTSUP);
}
}
BAN::ErrorOr<void> UnixDomainSocket::listen_impl(int backlog)
{
backlog = BAN::Math::clamp(backlog, 1, SOMAXCONN);
if (!is_bound())
return BAN::Error::from_errno(EDESTADDRREQ);
if (!m_info.has<ConnectionInfo>())
return BAN::Error::from_errno(EOPNOTSUPP);
auto& connection_info = m_info.get<ConnectionInfo>();
if (connection_info.connection)
return BAN::Error::from_errno(EINVAL);
TRY(connection_info.pending_connections.reserve(backlog));
connection_info.listening = true;
return {};
}
BAN::ErrorOr<void> UnixDomainSocket::bind_impl(const sockaddr* address, socklen_t address_len)
{
if (is_bound())
return BAN::Error::from_errno(EINVAL);
if (address_len != sizeof(sockaddr_un))
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_un = *reinterpret_cast<const struct sockaddr_un*>(address);
if (sockaddr_un.sun_family != AF_UNIX)
return BAN::Error::from_errno(EAFNOSUPPORT);
auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path));
if (auto ret = Process::current().create_file_or_dir(absolute_path, 0755 | S_IFSOCK); ret.is_error())
{
if (ret.error().get_error_code() == EEXIST)
return BAN::Error::from_errno(EADDRINUSE);
return ret.release_error();
}
auto file = TRY(VirtualFileSystem::get().file_from_absolute_path(
Process::current().credentials(),
absolute_path,
O_RDWR
));
LockGuard _(s_bound_socket_lock);
ASSERT(!s_bound_sockets.contains(file.canonical_path));
TRY(s_bound_sockets.emplace(file.canonical_path, this));
m_bound_path = BAN::move(file.canonical_path);
return {};
}
bool UnixDomainSocket::is_streaming() const
{
switch (m_socket_type)
{
case SocketType::STREAM:
return true;
case SocketType::SEQPACKET:
case SocketType::DGRAM:
return false;
default:
ASSERT_NOT_REACHED();
}
}
// This to feels too hacky to expose out of here
struct LockFreeGuard
{
LockFreeGuard(RecursivePrioritySpinLock& lock)
: m_lock(lock)
, m_depth(lock.lock_depth())
{
for (uint32_t i = 0; i < m_depth; i++)
m_lock.unlock();
}
~LockFreeGuard()
{
for (uint32_t i = 0; i < m_depth; i++)
m_lock.lock();
}
private:
RecursivePrioritySpinLock& m_lock;
const uint32_t m_depth;
};
BAN::ErrorOr<void> UnixDomainSocket::add_packet(BAN::ConstByteSpan packet)
{
LockGuard _(m_lock);
while (m_packet_sizes.full() || m_packet_size_total + packet.size() > s_packet_buffer_size)
{
LockFreeGuard _(m_lock);
TRY(Thread::current().block_or_eintr(m_packet_semaphore));
}
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total);
memcpy(packet_buffer, packet.data(), packet.size());
m_packet_size_total += packet.size();
if (!is_streaming())
m_packet_sizes.push(packet.size());
m_packet_semaphore.unblock();
return {};
}
BAN::ErrorOr<size_t> UnixDomainSocket::sendto_impl(const sys_sendto_t* arguments)
{
if (arguments->flags)
return BAN::Error::from_errno(ENOTSUP);
if (arguments->length > s_packet_buffer_size)
return BAN::Error::from_errno(ENOBUFS);
if (m_info.has<ConnectionInfo>())
{
auto& connection_info = m_info.get<ConnectionInfo>();
if (arguments->dest_addr)
return BAN::Error::from_errno(EISCONN);
auto target = connection_info.connection.lock();
if (!target)
return BAN::Error::from_errno(ENOTCONN);
TRY(target->add_packet({ reinterpret_cast<const uint8_t*>(arguments->message), arguments->length }));
return arguments->length;
}
else
{
return BAN::Error::from_errno(ENOTSUP);
}
}
BAN::ErrorOr<size_t> UnixDomainSocket::recvfrom_impl(sys_recvfrom_t* arguments)
{
if (arguments->flags)
return BAN::Error::from_errno(ENOTSUP);
if (m_info.has<ConnectionInfo>())
{
auto& connection_info = m_info.get<ConnectionInfo>();
if (!connection_info.connection)
return BAN::Error::from_errno(ENOTCONN);
}
while (m_packet_size_total == 0)
{
LockFreeGuard _(m_lock);
TRY(Thread::current().block_or_eintr(m_packet_semaphore));
}
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
size_t nread = 0;
if (is_streaming())
nread = BAN::Math::min(arguments->length, m_packet_size_total);
else
{
nread = BAN::Math::min(arguments->length, m_packet_sizes.front());
m_packet_sizes.pop();
}
memcpy(arguments->buffer, packet_buffer, nread);
memmove(packet_buffer, packet_buffer + nread, m_packet_size_total - nread);
m_packet_size_total -= nread;
m_packet_semaphore.unblock();
return nread;
}
}

View File

@ -55,6 +55,21 @@ namespace Kernel
return {};
}
BAN::ErrorOr<int> OpenFileDescriptorSet::open(BAN::RefPtr<Inode> inode, int flags)
{
ASSERT(inode);
ASSERT(!inode->mode().ifdir());
if (flags & ~(O_RDONLY | O_WRONLY))
return BAN::Error::from_errno(ENOTSUP);
int fd = TRY(get_free_fd());
// FIXME: path?
m_open_files[fd] = TRY(BAN::RefPtr<OpenFileDescription>::create(inode, ""sv, 0, flags));
return fd;
}
BAN::ErrorOr<int> OpenFileDescriptorSet::open(BAN::StringView absolute_path, int flags)
{
if (flags & ~(O_RDONLY | O_WRONLY | O_NOFOLLOW | O_SEARCH | O_APPEND | O_TRUNC | O_CLOEXEC | O_TTY_INIT | O_DIRECTORY | O_NONBLOCK))

View File

@ -708,6 +708,13 @@ namespace Kernel
return false;
}
BAN::ErrorOr<long> Process::open_inode(BAN::RefPtr<Inode> inode, int flags)
{
ASSERT(inode);
LockGuard _(m_lock);
return TRY(m_open_file_descriptors.open(inode, flags));
}
BAN::ErrorOr<long> Process::open_file(BAN::StringView path, int flags, mode_t mode)
{
LockGuard _(m_lock);
@ -924,8 +931,7 @@ namespace Kernel
if (!inode->mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
TRY(inode->accept(address, address_len));
return 0;
return TRY(inode->accept(address, address_len));
}
BAN::ErrorOr<long> Process::sys_bind(int socket, const sockaddr* address, socklen_t address_len)