Kernel/LibC: Implement {recv,send}msg as syscalls

This also removes the now old recvfrom and sendto syscalls. These are
now implemented as wrappers around recvmsg and sendmsg.

Also replace unnecessary spinlocks from unix socket with mutexes
This commit is contained in:
Bananymous 2025-11-09 16:23:37 +02:00
parent 2f38306c6b
commit 04d24bce70
16 changed files with 422 additions and 307 deletions

View File

@ -108,8 +108,8 @@ namespace Kernel
BAN::ErrorOr<void> bind(const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<void> connect(const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<void> listen(int backlog);
BAN::ErrorOr<size_t> sendto(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<size_t> recvfrom(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len);
BAN::ErrorOr<size_t> sendmsg(const msghdr& message, int flags);
BAN::ErrorOr<size_t> recvmsg(msghdr& message, int flags);
BAN::ErrorOr<void> getsockname(sockaddr* address, socklen_t* address_len);
BAN::ErrorOr<void> getpeername(sockaddr* address, socklen_t* address_len);
@ -155,8 +155,8 @@ namespace Kernel
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> listen_impl(int) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<size_t> sendto_impl(BAN::ConstByteSpan, const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<size_t> recvfrom_impl(BAN::ByteSpan, sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<size_t> recvmsg_impl(msghdr&, int) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<size_t> sendmsg_impl(const msghdr&, int) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> getsockname_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> getpeername_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); }

View File

@ -51,8 +51,6 @@ namespace Kernel
: m_info(info)
{}
BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan buffer) override { return recvfrom_impl(buffer, nullptr, nullptr); }
BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan buffer) override { return sendto_impl(buffer, nullptr, 0); }
BAN::ErrorOr<void> fsync_impl() final override { return {}; }
private:

View File

@ -60,8 +60,8 @@ namespace Kernel
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<void> listen_impl(int) override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<size_t> sendto_impl(BAN::ConstByteSpan, const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<size_t> recvfrom_impl(BAN::ByteSpan, sockaddr*, socklen_t*) override;
virtual BAN::ErrorOr<size_t> recvmsg_impl(msghdr& message, int flags) override;
virtual BAN::ErrorOr<size_t> sendmsg_impl(const msghdr& message, int flags) override;
virtual BAN::ErrorOr<void> getpeername_impl(sockaddr*, socklen_t*) override;
virtual BAN::ErrorOr<long> ioctl_impl(int, void*) override;

View File

@ -34,8 +34,8 @@ namespace Kernel
virtual void receive_packet(BAN::ConstByteSpan, const sockaddr* sender, socklen_t sender_len) override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr* address, socklen_t address_len) override;
virtual BAN::ErrorOr<size_t> sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) override;
virtual BAN::ErrorOr<size_t> recvfrom_impl(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len) override;
virtual BAN::ErrorOr<size_t> recvmsg_impl(msghdr& message, int flags) override;
virtual BAN::ErrorOr<size_t> sendmsg_impl(const msghdr& message, int flags) override;
virtual BAN::ErrorOr<void> getpeername_impl(sockaddr*, socklen_t*) override { return BAN::Error::from_errno(ENOTCONN); }
virtual BAN::ErrorOr<long> ioctl_impl(int, void*) override;

View File

@ -25,8 +25,8 @@ namespace Kernel
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<void> listen_impl(int) override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<size_t> sendto_impl(BAN::ConstByteSpan, const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<size_t> recvfrom_impl(BAN::ByteSpan, sockaddr*, socklen_t*) override;
virtual BAN::ErrorOr<size_t> recvmsg_impl(msghdr& message, int flags) override;
virtual BAN::ErrorOr<size_t> sendmsg_impl(const msghdr& message, int flags) override;
virtual BAN::ErrorOr<void> getpeername_impl(sockaddr*, socklen_t*) override;
virtual bool can_read_impl() const override;
@ -38,7 +38,7 @@ namespace Kernel
UnixDomainSocket(Socket::Type, const Socket::Info&);
~UnixDomainSocket();
BAN::ErrorOr<void> add_packet(BAN::ConstByteSpan);
BAN::ErrorOr<void> add_packet(const msghdr&, size_t total_size);
bool is_bound() const { return !m_bound_file.canonical_path.empty(); }
bool is_bound_to_unused() const { return !m_bound_file.inode; }
@ -54,7 +54,7 @@ namespace Kernel
BAN::WeakPtr<UnixDomainSocket> connection;
BAN::Queue<BAN::RefPtr<UnixDomainSocket>> pending_connections;
ThreadBlocker pending_thread_blocker;
SpinLock pending_lock;
Mutex pending_lock;
};
struct ConnectionlessInfo
@ -71,7 +71,7 @@ namespace Kernel
BAN::CircularQueue<size_t, 128> m_packet_sizes;
size_t m_packet_size_total { 0 };
BAN::UniqPtr<VirtualRange> m_packet_buffer;
SpinLock m_packet_lock;
Mutex m_packet_lock;
ThreadBlocker m_packet_thread_blocker;
friend class BAN::RefPtr<UnixDomainSocket>;

