Kernel: Make TCP sockets use the new ring buffer

Also fix a race condition that sometimes prevented window updates from
being sent after a zero window, effectively hanging the whole socket.
This commit is contained in:
Bananymous 2026-02-28 14:22:08 +02:00
parent 8b8af1a9d9
commit 812ae77cd7
2 changed files with 84 additions and 121 deletions

View File

@ -4,7 +4,7 @@
#include <BAN/Endianness.h> #include <BAN/Endianness.h>
#include <BAN/Queue.h> #include <BAN/Queue.h>
#include <kernel/Lock/Mutex.h> #include <kernel/Lock/Mutex.h>
#include <kernel/Memory/VirtualRange.h> #include <kernel/Memory/ByteRingBuffer.h>
#include <kernel/Networking/NetworkInterface.h> #include <kernel/Networking/NetworkInterface.h>
#include <kernel/Networking/NetworkSocket.h> #include <kernel/Networking/NetworkSocket.h>
#include <kernel/Thread.h> #include <kernel/Thread.h>
@ -97,10 +97,8 @@ namespace Kernel
bool has_ghost_byte { false }; bool has_ghost_byte { false };
uint32_t data_tail { 0 };
uint32_t data_size { 0 }; // number of bytes in this buffer
uint8_t scale_shift { 0 }; // window scale uint8_t scale_shift { 0 }; // window scale
BAN::UniqPtr<VirtualRange> buffer; BAN::UniqPtr<ByteRingBuffer> buffer;
}; };
struct SendWindowInfo struct SendWindowInfo
@ -119,10 +117,8 @@ namespace Kernel
bool has_ghost_byte { false }; bool has_ghost_byte { false };
bool had_zero_window { false }; bool had_zero_window { false };
uint32_t data_tail { 0 };
uint32_t data_size { 0 }; // number of bytes in this buffer
uint32_t sent_size { 0 }; // number of bytes in this buffer that have been sent uint32_t sent_size { 0 }; // number of bytes in this buffer that have been sent
BAN::UniqPtr<VirtualRange> buffer; BAN::UniqPtr<ByteRingBuffer> buffer;
}; };
struct ConnectionInfo struct ConnectionInfo
@ -180,8 +176,8 @@ namespace Kernel
bool m_keep_alive { false }; bool m_keep_alive { false };
bool m_no_delay { false }; bool m_no_delay { false };
bool m_should_send_ack { false };
bool m_should_send_zero_window { false }; bool m_should_send_zero_window { false };
bool m_should_send_window_update { false };
uint64_t m_time_wait_start_ms { 0 }; uint64_t m_time_wait_start_ms { 0 };

View File

