Kernel: Cleanup and fix UNIX sockets

EPOLLOUT is now sent to the correct socket, and the packet buffer is now a
ring buffer to avoid an unnecessary memmove on every packet.
This commit is contained in:
Bananymous 2026-02-08 19:38:28 +02:00
parent 2e59373a1e
commit c648ea12f2
2 changed files with 101 additions and 43 deletions

View File

@ -70,6 +70,7 @@ namespace Kernel
size_t size; size_t size;
BAN::Vector<FDWrapper> fds; BAN::Vector<FDWrapper> fds;
BAN::Optional<struct ucred> ucred; BAN::Optional<struct ucred> ucred;
BAN::WeakPtr<UnixDomainSocket> sender;
}; };
BAN::ErrorOr<size_t> add_packet(const msghdr&, PacketInfo&&, bool dont_block); BAN::ErrorOr<size_t> add_packet(const msghdr&, PacketInfo&&, bool dont_block);
@ -82,10 +83,14 @@ namespace Kernel
BAN::CircularQueue<PacketInfo, 512> m_packet_infos; BAN::CircularQueue<PacketInfo, 512> m_packet_infos;
size_t m_packet_size_total { 0 }; size_t m_packet_size_total { 0 };
size_t m_packet_buffer_tail { 0 };
BAN::UniqPtr<VirtualRange> m_packet_buffer; BAN::UniqPtr<VirtualRange> m_packet_buffer;
Mutex m_packet_lock; mutable Mutex m_packet_lock;
ThreadBlocker m_packet_thread_blocker; ThreadBlocker m_packet_thread_blocker;
BAN::Atomic<size_t> m_sndbuf { 0 };
BAN::Atomic<size_t> m_bytes_sent { 0 };
friend class BAN::RefPtr<UnixDomainSocket>; friend class BAN::RefPtr<UnixDomainSocket>;
}; };

View File

