Kernel: Make TCP sockets use the new ring buffer

Also fix a race condition that sometimes prevented window updates from
being sent after a zero window, effectively hanging the whole socket.
This commit is contained in:
2026-02-28 14:22:08 +02:00
parent 8b8af1a9d9
commit 812ae77cd7
2 changed files with 84 additions and 121 deletions

View File

@@ -35,23 +35,9 @@ namespace Kernel
{
auto socket = TRY(BAN::RefPtr<TCPSocket>::create(network_layer, info));
socket->m_last_sent_window_size = s_recv_window_buffer_size;
socket->m_recv_window.buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
~(vaddr_t)0,
s_recv_window_buffer_size,
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
true, false
));
socket->m_recv_window.buffer = TRY(ByteRingBuffer::create(s_recv_window_buffer_size));
socket->m_recv_window.scale_shift = s_window_shift;
socket->m_send_window.buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
~(vaddr_t)0,
s_send_window_buffer_size,
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
true, false
));
socket->m_send_window.buffer = TRY(ByteRingBuffer::create(s_send_window_buffer_size));
socket->m_thread = TRY(Thread::create_kernel(
[](void* socket_ptr)
{
@@ -206,8 +192,7 @@ namespace Kernel
BAN::ErrorOr<size_t> TCPSocket::recvmsg_impl(msghdr& message, int flags)
{
flags &= (MSG_OOB | MSG_PEEK | MSG_WAITALL);
if (flags != 0)
if (flags & ~(MSG_PEEK))
{
dwarnln("TODO: recvmsg with flags 0x{H}", flags);
return BAN::Error::from_errno(ENOTSUP);
@@ -222,7 +207,7 @@ namespace Kernel
if (!m_has_connected)
return BAN::Error::from_errno(ENOTCONN);
while (m_recv_window.data_size == 0)
while (m_recv_window.buffer->empty())
{
if (m_state != State::Established)
return return_with_maybe_zero();
@@ -232,37 +217,33 @@ namespace Kernel
message.msg_flags = 0;
size_t total_recv = 0;
for (int i = 0; i < message.msg_iovlen; i++)
for (int i = 0; i < message.msg_iovlen && total_recv < m_recv_window.buffer->size(); i++)
{
const auto* recv_base = reinterpret_cast<const uint8_t*>(m_recv_window.buffer->vaddr());
uint8_t* iov_base = static_cast<uint8_t*>(message.msg_iov[i].iov_base);
auto& iov = message.msg_iov[i];
const size_t nrecv = BAN::Math::min<size_t>(message.msg_iov[i].iov_len, m_recv_window.data_size);
const size_t before_wrap = BAN::Math::min(nrecv, m_recv_window.buffer->size() - m_recv_window.data_tail);
memcpy(iov_base, recv_base + m_recv_window.data_tail, before_wrap);
if (const size_t after_wrap = nrecv - before_wrap)
memcpy(iov_base + before_wrap, recv_base, after_wrap);
const size_t nrecv = BAN::Math::min(iov.iov_len, m_recv_window.buffer->size() - total_recv);
memcpy(iov.iov_base, m_recv_window.buffer->get_data().data() + total_recv, nrecv);
total_recv += nrecv;
m_recv_window.data_size -= nrecv;
m_recv_window.start_seq += nrecv;
m_recv_window.data_tail = (m_recv_window.data_tail + nrecv) % m_recv_window.buffer->size();
if (m_recv_window.data_size == 0)
break;
}
const size_t update_window_threshold = m_recv_window.buffer->size() / 8;
if (!(flags & MSG_PEEK))
{
m_recv_window.buffer->pop(total_recv);
m_recv_window.start_seq += total_recv;
}
const size_t update_window_threshold = m_recv_window.buffer->capacity() / 8;
const bool should_update_window_size =
m_last_sent_window_size != m_recv_window.buffer->size() && (
m_last_sent_window_size != m_recv_window.buffer->capacity() && (
(m_last_sent_window_size == 0) ||
(m_recv_window.data_size == 0) ||
(m_last_sent_window_size + update_window_threshold <= m_recv_window.buffer->size() - m_recv_window.data_size)
(m_recv_window.buffer->empty()) ||
(m_last_sent_window_size + update_window_threshold <= m_recv_window.buffer->free())
);
if (should_update_window_size)
if (should_update_window_size || m_should_send_zero_window)
{
m_should_send_ack = true;
m_should_send_window_update = true;
m_thread_blocker.unblock();
}
@@ -283,7 +264,7 @@ namespace Kernel
if (!m_has_connected)
return BAN::Error::from_errno(ENOTCONN);
while (m_send_window.data_size == m_send_window.buffer->size())
while (m_send_window.buffer->full())
{
if (m_state != State::Established)
return return_with_maybe_zero();
@@ -293,23 +274,14 @@ namespace Kernel
}
size_t total_sent = 0;
for (int i = 0; i < message.msg_iovlen; i++)
for (int i = 0; i < message.msg_iovlen && !m_send_window.buffer->full(); i++)
{
auto* send_base = reinterpret_cast<uint8_t*>(m_send_window.buffer->vaddr());
const auto* iov_base = static_cast<const uint8_t*>(message.msg_iov[i].iov_base);
const auto& iov = message.msg_iov[i];
const size_t nsend = BAN::Math::min<size_t>(message.msg_iov[i].iov_len, m_send_window.buffer->size() - m_send_window.data_size);
const size_t send_head = (m_send_window.data_tail + m_send_window.data_size) % m_send_window.buffer->size();
const size_t before_wrap = BAN::Math::min(nsend, m_send_window.buffer->size() - send_head);
memcpy(send_base + send_head, message.msg_iov[i].iov_base, before_wrap);
if (const size_t after_wrap = nsend - before_wrap)
memcpy(send_base, iov_base + before_wrap, after_wrap);
const size_t nsend = BAN::Math::min(iov.iov_len, m_send_window.buffer->free());
m_send_window.buffer->push({ static_cast<const uint8_t*>(iov.iov_base), nsend });
total_sent += nsend;
m_send_window.data_size += nsend;
if (m_send_window.data_size == m_send_window.buffer->size())
break;
}
m_thread_blocker.unblock();
@@ -347,7 +319,7 @@ namespace Kernel
result = m_send_window.scaled_size();
break;
case SO_RCVBUF:
result = m_recv_window.buffer->size();
result = m_recv_window.buffer->capacity();
break;
default:
dwarnln("getsockopt(SOL_SOCKET, {})", option);
@@ -420,7 +392,7 @@ namespace Kernel
switch (request)
{
case FIONREAD:
*static_cast<int*>(argument) = m_recv_window.data_size;
*static_cast<int*>(argument) = m_recv_window.buffer->size();
return 0;
}
@@ -433,14 +405,14 @@ namespace Kernel
return true;
if (m_state == State::Listen)
return !m_pending_connections.empty();
return m_recv_window.data_size > 0;
return !m_recv_window.buffer->empty();
}
bool TCPSocket::can_write_impl() const
{
if (m_state != State::Established)
return false;
return m_send_window.data_size < m_send_window.buffer->size();
return !m_send_window.buffer->full();
}
bool TCPSocket::has_hungup_impl() const
@@ -530,19 +502,14 @@ namespace Kernel
ASSERT(m_mutex.locker() == Thread::current().tid());
ASSERT(header_buffer.size() == protocol_header_size());
m_last_sent_window_size = m_recv_window.buffer->size() - m_recv_window.data_size;
if (m_should_send_zero_window)
m_last_sent_window_size = 0;
m_should_send_ack = false;
m_should_send_zero_window = false;
m_last_sent_window_size = m_should_send_zero_window ? 0 : m_recv_window.buffer->free();
auto& header = header_buffer.as<TCPHeader>();
header = {
.src_port = bound_port(),
.dst_port = dst_port,
.seq_number = m_send_window.current_seq + m_send_window.has_ghost_byte,
.ack_number = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte,
.ack_number = m_recv_window.start_seq + m_recv_window.buffer->size() + m_recv_window.has_ghost_byte,
.data_offset = (sizeof(TCPHeader) + m_tcp_options_bytes) / sizeof(uint32_t),
.flags = m_next_flags,
.window_size = BAN::Math::min<size_t>(0xFFFF, m_last_sent_window_size >> m_recv_window.scale_shift),
@@ -569,7 +536,7 @@ namespace Kernel
if (m_connection_info->has_window_scale)
add_tcp_header_option<4, TCPOption::WindowScale>(header, m_recv_window.scale_shift);
header.window_size = BAN::Math::min<size_t>(0xFFFF, m_recv_window.buffer->size());
header.window_size = BAN::Math::min<size_t>(0xFFFF, m_recv_window.buffer->capacity());
m_send_window.start_seq++;
m_send_window.current_seq = m_send_window.start_seq;
@@ -722,7 +689,7 @@ namespace Kernel
check_payload = true;
if (!(header.flags & FIN))
break;
if (m_recv_window.start_seq + m_recv_window.data_size != header.seq_number)
if (m_recv_window.start_seq + m_recv_window.buffer->size() != header.seq_number)
break;
m_next_flags = FIN | ACK;
m_next_state = State::LastAck;
@@ -771,7 +738,7 @@ namespace Kernel
break;
}
const uint32_t expected_seq = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte;
const uint32_t expected_seq = m_recv_window.start_seq + m_recv_window.buffer->size() + m_recv_window.has_ghost_byte;
if (header.seq_number > expected_seq)
dprintln_if(DEBUG_TCP, "Missing packets");
@@ -794,21 +761,12 @@ namespace Kernel
payload = {};
}
const bool can_receive_new_data = (payload.size() > 0 && m_recv_window.data_size < m_recv_window.buffer->size());
const bool can_receive_new_data = (payload.size() > 0 && !m_recv_window.buffer->full());
if (can_receive_new_data)
{
auto* recv_base = reinterpret_cast<uint8_t*>(m_recv_window.buffer->vaddr());
const size_t nrecv = BAN::Math::min(payload.size(), m_recv_window.buffer->size() - m_recv_window.data_size);
const size_t recv_head = (m_recv_window.data_tail + m_recv_window.data_size) % m_recv_window.buffer->size();
const size_t before_wrap = BAN::Math::min(nrecv, m_recv_window.buffer->size() - recv_head);
memcpy(recv_base + recv_head, payload.data(), before_wrap);
if (const size_t after_wrap = nrecv - before_wrap)
memcpy(recv_base, payload.data() + before_wrap, after_wrap);
m_recv_window.data_size += nrecv;
const size_t nrecv = BAN::Math::min(payload.size(), m_recv_window.buffer->free());
m_recv_window.buffer->push(payload.slice(0, nrecv));
epoll_notify(EPOLLIN);
@@ -816,10 +774,13 @@ namespace Kernel
}
// make sure zero window is reported
if (m_last_sent_window_size > 0 && m_recv_window.data_size == m_recv_window.buffer->size())
if (m_last_sent_window_size > 0 && m_recv_window.buffer->full())
m_should_send_zero_window = true;
else if (can_receive_new_data)
m_should_send_ack = true;
{
m_next_flags = ACK;
m_next_state = m_state;
}
}
if (!hungup_before && has_hungup_impl())
@@ -925,12 +886,11 @@ namespace Kernel
if (m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq)
{
const uint32_t acknowledged_bytes = m_send_window.current_ack - m_send_window.start_seq - m_send_window.has_ghost_byte;
ASSERT(acknowledged_bytes <= m_send_window.data_size);
ASSERT(acknowledged_bytes <= m_send_window.buffer->size());
m_send_window.data_size -= acknowledged_bytes;
m_send_window.start_seq += acknowledged_bytes;
m_send_window.sent_size -= acknowledged_bytes;
m_send_window.data_tail = (m_send_window.data_tail + acknowledged_bytes) % m_send_window.buffer->size();
m_send_window.buffer->pop(acknowledged_bytes);
epoll_notify(EPOLLOUT);
@@ -941,7 +901,7 @@ namespace Kernel
const bool should_retransmit = m_send_window.had_zero_window || (m_send_window.sent_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms);
const bool can_send_new_data = (m_send_window.data_size > m_send_window.sent_size && m_send_window.sent_size < m_send_window.scaled_size());
const bool can_send_new_data = (m_send_window.buffer->size() > m_send_window.sent_size && m_send_window.sent_size < m_send_window.scaled_size());
if (m_send_window.scaled_size() > 0 && (should_retransmit || can_send_new_data))
{
@@ -951,24 +911,20 @@ namespace Kernel
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
auto target_address_len = m_connection_info->address_len;
const size_t send_start_offset = should_retransmit ? 0 : m_send_window.sent_size;
const size_t send_offset = should_retransmit ? 0 : m_send_window.sent_size;
const size_t total_send = BAN::Math::min<size_t>(
m_send_window.data_size - send_start_offset,
m_send_window.scaled_size() - send_start_offset
m_send_window.buffer->size() - send_offset,
m_send_window.scaled_size() - send_offset
);
m_send_window.current_seq = m_send_window.start_seq + send_start_offset;
m_send_window.current_seq = m_send_window.start_seq + send_offset;
const auto* send_base = reinterpret_cast<const uint8_t*>(m_send_window.buffer->vaddr());
for (size_t i = 0; i < total_send;)
{
const size_t send_offset = (m_send_window.data_tail + send_start_offset + i) % m_send_window.buffer->size();
const size_t to_send = BAN::Math::min<size_t>(total_send - i, m_send_window.mss);
const size_t max_send = BAN::Math::min<size_t>(total_send - i, m_send_window.mss);
const size_t to_send = BAN::Math::min(max_send, m_send_window.buffer->size() - send_offset);
auto message = BAN::ConstByteSpan(send_base + send_offset, to_send);
auto message = m_send_window.buffer->get_data().slice(send_offset + i, to_send);
m_next_flags = ACK;
if (auto ret = m_network_layer.sendto(*this, message, target_address, target_address_len); ret.is_error())
@@ -979,9 +935,10 @@ namespace Kernel
dprintln_if(DEBUG_TCP, "Sent {} bytes", to_send);
m_send_window.sent_size += to_send;
m_send_window.current_seq += to_send;
i += to_send;
m_send_window.current_seq += to_send;
if (send_offset + i > m_send_window.sent_size)
m_send_window.sent_size = send_offset + i;
}
m_send_window.last_send_ms = current_ms;
@@ -989,13 +946,23 @@ namespace Kernel
continue;
}
if (const size_t ack_count = m_should_send_ack + m_should_send_zero_window)
if (m_last_sent_window_size == 0)
m_should_send_zero_window = false;
if (m_should_send_zero_window || m_should_send_window_update)
{
ASSERT(m_connection_info.has_value());
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
auto target_address_len = m_connection_info->address_len;
for (size_t i = 0; i < ack_count; i++)
m_next_flags = ACK;
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())
dwarnln("{}", ret.error());
m_should_send_zero_window = false;
m_should_send_window_update = false;
if (m_last_sent_window_size == 0 && !m_recv_window.buffer->full())
{
m_next_flags = ACK;
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())