Kernel: Implement fd passing with SCM_RIGTHS

This commit is contained in:
Bananymous 2025-11-09 23:36:49 +02:00
parent 641ccfdd47
commit 7b580b8f56
5 changed files with 186 additions and 35 deletions

View File

@ -7,6 +7,7 @@
#include <kernel/FS/TmpFS/Inode.h>
#include <kernel/FS/VirtualFileSystem.h>
#include <kernel/Lock/SpinLock.h>
#include <kernel/OpenFileDescriptorSet.h>
namespace Kernel
{
@ -16,6 +17,9 @@ namespace Kernel
BAN_NON_COPYABLE(UnixDomainSocket);
BAN_NON_MOVABLE(UnixDomainSocket);
public:
using FDWrapper = OpenFileDescriptorSet::FDWrapper;
public:
static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(Socket::Type, const Socket::Info&);
BAN::ErrorOr<void> make_socket_pair(UnixDomainSocket&);
@ -38,7 +42,7 @@ namespace Kernel
UnixDomainSocket(Socket::Type, const Socket::Info&);
~UnixDomainSocket();
BAN::ErrorOr<void> add_packet(const msghdr&, size_t total_size);
BAN::ErrorOr<void> add_packet(const msghdr&, size_t total_size, BAN::Vector<FDWrapper>&& fds_to_send);
bool is_bound() const { return !m_bound_file.canonical_path.empty(); }
bool is_bound_to_unused() const { return !m_bound_file.inode; }
@ -62,17 +66,23 @@ namespace Kernel
BAN::String peer_address;
};
struct PacketInfo
{
size_t size;
BAN::Vector<FDWrapper> fds;
};
private:
const Socket::Type m_socket_type;
VirtualFileSystem::File m_bound_file;
BAN::Variant<ConnectionInfo, ConnectionlessInfo> m_info;
BAN::CircularQueue<size_t, 128> m_packet_sizes;
size_t m_packet_size_total { 0 };
BAN::UniqPtr<VirtualRange> m_packet_buffer;
Mutex m_packet_lock;
ThreadBlocker m_packet_thread_blocker;
BAN::CircularQueue<PacketInfo, 512> m_packet_infos;
size_t m_packet_size_total { 0 };
BAN::UniqPtr<VirtualRange> m_packet_buffer;
Mutex m_packet_lock;
ThreadBlocker m_packet_thread_blocker;
friend class BAN::RefPtr<UnixDomainSocket>;
};

View File

@ -109,6 +109,32 @@ namespace Kernel
BAN::ErrorOr<int> get_free_fd() const;
BAN::ErrorOr<void> get_free_fd_pair(int fds[2]) const;
public:
class FDWrapper
{
public:
FDWrapper(BAN::RefPtr<OpenFileDescription>);
FDWrapper(const FDWrapper& other) { *this = other; }
FDWrapper(FDWrapper&& other) { *this = BAN::move(other); }
~FDWrapper();
FDWrapper& operator=(const FDWrapper&);
FDWrapper& operator=(FDWrapper&&);
int fd() const { return m_fd; }
void clear();
private:
BAN::RefPtr<OpenFileDescription> m_description;
int m_fd { -1 };
friend class OpenFileDescriptorSet;
};
BAN::ErrorOr<FDWrapper> get_fd_wrapper(int fd);
size_t open_all_fd_wrappers(BAN::Span<FDWrapper> fd_wrappers);
private:
const Credentials& m_credentials;
mutable Mutex m_mutex;

View File

@ -240,6 +240,8 @@ namespace Kernel
// FIXME: remove this API
BAN::ErrorOr<BAN::String> absolute_path_of(BAN::StringView) const;
OpenFileDescriptorSet& open_file_descriptor_set() { return m_open_file_descriptors; }
// ONLY CALLED BY TIMER INTERRUPT
static void update_alarm_queue();

View File

