Kernel: Cleanup and optimize TCP

We now only send enough data to fill the other end's window, not past it.
Previous logic had a bug that allowed sending too much data, leading to
retransmissions.

When the target sends zero window and later updates window size,
immediately retransmit non-acknowledged bytes.

Don't validate packets through the listening socket twice. The actual
socket will already verify the checksum, so the listening socket does not
have to.
This commit is contained in:
Bananymous 2026-02-24 14:10:15 +02:00
parent 2ea0a24795
commit ff378e4538
2 changed files with 78 additions and 51 deletions

View File

@ -117,6 +117,7 @@ namespace Kernel
uint64_t last_send_ms { 0 }; // last send time, used for retransmission timeout
bool has_ghost_byte { false };
bool had_zero_window { false };
uint32_t data_tail { 0 };
uint32_t data_size { 0 }; // number of bytes in this buffer
@ -179,6 +180,8 @@ namespace Kernel
bool m_keep_alive { false };
bool m_no_delay { false };
bool m_should_send_ack { false };
uint64_t m_time_wait_start_ms { 0 };
ThreadBlocker m_thread_blocker;

View File

@ -25,6 +25,9 @@ namespace Kernel
static constexpr size_t s_recv_window_buffer_size = 16 * PAGE_SIZE;
static constexpr size_t s_send_window_buffer_size = 16 * PAGE_SIZE;
// allows upto 1 MiB windows
static constexpr uint8_t s_window_shift = 4;
// https://www.rfc-editor.org/rfc/rfc1122 4.2.2.6
static constexpr uint16_t s_default_mss = 536;
@ -40,7 +43,7 @@ namespace Kernel
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
true, false
));
socket->m_recv_window.scale_shift = PAGE_SIZE_SHIFT; // use PAGE_SIZE windows
socket->m_recv_window.scale_shift = s_window_shift;
socket->m_send_window.buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
@ -212,7 +215,7 @@ namespace Kernel
if (CMSG_FIRSTHDR(&message))
{
dwarnln("ignoring recvmsg control message");
dprintln_if(DEBUG_TCP, "ignoring recvmsg control message");
message.msg_controllen = 0;
}
@ -249,13 +252,17 @@ namespace Kernel
break;
}
const size_t update_window_threshold = m_recv_window.buffer->size() / 8;
const bool should_update_window_size =
m_last_sent_window_size != m_recv_window.buffer->size() && (
(m_last_sent_window_size == 0) ||
(m_recv_window.data_size == 0) ||
(m_last_sent_window_size + PAGE_SIZE < m_recv_window.buffer->size() - m_recv_window.data_size);
if (m_next_flags == 0 && should_update_window_size)
(m_last_sent_window_size + update_window_threshold <= m_recv_window.buffer->size() - m_recv_window.data_size)
);
if (should_update_window_size)
{
m_next_flags = ACK;
m_should_send_ack = true;
m_thread_blocker.unblock();
}
@ -562,6 +569,8 @@ namespace Kernel
pseudo_header.extra = packet.size();
header.checksum = calculate_internet_checksum(packet, pseudo_header);
m_should_send_ack = false;
dprintln_if(DEBUG_TCP, "sending {} {8b}", (uint8_t)m_state, header.flags);
dprintln_if(DEBUG_TCP, " ack {}", (uint32_t)header.ack_number);
dprintln_if(DEBUG_TCP, " seq {}", (uint32_t)header.seq_number);
@ -569,7 +578,19 @@ namespace Kernel
void TCPSocket::receive_packet(BAN::ConstByteSpan buffer, const sockaddr* sender, socklen_t sender_len)
{
(void)sender_len;
if (m_state == State::Listen)
{
auto socket =
[&]() -> BAN::RefPtr<TCPSocket> {
LockGuard _(m_mutex);
if (auto it = m_listen_children.find(ListenKey(sender, sender_len)); it != m_listen_children.end())
return it->value;
return {};
}();
if (socket)
return socket->receive_packet(buffer, sender, sender_len);
}
{
uint16_t checksum = 0;
@ -609,11 +630,14 @@ namespace Kernel
const bool hungup_before = has_hungup_impl();
auto& header = buffer.as<const TCPHeader>();
dprintln_if(DEBUG_TCP, "receiving {} {8b}", (uint8_t)m_state, header.flags);
dprintln_if(DEBUG_TCP, " ack {}", (uint32_t)header.ack_number);
dprintln_if(DEBUG_TCP, " seq {}", (uint32_t)header.seq_number);
m_send_window.non_scaled_size = header.window_size;
if (m_send_window.scaled_size() == 0)
m_send_window.had_zero_window = true;
bool check_payload = false;
switch (m_state)
@ -657,9 +681,9 @@ namespace Kernel
m_has_connected = true;
break;
case State::Listen:
if (header.flags == SYN)
{
if (m_pending_connections.size() == m_pending_connections.capacity())
if (header.flags != SYN)
dprintln_if(DEBUG_TCP, "Unexpected packet to listening socket");
else if (m_pending_connections.size() == m_pending_connections.capacity())
dprintln_if(DEBUG_TCP, "No storage to store pending connection");
else
{
@ -679,20 +703,6 @@ namespace Kernel
epoll_notify(EPOLLIN);
m_thread_blocker.unblock();
}
}
else
{
auto it = m_listen_children.find(ListenKey(sender, sender_len));
if (it == m_listen_children.end())
{
dprintln_if(DEBUG_TCP, "Unexpected packet to listening socket");
break;
}
auto socket = it->value;
m_mutex.unlock();
socket->receive_packet(buffer, sender, sender_len);
m_mutex.lock();
}
return;
case State::Established:
check_payload = true;
@ -747,6 +757,8 @@ namespace Kernel
break;
}
// TODO: even without SACKs, if other end sends seq [0, 1000] and our current seq is 100, we should accept
// packet with seq [100, 1000]
if (header.seq_number != m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte)
dprintln_if(DEBUG_TCP, "Missing packets");
else if (check_payload)
@ -776,16 +788,12 @@ namespace Kernel
dprintln_if(DEBUG_TCP, "Received {} bytes", nrecv);
if (m_next_flags == 0)
{
m_next_flags = ACK;
m_next_state = m_state;
}
m_should_send_ack = true;
}
// make sure zero window is reported
if (m_next_flags == 0 && m_last_sent_window_size > 0 && m_recv_window.data_size == m_recv_window.buffer->size())
m_next_flags = ACK;
m_should_send_ack = true;
}
if (!hungup_before && has_hungup_impl())
@ -888,7 +896,7 @@ namespace Kernel
continue;
}
if (m_send_window.data_size > 0 && m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq)
if (m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq)
{
const uint32_t acknowledged_bytes = m_send_window.current_ack - m_send_window.start_seq - m_send_window.has_ghost_byte;
ASSERT(acknowledged_bytes <= m_send_window.data_size);
@ -905,17 +913,22 @@ namespace Kernel
continue;
}
const bool should_retransmit = m_send_window.sent_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms;
const bool should_retransmit = m_send_window.had_zero_window || (m_send_window.sent_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms);
if (m_send_window.data_size > m_send_window.sent_size || should_retransmit)
if (m_send_window.sent_size < m_send_window.scaled_size() && (should_retransmit || m_send_window.data_size > m_send_window.sent_size))
{
m_send_window.had_zero_window = false;
ASSERT(m_connection_info.has_value());
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
auto target_address_len = m_connection_info->address_len;
const size_t send_start_offset = should_retransmit ? 0 : m_send_window.sent_size;
const size_t total_send = BAN::Math::min<size_t>(m_send_window.data_size - send_start_offset, m_send_window.scaled_size());
const size_t total_send = BAN::Math::min<size_t>(
m_send_window.data_size - send_start_offset,
m_send_window.scaled_size() - m_send_window.sent_size
);
m_send_window.current_seq = m_send_window.start_seq + send_start_offset;
@ -948,6 +961,17 @@ namespace Kernel
continue;
}
if (m_should_send_ack)
{
ASSERT(m_connection_info.has_value());
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
auto target_address_len = m_connection_info->address_len;
m_next_flags = ACK;
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())
dwarnln("{}", ret.error());
}
m_thread_blocker.unblock();
m_thread_blocker.block_with_wake_time_ms(current_ms + retransmit_timeout_ms, &m_mutex);
}