View File

@ -51,8 +51,8 @@ namespace Kernel
BAN::ErrorOr<size_t> read_dir_entries(int fd, struct dirent* list, size_t list_len);
BAN::ErrorOr<size_t> recvfrom(int fd, BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len);
BAN::ErrorOr<size_t> sendto(int fd, BAN::ConstByteSpan buffer, const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<size_t> recvmsg(int socket, msghdr& message, int flags);
BAN::ErrorOr<size_t> sendmsg(int socket, const msghdr& message, int flags);
BAN::ErrorOr<VirtualFileSystem::File> file_of(int) const;
BAN::ErrorOr<BAN::String> path_of(int) const;

View File

@ -134,8 +134,8 @@ namespace Kernel
BAN::ErrorOr<long> sys_bind(int socket, const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<long> sys_connect(int socket, const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<long> sys_listen(int socket, int backlog);
BAN::ErrorOr<long> sys_sendto(const sys_sendto_t*);
BAN::ErrorOr<long> sys_recvfrom(sys_recvfrom_t*);
BAN::ErrorOr<long> sys_recvmsg(int socket, msghdr* message, int flags);
BAN::ErrorOr<long> sys_sendmsg(int socket, const msghdr* message, int flags);
BAN::ErrorOr<long> sys_ioctl(int fildes, int request, void* arg);

View File

@ -176,21 +176,21 @@ namespace Kernel
return listen_impl(backlog);
}
BAN::ErrorOr<size_t> Inode::sendto(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len)
BAN::ErrorOr<size_t> Inode::recvmsg(msghdr& message, int flags)
{
LockGuard _(m_mutex);
if (!mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
return sendto_impl(message, address, address_len);
};
return recvmsg_impl(message, flags);
}
BAN::ErrorOr<size_t> Inode::recvfrom(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len)
BAN::ErrorOr<size_t> Inode::sendmsg(const msghdr& message, int flags)
{
LockGuard _(m_mutex);
if (!mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
return recvfrom_impl(buffer, address, address_len);
};
return sendmsg_impl(message, flags);
}
BAN::ErrorOr<void> Inode::getsockname(sockaddr* address, socklen_t* address_len)
{

View File

@ -190,8 +190,20 @@ namespace Kernel
return m_network_layer.bind_socket_to_address(this, address, address_len);
}
BAN::ErrorOr<size_t> TCPSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr*, socklen_t*)
BAN::ErrorOr<size_t> TCPSocket::recvmsg_impl(msghdr& message, int flags)
{
if (flags != 0)
{
dwarnln("TODO: recvmsg with flags 0x{H}", flags);
return BAN::Error::from_errno(ENOTSUP);
}
if (CMSG_FIRSTHDR(&message))
{
dwarnln("ignoring recvmsg control message");
message.msg_controllen = 0;
}
if (!m_has_connected)
return BAN::Error::from_errno(ENOTCONN);
@ -202,26 +214,38 @@ namespace Kernel
TRY(Thread::current().block_or_eintr_indefinite(m_thread_blocker, &m_mutex));
}
const uint32_t to_recv = BAN::Math::min<uint32_t>(buffer.size(), m_recv_window.data_size);
size_t total_recv = 0;
for (int i = 0; i < message.msg_iovlen; i++)
{
auto* recv_buffer = reinterpret_cast<uint8_t*>(m_recv_window.buffer->vaddr());
memcpy(buffer.data(), recv_buffer, to_recv);
m_recv_window.data_size -= to_recv;
m_recv_window.start_seq += to_recv;
if (m_recv_window.data_size > 0)
memmove(recv_buffer, recv_buffer + to_recv, m_recv_window.data_size);
const size_t nrecv = BAN::Math::min<size_t>(message.msg_iov[i].iov_len, m_recv_window.data_size);
memcpy(message.msg_iov[i].iov_base, recv_buffer, nrecv);
return to_recv;
total_recv += nrecv;
m_recv_window.data_size -= nrecv;
m_recv_window.start_seq += nrecv;
if (m_recv_window.data_size == 0)
break;
// TODO: use circular buffer to avoid this
memmove(recv_buffer, recv_buffer + nrecv, m_recv_window.data_size);
}
BAN::ErrorOr<size_t> TCPSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len)
{
(void)address;
(void)address_len;
return total_recv;
}
BAN::ErrorOr<size_t> TCPSocket::sendmsg_impl(const msghdr& message, int flags)
{
if (flags != 0)
{
dwarnln("TODO: sendmsg with flags 0x{H}", flags);
return BAN::Error::from_errno(ENOTSUP);
}
if (CMSG_FIRSTHDR(&message))
dwarnln("ignoring sendmsg control message");
if (address)
return BAN::Error::from_errno(EISCONN);
if (!m_has_connected)
return BAN::Error::from_errno(ENOTCONN);
@ -232,17 +256,23 @@ namespace Kernel
TRY(Thread::current().block_or_eintr_indefinite(m_thread_blocker, &m_mutex));
}
const size_t to_send = BAN::Math::min<size_t>(message.size(), m_send_window.buffer->size() - m_send_window.data_size);
size_t total_sent = 0;
for (int i = 0; i < message.msg_iovlen; i++)
{
auto* buffer = reinterpret_cast<uint8_t*>(m_send_window.buffer->vaddr());
memcpy(buffer + m_send_window.data_size, message.data(), to_send);
m_send_window.data_size += to_send;
auto* send_buffer = reinterpret_cast<uint8_t*>(m_send_window.buffer->vaddr());
const size_t nsend = BAN::Math::min<size_t>(message.msg_iov[i].iov_len, m_send_window.buffer->size() - m_send_window.data_size);
memcpy(send_buffer + m_send_window.data_size, message.msg_iov[i].iov_base, nsend);
total_sent += nsend;
m_send_window.data_size += nsend;
if (m_send_window.data_size == m_send_window.buffer->size())
break;
}
m_thread_blocker.unblock();
return to_send;
return total_sent;
}
BAN::ErrorOr<void> TCPSocket::getpeername_impl(sockaddr* address, socklen_t* address_len)