@ -294,10 +294,11 @@ namespace Kernel
}
}
BAN::ErrorOr<void> UnixDomainSocket::add_packet(const msghdr& packet, size_t total_size)
BAN::ErrorOr<void> UnixDomainSocket::add_packet(const msghdr& packet, size_t total_size, BAN::Vector<FDWrapper>&& fds_to_send)
{
LockGuard _(m_packet_lock);
while (m_packet_sizes.full() || m_packet_size_total + total_size > s_packet_buffer_size)
while (m_packet_infos.full() || m_packet_size_total + total_size > s_packet_buffer_size)
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock));
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total);
@ -311,7 +312,7 @@ namespace Kernel
ASSERT(offset == total_size);
m_packet_size_total += total_size;
m_packet_sizes.push(total_size);
m_packet_infos.emplace(total_size, BAN::move(fds_to_send));
m_packet_thread_blocker.unblock();
@ -366,12 +367,6 @@ namespace Kernel
return BAN::Error::from_errno(ENOTSUP);
}
if (CMSG_FIRSTHDR(&message))
{
dwarnln("ignoring recvmsg control message");
message.msg_controllen = 0;
}
LockGuard _(m_packet_lock);
while (m_packet_size_total == 0)
{
@ -388,31 +383,68 @@ namespace Kernel
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock));
}
auto* cheader = CMSG_FIRSTHDR(&message);
if (cheader != nullptr)
cheader->cmsg_len = message.msg_controllen;
size_t cheader_len = 0;
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
const size_t max_recv_size = is_streaming() ? m_packet_size_total : m_packet_sizes.front();
message.msg_flags = 0;
size_t total_recv = 0;
for (int i = 0; i < message.msg_iovlen; i++)
{
const size_t nrecv = BAN::Math::min<size_t>(message.msg_iov[i].iov_len, max_recv_size - total_recv);
memcpy(message.msg_iov[i].iov_base, packet_buffer + total_recv, nrecv);
auto& packet_info = m_packet_infos.front();
auto fds_to_open = BAN::move(packet_info.fds);
if (cheader == nullptr && !fds_to_open.empty())
message.msg_flags |= MSG_CTRUNC;
else if (!fds_to_open.empty())
{
const size_t max_fd_count = (cheader->cmsg_len - sizeof(cmsghdr)) / sizeof(int);
if (max_fd_count < fds_to_open.size())
message.msg_flags |= MSG_CTRUNC;
const size_t fd_count = BAN::Math::min(fds_to_open.size(), max_fd_count);
const size_t fds_opened = Process::current().open_file_descriptor_set().open_all_fd_wrappers(fds_to_open.span().slice(0, fd_count));
auto* fd_data = reinterpret_cast<int*>(CMSG_DATA(cheader));
for (size_t i = 0; i < fds_opened; i++)
fd_data[i] = fds_to_open[i].fd();
const size_t header_length = CMSG_LEN(fds_opened * sizeof(int));
cheader->cmsg_level = SOL_SOCKET;
cheader->cmsg_type = SCM_RIGHTS;
cheader->cmsg_len = header_length;
cheader = CMSG_NXTHDR(&message, cheader);
if (cheader != nullptr)
cheader->cmsg_len = message.msg_controllen - header_length;
cheader_len += header_length;
}
const size_t nrecv = BAN::Math::min<size_t>(message.msg_iov[i].iov_len, packet_info.size);
memcpy(message.msg_iov[i].iov_base, packet_buffer, nrecv);
total_recv += nrecv;
if (!is_streaming() && nrecv < packet_info.size)
message.msg_flags |= MSG_TRUNC;
const size_t to_discard = is_streaming() ? nrecv : packet_info.size;
packet_info.size -= to_discard;
if (packet_info.size == 0)
m_packet_infos.pop();
// FIXME: get rid of this memmove :)
memmove(packet_buffer, packet_buffer + to_discard, m_packet_size_total - to_discard);
m_packet_size_total -= to_discard;
if (!is_streaming() || m_packet_infos.empty())
break;
}
size_t bytes_to_handle = total_recv;
while (bytes_to_handle)
{
const size_t to_handle = BAN::Math::min(bytes_to_handle, m_packet_sizes.front());
bytes_to_handle -= to_handle;
m_packet_sizes.front() -= to_handle;
if (m_packet_sizes.front() == 0)
m_packet_sizes.pop();
}
const size_t to_discard = is_streaming() ? total_recv : max_recv_size;
memmove(packet_buffer, packet_buffer + to_discard, m_packet_size_total - to_discard);
m_packet_size_total -= to_discard;
message.msg_controllen = cheader_len;
m_packet_thread_blocker.unblock();
@ -429,8 +461,27 @@ namespace Kernel
return BAN::Error::from_errno(ENOTSUP);
}
if (CMSG_FIRSTHDR(&message))
dwarnln("ignoring sendmsg control message");
BAN::Vector<OpenFileDescriptorSet::FDWrapper> fds_to_send;
for (const auto* header = CMSG_FIRSTHDR(&message); header; header = CMSG_NXTHDR(&message, header))
{
if (header->cmsg_level != SOL_SOCKET)
{
dwarnln("ignoring control message with level {}", header->cmsg_level);
continue;
}
if (header->cmsg_type != SCM_RIGHTS)
{
dwarnln("ignoring control message with type {}", header->cmsg_type);
continue;
}
const auto* fd_data = reinterpret_cast<const int*>(CMSG_DATA(header));
const size_t fd_count = (header->cmsg_len - sizeof(cmsghdr)) / sizeof(int);
for (size_t i = 0; i < fd_count; i++)
TRY(fds_to_send.push_back(TRY(Process::current().open_file_descriptor_set().get_fd_wrapper(fd_data[i]))));
}
const size_t total_message_size =
[&message]() -> size_t {
@ -449,7 +500,7 @@ namespace Kernel
auto target = connection_info.connection.lock();
if (!target)
return BAN::Error::from_errno(ENOTCONN);
TRY(target->add_packet(message, total_message_size));
TRY(target->add_packet(message, total_message_size, BAN::move(fds_to_send)));
return total_message_size;
}
else
@ -493,7 +544,7 @@ namespace Kernel
if (!target)
return BAN::Error::from_errno(EDESTADDRREQ);
TRY(target->add_packet(message, total_message_size));
TRY(target->add_packet(message, total_message_size, BAN::move(fds_to_send)));
return total_message_size;
}