@ -25,7 +25,7 @@ namespace Kernel
static BAN::HashMap<BAN::RefPtr<Inode>, BAN::WeakPtr<UnixDomainSocket>, UnixSocketHash> s_bound_sockets; static BAN::HashMap<BAN::RefPtr<Inode>, BAN::WeakPtr<UnixDomainSocket>, UnixSocketHash> s_bound_sockets;
static Mutex s_bound_socket_lock; static Mutex s_bound_socket_lock;
static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE; static constexpr size_t s_packet_buffer_size = 0x10000;
static BAN::ErrorOr<BAN::StringView> validate_sockaddr_un(const sockaddr* address, socklen_t address_len) static BAN::ErrorOr<BAN::StringView> validate_sockaddr_un(const sockaddr* address, socklen_t address_len)
{ {
@ -45,8 +45,6 @@ namespace Kernel
return BAN::StringView { sockaddr_un.sun_path, length }; return BAN::StringView { sockaddr_un.sun_path, length };
} }
// FIXME: why is this using spinlocks instead of mutexes??
BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> UnixDomainSocket::create(Socket::Type socket_type, const Socket::Info& info) BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> UnixDomainSocket::create(Socket::Type socket_type, const Socket::Info& info)
{ {
auto socket = TRY(BAN::RefPtr<UnixDomainSocket>::create(socket_type, info)); auto socket = TRY(BAN::RefPtr<UnixDomainSocket>::create(socket_type, info));
@ -64,6 +62,7 @@ namespace Kernel
UnixDomainSocket::UnixDomainSocket(Socket::Type socket_type, const Socket::Info& info) UnixDomainSocket::UnixDomainSocket(Socket::Type socket_type, const Socket::Info& info)
: Socket(info) : Socket(info)
, m_socket_type(socket_type) , m_socket_type(socket_type)
, m_sndbuf(s_packet_buffer_size)
{ {
switch (socket_type) switch (socket_type)
{ {
@ -290,29 +289,55 @@ namespace Kernel
case Socket::Type::DGRAM: case Socket::Type::DGRAM:
return false; return false;
} }
ASSERT_NOT_REACHED();
} }
BAN::ErrorOr<size_t> UnixDomainSocket::add_packet(const msghdr& packet, PacketInfo&& packet_info, bool dont_block) BAN::ErrorOr<size_t> UnixDomainSocket::add_packet(const msghdr& packet, PacketInfo&& packet_info, bool dont_block)
{ {
LockGuard _(m_packet_lock); LockGuard _(m_packet_lock);
while (m_packet_infos.full() || m_packet_size_total + packet_info.size > s_packet_buffer_size) const auto has_space =
[&]() -> bool
{
if (m_packet_infos.full())
return false;
if (is_streaming())
return m_packet_size_total < m_packet_buffer->size();
return m_packet_size_total + packet_info.size <= m_packet_buffer->size();
};
while (!has_space())
{ {
if (dont_block) if (dont_block)
return BAN::Error::from_errno(EAGAIN); return BAN::Error::from_errno(EAGAIN);
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock)); TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock));
} }
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total); if (auto available = m_packet_buffer->size() - m_packet_size_total; available < packet_info.size)
size_t offset = 0;
for (int i = 0; i < packet.msg_iovlen; i++)
{ {
memcpy(packet_buffer + offset, packet.msg_iov[i].iov_base, packet.msg_iov[i].iov_len); ASSERT(is_streaming());
offset += packet.msg_iov[i].iov_len; packet_info.size = available;
} }
ASSERT(offset == packet_info.size); uint8_t* packet_buffer_base_u8 = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
size_t bytes_copied = 0;
for (int i = 0; i < packet.msg_iovlen && bytes_copied < packet_info.size; i++)
{
const uint8_t* iov_base_u8 = static_cast<const uint8_t*>(packet.msg_iov[i].iov_base);
const size_t to_copy = BAN::Math::min(packet.msg_iov[i].iov_len, packet_info.size - bytes_copied);
const size_t copy_offset = (m_packet_buffer_tail + m_packet_size_total + bytes_copied) % m_packet_buffer->size();
const size_t before_wrap = BAN::Math::min(to_copy, m_packet_buffer->size() - copy_offset);
memcpy(packet_buffer_base_u8 + copy_offset, iov_base_u8, before_wrap);
if (const size_t after_wrap = to_copy - before_wrap)
memcpy(packet_buffer_base_u8, iov_base_u8 + before_wrap, after_wrap);
bytes_copied += to_copy;
}
ASSERT(bytes_copied == packet_info.size);
m_packet_size_total += packet_info.size; m_packet_size_total += packet_info.size;
m_packet_infos.emplace(BAN::move(packet_info)); m_packet_infos.emplace(BAN::move(packet_info));
@ -320,7 +345,7 @@ namespace Kernel
epoll_notify(EPOLLIN); epoll_notify(EPOLLIN);
return {}; return bytes_copied;
} }
bool UnixDomainSocket::can_read_impl() const bool UnixDomainSocket::can_read_impl() const
@ -336,24 +361,13 @@ namespace Kernel
return false; return false;
} }
LockGuard _(m_packet_lock);
return m_packet_size_total > 0; return m_packet_size_total > 0;
} }
bool UnixDomainSocket::can_write_impl() const bool UnixDomainSocket::can_write_impl() const
{ {
if (m_info.has<ConnectionInfo>()) return m_bytes_sent < m_sndbuf;
{
auto& connection_info = m_info.get<ConnectionInfo>();
auto connection = connection_info.connection.lock();
if (!connection)
return false;
if (connection->m_packet_infos.full())
return false;
if (connection->m_packet_size_total >= s_packet_buffer_size)
return false;
}
return true;
} }
bool UnixDomainSocket::has_hungup_impl() const bool UnixDomainSocket::has_hungup_impl() const
@ -377,6 +391,7 @@ namespace Kernel
} }
LockGuard _(m_packet_lock); LockGuard _(m_packet_lock);
while (m_packet_size_total == 0) while (m_packet_size_total == 0)
{ {
if (m_info.has<ConnectionInfo>()) if (m_info.has<ConnectionInfo>())
@ -397,7 +412,7 @@ namespace Kernel
cheader->cmsg_len = message.msg_controllen; cheader->cmsg_len = message.msg_controllen;
size_t cheader_len = 0; size_t cheader_len = 0;
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr()); uint8_t* packet_buffer_base_u8 = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
message.msg_flags = 0; message.msg_flags = 0;
@ -471,7 +486,12 @@ namespace Kernel
uint8_t* iov_base = static_cast<uint8_t*>(iov.iov_base); uint8_t* iov_base = static_cast<uint8_t*>(iov.iov_base);
const size_t nrecv = BAN::Math::min<size_t>(iov.iov_len - iov_offset, packet_info.size - packet_received); const size_t nrecv = BAN::Math::min<size_t>(iov.iov_len - iov_offset, packet_info.size - packet_received);
memcpy(iov_base + iov_offset, packet_buffer + packet_received, nrecv);
const size_t copy_offset = (m_packet_buffer_tail + packet_received) % m_packet_buffer->size();
const size_t before_wrap = BAN::Math::min<size_t>(nrecv, m_packet_buffer->size() - copy_offset);
memcpy(iov_base + iov_offset, packet_buffer_base_u8 + copy_offset, before_wrap);
if (const size_t after_wrap = nrecv - before_wrap)
memcpy(iov_base + iov_offset + before_wrap, packet_buffer_base_u8, after_wrap);
packet_received += nrecv; packet_received += nrecv;
@ -492,12 +512,17 @@ namespace Kernel
if (packet_info.size == 0) if (packet_info.size == 0)
m_packet_infos.pop(); m_packet_infos.pop();
// FIXME: get rid of this memmove :) m_packet_buffer_tail = (m_packet_buffer_tail + to_discard) % m_packet_buffer->size();
memmove(packet_buffer, packet_buffer + to_discard, m_packet_size_total - to_discard);
m_packet_size_total -= to_discard; m_packet_size_total -= to_discard;
total_recv += packet_received; total_recv += packet_received;
if (auto sender = packet_info.sender.lock())
{
sender->m_bytes_sent -= to_discard;
sender->epoll_notify(EPOLLOUT);
}
// on linux ancillary data is a barrier on stream sockets, lets do the same // on linux ancillary data is a barrier on stream sockets, lets do the same
if (!is_streaming() || had_ancillary_data) if (!is_streaming() || had_ancillary_data)
break; break;
@ -507,8 +532,6 @@ namespace Kernel
m_packet_thread_blocker.unblock(); m_packet_thread_blocker.unblock();
epoll_notify(EPOLLOUT);
return total_recv; return total_recv;
} }
@ -528,13 +551,14 @@ namespace Kernel
return result; return result;
}(); }();
if (total_message_size > s_packet_buffer_size) if (!is_streaming() && total_message_size > m_packet_buffer->size())
return BAN::Error::from_errno(ENOBUFS); return BAN::Error::from_errno(EMSGSIZE);
PacketInfo packet_info { PacketInfo packet_info {
.size = total_message_size, .size = total_message_size,
.fds = {}, .fds = {},
.ucred = {}, .ucred = {},
.sender = TRY(get_weak_ptr()),
}; };
for (const auto* header = CMSG_FIRSTHDR(&message); header; header = CMSG_NXTHDR(&message, header)) for (const auto* header = CMSG_FIRSTHDR(&message); header; header = CMSG_NXTHDR(&message, header))
@ -607,8 +631,9 @@ namespace Kernel
auto target = connection_info.connection.lock(); auto target = connection_info.connection.lock();
if (!target) if (!target)
return BAN::Error::from_errno(ENOTCONN); return BAN::Error::from_errno(ENOTCONN);
TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT)); const size_t bytes_sent = TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT));
return total_message_size; m_bytes_sent += bytes_sent;
return bytes_sent;
} }
else else
{ {
@ -625,7 +650,7 @@ namespace Kernel
Process::current().root_file().inode, Process::current().root_file().inode,
Process::current().credentials(), Process::current().credentials(),
absolute_path, absolute_path,
O_RDWR O_WRONLY
)).inode; )).inode;
} }
else else
@ -651,8 +676,11 @@ namespace Kernel
if (!target) if (!target)
return BAN::Error::from_errno(EDESTADDRREQ); return BAN::Error::from_errno(EDESTADDRREQ);
TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT)); if (target->m_socket_type != m_socket_type)
return total_message_size; return BAN::Error::from_errno(EPROTOTYPE);
const auto bytes_sent = TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT));
m_bytes_sent += bytes_sent;
return bytes_sent;
} }
} }
@ -688,7 +716,7 @@ namespace Kernel
result = 0; result = 0;
break; break;
case SO_SNDBUF: case SO_SNDBUF:
result = m_packet_buffer->size(); result = m_sndbuf;
break; break;
case SO_RCVBUF: case SO_RCVBUF:
result = m_packet_buffer->size(); result = m_packet_buffer->size();
@ -705,4 +733,29 @@ namespace Kernel
return {}; return {};
} }
// Applies a socket option to this socket.
//
// Only SOL_SOCKET / SO_SNDBUF is handled: the value must be exactly one
// non-negative int, which is stored into m_sndbuf.
//
// Errors:
//   EINVAL  - level is not SOL_SOCKET, value_len is not sizeof(int),
//             or the requested size is negative
//   ENOTSUP - any other SOL_SOCKET option (a warning is logged first)
BAN::ErrorOr<void> UnixDomainSocket::setsockopt_impl(int level, int option, const void* value, socklen_t value_len)
{
	if (level != SOL_SOCKET)
		return BAN::Error::from_errno(EINVAL);

	if (option != SO_SNDBUF)
	{
		dwarnln("setsockopt(SOL_SOCKET, {})", option);
		return BAN::Error::from_errno(ENOTSUP);
	}

	// the length has to be validated before the value is read through the pointer
	if (value_len != sizeof(int))
		return BAN::Error::from_errno(EINVAL);

	const int requested_size = *static_cast<const int*>(value);
	if (requested_size < 0)
		return BAN::Error::from_errno(EINVAL);

	m_sndbuf = requested_size;
	return {};
}
} }