View File

@ -86,8 +86,20 @@ namespace Kernel
return m_network_layer.bind_socket_to_address(this, address, address_len);
}
BAN::ErrorOr<size_t> UDPSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len)
BAN::ErrorOr<size_t> UDPSocket::recvmsg_impl(msghdr& message, int flags)
{
if (flags != 0)
{
dwarnln("TODO: recvmsg with flags 0x{H}", flags);
return BAN::Error::from_errno(ENOTSUP);
}
if (CMSG_FIRSTHDR(&message))
{
dwarnln("ignoring recvmsg control message");
message.msg_controllen = 0;
}
if (!is_bound())
{
dprintln("No interface bound");
@ -106,14 +118,16 @@ namespace Kernel
auto packet_info = m_packets.front();
m_packets.pop();
size_t nread = BAN::Math::min<size_t>(packet_info.packet_size, buffer.size());
auto* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
size_t total_recv = 0;
for (int i = 0; i < message.msg_iovlen; i++)
{
const size_t nrecv = BAN::Math::min<size_t>(message.msg_iov[i].iov_len, packet_info.packet_size - total_recv);
memcpy(message.msg_iov[i].iov_base, packet_buffer + total_recv, nrecv);
total_recv += nrecv;
}
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
memcpy(
buffer.data(),
packet_buffer,
nread
);
memmove(
packet_buffer,
packet_buffer + packet_info.packet_size,
@ -122,21 +136,49 @@ namespace Kernel
m_packet_total_size -= packet_info.packet_size;
if (address && address_len)
if (message.msg_name && message.msg_namelen)
{
if (*address_len > (socklen_t)sizeof(sockaddr_storage))
*address_len = sizeof(sockaddr_storage);
memcpy(address, &packet_info.sender, *address_len);
const size_t namelen = BAN::Math::min<size_t>(message.msg_namelen, sizeof(sockaddr_storage));
memcpy(message.msg_name, &packet_info.sender, namelen);
message.msg_namelen = namelen;
}
return nread;
return total_recv;
}
BAN::ErrorOr<size_t> UDPSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len)
BAN::ErrorOr<size_t> UDPSocket::sendmsg_impl(const msghdr& message, int flags)
{
if (flags != 0)
{
dwarnln("TODO: recvmsg with flags 0x{H}", flags);
return BAN::Error::from_errno(ENOTSUP);
}
if (CMSG_FIRSTHDR(&message))
dwarnln("ignoring recvmsg control message");
if (!is_bound())
TRY(m_network_layer.bind_socket_to_unused(this, address, address_len));
return TRY(m_network_layer.sendto(*this, message, address, address_len));
TRY(m_network_layer.bind_socket_to_unused(this, static_cast<sockaddr*>(message.msg_name), message.msg_namelen));
const size_t total_send_size =
[&message]() -> size_t {
size_t result = 0;
for (int i = 0; i < message.msg_iovlen; i++)
result += message.msg_iov[i].iov_len;
return result;
}();
BAN::Vector<uint8_t> buffer;
TRY(buffer.resize(total_send_size));
size_t offset = 0;
for (int i = 0; i < message.msg_iovlen; i++)
{
memcpy(buffer.data() + offset, message.msg_iov[i].iov_base, message.msg_iov[i].iov_len);
offset += message.msg_iov[i].iov_len;
}
return TRY(m_network_layer.sendto(*this, buffer.span(), static_cast<sockaddr*>(message.msg_name), message.msg_namelen));
}
BAN::ErrorOr<long> UDPSocket::ioctl_impl(int request, void* argument)

View File

@ -1,6 +1,7 @@
#include <BAN/HashMap.h>
#include <kernel/FS/VirtualFileSystem.h>
#include <kernel/Lock/SpinLockAsMutex.h>
#include <kernel/Lock/LockGuard.h>
#include <kernel/Networking/NetworkManager.h>
#include <kernel/Networking/UNIX/Socket.h>
#include <kernel/Process.h>
@ -22,10 +23,28 @@ namespace Kernel
};
static BAN::HashMap<BAN::RefPtr<Inode>, BAN::WeakPtr<UnixDomainSocket>, UnixSocketHash> s_bound_sockets;
static SpinLock s_bound_socket_lock;
static Mutex s_bound_socket_lock;
static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE;
static BAN::ErrorOr<BAN::StringView> validate_sockaddr_un(const sockaddr* address, socklen_t address_len)
{
if (address_len < static_cast<socklen_t>(sizeof(sa_family_t)))
return BAN::Error::from_errno(EINVAL);
if (address_len > static_cast<socklen_t>(sizeof(sockaddr_un)))
address_len = sizeof(sockaddr_un);
const auto& sockaddr_un = *reinterpret_cast<const struct sockaddr_un*>(address);
if (sockaddr_un.sun_family != AF_UNIX)
return BAN::Error::from_errno(EINVAL);
size_t length = 0;
while (length < address_len - sizeof(sa_family_t) && sockaddr_un.sun_path[length])
length++;
return BAN::StringView { sockaddr_un.sun_path, length };
}
// FIXME: why is this using spinlocks instead of mutexes??
BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> UnixDomainSocket::create(Socket::Type socket_type, const Socket::Info& info)
@ -64,7 +83,7 @@ namespace Kernel
{
if (is_bound() && !is_bound_to_unused())
{
SpinLockGuard _(s_bound_socket_lock);
LockGuard _(s_bound_socket_lock);
s_bound_sockets.remove(m_bound_file.inode);
}
if (m_info.has<ConnectionInfo>())
@ -105,11 +124,9 @@ namespace Kernel
BAN::RefPtr<UnixDomainSocket> pending;
{
SpinLockGuard guard(connection_info.pending_lock);
SpinLockGuardAsMutex smutex(guard);
LockGuard _(connection_info.pending_lock);
while (connection_info.pending_connections.empty())
TRY(Thread::current().block_or_eintr_indefinite(connection_info.pending_thread_blocker, &smutex));
TRY(Thread::current().block_or_eintr_indefinite(connection_info.pending_thread_blocker, &connection_info.pending_lock));
pending = connection_info.pending_connections.front();
connection_info.pending_connections.pop();
@ -146,15 +163,11 @@ namespace Kernel
BAN::ErrorOr<void> UnixDomainSocket::connect_impl(const sockaddr* address, socklen_t address_len)
{
if (address_len != sizeof(sockaddr_un))
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_un = *reinterpret_cast<const struct sockaddr_un*>(address);
if (sockaddr_un.sun_family != AF_UNIX)
return BAN::Error::from_errno(EAFNOSUPPORT);
const auto sun_path = TRY(validate_sockaddr_un(address, address_len));
if (!is_bound())
TRY(m_bound_file.canonical_path.push_back('X'));
auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path));
auto absolute_path = TRY(Process::current().absolute_path_of(sun_path));
auto file = TRY(VirtualFileSystem::get().file_from_absolute_path(
Process::current().root_file().inode,
Process::current().credentials(),
@ -165,7 +178,7 @@ namespace Kernel
BAN::RefPtr<UnixDomainSocket> target;
{
SpinLockGuard _(s_bound_socket_lock);
LockGuard _(s_bound_socket_lock);
auto it = s_bound_sockets.find(file.inode);
if (it == s_bound_sockets.end())
return BAN::Error::from_errno(ECONNREFUSED);
@ -196,7 +209,7 @@ namespace Kernel
{
auto& target_info = target->m_info.get<ConnectionInfo>();
SpinLockGuard guard(target_info.pending_lock);
LockGuard _(target_info.pending_lock);
if (target_info.pending_connections.size() < target_info.pending_connections.capacity())
{
@ -205,8 +218,7 @@ namespace Kernel
break;
}
SpinLockGuardAsMutex smutex(guard);
TRY(Thread::current().block_or_eintr_indefinite(target_info.pending_thread_blocker, &smutex));
TRY(Thread::current().block_or_eintr_indefinite(target_info.pending_thread_blocker, &target_info.pending_lock));
}
target->epoll_notify(EPOLLIN);
@ -236,21 +248,16 @@ namespace Kernel
{
if (is_bound())
return BAN::Error::from_errno(EINVAL);
if (address_len != sizeof(sockaddr_un))
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_un = *reinterpret_cast<const struct sockaddr_un*>(address);
if (sockaddr_un.sun_family != AF_UNIX)
return BAN::Error::from_errno(EAFNOSUPPORT);
auto bind_path = BAN::StringView(sockaddr_un.sun_path);
if (bind_path.empty())
const auto sun_path = TRY(validate_sockaddr_un(address, address_len));
if (sun_path.empty())
return BAN::Error::from_errno(EINVAL);
// FIXME: This feels sketchy
auto parent_file = bind_path.front() == '/'
auto parent_file = sun_path.front() == '/'
? TRY(Process::current().root_file().clone())
: TRY(Process::current().working_directory().clone());
if (auto ret = Process::current().create_file_or_dir(AT_FDCWD, bind_path.data(), 0755 | S_IFSOCK); ret.is_error())
if (auto ret = Process::current().create_file_or_dir(AT_FDCWD, sun_path.data(), 0755 | S_IFSOCK); ret.is_error())
{
if (ret.error().get_error_code() == EEXIST)
return BAN::Error::from_errno(EADDRINUSE);
@ -260,11 +267,11 @@ namespace Kernel
Process::current().root_file().inode,
parent_file,
Process::current().credentials(),
bind_path,
sun_path,
O_RDWR
));
SpinLockGuard _(s_bound_socket_lock);
LockGuard _(s_bound_socket_lock);
if (s_bound_sockets.contains(file.inode))
return BAN::Error::from_errno(EADDRINUSE);
TRY(s_bound_sockets.emplace(file.inode, TRY(get_weak_ptr())));
@ -287,21 +294,24 @@ namespace Kernel
}
}
BAN::ErrorOr<void> UnixDomainSocket::add_packet(BAN::ConstByteSpan packet)
BAN::ErrorOr<void> UnixDomainSocket::add_packet(const msghdr& packet, size_t total_size)
{
SpinLockGuard guard(m_packet_lock);
while (m_packet_sizes.full() || m_packet_size_total + packet.size() > s_packet_buffer_size)
{
SpinLockGuardAsMutex smutex(guard);
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &smutex));
}
LockGuard _(m_packet_lock);
while (m_packet_sizes.full() || m_packet_size_total + total_size > s_packet_buffer_size)
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock));
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total);
memcpy(packet_buffer, packet.data(), packet.size());
m_packet_size_total += packet.size();
if (!is_streaming())
m_packet_sizes.push(packet.size());
size_t offset = 0;
for (int i = 0; i < packet.msg_iovlen; i++)
{
memcpy(packet_buffer + offset, packet.msg_iov[i].iov_base, packet.msg_iov[i].iov_len);
offset += packet.msg_iov[i].iov_len;
}
ASSERT(offset == total_size);
m_packet_size_total += total_size;
m_packet_sizes.push(total_size);
m_packet_thread_blocker.unblock();
@ -348,27 +358,105 @@ namespace Kernel
return false;
}
BAN::ErrorOr<size_t> UnixDomainSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len)
BAN::ErrorOr<size_t> UnixDomainSocket::recvmsg_impl(msghdr& message, int flags)
{
if (message.size() > s_packet_buffer_size)
if (flags != 0)
{
dwarnln("TODO: recvmsg with flags 0x{H}", flags);
return BAN::Error::from_errno(ENOTSUP);
}
if (CMSG_FIRSTHDR(&message))
{
dwarnln("ignoring recvmsg control message");
message.msg_controllen = 0;
}
LockGuard _(m_packet_lock);
while (m_packet_size_total == 0)
{
if (m_info.has<ConnectionInfo>())
{
auto& connection_info = m_info.get<ConnectionInfo>();
bool expected = true;
if (connection_info.target_closed.compare_exchange(expected, false))
return 0;
if (!connection_info.connection)
return BAN::Error::from_errno(ENOTCONN);
}
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock));
}
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
const size_t max_recv_size = is_streaming() ? m_packet_size_total : m_packet_sizes.front();
size_t total_recv = 0;
for (int i = 0; i < message.msg_iovlen; i++)
{
const size_t nrecv = BAN::Math::min<size_t>(message.msg_iov[i].iov_len, max_recv_size - total_recv);
memcpy(message.msg_iov[i].iov_base, packet_buffer + total_recv, nrecv);
total_recv += nrecv;
}
size_t bytes_to_handle = total_recv;
while (bytes_to_handle)
{
const size_t to_handle = BAN::Math::min(bytes_to_handle, m_packet_sizes.front());
bytes_to_handle -= to_handle;
m_packet_sizes.front() -= to_handle;
if (m_packet_sizes.front() == 0)
m_packet_sizes.pop();
}
const size_t to_discard = is_streaming() ? total_recv : max_recv_size;
memmove(packet_buffer, packet_buffer + to_discard, m_packet_size_total - to_discard);
m_packet_size_total -= to_discard;
m_packet_thread_blocker.unblock();
epoll_notify(EPOLLOUT);
return total_recv;
}
BAN::ErrorOr<size_t> UnixDomainSocket::sendmsg_impl(const msghdr& message, int flags)
{
if (flags != 0)
{
dwarnln("TODO: sendmsg with flags 0x{H}", flags);
return BAN::Error::from_errno(ENOTSUP);
}
if (CMSG_FIRSTHDR(&message))
dwarnln("ignoring sendmsg control message");
const size_t total_message_size =
[&message]() -> size_t {
size_t result = 0;
for (int i = 0; i < message.msg_iovlen; i++)
result += message.msg_iov[i].iov_len;
return result;
}();
if (total_message_size > s_packet_buffer_size)
return BAN::Error::from_errno(ENOBUFS);
if (m_info.has<ConnectionInfo>())
{
auto& connection_info = m_info.get<ConnectionInfo>();
if (address)
return BAN::Error::from_errno(EISCONN);
auto target = connection_info.connection.lock();
if (!target)
return BAN::Error::from_errno(ENOTCONN);
TRY(target->add_packet(message));
return message.size();
TRY(target->add_packet(message, total_message_size));
return total_message_size;
}
else
{
BAN::RefPtr<Inode> target_inode;
if (!address)
if (!message.msg_name)
{
auto& connectionless_info = m_info.get<ConnectionlessInfo>();
if (connectionless_info.peer_address.empty())
@ -384,13 +472,8 @@ namespace Kernel
}
else
{
if (address_len != sizeof(sockaddr_un))
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_un = *reinterpret_cast<const struct sockaddr_un*>(address);
if (sockaddr_un.sun_family != AF_UNIX)
return BAN::Error::from_errno(EAFNOSUPPORT);
auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path));
const auto sun_path = TRY(validate_sockaddr_un(static_cast<sockaddr*>(message.msg_name), message.msg_namelen));
auto absolute_path = TRY(Process::current().absolute_path_of(sun_path));
target_inode = TRY(VirtualFileSystem::get().file_from_absolute_path(
Process::current().root_file().inode,
Process::current().credentials(),
@ -399,59 +482,23 @@ namespace Kernel
)).inode;
}
SpinLockGuard _(s_bound_socket_lock);
BAN::RefPtr<UnixDomainSocket> target;
{
LockGuard _(s_bound_socket_lock);
auto it = s_bound_sockets.find(target_inode);
if (it == s_bound_sockets.end())
return BAN::Error::from_errno(EDESTADDRREQ);
auto target = it->value.lock();
target = it->value.lock();
}
if (!target)
return BAN::Error::from_errno(EDESTADDRREQ);
TRY(target->add_packet(message));
return message.size();
TRY(target->add_packet(message, total_message_size));
return total_message_size;
}
}
BAN::ErrorOr<size_t> UnixDomainSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr*, socklen_t*)
{
SpinLockGuard guard(m_packet_lock);
while (m_packet_size_total == 0)
{
if (m_info.has<ConnectionInfo>())
{
auto& connection_info = m_info.get<ConnectionInfo>();
bool expected = true;
if (connection_info.target_closed.compare_exchange(expected, false))
return 0;
if (!connection_info.connection)
return BAN::Error::from_errno(ENOTCONN);
}
SpinLockGuardAsMutex smutex(guard);
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &smutex));
}
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
size_t nread = 0;
if (is_streaming())
nread = BAN::Math::min(buffer.size(), m_packet_size_total);
else
{
nread = BAN::Math::min(buffer.size(), m_packet_sizes.front());
m_packet_sizes.pop();
}
memcpy(buffer.data(), packet_buffer, nread);
memmove(packet_buffer, packet_buffer + nread, m_packet_size_total - nread);
m_packet_size_total -= nread;
m_packet_thread_blocker.unblock();
epoll_notify(EPOLLOUT);
return nread;
}
BAN::ErrorOr<void> UnixDomainSocket::getpeername_impl(sockaddr* address, socklen_t* address_len)
{
if (!m_info.has<ConnectionInfo>())

View File

@ -424,7 +424,24 @@ namespace Kernel
}
if (inode->mode().ifsock())
return recvfrom(fd, buffer, nullptr, nullptr);
{
iovec iov {
.iov_base = buffer.data(),
.iov_len = buffer.size(),
};
msghdr message {
.msg_name = nullptr,
.msg_namelen = 0,
.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = nullptr,
.msg_controllen = 0,
.msg_flags = 0,
};
return recvmsg(fd, message, 0);
}
size_t nread;
{
@ -461,7 +478,24 @@ namespace Kernel
}
if (inode->mode().ifsock())
return sendto(fd, buffer, nullptr, 0);
{
iovec iov {
.iov_base = const_cast<uint8_t*>(buffer.data()),
.iov_len = buffer.size(),
};
msghdr message {
.msg_name = nullptr,
.msg_namelen = 0,
.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = nullptr,
.msg_controllen = 0,
.msg_flags = 0,
};
return sendmsg(fd, message, 0);
}
size_t nwrite;
{
@ -515,7 +549,7 @@ namespace Kernel
}
}
BAN::ErrorOr<size_t> OpenFileDescriptorSet::recvfrom(int fd, BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len)
BAN::ErrorOr<size_t> OpenFileDescriptorSet::recvmsg(int fd, msghdr& message, int flags)
{
BAN::RefPtr<Inode> inode;
bool is_nonblock;
@ -533,10 +567,10 @@ namespace Kernel
LockGuard _(inode->m_mutex);
if (is_nonblock && !inode->can_read())
return BAN::Error::from_errno(EWOULDBLOCK);
return inode->recvfrom(buffer, address, address_len);
return inode->recvmsg(message, flags);
}
BAN::ErrorOr<size_t> OpenFileDescriptorSet::sendto(int fd, BAN::ConstByteSpan buffer, const sockaddr* address, socklen_t address_len)
BAN::ErrorOr<size_t> OpenFileDescriptorSet::sendmsg(int fd, const msghdr& message, int flags)
{
BAN::RefPtr<Inode> inode;
bool is_nonblock;
@ -559,7 +593,7 @@ namespace Kernel
}
if (is_nonblock && !inode->can_write())
return BAN::Error::from_errno(EWOULDBLOCK);
return inode->sendto(buffer, address, address_len);
return inode->sendmsg(message, flags);
}
BAN::ErrorOr<VirtualFileSystem::File> OpenFileDescriptorSet::file_of(int fd) const

