Kernel: Clean up and fix UNIX sockets

EPOLLOUT is now sent to the correct socket, and the packet buffer is now a
ring buffer, which avoids an unnecessary memmove on every packet.
This commit is contained in:
Bananymous 2026-02-08 19:38:28 +02:00
parent 2e59373a1e
commit c648ea12f2
2 changed files with 101 additions and 43 deletions

View File

@ -70,6 +70,7 @@ namespace Kernel
size_t size;
BAN::Vector<FDWrapper> fds;
BAN::Optional<struct ucred> ucred;
BAN::WeakPtr<UnixDomainSocket> sender;
};
BAN::ErrorOr<size_t> add_packet(const msghdr&, PacketInfo&&, bool dont_block);
@ -82,10 +83,14 @@ namespace Kernel
BAN::CircularQueue<PacketInfo, 512> m_packet_infos;
size_t m_packet_size_total { 0 };
size_t m_packet_buffer_tail { 0 };
BAN::UniqPtr<VirtualRange> m_packet_buffer;
Mutex m_packet_lock;
mutable Mutex m_packet_lock;
ThreadBlocker m_packet_thread_blocker;
BAN::Atomic<size_t> m_sndbuf { 0 };
BAN::Atomic<size_t> m_bytes_sent { 0 };
friend class BAN::RefPtr<UnixDomainSocket>;
};

View File