@ -35,23 +35,9 @@ namespace Kernel
{ {
auto socket = TRY(BAN::RefPtr<TCPSocket>::create(network_layer, info)); auto socket = TRY(BAN::RefPtr<TCPSocket>::create(network_layer, info));
socket->m_last_sent_window_size = s_recv_window_buffer_size; socket->m_last_sent_window_size = s_recv_window_buffer_size;
socket->m_recv_window.buffer = TRY(VirtualRange::create_to_vaddr_range( socket->m_recv_window.buffer = TRY(ByteRingBuffer::create(s_recv_window_buffer_size));
PageTable::kernel(),
KERNEL_OFFSET,
~(vaddr_t)0,
s_recv_window_buffer_size,
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
true, false
));
socket->m_recv_window.scale_shift = s_window_shift; socket->m_recv_window.scale_shift = s_window_shift;
socket->m_send_window.buffer = TRY(VirtualRange::create_to_vaddr_range( socket->m_send_window.buffer = TRY(ByteRingBuffer::create(s_send_window_buffer_size));
PageTable::kernel(),
KERNEL_OFFSET,
~(vaddr_t)0,
s_send_window_buffer_size,
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
true, false
));
socket->m_thread = TRY(Thread::create_kernel( socket->m_thread = TRY(Thread::create_kernel(
[](void* socket_ptr) [](void* socket_ptr)
{ {
@ -206,8 +192,7 @@ namespace Kernel
BAN::ErrorOr<size_t> TCPSocket::recvmsg_impl(msghdr& message, int flags) BAN::ErrorOr<size_t> TCPSocket::recvmsg_impl(msghdr& message, int flags)
{ {
flags &= (MSG_OOB | MSG_PEEK | MSG_WAITALL); if (flags & ~(MSG_PEEK))
if (flags != 0)
{ {
dwarnln("TODO: recvmsg with flags 0x{H}", flags); dwarnln("TODO: recvmsg with flags 0x{H}", flags);
return BAN::Error::from_errno(ENOTSUP); return BAN::Error::from_errno(ENOTSUP);
@ -222,7 +207,7 @@ namespace Kernel
if (!m_has_connected) if (!m_has_connected)
return BAN::Error::from_errno(ENOTCONN); return BAN::Error::from_errno(ENOTCONN);
while (m_recv_window.data_size == 0) while (m_recv_window.buffer->empty())
{ {
if (m_state != State::Established) if (m_state != State::Established)
return return_with_maybe_zero(); return return_with_maybe_zero();
@ -232,37 +217,33 @@ namespace Kernel
message.msg_flags = 0; message.msg_flags = 0;
size_t total_recv = 0; size_t total_recv = 0;
for (int i = 0; i < message.msg_iovlen; i++) for (int i = 0; i < message.msg_iovlen && total_recv < m_recv_window.buffer->size(); i++)
{ {
const auto* recv_base = reinterpret_cast<const uint8_t*>(m_recv_window.buffer->vaddr()); auto& iov = message.msg_iov[i];
uint8_t* iov_base = static_cast<uint8_t*>(message.msg_iov[i].iov_base);
const size_t nrecv = BAN::Math::min<size_t>(message.msg_iov[i].iov_len, m_recv_window.data_size); const size_t nrecv = BAN::Math::min(iov.iov_len, m_recv_window.buffer->size() - total_recv);
memcpy(iov.iov_base, m_recv_window.buffer->get_data().data() + total_recv, nrecv);
const size_t before_wrap = BAN::Math::min(nrecv, m_recv_window.buffer->size() - m_recv_window.data_tail);
memcpy(iov_base, recv_base + m_recv_window.data_tail, before_wrap);
if (const size_t after_wrap = nrecv - before_wrap)
memcpy(iov_base + before_wrap, recv_base, after_wrap);
total_recv += nrecv; total_recv += nrecv;
m_recv_window.data_size -= nrecv;
m_recv_window.start_seq += nrecv;
m_recv_window.data_tail = (m_recv_window.data_tail + nrecv) % m_recv_window.buffer->size();
if (m_recv_window.data_size == 0)
break;
} }
const size_t update_window_threshold = m_recv_window.buffer->size() / 8; if (!(flags & MSG_PEEK))
{
m_recv_window.buffer->pop(total_recv);
m_recv_window.start_seq += total_recv;
}
const size_t update_window_threshold = m_recv_window.buffer->capacity() / 8;
const bool should_update_window_size = const bool should_update_window_size =
m_last_sent_window_size != m_recv_window.buffer->size() && ( m_last_sent_window_size != m_recv_window.buffer->capacity() && (
(m_last_sent_window_size == 0) || (m_last_sent_window_size == 0) ||
(m_recv_window.data_size == 0) || (m_recv_window.buffer->empty()) ||
(m_last_sent_window_size + update_window_threshold <= m_recv_window.buffer->size() - m_recv_window.data_size) (m_last_sent_window_size + update_window_threshold <= m_recv_window.buffer->free())
); );
if (should_update_window_size) if (should_update_window_size || m_should_send_zero_window)
{ {
m_should_send_ack = true; m_should_send_window_update = true;
m_thread_blocker.unblock(); m_thread_blocker.unblock();
} }
@ -283,7 +264,7 @@ namespace Kernel
if (!m_has_connected) if (!m_has_connected)
return BAN::Error::from_errno(ENOTCONN); return BAN::Error::from_errno(ENOTCONN);
while (m_send_window.data_size == m_send_window.buffer->size()) while (m_send_window.buffer->full())
{ {
if (m_state != State::Established) if (m_state != State::Established)
return return_with_maybe_zero(); return return_with_maybe_zero();
@ -293,23 +274,14 @@ namespace Kernel
} }
size_t total_sent = 0; size_t total_sent = 0;
for (int i = 0; i < message.msg_iovlen; i++) for (int i = 0; i < message.msg_iovlen && !m_send_window.buffer->full(); i++)
{ {
auto* send_base = reinterpret_cast<uint8_t*>(m_send_window.buffer->vaddr()); const auto& iov = message.msg_iov[i];
const auto* iov_base = static_cast<const uint8_t*>(message.msg_iov[i].iov_base);
const size_t nsend = BAN::Math::min<size_t>(message.msg_iov[i].iov_len, m_send_window.buffer->size() - m_send_window.data_size); const size_t nsend = BAN::Math::min(iov.iov_len, m_send_window.buffer->free());
m_send_window.buffer->push({ static_cast<const uint8_t*>(iov.iov_base), nsend });
const size_t send_head = (m_send_window.data_tail + m_send_window.data_size) % m_send_window.buffer->size();
const size_t before_wrap = BAN::Math::min(nsend, m_send_window.buffer->size() - send_head);
memcpy(send_base + send_head, message.msg_iov[i].iov_base, before_wrap);
if (const size_t after_wrap = nsend - before_wrap)
memcpy(send_base, iov_base + before_wrap, after_wrap);
total_sent += nsend; total_sent += nsend;
m_send_window.data_size += nsend;
if (m_send_window.data_size == m_send_window.buffer->size())
break;
} }
m_thread_blocker.unblock(); m_thread_blocker.unblock();
@ -347,7 +319,7 @@ namespace Kernel
result = m_send_window.scaled_size(); result = m_send_window.scaled_size();
break; break;
case SO_RCVBUF: case SO_RCVBUF:
result = m_recv_window.buffer->size(); result = m_recv_window.buffer->capacity();
break; break;
default: default:
dwarnln("getsockopt(SOL_SOCKET, {})", option); dwarnln("getsockopt(SOL_SOCKET, {})", option);
@ -420,7 +392,7 @@ namespace Kernel
switch (request) switch (request)
{ {
case FIONREAD: case FIONREAD:
*static_cast<int*>(argument) = m_recv_window.data_size; *static_cast<int*>(argument) = m_recv_window.buffer->size();
return 0; return 0;
} }
@ -433,14 +405,14 @@ namespace Kernel
return true; return true;
if (m_state == State::Listen) if (m_state == State::Listen)
return !m_pending_connections.empty(); return !m_pending_connections.empty();
return m_recv_window.data_size > 0; return !m_recv_window.buffer->empty();
} }
bool TCPSocket::can_write_impl() const bool TCPSocket::can_write_impl() const
{ {
if (m_state != State::Established) if (m_state != State::Established)
return false; return false;
return m_send_window.data_size < m_send_window.buffer->size(); return !m_send_window.buffer->full();
} }
bool TCPSocket::has_hungup_impl() const bool TCPSocket::has_hungup_impl() const
@ -530,19 +502,14 @@ namespace Kernel
ASSERT(m_mutex.locker() == Thread::current().tid()); ASSERT(m_mutex.locker() == Thread::current().tid());
ASSERT(header_buffer.size() == protocol_header_size()); ASSERT(header_buffer.size() == protocol_header_size());
m_last_sent_window_size = m_recv_window.buffer->size() - m_recv_window.data_size; m_last_sent_window_size = m_should_send_zero_window ? 0 : m_recv_window.buffer->free();
if (m_should_send_zero_window)
m_last_sent_window_size = 0;
m_should_send_ack = false;
m_should_send_zero_window = false;
auto& header = header_buffer.as<TCPHeader>(); auto& header = header_buffer.as<TCPHeader>();
header = { header = {
.src_port = bound_port(), .src_port = bound_port(),
.dst_port = dst_port, .dst_port = dst_port,
.seq_number = m_send_window.current_seq + m_send_window.has_ghost_byte, .seq_number = m_send_window.current_seq + m_send_window.has_ghost_byte,
.ack_number = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte, .ack_number = m_recv_window.start_seq + m_recv_window.buffer->size() + m_recv_window.has_ghost_byte,
.data_offset = (sizeof(TCPHeader) + m_tcp_options_bytes) / sizeof(uint32_t), .data_offset = (sizeof(TCPHeader) + m_tcp_options_bytes) / sizeof(uint32_t),
.flags = m_next_flags, .flags = m_next_flags,
.window_size = BAN::Math::min<size_t>(0xFFFF, m_last_sent_window_size >> m_recv_window.scale_shift), .window_size = BAN::Math::min<size_t>(0xFFFF, m_last_sent_window_size >> m_recv_window.scale_shift),
@ -569,7 +536,7 @@ namespace Kernel
if (m_connection_info->has_window_scale) if (m_connection_info->has_window_scale)
add_tcp_header_option<4, TCPOption::WindowScale>(header, m_recv_window.scale_shift); add_tcp_header_option<4, TCPOption::WindowScale>(header, m_recv_window.scale_shift);
header.window_size = BAN::Math::min<size_t>(0xFFFF, m_recv_window.buffer->size()); header.window_size = BAN::Math::min<size_t>(0xFFFF, m_recv_window.buffer->capacity());
m_send_window.start_seq++; m_send_window.start_seq++;
m_send_window.current_seq = m_send_window.start_seq; m_send_window.current_seq = m_send_window.start_seq;
@ -722,7 +689,7 @@ namespace Kernel
check_payload = true; check_payload = true;
if (!(header.flags & FIN)) if (!(header.flags & FIN))
break; break;
if (m_recv_window.start_seq + m_recv_window.data_size != header.seq_number) if (m_recv_window.start_seq + m_recv_window.buffer->size() != header.seq_number)
break; break;
m_next_flags = FIN | ACK; m_next_flags = FIN | ACK;
m_next_state = State::LastAck; m_next_state = State::LastAck;
@ -771,7 +738,7 @@ namespace Kernel
break; break;
} }
const uint32_t expected_seq = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte; const uint32_t expected_seq = m_recv_window.start_seq + m_recv_window.buffer->size() + m_recv_window.has_ghost_byte;
if (header.seq_number > expected_seq) if (header.seq_number > expected_seq)
dprintln_if(DEBUG_TCP, "Missing packets"); dprintln_if(DEBUG_TCP, "Missing packets");
@ -794,21 +761,12 @@ namespace Kernel
payload = {}; payload = {};
} }
const bool can_receive_new_data = (payload.size() > 0 && m_recv_window.data_size < m_recv_window.buffer->size()); const bool can_receive_new_data = (payload.size() > 0 && !m_recv_window.buffer->full());
if (can_receive_new_data) if (can_receive_new_data)
{ {
auto* recv_base = reinterpret_cast<uint8_t*>(m_recv_window.buffer->vaddr()); const size_t nrecv = BAN::Math::min(payload.size(), m_recv_window.buffer->free());
m_recv_window.buffer->push(payload.slice(0, nrecv));
const size_t nrecv = BAN::Math::min(payload.size(), m_recv_window.buffer->size() - m_recv_window.data_size);
const size_t recv_head = (m_recv_window.data_tail + m_recv_window.data_size) % m_recv_window.buffer->size();
const size_t before_wrap = BAN::Math::min(nrecv, m_recv_window.buffer->size() - recv_head);
memcpy(recv_base + recv_head, payload.data(), before_wrap);
if (const size_t after_wrap = nrecv - before_wrap)
memcpy(recv_base, payload.data() + before_wrap, after_wrap);
m_recv_window.data_size += nrecv;
epoll_notify(EPOLLIN); epoll_notify(EPOLLIN);
@ -816,10 +774,13 @@ namespace Kernel
} }
// make sure zero window is reported // make sure zero window is reported
if (m_last_sent_window_size > 0 && m_recv_window.data_size == m_recv_window.buffer->size()) if (m_last_sent_window_size > 0 && m_recv_window.buffer->full())
m_should_send_zero_window = true; m_should_send_zero_window = true;
else if (can_receive_new_data) else if (can_receive_new_data)
m_should_send_ack = true; {
m_next_flags = ACK;
m_next_state = m_state;
}
} }
if (!hungup_before && has_hungup_impl()) if (!hungup_before && has_hungup_impl())
@ -925,12 +886,11 @@ namespace Kernel
if (m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq) if (m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq)
{ {
const uint32_t acknowledged_bytes = m_send_window.current_ack - m_send_window.start_seq - m_send_window.has_ghost_byte; const uint32_t acknowledged_bytes = m_send_window.current_ack - m_send_window.start_seq - m_send_window.has_ghost_byte;
ASSERT(acknowledged_bytes <= m_send_window.data_size); ASSERT(acknowledged_bytes <= m_send_window.buffer->size());
m_send_window.data_size -= acknowledged_bytes;
m_send_window.start_seq += acknowledged_bytes; m_send_window.start_seq += acknowledged_bytes;
m_send_window.sent_size -= acknowledged_bytes; m_send_window.sent_size -= acknowledged_bytes;
m_send_window.data_tail = (m_send_window.data_tail + acknowledged_bytes) % m_send_window.buffer->size(); m_send_window.buffer->pop(acknowledged_bytes);
epoll_notify(EPOLLOUT); epoll_notify(EPOLLOUT);
@ -941,7 +901,7 @@ namespace Kernel
const bool should_retransmit = m_send_window.had_zero_window || (m_send_window.sent_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms); const bool should_retransmit = m_send_window.had_zero_window || (m_send_window.sent_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms);
const bool can_send_new_data = (m_send_window.data_size > m_send_window.sent_size && m_send_window.sent_size < m_send_window.scaled_size()); const bool can_send_new_data = (m_send_window.buffer->size() > m_send_window.sent_size && m_send_window.sent_size < m_send_window.scaled_size());
if (m_send_window.scaled_size() > 0 && (should_retransmit || can_send_new_data)) if (m_send_window.scaled_size() > 0 && (should_retransmit || can_send_new_data))
{ {
@ -951,24 +911,20 @@ namespace Kernel
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address); auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
auto target_address_len = m_connection_info->address_len; auto target_address_len = m_connection_info->address_len;
const size_t send_start_offset = should_retransmit ? 0 : m_send_window.sent_size; const size_t send_offset = should_retransmit ? 0 : m_send_window.sent_size;
const size_t total_send = BAN::Math::min<size_t>( const size_t total_send = BAN::Math::min<size_t>(
m_send_window.data_size - send_start_offset, m_send_window.buffer->size() - send_offset,
m_send_window.scaled_size() - send_start_offset m_send_window.scaled_size() - send_offset
); );
m_send_window.current_seq = m_send_window.start_seq + send_start_offset; m_send_window.current_seq = m_send_window.start_seq + send_offset;
const auto* send_base = reinterpret_cast<const uint8_t*>(m_send_window.buffer->vaddr());
for (size_t i = 0; i < total_send;) for (size_t i = 0; i < total_send;)
{ {
const size_t send_offset = (m_send_window.data_tail + send_start_offset + i) % m_send_window.buffer->size(); const size_t to_send = BAN::Math::min<size_t>(total_send - i, m_send_window.mss);
const size_t max_send = BAN::Math::min<size_t>(total_send - i, m_send_window.mss); auto message = m_send_window.buffer->get_data().slice(send_offset + i, to_send);
const size_t to_send = BAN::Math::min(max_send, m_send_window.buffer->size() - send_offset);
auto message = BAN::ConstByteSpan(send_base + send_offset, to_send);
m_next_flags = ACK; m_next_flags = ACK;
if (auto ret = m_network_layer.sendto(*this, message, target_address, target_address_len); ret.is_error()) if (auto ret = m_network_layer.sendto(*this, message, target_address, target_address_len); ret.is_error())
@ -979,9 +935,10 @@ namespace Kernel
dprintln_if(DEBUG_TCP, "Sent {} bytes", to_send); dprintln_if(DEBUG_TCP, "Sent {} bytes", to_send);
m_send_window.sent_size += to_send;
m_send_window.current_seq += to_send;
i += to_send; i += to_send;
m_send_window.current_seq += to_send;
if (send_offset + i > m_send_window.sent_size)
m_send_window.sent_size = send_offset + i;
} }
m_send_window.last_send_ms = current_ms; m_send_window.last_send_ms = current_ms;
@ -989,13 +946,23 @@ namespace Kernel
continue; continue;
} }
if (const size_t ack_count = m_should_send_ack + m_should_send_zero_window) if (m_last_sent_window_size == 0)
m_should_send_zero_window = false;
if (m_should_send_zero_window || m_should_send_window_update)
{ {
ASSERT(m_connection_info.has_value()); ASSERT(m_connection_info.has_value());
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address); auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
auto target_address_len = m_connection_info->address_len; auto target_address_len = m_connection_info->address_len;
for (size_t i = 0; i < ack_count; i++) m_next_flags = ACK;
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())
dwarnln("{}", ret.error());
m_should_send_zero_window = false;
m_should_send_window_update = false;
if (m_last_sent_window_size == 0 && !m_recv_window.buffer->full())
{ {
m_next_flags = ACK; m_next_flags = ACK;
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error()) if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())