View File

@ -1569,73 +1569,72 @@ namespace Kernel
return 0;
}
BAN::ErrorOr<long> Process::sys_sendto(const sys_sendto_t* _arguments)
BAN::ErrorOr<long> Process::sys_recvmsg(int socket, msghdr* _message, int flags)
{
sys_sendto_t arguments;
msghdr message;
{
LockGuard _(m_process_lock);
TRY(validate_pointer_access(_arguments, sizeof(sys_sendto_t), false));
arguments = *_arguments;
TRY(validate_pointer_access(_message, sizeof(msghdr), true));
message = *_message;
}
if (arguments.length == 0)
return BAN::Error::from_errno(EINVAL);
MemoryRegion* message_region = nullptr;
MemoryRegion* address_region = nullptr;
BAN::ScopeGuard _([&] {
if (message_region)
message_region->unpin();
if (address_region)
address_region->unpin();
BAN::Vector<MemoryRegion*> regions;
BAN::ScopeGuard _([&regions] {
for (auto* region : regions)
if (region != nullptr)
region->unpin();
});
message_region = TRY(validate_and_pin_pointer_access(arguments.message, arguments.length, false));
if (arguments.dest_addr)
address_region = TRY(validate_and_pin_pointer_access(arguments.dest_addr, arguments.dest_len, false));
auto message = BAN::ConstByteSpan(static_cast<const uint8_t*>(arguments.message), arguments.length);
return TRY(m_open_file_descriptors.sendto(arguments.socket, message, arguments.dest_addr, arguments.dest_len));
if (message.msg_name)
TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_name, message.msg_namelen, true))));
if (message.msg_control)
TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_control, message.msg_controllen, true))));
if (message.msg_iov)
{
TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_iov, message.msg_iovlen * sizeof(iovec), true))));
for (int i = 0; i < message.msg_iovlen; i++)
TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_iov[i].iov_base, message.msg_iov[i].iov_len, true))));
}
BAN::ErrorOr<long> Process::sys_recvfrom(sys_recvfrom_t* _arguments)
{
sys_recvfrom_t arguments;
auto ret = TRY(m_open_file_descriptors.recvmsg(socket, message, flags));
{
LockGuard _(m_process_lock);
TRY(validate_pointer_access(_arguments, sizeof(sys_sendto_t), false));
arguments = *_arguments;
TRY(validate_pointer_access(_message, sizeof(msghdr), true));
*_message = message;
}
if (!arguments.address != !arguments.address_len)
return BAN::Error::from_errno(EINVAL);
if (arguments.length == 0)
return BAN::Error::from_errno(EINVAL);
return ret;
}
MemoryRegion* buffer_region = nullptr;
MemoryRegion* address_region1 = nullptr;
MemoryRegion* address_region2 = nullptr;
BAN::ErrorOr<long> Process::sys_sendmsg(int socket, const msghdr* _message, int flags)
{
msghdr message;
{
LockGuard _(m_process_lock);
TRY(validate_pointer_access(_message, sizeof(msghdr), false));
message = *_message;
}
BAN::ScopeGuard _([&] {
if (buffer_region)
buffer_region->unpin();
if (address_region1)
address_region1->unpin();
if (address_region2)
address_region2->unpin();
BAN::Vector<MemoryRegion*> regions;
BAN::ScopeGuard _([&regions] {
for (auto* region : regions)
if (region != nullptr)
region->unpin();
});
buffer_region = TRY(validate_and_pin_pointer_access(arguments.buffer, arguments.length, true));
if (arguments.address_len)
if (message.msg_name)
TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_name, message.msg_namelen, false))));
if (message.msg_control)
TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_control, message.msg_controllen, false))));
if (message.msg_iov)
{
address_region1 = TRY(validate_and_pin_pointer_access(arguments.address_len, sizeof(*arguments.address_len), true));
address_region2 = TRY(validate_and_pin_pointer_access(arguments.address, *arguments.address_len, true));
TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_iov, message.msg_iovlen * sizeof(iovec), false))));
for (int i = 0; i < message.msg_iovlen; i++)
TRY(regions.push_back(TRY(validate_and_pin_pointer_access(message.msg_iov[i].iov_base, message.msg_iov[i].iov_len, false))));
}
auto message = BAN::ByteSpan(static_cast<uint8_t*>(arguments.buffer), arguments.length);
return TRY(m_open_file_descriptors.recvfrom(arguments.socket, message, arguments.address, arguments.address_len));
return TRY(m_open_file_descriptors.sendmsg(socket, message, flags));
}
BAN::ErrorOr<long> Process::sys_ioctl(int fildes, int request, void* arg)