View File

@ -659,4 +659,66 @@ namespace Kernel
return BAN::Error::from_errno(EMFILE);
}
using FDWrapper = OpenFileDescriptorSet::FDWrapper;
FDWrapper::FDWrapper(BAN::RefPtr<OpenFileDescription> description)
: m_description(description)
{
if (m_description)
m_description->file.inode->on_clone(m_description->status_flags);
}
FDWrapper::~FDWrapper()
{
clear();
}
FDWrapper& FDWrapper::operator=(const FDWrapper& other)
{
clear();
m_description = other.m_description;
if (m_description)
m_description->file.inode->on_clone(m_description->status_flags);
return *this;
}
FDWrapper& FDWrapper::operator=(FDWrapper&& other)
{
clear();
m_description = BAN::move(other.m_description);
return *this;
}
void FDWrapper::clear()
{
if (m_description)
m_description->file.inode->on_close(m_description->status_flags);
}
BAN::ErrorOr<FDWrapper> OpenFileDescriptorSet::get_fd_wrapper(int fd)
{
LockGuard _(m_mutex);
TRY(validate_fd(fd));
return FDWrapper { m_open_files[fd].description };
}
size_t OpenFileDescriptorSet::open_all_fd_wrappers(BAN::Span<FDWrapper> fd_wrappers)
{
LockGuard _(m_mutex);
for (size_t i = 0; i < fd_wrappers.size(); i++)
{
auto fd_or_error = get_free_fd();
if (fd_or_error.is_error())
return i;
const int fd = fd_or_error.release_value();
m_open_files[fd].description = BAN::move(fd_wrappers[i].m_description);
m_open_files[fd].descriptor_flags = 0;
fd_wrappers[i].m_fd = fd;
}
return fd_wrappers.size();
}
}