From ff378e4538ae28c6e5fd3f3dfa4c8dd7a5061e4a Mon Sep 17 00:00:00 2001 From: Bananymous Date: Tue, 24 Feb 2026 14:10:15 +0200 Subject: [PATCH] Kernel: Cleanup and optimize TCP We now only send enough data to fill other ends window, not past that. Previous logic had a but that allowed sending too much data leading to retransmissions. When the target sends zero window and later updates window size, immediately retransmit non-acknowledged bytes. Don't validate packets through listeing socket twice. The actual socket will already verify the checksum so the listening socket does not have to. --- kernel/include/kernel/Networking/TCPSocket.h | 3 + kernel/kernel/Networking/TCPSocket.cpp | 126 +++++++++++-------- 2 files changed, 78 insertions(+), 51 deletions(-) diff --git a/kernel/include/kernel/Networking/TCPSocket.h b/kernel/include/kernel/Networking/TCPSocket.h index f83325c1..221080d8 100644 --- a/kernel/include/kernel/Networking/TCPSocket.h +++ b/kernel/include/kernel/Networking/TCPSocket.h @@ -117,6 +117,7 @@ namespace Kernel uint64_t last_send_ms { 0 }; // last send time, used for retransmission timeout bool has_ghost_byte { false }; + bool had_zero_window { false }; uint32_t data_tail { 0 }; uint32_t data_size { 0 }; // number of bytes in this buffer @@ -179,6 +180,8 @@ namespace Kernel bool m_keep_alive { false }; bool m_no_delay { false }; + bool m_should_send_ack { false }; + uint64_t m_time_wait_start_ms { 0 }; ThreadBlocker m_thread_blocker; diff --git a/kernel/kernel/Networking/TCPSocket.cpp b/kernel/kernel/Networking/TCPSocket.cpp index a09b0a8c..cd83846a 100644 --- a/kernel/kernel/Networking/TCPSocket.cpp +++ b/kernel/kernel/Networking/TCPSocket.cpp @@ -25,6 +25,9 @@ namespace Kernel static constexpr size_t s_recv_window_buffer_size = 16 * PAGE_SIZE; static constexpr size_t s_send_window_buffer_size = 16 * PAGE_SIZE; + // allows upto 1 MiB windows + static constexpr uint8_t s_window_shift = 4; + // https://www.rfc-editor.org/rfc/rfc1122 4.2.2.6 static constexpr uint16_t s_default_mss = 536; @@ -40,7 +43,7 @@ namespace Kernel PageTable::Flags::ReadWrite | PageTable::Flags::Present, true, false )); - socket->m_recv_window.scale_shift = PAGE_SIZE_SHIFT; // use PAGE_SIZE windows + socket->m_recv_window.scale_shift = s_window_shift; socket->m_send_window.buffer = TRY(VirtualRange::create_to_vaddr_range( PageTable::kernel(), KERNEL_OFFSET, @@ -212,7 +215,7 @@ namespace Kernel if (CMSG_FIRSTHDR(&message)) { - dwarnln("ignoring recvmsg control message"); + dprintln_if(DEBUG_TCP, "ignoring recvmsg control message"); message.msg_controllen = 0; } @@ -249,13 +252,17 @@ namespace Kernel break; } + const size_t update_window_threshold = m_recv_window.buffer->size() / 8; const bool should_update_window_size = - (m_last_sent_window_size == 0) || - (m_recv_window.data_size == 0) || - (m_last_sent_window_size + PAGE_SIZE < m_recv_window.buffer->size() - m_recv_window.data_size); - if (m_next_flags == 0 && should_update_window_size) + m_last_sent_window_size != m_recv_window.buffer->size() && ( + (m_last_sent_window_size == 0) || + (m_recv_window.data_size == 0) || + (m_last_sent_window_size + update_window_threshold <= m_recv_window.buffer->size() - m_recv_window.data_size) + ); + + if (should_update_window_size) { - m_next_flags = ACK; + m_should_send_ack = true; m_thread_blocker.unblock(); } @@ -562,6 +569,8 @@ namespace Kernel pseudo_header.extra = packet.size(); header.checksum = calculate_internet_checksum(packet, pseudo_header); + m_should_send_ack = false; + dprintln_if(DEBUG_TCP, "sending {} {8b}", (uint8_t)m_state, header.flags); dprintln_if(DEBUG_TCP, " ack {}", (uint32_t)header.ack_number); dprintln_if(DEBUG_TCP, " seq {}", (uint32_t)header.seq_number); @@ -569,7 +578,19 @@ namespace Kernel void TCPSocket::receive_packet(BAN::ConstByteSpan buffer, const sockaddr* sender, socklen_t sender_len) { - (void)sender_len; + if (m_state == State::Listen) + { + auto socket = + [&]() -> BAN::RefPtr { + LockGuard _(m_mutex); + if (auto it = m_listen_children.find(ListenKey(sender, sender_len)); it != m_listen_children.end()) + return it->value; + return {}; + }(); + + if (socket) + return socket->receive_packet(buffer, sender, sender_len); + } { uint16_t checksum = 0; @@ -609,11 +630,14 @@ namespace Kernel const bool hungup_before = has_hungup_impl(); auto& header = buffer.as(); + dprintln_if(DEBUG_TCP, "receiving {} {8b}", (uint8_t)m_state, header.flags); dprintln_if(DEBUG_TCP, " ack {}", (uint32_t)header.ack_number); dprintln_if(DEBUG_TCP, " seq {}", (uint32_t)header.seq_number); m_send_window.non_scaled_size = header.window_size; + if (m_send_window.scaled_size() == 0) + m_send_window.had_zero_window = true; bool check_payload = false; switch (m_state) @@ -657,41 +681,27 @@ namespace Kernel m_has_connected = true; break; case State::Listen: - if (header.flags == SYN) - { - if (m_pending_connections.size() == m_pending_connections.capacity()) - dprintln_if(DEBUG_TCP, "No storage to store pending connection"); - else - { - const auto options = parse_tcp_options(header); - - ConnectionInfo connection_info; - memcpy(&connection_info.address, sender, sender_len); - connection_info.address_len = sender_len; - connection_info.has_window_scale = options.window_scale.has_value(); - MUST(m_pending_connections.emplace( - connection_info, - header.seq_number + 1, - options.maximum_seqment_size.value_or(s_default_mss), - options.window_scale.value_or(0) - )); - - epoll_notify(EPOLLIN); - m_thread_blocker.unblock(); - } - } + if (header.flags != SYN) + dprintln_if(DEBUG_TCP, "Unexpected packet to listening socket"); + else if (m_pending_connections.size() == m_pending_connections.capacity()) + dprintln_if(DEBUG_TCP, "No storage to store pending connection"); else { - auto it = m_listen_children.find(ListenKey(sender, sender_len)); - if (it == m_listen_children.end()) - { - dprintln_if(DEBUG_TCP, "Unexpected packet to listening socket"); - break; - } - auto socket = it->value; - m_mutex.unlock(); - socket->receive_packet(buffer, sender, sender_len); - m_mutex.lock(); + const auto options = parse_tcp_options(header); + + ConnectionInfo connection_info; + memcpy(&connection_info.address, sender, sender_len); + connection_info.address_len = sender_len; + connection_info.has_window_scale = options.window_scale.has_value(); + MUST(m_pending_connections.emplace( + connection_info, + header.seq_number + 1, + options.maximum_seqment_size.value_or(s_default_mss), + options.window_scale.value_or(0) + )); + + epoll_notify(EPOLLIN); + m_thread_blocker.unblock(); } return; case State::Established: @@ -747,6 +757,8 @@ namespace Kernel break; } + // TODO: even without SACKs, if other end sends seq [0, 1000] and our current seq is 100, we should accept + // packet with seq [100, 1000] if (header.seq_number != m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte) dprintln_if(DEBUG_TCP, "Missing packets"); else if (check_payload) @@ -776,16 +788,12 @@ namespace Kernel dprintln_if(DEBUG_TCP, "Received {} bytes", nrecv); - if (m_next_flags == 0) - { - m_next_flags = ACK; - m_next_state = m_state; - } + m_should_send_ack = true; } // make sure zero window is reported if (m_next_flags == 0 && m_last_sent_window_size > 0 && m_recv_window.data_size == m_recv_window.buffer->size()) - m_next_flags = ACK; + m_should_send_ack = true; } if (!hungup_before && has_hungup_impl()) @@ -888,7 +896,7 @@ namespace Kernel continue; } - if (m_send_window.data_size > 0 && m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq) + if (m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq) { const uint32_t acknowledged_bytes = m_send_window.current_ack - m_send_window.start_seq - m_send_window.has_ghost_byte; ASSERT(acknowledged_bytes <= m_send_window.data_size); @@ -905,17 +913,22 @@ namespace Kernel continue; } - const bool should_retransmit = m_send_window.sent_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms; + const bool should_retransmit = m_send_window.had_zero_window || (m_send_window.sent_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms); - if (m_send_window.data_size > m_send_window.sent_size || should_retransmit) + if (m_send_window.sent_size < m_send_window.scaled_size() && (should_retransmit || m_send_window.data_size > m_send_window.sent_size)) { + m_send_window.had_zero_window = false; + ASSERT(m_connection_info.has_value()); auto* target_address = reinterpret_cast(&m_connection_info->address); auto target_address_len = m_connection_info->address_len; const size_t send_start_offset = should_retransmit ? 0 : m_send_window.sent_size; - const size_t total_send = BAN::Math::min(m_send_window.data_size - send_start_offset, m_send_window.scaled_size()); + const size_t total_send = BAN::Math::min( + m_send_window.data_size - send_start_offset, + m_send_window.scaled_size() - m_send_window.sent_size + ); m_send_window.current_seq = m_send_window.start_seq + send_start_offset; @@ -948,6 +961,17 @@ namespace Kernel continue; } + if (m_should_send_ack) + { + ASSERT(m_connection_info.has_value()); + auto* target_address = reinterpret_cast(&m_connection_info->address); + auto target_address_len = m_connection_info->address_len; + + m_next_flags = ACK; + if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error()) + dwarnln("{}", ret.error()); + } + m_thread_blocker.unblock(); m_thread_blocker.block_with_wake_time_ms(current_ms + retransmit_timeout_ms, &m_mutex); }