View File

@ -124,8 +124,8 @@ namespace Kernel
case SYS_WAIT:
case SYS_ACCEPT:
case SYS_CONNECT:
case SYS_RECVFROM:
case SYS_SENDTO:
case SYS_RECVMSG:
case SYS_SENDMSG:
case SYS_FLOCK:
return true;
default:

View File

@ -67,8 +67,8 @@ __BEGIN_DECLS
O(SYS_SOCKET, socket) \
O(SYS_SOCKETPAIR, socketpair) \
O(SYS_BIND, bind) \
O(SYS_SENDTO, sendto) \
O(SYS_RECVFROM, recvfrom) \
O(SYS_RECVMSG, recvmsg) \
O(SYS_SENDMSG, sendmsg) \
O(SYS_IOCTL, ioctl) \
O(SYS_ACCEPT, accept) \
O(SYS_CONNECT, connect) \

View File

@ -42,104 +42,69 @@ ssize_t recv(int socket, void* __restrict buffer, size_t length, int flags)
ssize_t recvfrom(int socket, void* __restrict buffer, size_t length, int flags, struct sockaddr* __restrict address, socklen_t* __restrict address_len)
{
pthread_testcancel();
sys_recvfrom_t arguments {
.socket = socket,
.buffer = buffer,
.length = length,
.flags = flags,
.address = address,
.address_len = address_len
};
return syscall(SYS_RECVFROM, &arguments);
}
// cancellation point in recvmsg
ssize_t send(int socket, const void* message, size_t length, int flags)
{
// cancellation point in sendto
return sendto(socket, message, length, flags, nullptr, 0);
}
ssize_t sendto(int socket, const void* message, size_t length, int flags, const struct sockaddr* dest_addr, socklen_t dest_len)
{
pthread_testcancel();
sys_sendto_t arguments {
.socket = socket,
.message = message,
.length = length,
.flags = flags,
.dest_addr = dest_addr,
.dest_len = dest_len
iovec iov {
.iov_base = buffer,
.iov_len = length,
};
return syscall(SYS_SENDTO, &arguments);
msghdr message {
.msg_name = address,
.msg_namelen = address_len ? *address_len : 0,
.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = NULL,
.msg_controllen = 0,
.msg_flags = 0,
};
const ssize_t ret = recvmsg(socket, &message, flags);
if (address_len)
*address_len = message.msg_namelen;
return ret;
}
ssize_t recvmsg(int socket, struct msghdr* message, int flags)
{
if (CMSG_FIRSTHDR(message))
{
dwarnln("TODO: recvmsg ancillary data");
errno = ENOTSUP;
return -1;
}
pthread_testcancel();
return syscall(SYS_RECVMSG, socket, message, flags);
}
size_t total_recv = 0;
ssize_t send(int socket, const void* buffer, size_t length, int flags)
{
// cancellation point in sendto
return sendto(socket, buffer, length, flags, nullptr, 0);
}
for (int i = 0; i < message->msg_iovlen; i++)
{
const ssize_t nrecv = recvfrom(
socket,
message->msg_iov[i].iov_base,
message->msg_iov[i].iov_len,
flags,
static_cast<sockaddr*>(message->msg_name),
&message->msg_namelen
);
ssize_t sendto(int socket, const void* buffer, size_t length, int flags, const struct sockaddr* address, socklen_t address_len)
{
// cancellation point in sendmsg
if (nrecv < 0)
return -1;
iovec iov {
.iov_base = const_cast<void*>(buffer),
.iov_len = length,
};
total_recv += nrecv;
msghdr message {
.msg_name = const_cast<sockaddr*>(address),
.msg_namelen = address_len,
.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = NULL,
.msg_controllen = 0,
.msg_flags = 0,
};
if (static_cast<size_t>(nrecv) < message->msg_iov[i].iov_len)
break;
}
return total_recv;
return sendmsg(socket, &message, flags);
}
ssize_t sendmsg(int socket, const struct msghdr* message, int flags)
{
if (CMSG_FIRSTHDR(message))
{
dwarnln("TODO: sendmsg ancillary data");
errno = ENOTSUP;
return -1;
}
size_t total_sent = 0;
for (int i = 0; i < message->msg_iovlen; i++)
{
const ssize_t nsend = sendto(
socket,
message->msg_iov[i].iov_base,
message->msg_iov[i].iov_len,
flags,
static_cast<sockaddr*>(message->msg_name),
message->msg_namelen
);
if (nsend < 0)
return -1;
total_sent += nsend;
if (static_cast<size_t>(nsend) < message->msg_iov[i].iov_len)
break;
}
return total_sent;
pthread_testcancel();
return syscall(SYS_SENDMSG, socket, message, flags);
}
int socket(int domain, int type, int protocol)