@ -25,7 +25,7 @@ namespace Kernel
static BAN::HashMap<BAN::RefPtr<Inode>, BAN::WeakPtr<UnixDomainSocket>, UnixSocketHash> s_bound_sockets;
static Mutex s_bound_socket_lock;
static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE;
static constexpr size_t s_packet_buffer_size = 0x10000;
static BAN::ErrorOr<BAN::StringView> validate_sockaddr_un(const sockaddr* address, socklen_t address_len)
{
@ -45,8 +45,6 @@ namespace Kernel
return BAN::StringView { sockaddr_un.sun_path, length };
}
// FIXME: why is this using spinlocks instead of mutexes??
BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> UnixDomainSocket::create(Socket::Type socket_type, const Socket::Info& info)
{
auto socket = TRY(BAN::RefPtr<UnixDomainSocket>::create(socket_type, info));
@ -64,6 +62,7 @@ namespace Kernel
UnixDomainSocket::UnixDomainSocket(Socket::Type socket_type, const Socket::Info& info)
: Socket(info)
, m_socket_type(socket_type)
, m_sndbuf(s_packet_buffer_size)
{
switch (socket_type)
{
@ -290,29 +289,55 @@ namespace Kernel
case Socket::Type::DGRAM:
return false;
}
ASSERT_NOT_REACHED();
}
BAN::ErrorOr<size_t> UnixDomainSocket::add_packet(const msghdr& packet, PacketInfo&& packet_info, bool dont_block)
{
LockGuard _(m_packet_lock);
while (m_packet_infos.full() || m_packet_size_total + packet_info.size > s_packet_buffer_size)
const auto has_space =
[&]() -> bool
{
if (m_packet_infos.full())
return false;
if (is_streaming())
return m_packet_size_total < m_packet_buffer->size();
return m_packet_size_total + packet_info.size <= m_packet_buffer->size();
};
while (!has_space())
{
if (dont_block)
return BAN::Error::from_errno(EAGAIN);
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock));
}
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total);
size_t offset = 0;
for (int i = 0; i < packet.msg_iovlen; i++)
if (auto available = m_packet_buffer->size() - m_packet_size_total; available < packet_info.size)
{
memcpy(packet_buffer + offset, packet.msg_iov[i].iov_base, packet.msg_iov[i].iov_len);
offset += packet.msg_iov[i].iov_len;
ASSERT(is_streaming());
packet_info.size = available;
}
ASSERT(offset == packet_info.size);
uint8_t* packet_buffer_base_u8 = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
size_t bytes_copied = 0;
for (int i = 0; i < packet.msg_iovlen && bytes_copied < packet_info.size; i++)
{
const uint8_t* iov_base_u8 = static_cast<const uint8_t*>(packet.msg_iov[i].iov_base);
const size_t to_copy = BAN::Math::min(packet.msg_iov[i].iov_len, packet_info.size - bytes_copied);
const size_t copy_offset = (m_packet_buffer_tail + m_packet_size_total + bytes_copied) % m_packet_buffer->size();
const size_t before_wrap = BAN::Math::min(to_copy, m_packet_buffer->size() - copy_offset);
memcpy(packet_buffer_base_u8 + copy_offset, iov_base_u8, before_wrap);
if (const size_t after_wrap = to_copy - before_wrap)
memcpy(packet_buffer_base_u8, iov_base_u8 + before_wrap, after_wrap);
bytes_copied += to_copy;
}
ASSERT(bytes_copied == packet_info.size);
m_packet_size_total += packet_info.size;
m_packet_infos.emplace(BAN::move(packet_info));
@ -320,7 +345,7 @@ namespace Kernel
epoll_notify(EPOLLIN);
return {};
return bytes_copied;
}
bool UnixDomainSocket::can_read_impl() const
@ -336,24 +361,13 @@ namespace Kernel
return false;
}
LockGuard _(m_packet_lock);
return m_packet_size_total > 0;
}
// Reports whether this socket is currently considered writable.
//
// NOTE(review): this span comes from a rendered diff whose +/- markers were
// stripped, so it presumably shows the removed body followed by its
// replacement. Confirm against the real file which lines survive.
bool UnixDomainSocket::can_write_impl() const
{
// For a socket holding ConnectionInfo, look at the peer's receive state.
if (m_info.has<ConnectionInfo>())
{
auto& connection_info = m_info.get<ConnectionInfo>();
// The peer is held through a weak pointer; lock() yields null once it is gone.
auto connection = connection_info.connection.lock();
if (!connection)
return false;
// Not writable while the peer's packet info queue is full.
if (connection->m_packet_infos.full())
return false;
// Not writable while the peer's buffered bytes have reached the buffer size.
if (connection->m_packet_size_total >= s_packet_buffer_size)
return false;
}
return true;
// NOTE(review): unreachable as written here; presumably this is the new
// implementation (writable while bytes in flight are below the send-buffer
// limit) that replaces everything above — confirm.
return m_bytes_sent < m_sndbuf;
}
bool UnixDomainSocket::has_hungup_impl() const
@ -377,6 +391,7 @@ namespace Kernel
}
LockGuard _(m_packet_lock);
while (m_packet_size_total == 0)
{
if (m_info.has<ConnectionInfo>())
@ -397,7 +412,7 @@ namespace Kernel
cheader->cmsg_len = message.msg_controllen;
size_t cheader_len = 0;
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
uint8_t* packet_buffer_base_u8 = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
message.msg_flags = 0;
@ -471,7 +486,12 @@ namespace Kernel
uint8_t* iov_base = static_cast<uint8_t*>(iov.iov_base);
const size_t nrecv = BAN::Math::min<size_t>(iov.iov_len - iov_offset, packet_info.size - packet_received);
memcpy(iov_base + iov_offset, packet_buffer + packet_received, nrecv);
const size_t copy_offset = (m_packet_buffer_tail + packet_received) % m_packet_buffer->size();
const size_t before_wrap = BAN::Math::min<size_t>(nrecv, m_packet_buffer->size() - copy_offset);
memcpy(iov_base + iov_offset, packet_buffer_base_u8 + copy_offset, before_wrap);
if (const size_t after_wrap = nrecv - before_wrap)
memcpy(iov_base + iov_offset + before_wrap, packet_buffer_base_u8, after_wrap);
packet_received += nrecv;
@ -492,12 +512,17 @@ namespace Kernel
if (packet_info.size == 0)
m_packet_infos.pop();
// FIXME: get rid of this memmove :)
memmove(packet_buffer, packet_buffer + to_discard, m_packet_size_total - to_discard);
m_packet_buffer_tail = (m_packet_buffer_tail + to_discard) % m_packet_buffer->size();
m_packet_size_total -= to_discard;
total_recv += packet_received;
if (auto sender = packet_info.sender.lock())
{
sender->m_bytes_sent -= to_discard;
sender->epoll_notify(EPOLLOUT);
}
// on linux ancillary data is a barrier on stream sockets, lets do the same
if (!is_streaming() || had_ancillary_data)
break;
@ -507,8 +532,6 @@ namespace Kernel
m_packet_thread_blocker.unblock();
epoll_notify(EPOLLOUT);
return total_recv;
}
@ -528,13 +551,14 @@ namespace Kernel
return result;
}();
if (total_message_size > s_packet_buffer_size)
return BAN::Error::from_errno(ENOBUFS);
if (!is_streaming() && total_message_size > m_packet_buffer->size())
return BAN::Error::from_errno(EMSGSIZE);
PacketInfo packet_info {
.size = total_message_size,
.fds = {},
.ucred = {},
.size = total_message_size,
.fds = {},
.ucred = {},
.sender = TRY(get_weak_ptr()),
};
for (const auto* header = CMSG_FIRSTHDR(&message); header; header = CMSG_NXTHDR(&message, header))
@ -607,8 +631,9 @@ namespace Kernel
auto target = connection_info.connection.lock();
if (!target)
return BAN::Error::from_errno(ENOTCONN);
TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT));
return total_message_size;
const size_t bytes_sent = TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT));
m_bytes_sent += bytes_sent;
return bytes_sent;
}
else
{
@ -625,7 +650,7 @@ namespace Kernel
Process::current().root_file().inode,
Process::current().credentials(),
absolute_path,
O_RDWR
O_WRONLY
)).inode;
}
else
@ -651,8 +676,11 @@ namespace Kernel
if (!target)
return BAN::Error::from_errno(EDESTADDRREQ);
TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT));
return total_message_size;
if (target->m_socket_type != m_socket_type)
return BAN::Error::from_errno(EPROTOTYPE);
const auto bytes_sent = TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT));
m_bytes_sent += bytes_sent;
return bytes_sent;
}
}
@ -688,7 +716,7 @@ namespace Kernel
result = 0;
break;
case SO_SNDBUF:
result = m_packet_buffer->size();
result = m_sndbuf;
break;
case SO_RCVBUF:
result = m_packet_buffer->size();
@ -705,4 +733,29 @@ namespace Kernel
return {};
}
// Applies a socket option to this UNIX domain socket.
//
// Only SOL_SOCKET / SO_SNDBUF is handled: the value must be exactly one
// non-negative int, which is stored as the new send-buffer limit.
//
// Errors:
//   EINVAL  - level is not SOL_SOCKET, value_len is not sizeof(int),
//             or the requested size is negative
//   ENOTSUP - any other SOL_SOCKET option (a warning is logged as well)
BAN::ErrorOr<void> UnixDomainSocket::setsockopt_impl(int level, int option, const void* value, socklen_t value_len)
{
	if (level != SOL_SOCKET)
		return BAN::Error::from_errno(EINVAL);

	// Reject every option other than SO_SNDBUF up front.
	if (option != SO_SNDBUF)
	{
		dwarnln("setsockopt(SOL_SOCKET, {})", option);
		return BAN::Error::from_errno(ENOTSUP);
	}

	if (value_len != sizeof(int))
		return BAN::Error::from_errno(EINVAL);

	const int requested_size = *static_cast<const int*>(value);
	if (requested_size < 0)
		return BAN::Error::from_errno(EINVAL);

	m_sndbuf = requested_size;
	return {};
}
}