Kernel: Implement fd passing with SCM_RIGTHS
This commit is contained in:
parent
641ccfdd47
commit
7b580b8f56
|
|
@ -7,6 +7,7 @@
|
||||||
#include <kernel/FS/TmpFS/Inode.h>
|
#include <kernel/FS/TmpFS/Inode.h>
|
||||||
#include <kernel/FS/VirtualFileSystem.h>
|
#include <kernel/FS/VirtualFileSystem.h>
|
||||||
#include <kernel/Lock/SpinLock.h>
|
#include <kernel/Lock/SpinLock.h>
|
||||||
|
#include <kernel/OpenFileDescriptorSet.h>
|
||||||
|
|
||||||
namespace Kernel
|
namespace Kernel
|
||||||
{
|
{
|
||||||
|
|
@ -16,6 +17,9 @@ namespace Kernel
|
||||||
BAN_NON_COPYABLE(UnixDomainSocket);
|
BAN_NON_COPYABLE(UnixDomainSocket);
|
||||||
BAN_NON_MOVABLE(UnixDomainSocket);
|
BAN_NON_MOVABLE(UnixDomainSocket);
|
||||||
|
|
||||||
|
public:
|
||||||
|
using FDWrapper = OpenFileDescriptorSet::FDWrapper;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(Socket::Type, const Socket::Info&);
|
static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(Socket::Type, const Socket::Info&);
|
||||||
BAN::ErrorOr<void> make_socket_pair(UnixDomainSocket&);
|
BAN::ErrorOr<void> make_socket_pair(UnixDomainSocket&);
|
||||||
|
|
@ -38,7 +42,7 @@ namespace Kernel
|
||||||
UnixDomainSocket(Socket::Type, const Socket::Info&);
|
UnixDomainSocket(Socket::Type, const Socket::Info&);
|
||||||
~UnixDomainSocket();
|
~UnixDomainSocket();
|
||||||
|
|
||||||
BAN::ErrorOr<void> add_packet(const msghdr&, size_t total_size);
|
BAN::ErrorOr<void> add_packet(const msghdr&, size_t total_size, BAN::Vector<FDWrapper>&& fds_to_send);
|
||||||
|
|
||||||
bool is_bound() const { return !m_bound_file.canonical_path.empty(); }
|
bool is_bound() const { return !m_bound_file.canonical_path.empty(); }
|
||||||
bool is_bound_to_unused() const { return !m_bound_file.inode; }
|
bool is_bound_to_unused() const { return !m_bound_file.inode; }
|
||||||
|
|
@ -62,13 +66,19 @@ namespace Kernel
|
||||||
BAN::String peer_address;
|
BAN::String peer_address;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct PacketInfo
|
||||||
|
{
|
||||||
|
size_t size;
|
||||||
|
BAN::Vector<FDWrapper> fds;
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const Socket::Type m_socket_type;
|
const Socket::Type m_socket_type;
|
||||||
VirtualFileSystem::File m_bound_file;
|
VirtualFileSystem::File m_bound_file;
|
||||||
|
|
||||||
BAN::Variant<ConnectionInfo, ConnectionlessInfo> m_info;
|
BAN::Variant<ConnectionInfo, ConnectionlessInfo> m_info;
|
||||||
|
|
||||||
BAN::CircularQueue<size_t, 128> m_packet_sizes;
|
BAN::CircularQueue<PacketInfo, 512> m_packet_infos;
|
||||||
size_t m_packet_size_total { 0 };
|
size_t m_packet_size_total { 0 };
|
||||||
BAN::UniqPtr<VirtualRange> m_packet_buffer;
|
BAN::UniqPtr<VirtualRange> m_packet_buffer;
|
||||||
Mutex m_packet_lock;
|
Mutex m_packet_lock;
|
||||||
|
|
|
||||||
|
|
@ -109,6 +109,32 @@ namespace Kernel
|
||||||
BAN::ErrorOr<int> get_free_fd() const;
|
BAN::ErrorOr<int> get_free_fd() const;
|
||||||
BAN::ErrorOr<void> get_free_fd_pair(int fds[2]) const;
|
BAN::ErrorOr<void> get_free_fd_pair(int fds[2]) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
class FDWrapper
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
FDWrapper(BAN::RefPtr<OpenFileDescription>);
|
||||||
|
FDWrapper(const FDWrapper& other) { *this = other; }
|
||||||
|
FDWrapper(FDWrapper&& other) { *this = BAN::move(other); }
|
||||||
|
~FDWrapper();
|
||||||
|
|
||||||
|
FDWrapper& operator=(const FDWrapper&);
|
||||||
|
FDWrapper& operator=(FDWrapper&&);
|
||||||
|
|
||||||
|
int fd() const { return m_fd; }
|
||||||
|
|
||||||
|
void clear();
|
||||||
|
|
||||||
|
private:
|
||||||
|
BAN::RefPtr<OpenFileDescription> m_description;
|
||||||
|
int m_fd { -1 };
|
||||||
|
|
||||||
|
friend class OpenFileDescriptorSet;
|
||||||
|
};
|
||||||
|
|
||||||
|
BAN::ErrorOr<FDWrapper> get_fd_wrapper(int fd);
|
||||||
|
size_t open_all_fd_wrappers(BAN::Span<FDWrapper> fd_wrappers);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const Credentials& m_credentials;
|
const Credentials& m_credentials;
|
||||||
mutable Mutex m_mutex;
|
mutable Mutex m_mutex;
|
||||||
|
|
|
||||||
|
|
@ -240,6 +240,8 @@ namespace Kernel
|
||||||
// FIXME: remove this API
|
// FIXME: remove this API
|
||||||
BAN::ErrorOr<BAN::String> absolute_path_of(BAN::StringView) const;
|
BAN::ErrorOr<BAN::String> absolute_path_of(BAN::StringView) const;
|
||||||
|
|
||||||
|
OpenFileDescriptorSet& open_file_descriptor_set() { return m_open_file_descriptors; }
|
||||||
|
|
||||||
// ONLY CALLED BY TIMER INTERRUPT
|
// ONLY CALLED BY TIMER INTERRUPT
|
||||||
static void update_alarm_queue();
|
static void update_alarm_queue();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -294,10 +294,11 @@ namespace Kernel
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
BAN::ErrorOr<void> UnixDomainSocket::add_packet(const msghdr& packet, size_t total_size)
|
BAN::ErrorOr<void> UnixDomainSocket::add_packet(const msghdr& packet, size_t total_size, BAN::Vector<FDWrapper>&& fds_to_send)
|
||||||
{
|
{
|
||||||
LockGuard _(m_packet_lock);
|
LockGuard _(m_packet_lock);
|
||||||
while (m_packet_sizes.full() || m_packet_size_total + total_size > s_packet_buffer_size)
|
|
||||||
|
while (m_packet_infos.full() || m_packet_size_total + total_size > s_packet_buffer_size)
|
||||||
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock));
|
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock));
|
||||||
|
|
||||||
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total);
|
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total);
|
||||||
|
|
@ -311,7 +312,7 @@ namespace Kernel
|
||||||
|
|
||||||
ASSERT(offset == total_size);
|
ASSERT(offset == total_size);
|
||||||
m_packet_size_total += total_size;
|
m_packet_size_total += total_size;
|
||||||
m_packet_sizes.push(total_size);
|
m_packet_infos.emplace(total_size, BAN::move(fds_to_send));
|
||||||
|
|
||||||
m_packet_thread_blocker.unblock();
|
m_packet_thread_blocker.unblock();
|
||||||
|
|
||||||
|
|
@ -366,12 +367,6 @@ namespace Kernel
|
||||||
return BAN::Error::from_errno(ENOTSUP);
|
return BAN::Error::from_errno(ENOTSUP);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (CMSG_FIRSTHDR(&message))
|
|
||||||
{
|
|
||||||
dwarnln("ignoring recvmsg control message");
|
|
||||||
message.msg_controllen = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
LockGuard _(m_packet_lock);
|
LockGuard _(m_packet_lock);
|
||||||
while (m_packet_size_total == 0)
|
while (m_packet_size_total == 0)
|
||||||
{
|
{
|
||||||
|
|
@ -388,32 +383,69 @@ namespace Kernel
|
||||||
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock));
|
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto* cheader = CMSG_FIRSTHDR(&message);
|
||||||
|
if (cheader != nullptr)
|
||||||
|
cheader->cmsg_len = message.msg_controllen;
|
||||||
|
size_t cheader_len = 0;
|
||||||
|
|
||||||
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
|
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
|
||||||
|
|
||||||
const size_t max_recv_size = is_streaming() ? m_packet_size_total : m_packet_sizes.front();
|
message.msg_flags = 0;
|
||||||
|
|
||||||
size_t total_recv = 0;
|
size_t total_recv = 0;
|
||||||
for (int i = 0; i < message.msg_iovlen; i++)
|
for (int i = 0; i < message.msg_iovlen; i++)
|
||||||
{
|
{
|
||||||
const size_t nrecv = BAN::Math::min<size_t>(message.msg_iov[i].iov_len, max_recv_size - total_recv);
|
auto& packet_info = m_packet_infos.front();
|
||||||
memcpy(message.msg_iov[i].iov_base, packet_buffer + total_recv, nrecv);
|
|
||||||
total_recv += nrecv;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t bytes_to_handle = total_recv;
|
auto fds_to_open = BAN::move(packet_info.fds);
|
||||||
while (bytes_to_handle)
|
if (cheader == nullptr && !fds_to_open.empty())
|
||||||
|
message.msg_flags |= MSG_CTRUNC;
|
||||||
|
else if (!fds_to_open.empty())
|
||||||
{
|
{
|
||||||
const size_t to_handle = BAN::Math::min(bytes_to_handle, m_packet_sizes.front());
|
const size_t max_fd_count = (cheader->cmsg_len - sizeof(cmsghdr)) / sizeof(int);
|
||||||
bytes_to_handle -= to_handle;
|
if (max_fd_count < fds_to_open.size())
|
||||||
m_packet_sizes.front() -= to_handle;
|
message.msg_flags |= MSG_CTRUNC;
|
||||||
if (m_packet_sizes.front() == 0)
|
|
||||||
m_packet_sizes.pop();
|
const size_t fd_count = BAN::Math::min(fds_to_open.size(), max_fd_count);
|
||||||
|
const size_t fds_opened = Process::current().open_file_descriptor_set().open_all_fd_wrappers(fds_to_open.span().slice(0, fd_count));
|
||||||
|
|
||||||
|
auto* fd_data = reinterpret_cast<int*>(CMSG_DATA(cheader));
|
||||||
|
for (size_t i = 0; i < fds_opened; i++)
|
||||||
|
fd_data[i] = fds_to_open[i].fd();
|
||||||
|
|
||||||
|
const size_t header_length = CMSG_LEN(fds_opened * sizeof(int));
|
||||||
|
cheader->cmsg_level = SOL_SOCKET;
|
||||||
|
cheader->cmsg_type = SCM_RIGHTS;
|
||||||
|
cheader->cmsg_len = header_length;
|
||||||
|
cheader = CMSG_NXTHDR(&message, cheader);
|
||||||
|
if (cheader != nullptr)
|
||||||
|
cheader->cmsg_len = message.msg_controllen - header_length;
|
||||||
|
cheader_len += header_length;
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t to_discard = is_streaming() ? total_recv : max_recv_size;
|
const size_t nrecv = BAN::Math::min<size_t>(message.msg_iov[i].iov_len, packet_info.size);
|
||||||
|
memcpy(message.msg_iov[i].iov_base, packet_buffer, nrecv);
|
||||||
|
total_recv += nrecv;
|
||||||
|
|
||||||
|
if (!is_streaming() && nrecv < packet_info.size)
|
||||||
|
message.msg_flags |= MSG_TRUNC;
|
||||||
|
|
||||||
|
const size_t to_discard = is_streaming() ? nrecv : packet_info.size;
|
||||||
|
|
||||||
|
packet_info.size -= to_discard;
|
||||||
|
if (packet_info.size == 0)
|
||||||
|
m_packet_infos.pop();
|
||||||
|
|
||||||
|
// FIXME: get rid of this memmove :)
|
||||||
memmove(packet_buffer, packet_buffer + to_discard, m_packet_size_total - to_discard);
|
memmove(packet_buffer, packet_buffer + to_discard, m_packet_size_total - to_discard);
|
||||||
m_packet_size_total -= to_discard;
|
m_packet_size_total -= to_discard;
|
||||||
|
|
||||||
|
if (!is_streaming() || m_packet_infos.empty())
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
message.msg_controllen = cheader_len;
|
||||||
|
|
||||||
m_packet_thread_blocker.unblock();
|
m_packet_thread_blocker.unblock();
|
||||||
|
|
||||||
epoll_notify(EPOLLOUT);
|
epoll_notify(EPOLLOUT);
|
||||||
|
|
@ -429,8 +461,27 @@ namespace Kernel
|
||||||
return BAN::Error::from_errno(ENOTSUP);
|
return BAN::Error::from_errno(ENOTSUP);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (CMSG_FIRSTHDR(&message))
|
BAN::Vector<OpenFileDescriptorSet::FDWrapper> fds_to_send;
|
||||||
dwarnln("ignoring sendmsg control message");
|
|
||||||
|
for (const auto* header = CMSG_FIRSTHDR(&message); header; header = CMSG_NXTHDR(&message, header))
|
||||||
|
{
|
||||||
|
if (header->cmsg_level != SOL_SOCKET)
|
||||||
|
{
|
||||||
|
dwarnln("ignoring control message with level {}", header->cmsg_level);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (header->cmsg_type != SCM_RIGHTS)
|
||||||
|
{
|
||||||
|
dwarnln("ignoring control message with type {}", header->cmsg_type);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto* fd_data = reinterpret_cast<const int*>(CMSG_DATA(header));
|
||||||
|
const size_t fd_count = (header->cmsg_len - sizeof(cmsghdr)) / sizeof(int);
|
||||||
|
for (size_t i = 0; i < fd_count; i++)
|
||||||
|
TRY(fds_to_send.push_back(TRY(Process::current().open_file_descriptor_set().get_fd_wrapper(fd_data[i]))));
|
||||||
|
}
|
||||||
|
|
||||||
const size_t total_message_size =
|
const size_t total_message_size =
|
||||||
[&message]() -> size_t {
|
[&message]() -> size_t {
|
||||||
|
|
@ -449,7 +500,7 @@ namespace Kernel
|
||||||
auto target = connection_info.connection.lock();
|
auto target = connection_info.connection.lock();
|
||||||
if (!target)
|
if (!target)
|
||||||
return BAN::Error::from_errno(ENOTCONN);
|
return BAN::Error::from_errno(ENOTCONN);
|
||||||
TRY(target->add_packet(message, total_message_size));
|
TRY(target->add_packet(message, total_message_size, BAN::move(fds_to_send)));
|
||||||
return total_message_size;
|
return total_message_size;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
|
|
@ -493,7 +544,7 @@ namespace Kernel
|
||||||
|
|
||||||
if (!target)
|
if (!target)
|
||||||
return BAN::Error::from_errno(EDESTADDRREQ);
|
return BAN::Error::from_errno(EDESTADDRREQ);
|
||||||
TRY(target->add_packet(message, total_message_size));
|
TRY(target->add_packet(message, total_message_size, BAN::move(fds_to_send)));
|
||||||
|
|
||||||
return total_message_size;
|
return total_message_size;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -659,4 +659,66 @@ namespace Kernel
|
||||||
return BAN::Error::from_errno(EMFILE);
|
return BAN::Error::from_errno(EMFILE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
using FDWrapper = OpenFileDescriptorSet::FDWrapper;
|
||||||
|
|
||||||
|
FDWrapper::FDWrapper(BAN::RefPtr<OpenFileDescription> description)
|
||||||
|
: m_description(description)
|
||||||
|
{
|
||||||
|
if (m_description)
|
||||||
|
m_description->file.inode->on_clone(m_description->status_flags);
|
||||||
|
}
|
||||||
|
|
||||||
|
FDWrapper::~FDWrapper()
|
||||||
|
{
|
||||||
|
clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
FDWrapper& FDWrapper::operator=(const FDWrapper& other)
|
||||||
|
{
|
||||||
|
clear();
|
||||||
|
m_description = other.m_description;
|
||||||
|
if (m_description)
|
||||||
|
m_description->file.inode->on_clone(m_description->status_flags);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
FDWrapper& FDWrapper::operator=(FDWrapper&& other)
|
||||||
|
{
|
||||||
|
clear();
|
||||||
|
m_description = BAN::move(other.m_description);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
void FDWrapper::clear()
|
||||||
|
{
|
||||||
|
if (m_description)
|
||||||
|
m_description->file.inode->on_close(m_description->status_flags);
|
||||||
|
}
|
||||||
|
|
||||||
|
BAN::ErrorOr<FDWrapper> OpenFileDescriptorSet::get_fd_wrapper(int fd)
|
||||||
|
{
|
||||||
|
LockGuard _(m_mutex);
|
||||||
|
TRY(validate_fd(fd));
|
||||||
|
return FDWrapper { m_open_files[fd].description };
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t OpenFileDescriptorSet::open_all_fd_wrappers(BAN::Span<FDWrapper> fd_wrappers)
|
||||||
|
{
|
||||||
|
LockGuard _(m_mutex);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < fd_wrappers.size(); i++)
|
||||||
|
{
|
||||||
|
auto fd_or_error = get_free_fd();
|
||||||
|
if (fd_or_error.is_error())
|
||||||
|
return i;
|
||||||
|
|
||||||
|
const int fd = fd_or_error.release_value();
|
||||||
|
m_open_files[fd].description = BAN::move(fd_wrappers[i].m_description);
|
||||||
|
m_open_files[fd].descriptor_flags = 0;
|
||||||
|
fd_wrappers[i].m_fd = fd;
|
||||||
|
}
|
||||||
|
|
||||||
|
return fd_wrappers.size();
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue