Kernel: Add support for SCM_CREDENTIALS and fix recvmsg

recvmsg was broken when receiving into more than a single iovec.
This commit is contained in:
Bananymous 2025-11-18 02:51:28 +02:00
parent b8a2573bb4
commit d60f12d3b8
3 changed files with 163 additions and 47 deletions

View File

@ -42,8 +42,6 @@ namespace Kernel
UnixDomainSocket(Socket::Type, const Socket::Info&); UnixDomainSocket(Socket::Type, const Socket::Info&);
~UnixDomainSocket(); ~UnixDomainSocket();
BAN::ErrorOr<void> add_packet(const msghdr&, size_t total_size, BAN::Vector<FDWrapper>&& fds_to_send);
bool is_bound() const { return !m_bound_file.canonical_path.empty(); } bool is_bound() const { return !m_bound_file.canonical_path.empty(); }
bool is_bound_to_unused() const { return !m_bound_file.inode; } bool is_bound_to_unused() const { return !m_bound_file.inode; }
@ -70,8 +68,11 @@ namespace Kernel
{ {
size_t size; size_t size;
BAN::Vector<FDWrapper> fds; BAN::Vector<FDWrapper> fds;
BAN::Optional<struct ucred> ucred;
}; };
BAN::ErrorOr<void> add_packet(const msghdr&, PacketInfo&&);
private: private:
const Socket::Type m_socket_type; const Socket::Type m_socket_type;
VirtualFileSystem::File m_bound_file; VirtualFileSystem::File m_bound_file;

View File

