Kernel: Cleanup and fix UNIX sockets
EPOLLOUT is now sent to the correct socket and buffer is now a ring buffer to avoid unnecessary memmove on every packet
This commit is contained in:
parent
2e59373a1e
commit
c648ea12f2
|
|
@ -70,6 +70,7 @@ namespace Kernel
|
||||||
size_t size;
|
size_t size;
|
||||||
BAN::Vector<FDWrapper> fds;
|
BAN::Vector<FDWrapper> fds;
|
||||||
BAN::Optional<struct ucred> ucred;
|
BAN::Optional<struct ucred> ucred;
|
||||||
|
BAN::WeakPtr<UnixDomainSocket> sender;
|
||||||
};
|
};
|
||||||
|
|
||||||
BAN::ErrorOr<size_t> add_packet(const msghdr&, PacketInfo&&, bool dont_block);
|
BAN::ErrorOr<size_t> add_packet(const msghdr&, PacketInfo&&, bool dont_block);
|
||||||
|
|
@ -82,10 +83,14 @@ namespace Kernel
|
||||||
|
|
||||||
BAN::CircularQueue<PacketInfo, 512> m_packet_infos;
|
BAN::CircularQueue<PacketInfo, 512> m_packet_infos;
|
||||||
size_t m_packet_size_total { 0 };
|
size_t m_packet_size_total { 0 };
|
||||||
|
size_t m_packet_buffer_tail { 0 };
|
||||||
BAN::UniqPtr<VirtualRange> m_packet_buffer;
|
BAN::UniqPtr<VirtualRange> m_packet_buffer;
|
||||||
Mutex m_packet_lock;
|
mutable Mutex m_packet_lock;
|
||||||
ThreadBlocker m_packet_thread_blocker;
|
ThreadBlocker m_packet_thread_blocker;
|
||||||
|
|
||||||
|
BAN::Atomic<size_t> m_sndbuf { 0 };
|
||||||
|
BAN::Atomic<size_t> m_bytes_sent { 0 };
|
||||||
|
|
||||||
friend class BAN::RefPtr<UnixDomainSocket>;
|
friend class BAN::RefPtr<UnixDomainSocket>;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ namespace Kernel
|
||||||
static BAN::HashMap<BAN::RefPtr<Inode>, BAN::WeakPtr<UnixDomainSocket>, UnixSocketHash> s_bound_sockets;
|
static BAN::HashMap<BAN::RefPtr<Inode>, BAN::WeakPtr<UnixDomainSocket>, UnixSocketHash> s_bound_sockets;
|
||||||
static Mutex s_bound_socket_lock;
|
static Mutex s_bound_socket_lock;
|
||||||
|
|
||||||
static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE;
|
static constexpr size_t s_packet_buffer_size = 0x10000;
|
||||||
|
|
||||||
static BAN::ErrorOr<BAN::StringView> validate_sockaddr_un(const sockaddr* address, socklen_t address_len)
|
static BAN::ErrorOr<BAN::StringView> validate_sockaddr_un(const sockaddr* address, socklen_t address_len)
|
||||||
{
|
{
|
||||||
|
|
@ -45,8 +45,6 @@ namespace Kernel
|
||||||
return BAN::StringView { sockaddr_un.sun_path, length };
|
return BAN::StringView { sockaddr_un.sun_path, length };
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME: why is this using spinlocks instead of mutexes??
|
|
||||||
|
|
||||||
BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> UnixDomainSocket::create(Socket::Type socket_type, const Socket::Info& info)
|
BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> UnixDomainSocket::create(Socket::Type socket_type, const Socket::Info& info)
|
||||||
{
|
{
|
||||||
auto socket = TRY(BAN::RefPtr<UnixDomainSocket>::create(socket_type, info));
|
auto socket = TRY(BAN::RefPtr<UnixDomainSocket>::create(socket_type, info));
|
||||||
|
|
@ -64,6 +62,7 @@ namespace Kernel
|
||||||
UnixDomainSocket::UnixDomainSocket(Socket::Type socket_type, const Socket::Info& info)
|
UnixDomainSocket::UnixDomainSocket(Socket::Type socket_type, const Socket::Info& info)
|
||||||
: Socket(info)
|
: Socket(info)
|
||||||
, m_socket_type(socket_type)
|
, m_socket_type(socket_type)
|
||||||
|
, m_sndbuf(s_packet_buffer_size)
|
||||||
{
|
{
|
||||||
switch (socket_type)
|
switch (socket_type)
|
||||||
{
|
{
|
||||||
|
|
@ -290,29 +289,55 @@ namespace Kernel
|
||||||
case Socket::Type::DGRAM:
|
case Socket::Type::DGRAM:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
ASSERT_NOT_REACHED();
|
||||||
}
|
}
|
||||||
|
|
||||||
BAN::ErrorOr<size_t> UnixDomainSocket::add_packet(const msghdr& packet, PacketInfo&& packet_info, bool dont_block)
|
BAN::ErrorOr<size_t> UnixDomainSocket::add_packet(const msghdr& packet, PacketInfo&& packet_info, bool dont_block)
|
||||||
{
|
{
|
||||||
LockGuard _(m_packet_lock);
|
LockGuard _(m_packet_lock);
|
||||||
|
|
||||||
while (m_packet_infos.full() || m_packet_size_total + packet_info.size > s_packet_buffer_size)
|
const auto has_space =
|
||||||
|
[&]() -> bool
|
||||||
|
{
|
||||||
|
if (m_packet_infos.full())
|
||||||
|
return false;
|
||||||
|
if (is_streaming())
|
||||||
|
return m_packet_size_total < m_packet_buffer->size();
|
||||||
|
return m_packet_size_total + packet_info.size <= m_packet_buffer->size();
|
||||||
|
};
|
||||||
|
|
||||||
|
while (!has_space())
|
||||||
{
|
{
|
||||||
if (dont_block)
|
if (dont_block)
|
||||||
return BAN::Error::from_errno(EAGAIN);
|
return BAN::Error::from_errno(EAGAIN);
|
||||||
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock));
|
TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock));
|
||||||
}
|
}
|
||||||
|
|
||||||
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total);
|
if (auto available = m_packet_buffer->size() - m_packet_size_total; available < packet_info.size)
|
||||||
|
|
||||||
size_t offset = 0;
|
|
||||||
for (int i = 0; i < packet.msg_iovlen; i++)
|
|
||||||
{
|
{
|
||||||
memcpy(packet_buffer + offset, packet.msg_iov[i].iov_base, packet.msg_iov[i].iov_len);
|
ASSERT(is_streaming());
|
||||||
offset += packet.msg_iov[i].iov_len;
|
packet_info.size = available;
|
||||||
}
|
}
|
||||||
|
|
||||||
ASSERT(offset == packet_info.size);
|
uint8_t* packet_buffer_base_u8 = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
|
||||||
|
|
||||||
|
size_t bytes_copied = 0;
|
||||||
|
for (int i = 0; i < packet.msg_iovlen && bytes_copied < packet_info.size; i++)
|
||||||
|
{
|
||||||
|
const uint8_t* iov_base_u8 = static_cast<const uint8_t*>(packet.msg_iov[i].iov_base);
|
||||||
|
|
||||||
|
const size_t to_copy = BAN::Math::min(packet.msg_iov[i].iov_len, packet_info.size - bytes_copied);
|
||||||
|
|
||||||
|
const size_t copy_offset = (m_packet_buffer_tail + m_packet_size_total + bytes_copied) % m_packet_buffer->size();
|
||||||
|
const size_t before_wrap = BAN::Math::min(to_copy, m_packet_buffer->size() - copy_offset);
|
||||||
|
memcpy(packet_buffer_base_u8 + copy_offset, iov_base_u8, before_wrap);
|
||||||
|
if (const size_t after_wrap = to_copy - before_wrap)
|
||||||
|
memcpy(packet_buffer_base_u8, iov_base_u8 + before_wrap, after_wrap);
|
||||||
|
|
||||||
|
bytes_copied += to_copy;
|
||||||
|
}
|
||||||
|
|
||||||
|
ASSERT(bytes_copied == packet_info.size);
|
||||||
m_packet_size_total += packet_info.size;
|
m_packet_size_total += packet_info.size;
|
||||||
m_packet_infos.emplace(BAN::move(packet_info));
|
m_packet_infos.emplace(BAN::move(packet_info));
|
||||||
|
|
||||||
|
|
@ -320,7 +345,7 @@ namespace Kernel
|
||||||
|
|
||||||
epoll_notify(EPOLLIN);
|
epoll_notify(EPOLLIN);
|
||||||
|
|
||||||
return {};
|
return bytes_copied;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool UnixDomainSocket::can_read_impl() const
|
bool UnixDomainSocket::can_read_impl() const
|
||||||
|
|
@ -336,24 +361,13 @@ namespace Kernel
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LockGuard _(m_packet_lock);
|
||||||
return m_packet_size_total > 0;
|
return m_packet_size_total > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool UnixDomainSocket::can_write_impl() const
|
bool UnixDomainSocket::can_write_impl() const
|
||||||
{
|
{
|
||||||
if (m_info.has<ConnectionInfo>())
|
return m_bytes_sent < m_sndbuf;
|
||||||
{
|
|
||||||
auto& connection_info = m_info.get<ConnectionInfo>();
|
|
||||||
auto connection = connection_info.connection.lock();
|
|
||||||
if (!connection)
|
|
||||||
return false;
|
|
||||||
if (connection->m_packet_infos.full())
|
|
||||||
return false;
|
|
||||||
if (connection->m_packet_size_total >= s_packet_buffer_size)
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool UnixDomainSocket::has_hungup_impl() const
|
bool UnixDomainSocket::has_hungup_impl() const
|
||||||
|
|
@ -377,6 +391,7 @@ namespace Kernel
|
||||||
}
|
}
|
||||||
|
|
||||||
LockGuard _(m_packet_lock);
|
LockGuard _(m_packet_lock);
|
||||||
|
|
||||||
while (m_packet_size_total == 0)
|
while (m_packet_size_total == 0)
|
||||||
{
|
{
|
||||||
if (m_info.has<ConnectionInfo>())
|
if (m_info.has<ConnectionInfo>())
|
||||||
|
|
@ -397,7 +412,7 @@ namespace Kernel
|
||||||
cheader->cmsg_len = message.msg_controllen;
|
cheader->cmsg_len = message.msg_controllen;
|
||||||
size_t cheader_len = 0;
|
size_t cheader_len = 0;
|
||||||
|
|
||||||
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
|
uint8_t* packet_buffer_base_u8 = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
|
||||||
|
|
||||||
message.msg_flags = 0;
|
message.msg_flags = 0;
|
||||||
|
|
||||||
|
|
@ -471,7 +486,12 @@ namespace Kernel
|
||||||
uint8_t* iov_base = static_cast<uint8_t*>(iov.iov_base);
|
uint8_t* iov_base = static_cast<uint8_t*>(iov.iov_base);
|
||||||
|
|
||||||
const size_t nrecv = BAN::Math::min<size_t>(iov.iov_len - iov_offset, packet_info.size - packet_received);
|
const size_t nrecv = BAN::Math::min<size_t>(iov.iov_len - iov_offset, packet_info.size - packet_received);
|
||||||
memcpy(iov_base + iov_offset, packet_buffer + packet_received, nrecv);
|
|
||||||
|
const size_t copy_offset = (m_packet_buffer_tail + packet_received) % m_packet_buffer->size();
|
||||||
|
const size_t before_wrap = BAN::Math::min<size_t>(nrecv, m_packet_buffer->size() - copy_offset);
|
||||||
|
memcpy(iov_base + iov_offset, packet_buffer_base_u8 + copy_offset, before_wrap);
|
||||||
|
if (const size_t after_wrap = nrecv - before_wrap)
|
||||||
|
memcpy(iov_base + iov_offset + before_wrap, packet_buffer_base_u8, after_wrap);
|
||||||
|
|
||||||
packet_received += nrecv;
|
packet_received += nrecv;
|
||||||
|
|
||||||
|
|
@ -492,12 +512,17 @@ namespace Kernel
|
||||||
if (packet_info.size == 0)
|
if (packet_info.size == 0)
|
||||||
m_packet_infos.pop();
|
m_packet_infos.pop();
|
||||||
|
|
||||||
// FIXME: get rid of this memmove :)
|
m_packet_buffer_tail = (m_packet_buffer_tail + to_discard) % m_packet_buffer->size();
|
||||||
memmove(packet_buffer, packet_buffer + to_discard, m_packet_size_total - to_discard);
|
|
||||||
m_packet_size_total -= to_discard;
|
m_packet_size_total -= to_discard;
|
||||||
|
|
||||||
total_recv += packet_received;
|
total_recv += packet_received;
|
||||||
|
|
||||||
|
if (auto sender = packet_info.sender.lock())
|
||||||
|
{
|
||||||
|
sender->m_bytes_sent -= to_discard;
|
||||||
|
sender->epoll_notify(EPOLLOUT);
|
||||||
|
}
|
||||||
|
|
||||||
// on linux ancillary data is a barrier on stream sockets, lets do the same
|
// on linux ancillary data is a barrier on stream sockets, lets do the same
|
||||||
if (!is_streaming() || had_ancillary_data)
|
if (!is_streaming() || had_ancillary_data)
|
||||||
break;
|
break;
|
||||||
|
|
@ -507,8 +532,6 @@ namespace Kernel
|
||||||
|
|
||||||
m_packet_thread_blocker.unblock();
|
m_packet_thread_blocker.unblock();
|
||||||
|
|
||||||
epoll_notify(EPOLLOUT);
|
|
||||||
|
|
||||||
return total_recv;
|
return total_recv;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -528,13 +551,14 @@ namespace Kernel
|
||||||
return result;
|
return result;
|
||||||
}();
|
}();
|
||||||
|
|
||||||
if (total_message_size > s_packet_buffer_size)
|
if (!is_streaming() && total_message_size > m_packet_buffer->size())
|
||||||
return BAN::Error::from_errno(ENOBUFS);
|
return BAN::Error::from_errno(EMSGSIZE);
|
||||||
|
|
||||||
PacketInfo packet_info {
|
PacketInfo packet_info {
|
||||||
.size = total_message_size,
|
.size = total_message_size,
|
||||||
.fds = {},
|
.fds = {},
|
||||||
.ucred = {},
|
.ucred = {},
|
||||||
|
.sender = TRY(get_weak_ptr()),
|
||||||
};
|
};
|
||||||
|
|
||||||
for (const auto* header = CMSG_FIRSTHDR(&message); header; header = CMSG_NXTHDR(&message, header))
|
for (const auto* header = CMSG_FIRSTHDR(&message); header; header = CMSG_NXTHDR(&message, header))
|
||||||
|
|
@ -607,8 +631,9 @@ namespace Kernel
|
||||||
auto target = connection_info.connection.lock();
|
auto target = connection_info.connection.lock();
|
||||||
if (!target)
|
if (!target)
|
||||||
return BAN::Error::from_errno(ENOTCONN);
|
return BAN::Error::from_errno(ENOTCONN);
|
||||||
TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT));
|
const size_t bytes_sent = TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT));
|
||||||
return total_message_size;
|
m_bytes_sent += bytes_sent;
|
||||||
|
return bytes_sent;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
|
@ -625,7 +650,7 @@ namespace Kernel
|
||||||
Process::current().root_file().inode,
|
Process::current().root_file().inode,
|
||||||
Process::current().credentials(),
|
Process::current().credentials(),
|
||||||
absolute_path,
|
absolute_path,
|
||||||
O_RDWR
|
O_WRONLY
|
||||||
)).inode;
|
)).inode;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
|
|
@ -651,8 +676,11 @@ namespace Kernel
|
||||||
|
|
||||||
if (!target)
|
if (!target)
|
||||||
return BAN::Error::from_errno(EDESTADDRREQ);
|
return BAN::Error::from_errno(EDESTADDRREQ);
|
||||||
TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT));
|
if (target->m_socket_type != m_socket_type)
|
||||||
return total_message_size;
|
return BAN::Error::from_errno(EPROTOTYPE);
|
||||||
|
const auto bytes_sent = TRY(target->add_packet(message, BAN::move(packet_info), flags & MSG_DONTWAIT));
|
||||||
|
m_bytes_sent += bytes_sent;
|
||||||
|
return bytes_sent;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -688,7 +716,7 @@ namespace Kernel
|
||||||
result = 0;
|
result = 0;
|
||||||
break;
|
break;
|
||||||
case SO_SNDBUF:
|
case SO_SNDBUF:
|
||||||
result = m_packet_buffer->size();
|
result = m_sndbuf;
|
||||||
break;
|
break;
|
||||||
case SO_RCVBUF:
|
case SO_RCVBUF:
|
||||||
result = m_packet_buffer->size();
|
result = m_packet_buffer->size();
|
||||||
|
|
@ -705,4 +733,29 @@ namespace Kernel
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BAN::ErrorOr<void> UnixDomainSocket::setsockopt_impl(int level, int option, const void* value, socklen_t value_len)
|
||||||
|
{
|
||||||
|
if (level != SOL_SOCKET)
|
||||||
|
return BAN::Error::from_errno(EINVAL);
|
||||||
|
|
||||||
|
switch (option)
|
||||||
|
{
|
||||||
|
case SO_SNDBUF:
|
||||||
|
{
|
||||||
|
if (value_len != sizeof(int))
|
||||||
|
return BAN::Error::from_errno(EINVAL);
|
||||||
|
const int new_sndbuf = *static_cast<const int*>(value);
|
||||||
|
if (new_sndbuf < 0)
|
||||||
|
return BAN::Error::from_errno(EINVAL);
|
||||||
|
m_sndbuf = new_sndbuf;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
dwarnln("setsockopt(SOL_SOCKET, {})", option);
|
||||||
|
return BAN::Error::from_errno(ENOTSUP);
|
||||||
|
}
|
||||||
|
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue