Kernel: Add support for SCM_CREDENTIALS and fix recvmsg
recvmsg was broken when receiving into more than a single iovec
This commit is contained in:
parent
b8a2573bb4
commit
d60f12d3b8
|
|
@ -42,8 +42,6 @@ namespace Kernel
|
|||
UnixDomainSocket(Socket::Type, const Socket::Info&);
|
||||
~UnixDomainSocket();
|
||||
|
||||
BAN::ErrorOr<void> add_packet(const msghdr&, size_t total_size, BAN::Vector<FDWrapper>&& fds_to_send);
|
||||
|
||||
bool is_bound() const { return !m_bound_file.canonical_path.empty(); }
|
||||
bool is_bound_to_unused() const { return !m_bound_file.inode; }
|
||||
|
||||
|
|
@ -70,8 +68,11 @@ namespace Kernel
|
|||
{
|
||||
size_t size;
|
||||
BAN::Vector<FDWrapper> fds;
|
||||
BAN::Optional<struct ucred> ucred;
|
||||
};
|
||||
|
||||
BAN::ErrorOr<void> add_packet(const msghdr&, PacketInfo&&);
|
||||
|
||||
private:
|
||||
const Socket::Type m_socket_type;
|
||||
VirtualFileSystem::File m_bound_file;
|
||||
|
|
|
|||
|
|
@ -294,11 +294,11 @@ namespace Kernel
|
|||
}
|
||||
}
|
||||
|
||||
BAN::ErrorOr<void> UnixDomainSocket::add_packet(const msghdr& packet, size_t total_size, BAN::Vector<FDWrapper>&& fds_to_send)
|
||||
BAN::ErrorOr<void> UnixDomainSocket::add_packet(const msghdr& packet, PacketInfo&& packet_info)
|
||||
{
|
||||
LockGuard _(m_packet_lock);
|
||||
|
||||
while (m_packet_infos.full() || m_packet_size_total + total_size > s_packet_buffer_size)
|
||||
while (m_packet_infos.full() || m_packet_size_total + packet_info.size > s_packet_buffer_size)
|
||||
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock));
|
||||
|
||||
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total);
|
||||
|
|
@ -310,9 +310,9 @@ namespace Kernel
|
|||
offset += packet.msg_iov[i].iov_len;
|
||||
}
|
||||
|
||||
ASSERT(offset == total_size);
|
||||
m_packet_size_total += total_size;
|
||||
m_packet_infos.emplace(total_size, BAN::move(fds_to_send));
|
||||
ASSERT(offset == packet_info.size);
|
||||
m_packet_size_total += packet_info.size;
|
||||
m_packet_infos.emplace(BAN::move(packet_info));
|
||||
|
||||
m_packet_thread_blocker.unblock();
|
||||
|
||||
|
|
@ -326,10 +326,10 @@ namespace Kernel
|
|||
if (m_info.has<ConnectionInfo>())
|
||||
{
|
||||
auto& connection_info = m_info.get<ConnectionInfo>();
|
||||
if (connection_info.listening)
|
||||
return !connection_info.pending_connections.empty();
|
||||
if (connection_info.target_closed)
|
||||
return true;
|
||||
if (!connection_info.pending_connections.empty())
|
||||
return true;
|
||||
if (!connection_info.connection)
|
||||
return false;
|
||||
}
|
||||
|
|
@ -342,7 +342,13 @@ namespace Kernel
|
|||
if (m_info.has<ConnectionInfo>())
|
||||
{
|
||||
auto& connection_info = m_info.get<ConnectionInfo>();
|
||||
return connection_info.connection.valid();
|
||||
auto connection = connection_info.connection.lock();
|
||||
if (!connection)
|
||||
return false;
|
||||
if (connection->m_packet_infos.full())
|
||||
return false;
|
||||
if (connection->m_packet_size_total >= s_packet_buffer_size)
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
|
|
@ -393,16 +399,27 @@ namespace Kernel
|
|||
|
||||
message.msg_flags = 0;
|
||||
|
||||
int iov_index = 0;
|
||||
size_t iov_offset = 0;
|
||||
size_t total_recv = 0;
|
||||
for (int i = 0; i < message.msg_iovlen; i++)
|
||||
|
||||
while (!m_packet_infos.empty() && iov_index < message.msg_iovlen)
|
||||
{
|
||||
auto& packet_info = m_packet_infos.front();
|
||||
|
||||
auto fds_to_open = BAN::move(packet_info.fds);
|
||||
if (cheader == nullptr && !fds_to_open.empty())
|
||||
message.msg_flags |= MSG_CTRUNC;
|
||||
else if (!fds_to_open.empty())
|
||||
auto ucred_to_recv = BAN::move(packet_info.ucred);
|
||||
const bool had_ancillary_data = !fds_to_open.empty() || ucred_to_recv.has_value();
|
||||
|
||||
if (!fds_to_open.empty()) do
|
||||
{
|
||||
if (cheader == nullptr)
|
||||
{
|
||||
dwarnln("no space to receive {} fds", fds_to_open.size());
|
||||
message.msg_flags |= MSG_CTRUNC;
|
||||
break;
|
||||
}
|
||||
|
||||
const size_t max_fd_count = (cheader->cmsg_len - sizeof(cmsghdr)) / sizeof(int);
|
||||
if (max_fd_count < fds_to_open.size())
|
||||
message.msg_flags |= MSG_CTRUNC;
|
||||
|
|
@ -422,16 +439,52 @@ namespace Kernel
|
|||
if (cheader != nullptr)
|
||||
cheader->cmsg_len = message.msg_controllen - header_length;
|
||||
cheader_len += header_length;
|
||||
} while (false);
|
||||
|
||||
if (ucred_to_recv.has_value()) do
|
||||
{
|
||||
if (cheader == nullptr || cheader->cmsg_len - sizeof(cmsghdr) < sizeof(struct ucred))
|
||||
{
|
||||
dwarnln("no space to receive credentials");
|
||||
message.msg_flags |= MSG_CTRUNC;
|
||||
break;
|
||||
}
|
||||
|
||||
*reinterpret_cast<struct ucred*>(CMSG_DATA(cheader)) = ucred_to_recv.value();
|
||||
|
||||
const size_t header_length = CMSG_LEN(sizeof(struct ucred));
|
||||
cheader->cmsg_level = SOL_SOCKET;
|
||||
cheader->cmsg_type = SCM_CREDENTIALS;
|
||||
cheader->cmsg_len = header_length;
|
||||
cheader = CMSG_NXTHDR(&message, cheader);
|
||||
if (cheader != nullptr)
|
||||
cheader->cmsg_len = message.msg_controllen - header_length;
|
||||
cheader_len += header_length;
|
||||
} while (false);
|
||||
|
||||
size_t packet_received = 0;
|
||||
while (iov_index < message.msg_iovlen && packet_received < packet_info.size)
|
||||
{
|
||||
auto& iov = message.msg_iov[iov_index];
|
||||
uint8_t* iov_base = static_cast<uint8_t*>(iov.iov_base);
|
||||
|
||||
const size_t nrecv = BAN::Math::min<size_t>(iov.iov_len - iov_offset, packet_info.size - packet_received);
|
||||
memcpy(iov_base + iov_offset, packet_buffer + packet_received, nrecv);
|
||||
|
||||
packet_received += nrecv;
|
||||
|
||||
iov_offset += nrecv;
|
||||
if (iov_offset >= iov.iov_len)
|
||||
{
|
||||
iov_offset = 0;
|
||||
iov_index++;
|
||||
}
|
||||
}
|
||||
|
||||
const size_t nrecv = BAN::Math::min<size_t>(message.msg_iov[i].iov_len, packet_info.size);
|
||||
memcpy(message.msg_iov[i].iov_base, packet_buffer, nrecv);
|
||||
total_recv += nrecv;
|
||||
|
||||
if (!is_streaming() && nrecv < packet_info.size)
|
||||
if (!is_streaming() && packet_received < packet_info.size)
|
||||
message.msg_flags |= MSG_TRUNC;
|
||||
|
||||
const size_t to_discard = is_streaming() ? nrecv : packet_info.size;
|
||||
const size_t to_discard = is_streaming() ? packet_received : packet_info.size;
|
||||
|
||||
packet_info.size -= to_discard;
|
||||
if (packet_info.size == 0)
|
||||
|
|
@ -441,7 +494,10 @@ namespace Kernel
|
|||
memmove(packet_buffer, packet_buffer + to_discard, m_packet_size_total - to_discard);
|
||||
m_packet_size_total -= to_discard;
|
||||
|
||||
if (!is_streaming() || m_packet_infos.empty())
|
||||
total_recv += packet_received;
|
||||
|
||||
// on linux ancillary data is a barrier on stream sockets, lets do the same
|
||||
if (!is_streaming() || had_ancillary_data)
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -465,28 +521,6 @@ namespace Kernel
|
|||
return BAN::Error::from_errno(ENOTSUP);
|
||||
}
|
||||
|
||||
BAN::Vector<OpenFileDescriptorSet::FDWrapper> fds_to_send;
|
||||
|
||||
for (const auto* header = CMSG_FIRSTHDR(&message); header; header = CMSG_NXTHDR(&message, header))
|
||||
{
|
||||
if (header->cmsg_level != SOL_SOCKET)
|
||||
{
|
||||
dwarnln("ignoring control message with level {}", header->cmsg_level);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (header->cmsg_type != SCM_RIGHTS)
|
||||
{
|
||||
dwarnln("ignoring control message with type {}", header->cmsg_type);
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto* fd_data = reinterpret_cast<const int*>(CMSG_DATA(header));
|
||||
const size_t fd_count = (header->cmsg_len - sizeof(cmsghdr)) / sizeof(int);
|
||||
for (size_t i = 0; i < fd_count; i++)
|
||||
TRY(fds_to_send.push_back(TRY(Process::current().open_file_descriptor_set().get_fd_wrapper(fd_data[i]))));
|
||||
}
|
||||
|
||||
const size_t total_message_size =
|
||||
[&message]() -> size_t {
|
||||
size_t result = 0;
|
||||
|
|
@ -498,20 +532,90 @@ namespace Kernel
|
|||
if (total_message_size > s_packet_buffer_size)
|
||||
return BAN::Error::from_errno(ENOBUFS);
|
||||
|
||||
PacketInfo packet_info {
|
||||
.size = total_message_size,
|
||||
.fds = {},
|
||||
.ucred = {},
|
||||
};
|
||||
|
||||
for (const auto* header = CMSG_FIRSTHDR(&message); header; header = CMSG_NXTHDR(&message, header))
|
||||
{
|
||||
if (header->cmsg_level != SOL_SOCKET)
|
||||
{
|
||||
dwarnln("ignoring control message with level {}", header->cmsg_level);
|
||||
continue;
|
||||
}
|
||||
|
||||
switch (header->cmsg_type)
|
||||
{
|
||||
case SCM_RIGHTS:
|
||||
{
|
||||
if (!packet_info.fds.empty())
|
||||
{
|
||||
dwarnln("multiple SCM_RIGHTS in one sendmsg");
|
||||
return BAN::Error::from_errno(EINVAL);
|
||||
}
|
||||
|
||||
const auto* fd_data = reinterpret_cast<const int*>(CMSG_DATA(header));
|
||||
const size_t fd_count = (header->cmsg_len - sizeof(cmsghdr)) / sizeof(int);
|
||||
for (size_t i = 0; i < fd_count; i++)
|
||||
TRY(packet_info.fds.push_back(TRY(Process::current().open_file_descriptor_set().get_fd_wrapper(fd_data[i]))));
|
||||
break;
|
||||
}
|
||||
case SCM_CREDENTIALS:
|
||||
{
|
||||
if (packet_info.ucred.has_value())
|
||||
{
|
||||
dwarnln("multiple SCM_CREDENTIALS in one sendmsg");
|
||||
return BAN::Error::from_errno(EINVAL);
|
||||
}
|
||||
|
||||
if (header->cmsg_len - sizeof(cmsghdr) < sizeof(struct ucred))
|
||||
return BAN::Error::from_errno(EINVAL);
|
||||
const ucred* ucred = reinterpret_cast<const struct ucred*>(CMSG_DATA(header));
|
||||
|
||||
const bool is_valid_ucred =
|
||||
[ucred]() -> bool
|
||||
{
|
||||
const auto& creds = Process::current().credentials();
|
||||
if (creds.is_superuser())
|
||||
return true;
|
||||
if (ucred->pid != Process::current().pid())
|
||||
return false;
|
||||
if (ucred->uid != creds.ruid() && ucred->uid != creds.euid() && ucred->uid != creds.suid())
|
||||
return false;
|
||||
if (ucred->gid != creds.rgid() && !creds.has_egid(ucred->gid) && ucred->gid != creds.sgid())
|
||||
return false;
|
||||
return true;
|
||||
}();
|
||||
|
||||
if (!is_valid_ucred)
|
||||
return BAN::Error::from_errno(EPERM);
|
||||
|
||||
packet_info.ucred = *ucred;
|
||||
|
||||
break;
|
||||
}
|
||||
default:
|
||||
dwarnln("ignoring control message with type {}", header->cmsg_type);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (m_info.has<ConnectionInfo>())
|
||||
{
|
||||
auto& connection_info = m_info.get<ConnectionInfo>();
|
||||
auto target = connection_info.connection.lock();
|
||||
if (!target)
|
||||
return BAN::Error::from_errno(ENOTCONN);
|
||||
TRY(target->add_packet(message, total_message_size, BAN::move(fds_to_send)));
|
||||
TRY(target->add_packet(message, BAN::move(packet_info)));
|
||||
return total_message_size;
|
||||
}
|
||||
else
|
||||
{
|
||||
BAN::RefPtr<Inode> target_inode;
|
||||
|
||||
if (!message.msg_name)
|
||||
if (!message.msg_name || message.msg_namelen == 0)
|
||||
{
|
||||
auto& connectionless_info = m_info.get<ConnectionlessInfo>();
|
||||
if (connectionless_info.peer_address.empty())
|
||||
|
|
@ -548,7 +652,7 @@ namespace Kernel
|
|||
|
||||
if (!target)
|
||||
return BAN::Error::from_errno(EDESTADDRREQ);
|
||||
TRY(target->add_packet(message, total_message_size, BAN::move(fds_to_send)));
|
||||
TRY(target->add_packet(message, BAN::move(packet_info)));
|
||||
|
||||
return total_message_size;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,9 @@
|
|||
|
||||
__BEGIN_DECLS
|
||||
|
||||
#define __need_pid_t
|
||||
#define __need_uid_t
|
||||
#define __need_gid_t
|
||||
#define __need_size_t
|
||||
#define __need_ssize_t
|
||||
#include <sys/types.h>
|
||||
|
|
@ -53,7 +56,8 @@ struct cmsghdr
|
|||
unsigned char __cmg_data[];
|
||||
};
|
||||
|
||||
#define SCM_RIGHTS 1
|
||||
#define SCM_RIGHTS 1
|
||||
#define SCM_CREDENTIALS 2
|
||||
|
||||
#define CMSG_DATA(cmsg) ((cmsg)->__cmg_data)
|
||||
|
||||
|
|
@ -78,6 +82,13 @@ struct cmsghdr
|
|||
#define CMSG_LEN(length) \
|
||||
(socklen_t)((length) + sizeof(struct cmsghdr))
|
||||
|
||||
struct ucred
|
||||
{
|
||||
pid_t pid;
|
||||
uid_t uid;
|
||||
gid_t gid;
|
||||
};
|
||||
|
||||
struct linger
|
||||
{
|
||||
int l_onoff; /* Indicates wheter linger option is enabled. */
|
||||
|
|
|
|||
Loading…
Reference in New Issue