@ -294,11 +294,11 @@ namespace Kernel
} }
} }
BAN::ErrorOr<void> UnixDomainSocket::add_packet(const msghdr& packet, size_t total_size, BAN::Vector<FDWrapper>&& fds_to_send) BAN::ErrorOr<void> UnixDomainSocket::add_packet(const msghdr& packet, PacketInfo&& packet_info)
{ {
LockGuard _(m_packet_lock); LockGuard _(m_packet_lock);
while (m_packet_infos.full() || m_packet_size_total + total_size > s_packet_buffer_size) while (m_packet_infos.full() || m_packet_size_total + packet_info.size > s_packet_buffer_size)
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock)); TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock));
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total); uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total);
@ -310,9 +310,9 @@ namespace Kernel
offset += packet.msg_iov[i].iov_len; offset += packet.msg_iov[i].iov_len;
} }
ASSERT(offset == total_size); ASSERT(offset == packet_info.size);
m_packet_size_total += total_size; m_packet_size_total += packet_info.size;
m_packet_infos.emplace(total_size, BAN::move(fds_to_send)); m_packet_infos.emplace(BAN::move(packet_info));
m_packet_thread_blocker.unblock(); m_packet_thread_blocker.unblock();
@ -326,10 +326,10 @@ namespace Kernel
if (m_info.has<ConnectionInfo>()) if (m_info.has<ConnectionInfo>())
{ {
auto& connection_info = m_info.get<ConnectionInfo>(); auto& connection_info = m_info.get<ConnectionInfo>();
if (connection_info.listening)
return !connection_info.pending_connections.empty();
if (connection_info.target_closed) if (connection_info.target_closed)
return true; return true;
if (!connection_info.pending_connections.empty())
return true;
if (!connection_info.connection) if (!connection_info.connection)
return false; return false;
} }
@ -342,7 +342,13 @@ namespace Kernel
if (m_info.has<ConnectionInfo>()) if (m_info.has<ConnectionInfo>())
{ {
auto& connection_info = m_info.get<ConnectionInfo>(); auto& connection_info = m_info.get<ConnectionInfo>();
return connection_info.connection.valid(); auto connection = connection_info.connection.lock();
if (!connection)
return false;
if (connection->m_packet_infos.full())
return false;
if (connection->m_packet_size_total >= s_packet_buffer_size)
return false;
} }
return true; return true;
@ -393,16 +399,27 @@ namespace Kernel
message.msg_flags = 0; message.msg_flags = 0;
int iov_index = 0;
size_t iov_offset = 0;
size_t total_recv = 0; size_t total_recv = 0;
for (int i = 0; i < message.msg_iovlen; i++)
while (!m_packet_infos.empty() && iov_index < message.msg_iovlen)
{ {
auto& packet_info = m_packet_infos.front(); auto& packet_info = m_packet_infos.front();
auto fds_to_open = BAN::move(packet_info.fds); auto fds_to_open = BAN::move(packet_info.fds);
if (cheader == nullptr && !fds_to_open.empty()) auto ucred_to_recv = BAN::move(packet_info.ucred);
message.msg_flags |= MSG_CTRUNC; const bool had_ancillary_data = !fds_to_open.empty() || ucred_to_recv.has_value();
else if (!fds_to_open.empty())
if (!fds_to_open.empty()) do
{ {
if (cheader == nullptr)
{
dwarnln("no space to receive {} fds", fds_to_open.size());
message.msg_flags |= MSG_CTRUNC;
break;
}
const size_t max_fd_count = (cheader->cmsg_len - sizeof(cmsghdr)) / sizeof(int); const size_t max_fd_count = (cheader->cmsg_len - sizeof(cmsghdr)) / sizeof(int);
if (max_fd_count < fds_to_open.size()) if (max_fd_count < fds_to_open.size())
message.msg_flags |= MSG_CTRUNC; message.msg_flags |= MSG_CTRUNC;
@ -422,16 +439,52 @@ namespace Kernel
if (cheader != nullptr) if (cheader != nullptr)
cheader->cmsg_len = message.msg_controllen - header_length; cheader->cmsg_len = message.msg_controllen - header_length;
cheader_len += header_length; cheader_len += header_length;
} while (false);
if (ucred_to_recv.has_value()) do
{
if (cheader == nullptr || cheader->cmsg_len - sizeof(cmsghdr) < sizeof(struct ucred))
{
dwarnln("no space to receive credentials");
message.msg_flags |= MSG_CTRUNC;
break;
} }
const size_t nrecv = BAN::Math::min<size_t>(message.msg_iov[i].iov_len, packet_info.size); *reinterpret_cast<struct ucred*>(CMSG_DATA(cheader)) = ucred_to_recv.value();
memcpy(message.msg_iov[i].iov_base, packet_buffer, nrecv);
total_recv += nrecv;
if (!is_streaming() && nrecv < packet_info.size) const size_t header_length = CMSG_LEN(sizeof(struct ucred));
cheader->cmsg_level = SOL_SOCKET;
cheader->cmsg_type = SCM_CREDENTIALS;
cheader->cmsg_len = header_length;
cheader = CMSG_NXTHDR(&message, cheader);
if (cheader != nullptr)
cheader->cmsg_len = message.msg_controllen - header_length;
cheader_len += header_length;
} while (false);
size_t packet_received = 0;
while (iov_index < message.msg_iovlen && packet_received < packet_info.size)
{
auto& iov = message.msg_iov[iov_index];
uint8_t* iov_base = static_cast<uint8_t*>(iov.iov_base);
const size_t nrecv = BAN::Math::min<size_t>(iov.iov_len - iov_offset, packet_info.size - packet_received);
memcpy(iov_base + iov_offset, packet_buffer + packet_received, nrecv);
packet_received += nrecv;
iov_offset += nrecv;
if (iov_offset >= iov.iov_len)
{
iov_offset = 0;
iov_index++;
}
}
if (!is_streaming() && packet_received < packet_info.size)
message.msg_flags |= MSG_TRUNC; message.msg_flags |= MSG_TRUNC;
const size_t to_discard = is_streaming() ? nrecv : packet_info.size; const size_t to_discard = is_streaming() ? packet_received : packet_info.size;
packet_info.size -= to_discard; packet_info.size -= to_discard;
if (packet_info.size == 0) if (packet_info.size == 0)
@ -441,7 +494,10 @@ namespace Kernel
memmove(packet_buffer, packet_buffer + to_discard, m_packet_size_total - to_discard); memmove(packet_buffer, packet_buffer + to_discard, m_packet_size_total - to_discard);
m_packet_size_total -= to_discard; m_packet_size_total -= to_discard;
if (!is_streaming() || m_packet_infos.empty()) total_recv += packet_received;
// on Linux ancillary data is a barrier on stream sockets, let's do the same
if (!is_streaming() || had_ancillary_data)
break; break;
} }
@ -465,28 +521,6 @@ namespace Kernel
return BAN::Error::from_errno(ENOTSUP); return BAN::Error::from_errno(ENOTSUP);
} }
BAN::Vector<OpenFileDescriptorSet::FDWrapper> fds_to_send;
for (const auto* header = CMSG_FIRSTHDR(&message); header; header = CMSG_NXTHDR(&message, header))
{
if (header->cmsg_level != SOL_SOCKET)
{
dwarnln("ignoring control message with level {}", header->cmsg_level);
continue;
}
if (header->cmsg_type != SCM_RIGHTS)
{
dwarnln("ignoring control message with type {}", header->cmsg_type);
continue;
}
const auto* fd_data = reinterpret_cast<const int*>(CMSG_DATA(header));
const size_t fd_count = (header->cmsg_len - sizeof(cmsghdr)) / sizeof(int);
for (size_t i = 0; i < fd_count; i++)
TRY(fds_to_send.push_back(TRY(Process::current().open_file_descriptor_set().get_fd_wrapper(fd_data[i]))));
}
const size_t total_message_size = const size_t total_message_size =
[&message]() -> size_t { [&message]() -> size_t {
size_t result = 0; size_t result = 0;
@ -498,20 +532,90 @@ namespace Kernel
if (total_message_size > s_packet_buffer_size) if (total_message_size > s_packet_buffer_size)
return BAN::Error::from_errno(ENOBUFS); return BAN::Error::from_errno(ENOBUFS);
PacketInfo packet_info {
.size = total_message_size,
.fds = {},
.ucred = {},
};
for (const auto* header = CMSG_FIRSTHDR(&message); header; header = CMSG_NXTHDR(&message, header))
{
if (header->cmsg_level != SOL_SOCKET)
{
dwarnln("ignoring control message with level {}", header->cmsg_level);
continue;
}
switch (header->cmsg_type)
{
case SCM_RIGHTS:
{
if (!packet_info.fds.empty())
{
dwarnln("multiple SCM_RIGHTS in one sendmsg");
return BAN::Error::from_errno(EINVAL);
}
const auto* fd_data = reinterpret_cast<const int*>(CMSG_DATA(header));
const size_t fd_count = (header->cmsg_len - sizeof(cmsghdr)) / sizeof(int);
for (size_t i = 0; i < fd_count; i++)
TRY(packet_info.fds.push_back(TRY(Process::current().open_file_descriptor_set().get_fd_wrapper(fd_data[i]))));
break;
}
case SCM_CREDENTIALS:
{
if (packet_info.ucred.has_value())
{
dwarnln("multiple SCM_CREDENTIALS in one sendmsg");
return BAN::Error::from_errno(EINVAL);
}
if (header->cmsg_len - sizeof(cmsghdr) < sizeof(struct ucred))
return BAN::Error::from_errno(EINVAL);
const ucred* ucred = reinterpret_cast<const struct ucred*>(CMSG_DATA(header));
const bool is_valid_ucred =
[ucred]() -> bool
{
const auto& creds = Process::current().credentials();
if (creds.is_superuser())
return true;
if (ucred->pid != Process::current().pid())
return false;
if (ucred->uid != creds.ruid() && ucred->uid != creds.euid() && ucred->uid != creds.suid())
return false;
if (ucred->gid != creds.rgid() && !creds.has_egid(ucred->gid) && ucred->gid != creds.sgid())
return false;
return true;
}();
if (!is_valid_ucred)
return BAN::Error::from_errno(EPERM);
packet_info.ucred = *ucred;
break;
}
default:
dwarnln("ignoring control message with type {}", header->cmsg_type);
break;
}
}
if (m_info.has<ConnectionInfo>()) if (m_info.has<ConnectionInfo>())
{ {
auto& connection_info = m_info.get<ConnectionInfo>(); auto& connection_info = m_info.get<ConnectionInfo>();
auto target = connection_info.connection.lock(); auto target = connection_info.connection.lock();
if (!target) if (!target)
return BAN::Error::from_errno(ENOTCONN); return BAN::Error::from_errno(ENOTCONN);
TRY(target->add_packet(message, total_message_size, BAN::move(fds_to_send))); TRY(target->add_packet(message, BAN::move(packet_info)));
return total_message_size; return total_message_size;
} }
else else
{ {
BAN::RefPtr<Inode> target_inode; BAN::RefPtr<Inode> target_inode;
if (!message.msg_name) if (!message.msg_name || message.msg_namelen == 0)
{ {
auto& connectionless_info = m_info.get<ConnectionlessInfo>(); auto& connectionless_info = m_info.get<ConnectionlessInfo>();
if (connectionless_info.peer_address.empty()) if (connectionless_info.peer_address.empty())
@ -548,7 +652,7 @@ namespace Kernel
if (!target) if (!target)
return BAN::Error::from_errno(EDESTADDRREQ); return BAN::Error::from_errno(EDESTADDRREQ);
TRY(target->add_packet(message, total_message_size, BAN::move(fds_to_send))); TRY(target->add_packet(message, BAN::move(packet_info)));
return total_message_size; return total_message_size;
} }

View File

@ -7,6 +7,9 @@
__BEGIN_DECLS __BEGIN_DECLS
#define __need_pid_t
#define __need_uid_t
#define __need_gid_t
#define __need_size_t #define __need_size_t
#define __need_ssize_t #define __need_ssize_t
#include <sys/types.h> #include <sys/types.h>
@ -54,6 +57,7 @@ struct cmsghdr
}; };
#define SCM_RIGHTS 1 #define SCM_RIGHTS 1
#define SCM_CREDENTIALS 2
#define CMSG_DATA(cmsg) ((cmsg)->__cmg_data) #define CMSG_DATA(cmsg) ((cmsg)->__cmg_data)
@ -78,6 +82,13 @@ struct cmsghdr
#define CMSG_LEN(length) \ #define CMSG_LEN(length) \
(socklen_t)((length) + sizeof(struct cmsghdr)) (socklen_t)((length) + sizeof(struct cmsghdr))
/* Credentials exchanged as the payload of an SCM_CREDENTIALS control message. */
struct ucred
{
pid_t pid; /* Process ID reported for the sender. */
uid_t uid; /* User ID reported for the sender. */
gid_t gid; /* Group ID reported for the sender. */
};
struct linger struct linger
{ {
int l_onoff; /* Indicates wheter linger option is enabled. */ int l_onoff; /* Indicates wheter linger option is enabled. */