Kernel: Cleanup and fix UNIX sockets
EPOLLOUT is now sent to the correct socket, and the packet buffer is now a ring buffer, which avoids an unnecessary memmove on every packet.
This commit is contained in:
parent
2e59373a1e
commit
c648ea12f2
|
|
@ -70,6 +70,7 @@ namespace Kernel
|
|||
size_t size;
|
||||
BAN::Vector<FDWrapper> fds;
|
||||
BAN::Optional<struct ucred> ucred;
|
||||
BAN::WeakPtr<UnixDomainSocket> sender;
|
||||
};
|
||||
|
||||
BAN::ErrorOr<size_t> add_packet(const msghdr&, PacketInfo&&, bool dont_block);
|
||||
|
|
@ -82,10 +83,14 @@ namespace Kernel
|
|||
|
||||
BAN::CircularQueue<PacketInfo, 512> m_packet_infos;
|
||||
size_t m_packet_size_total { 0 };
|
||||
size_t m_packet_buffer_tail { 0 };
|
||||
BAN::UniqPtr<VirtualRange> m_packet_buffer;
|
||||
Mutex m_packet_lock;
|
||||
mutable Mutex m_packet_lock;
|
||||
ThreadBlocker m_packet_thread_blocker;
|
||||
|
||||
BAN::Atomic<size_t> m_sndbuf { 0 };
|
||||
BAN::Atomic<size_t> m_bytes_sent { 0 };
|
||||
|
||||
friend class BAN::RefPtr<UnixDomainSocket>;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ namespace Kernel
|
|||
static BAN::HashMap<BAN::RefPtr<Inode>, BAN::WeakPtr<UnixDomainSocket>, UnixSocketHash> s_bound_sockets;
|
||||
static Mutex s_bound_socket_lock;
|
||||
|
||||
static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE;
|
||||
static constexpr size_t s_packet_buffer_size = 0x10000;
|
||||
|
||||
static BAN::ErrorOr<BAN::StringView> validate_sockaddr_un(const sockaddr* address, socklen_t address_len)
|
||||
{
|
||||
|
|
@ -45,8 +45,6 @@ namespace Kernel
|
|||
return BAN::StringView { sockaddr_un.sun_path, length };
|
||||
}
|
||||
|
||||
// FIXME: why is this using spinlocks instead of mutexes??
|
||||
|
||||
BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> UnixDomainSocket::create(Socket::Type socket_type, const Socket::Info& info)
|
||||
{
|
||||
auto socket = TRY(BAN::RefPtr<UnixDomainSocket>::create(socket_type, info));
|
||||
|
|
@ -64,6 +62,7 @@ namespace Kernel
|
|||
UnixDomainSocket::UnixDomainSocket(Socket::Type socket_type, const Socket::Info& info)
|
||||
: Socket(info)
|
||||
, m_socket_type(socket_type)
|
||||
, m_sndbuf(s_packet_buffer_size)
|
||||
{
|
||||
switch (socket_type)
|
||||
{
|
||||
|
|
@ -290,29 +289,55 @@ namespace Kernel
|
|||
case Socket::Type::DGRAM:
|
||||
return false;
|
||||
}
|
||||
ASSERT_NOT_REACHED();
|
||||
}
|
||||
|
||||
BAN::ErrorOr<size_t> UnixDomainSocket::add_packet(const msghdr& packet, PacketInfo&& packet_info, bool dont_block)
|
||||
{
|
||||
LockGuard _(m_packet_lock);
|
||||
|
||||
while (m_packet_infos.full() || m_packet_size_total + packet_info.size > s_packet_buffer_size)
|
||||
const auto has_space =
|
||||
[&]() -> bool
|
||||
{
|
||||
if (m_packet_infos.full())
|
||||
return false;
|
||||
if (is_streaming())
|
||||
return m_packet_size_total < m_packet_buffer->size();
|
||||
return m_packet_size_total + packet_info.size <= m_packet_buffer->size();
|
||||
};
|
||||
|
||||
while (!has_space())
|
||||
{
|
||||
if (dont_block)
|
||||
return BAN::Error::from_errno(EAGAIN);
|
||||
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock));
|
||||
}
|
||||
|
||||
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total);
|
||||
|
||||
size_t offset = 0;
|
||||
for (int i = 0; i < packet.msg_iovlen; i++)
|
||||
if (auto available = m_packet_buffer->size() - m_packet_size_total; available < packet_info.size)
|
||||
{
|
||||
memcpy(packet_buffer + offset, packet.msg_iov[i].iov_base, packet.msg_iov[i].iov_len);
|
||||
offset += packet.msg_iov[i].iov_len;
|
||||
ASSERT(is_streaming());
|
||||
packet_info.size = available;
|
||||
}
|
||||
|
||||
ASSERT(offset == packet_info.size);
|
||||
uint8_t* packet_buffer_base_u8 = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
|
||||
|
||||
size_t bytes_copied = 0;
|
||||
for (int i = 0; i < packet.msg_iovlen && bytes_copied < packet_info.size; i++)
|
||||
{
|
||||
const uint8_t* iov_base_u8 = static_cast<const uint8_t*>(packet.msg_iov[i].iov_base);
|
||||
|
||||
const size_t to_copy = BAN::Math::min(packet.msg_iov[i].iov_len, packet_info.size - bytes_copied);
|
||||
|
||||
const size_t copy_offset = (m_packet_buffer_tail + m_packet_size_total + bytes_copied) % m_packet_buffer->size();
|
||||
const size_t before_wrap = BAN::Math::min(to_copy, m_packet_buffer->size() - copy_offset);
|
||||
memcpy(packet_buffer_base_u8 + copy_offset, iov_base_u8, before_wrap);
|
||||
if (const size_t after_wrap = to_copy - before_wrap)
|
||||
memcpy(packet_buffer_base_u8, iov_base_u8 + before_wrap, after_wrap);
|
||||
|
||||
bytes_copied += to_copy;
|
||||
}
|
||||
|
||||
ASSERT(bytes_copied == packet_info.size);
|
||||
m_packet_size_total += packet_info.size;
|
||||
m_packet_infos.emplace(BAN::move(packet_info));
|
||||
|
||||
|
|
@ -320,7 +345,7 @@ namespace Kernel
|
|||
|
||||
epoll_notify(EPOLLIN);
|
||||
|
||||
return {};
|
||||
return bytes_copied;
|
||||
}
|
||||
|
||||
bool UnixDomainSocket::can_read_impl() const
|
||||
|
|
@ -336,24 +361,13 @@ namespace Kernel
|
|||
return false;
|
||||
}
|
||||
|
||||
LockGuard _(m_packet_lock);
|
||||
return m_packet_size_total > 0;
|
||||
}
|
||||
|
||||
bool UnixDomainSocket::can_write_impl() const
|
||||
{
|
||||
if (m_info.has<ConnectionInfo>())
|
||||
{
|
||||
auto& connection_info = m_info.get<ConnectionInfo>();
|
||||
auto connection = connection_info.connection.lock();
|
||||
if (!connection)
|
||||
return false;
|
||||
if (connection->m_packet_infos.full())
|
||||
return false;
|
||||
if (connection->m_packet_size_total >= s_packet_buffer_size)
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
return m_bytes_sent < m_sndbuf;
|
||||
}
|
||||
|
||||
bool UnixDomainSocket::has_hungup_impl() const
|
||||
|
|
@ -377,6 +391,7 @@ namespace Kernel
|
|||
}
|
||||
|
||||
LockGuard _(m_packet_lock);
|
||||
|
||||
while (m_packet_size_total == 0)
|
||||
{
|
||||
if (m_info.has<ConnectionInfo>())
|
||||
|
|
@ -397,7 +412,7 @@ namespace Kernel
|
|||
cheader->cmsg_len = message.msg_controllen;
|
||||
size_t cheader_len = 0;
|
||||
|
||||
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
|
||||
uint8_t* packet_buffer_base_u8 = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
|
||||
|
||||
message.msg_flags = 0;
|
||||
|
||||
|
|
@ -471,7 +486,12 @@ namespace Kernel
|
|||
uint8_t* iov_base = static_cast<uint8_t*>(iov.iov_base);
|
||||
|
||||
const size_t nrecv = BAN::Math::min<size_t>(iov.iov_len - iov_offset, packet_info.size - packet_received);
|
||||
memcpy(iov_base + iov_offset, packet_buffer + packet_received, nrecv);
|
||||
|
||||
const size_t copy_offset = (m_packet_buffer_tail + packet_received) % m_packet_buffer->size();
|
||||
const size_t before_wrap = BAN::Math::min<size_t>(nrecv, m_packet_buffer->size() - copy_offset);
|
||||
memcpy(iov_base + iov_offset, packet_buffer_base_u8 + copy_offset, before_wrap);
|
||||
if (const size_t after_wrap = nrecv - before_wrap)
|
||||
memcpy(iov_base + iov_offset + before_wrap, packet_buffer_base_u8, after_wrap);
|
||||
|
||||
packet_received += nrecv;
|
||||
|
||||
|
|
@ -492,12 +512,17 @@ namespace Kernel
|
|||
if (packet_info.size == 0)
|
||||
m_packet_infos.pop();
|
||||
|
||||
// FIXME: get rid of this memmove :)
|
||||
memmove(packet_buffer, packet_buffer + to_discard, m_packet_size_total - to_discard);
|
||||
m_packet_buffer_tail = (m_packet_buffer_tail + to_discard) % m_packet_buffer->size();
|
||||
m_packet_size_total -= to_discard;
|
||||
|
||||
total_recv += packet_received;
|
||||
|
||||
if (auto sender = packet_info.sender.lock())
|
||||
{
|
||||
sender->m_bytes_sent -= to_discard;
|
||||
sender->epoll_notify(EPOLLOUT);
|
||||
}
|
||||
|
||||
// on linux ancillary data is a barrier on stream sockets, lets do the same
|
||||
if (!is_streaming() || had_ancillary_data)
|
||||
break;
|
||||
|
|
@ -507,8 +532,6 @@ namespace Kernel
|
|||
|
||||
m_packet_thread_blocker.unblock();
|
||||
|
||||
epoll_notify(EPOLLOUT);
|
||||
|
||||
return total_recv;
|
||||
}
|
||||
|
||||
|
|
@ -528,13 +551,14 @@ namespace Kernel
|
|||
return result;
|
||||
}();
|
||||
|
||||
if (total_message_size > s_packet_buffer_size)
|
||||
return BAN::Error::from_errno(ENOBUFS);
|
||||
if (!is_streaming() && total_message_size > m_packet_buffer->size())
|
||||
return BAN::Error::from_errno(EMSGSIZE);
|
||||
|
||||
PacketInfo packet_info {
|
||||
.size = total_message_size,
|
||||
.fds = {},
|
||||
.ucred = {},
|
||||
.sender = TRY(get_weak_ptr()),
|
||||
};
|
||||
|
||||
for (const auto* header = CMSG_FIRSTHDR(&message); header; header = CMSG_NXTHDR(&message, header))
|
||||
|
|
@ -607,8 +631,9 @@ namespace Kernel
|
|||
auto target = connection_info.connection.lock();
|
||||
if (!target)
|
||||
return BAN::Error::from_errno(ENOTCONN);
|
||||
TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT));
|
||||
return total_message_size;
|
||||
const size_t bytes_sent = TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT));
|
||||
m_bytes_sent += bytes_sent;
|
||||
return bytes_sent;
|
||||
}
|
||||
else
|
||||
{
|
||||
|
|
@ -625,7 +650,7 @@ namespace Kernel
|
|||
Process::current().root_file().inode,
|
||||
Process::current().credentials(),
|
||||
absolute_path,
|
||||
O_RDWR
|
||||
O_WRONLY
|
||||
)).inode;
|
||||
}
|
||||
else
|
||||
|
|
@ -651,8 +676,11 @@ namespace Kernel
|
|||
|
||||
if (!target)
|
||||
return BAN::Error::from_errno(EDESTADDRREQ);
|
||||
TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT));
|
||||
return total_message_size;
|
||||
if (target->m_socket_type != m_socket_type)
|
||||
return BAN::Error::from_errno(EPROTOTYPE);
|
||||
const auto bytes_sent = TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT));
|
||||
m_bytes_sent += bytes_sent;
|
||||
return bytes_sent;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -688,7 +716,7 @@ namespace Kernel
|
|||
result = 0;
|
||||
break;
|
||||
case SO_SNDBUF:
|
||||
result = m_packet_buffer->size();
|
||||
result = m_sndbuf;
|
||||
break;
|
||||
case SO_RCVBUF:
|
||||
result = m_packet_buffer->size();
|
||||
|
|
@ -705,4 +733,29 @@ namespace Kernel
|
|||
return {};
|
||||
}
|
||||
|
||||
BAN::ErrorOr<void> UnixDomainSocket::setsockopt_impl(int level, int option, const void* value, socklen_t value_len)
|
||||
{
|
||||
if (level != SOL_SOCKET)
|
||||
return BAN::Error::from_errno(EINVAL);
|
||||
|
||||
switch (option)
|
||||
{
|
||||
case SO_SNDBUF:
|
||||
{
|
||||
if (value_len != sizeof(int))
|
||||
return BAN::Error::from_errno(EINVAL);
|
||||
const int new_sndbuf = *static_cast<const int*>(value);
|
||||
if (new_sndbuf < 0)
|
||||
return BAN::Error::from_errno(EINVAL);
|
||||
m_sndbuf = new_sndbuf;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
dwarnln("setsockopt(SOL_SOCKET, {})", option);
|
||||
return BAN::Error::from_errno(ENOTSUP);
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue