Kernel: Cleanup and optimize TCP

We now only send enough data to fill the other end's window, not past that.
The previous logic had a bug that allowed sending too much data, leading to
retransmissions.

When the target sends a zero window and later updates the window size,
immediately retransmit non-acknowledged bytes.

Don't validate packets through the listening socket twice. The actual socket
will already verify the checksum, so the listening socket does not have
to.
This commit is contained in:
Bananymous 2026-02-24 14:10:15 +02:00
parent 2ea0a24795
commit ff378e4538
2 changed files with 78 additions and 51 deletions

View File

@ -117,6 +117,7 @@ namespace Kernel
uint64_t last_send_ms { 0 }; // last send time, used for retransmission timeout uint64_t last_send_ms { 0 }; // last send time, used for retransmission timeout
bool has_ghost_byte { false }; bool has_ghost_byte { false };
bool had_zero_window { false };
uint32_t data_tail { 0 }; uint32_t data_tail { 0 };
uint32_t data_size { 0 }; // number of bytes in this buffer uint32_t data_size { 0 }; // number of bytes in this buffer
@ -179,6 +180,8 @@ namespace Kernel
bool m_keep_alive { false }; bool m_keep_alive { false };
bool m_no_delay { false }; bool m_no_delay { false };
bool m_should_send_ack { false };
uint64_t m_time_wait_start_ms { 0 }; uint64_t m_time_wait_start_ms { 0 };
ThreadBlocker m_thread_blocker; ThreadBlocker m_thread_blocker;

View File

@ -25,6 +25,9 @@ namespace Kernel
static constexpr size_t s_recv_window_buffer_size = 16 * PAGE_SIZE; static constexpr size_t s_recv_window_buffer_size = 16 * PAGE_SIZE;
static constexpr size_t s_send_window_buffer_size = 16 * PAGE_SIZE; static constexpr size_t s_send_window_buffer_size = 16 * PAGE_SIZE;
// allows up to 1 MiB windows
static constexpr uint8_t s_window_shift = 4;
// https://www.rfc-editor.org/rfc/rfc1122 4.2.2.6 // https://www.rfc-editor.org/rfc/rfc1122 4.2.2.6
static constexpr uint16_t s_default_mss = 536; static constexpr uint16_t s_default_mss = 536;
@ -40,7 +43,7 @@ namespace Kernel
PageTable::Flags::ReadWrite | PageTable::Flags::Present, PageTable::Flags::ReadWrite | PageTable::Flags::Present,
true, false true, false
)); ));
socket->m_recv_window.scale_shift = PAGE_SIZE_SHIFT; // use PAGE_SIZE windows socket->m_recv_window.scale_shift = s_window_shift;
socket->m_send_window.buffer = TRY(VirtualRange::create_to_vaddr_range( socket->m_send_window.buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(), PageTable::kernel(),
KERNEL_OFFSET, KERNEL_OFFSET,
@ -212,7 +215,7 @@ namespace Kernel
if (CMSG_FIRSTHDR(&message)) if (CMSG_FIRSTHDR(&message))
{ {
dwarnln("ignoring recvmsg control message"); dprintln_if(DEBUG_TCP, "ignoring recvmsg control message");
message.msg_controllen = 0; message.msg_controllen = 0;
} }
@ -249,13 +252,17 @@ namespace Kernel
break; break;
} }
const size_t update_window_threshold = m_recv_window.buffer->size() / 8;
const bool should_update_window_size = const bool should_update_window_size =
(m_last_sent_window_size == 0) || m_last_sent_window_size != m_recv_window.buffer->size() && (
(m_recv_window.data_size == 0) || (m_last_sent_window_size == 0) ||
(m_last_sent_window_size + PAGE_SIZE < m_recv_window.buffer->size() - m_recv_window.data_size); (m_recv_window.data_size == 0) ||
if (m_next_flags == 0 && should_update_window_size) (m_last_sent_window_size + update_window_threshold <= m_recv_window.buffer->size() - m_recv_window.data_size)
);
if (should_update_window_size)
{ {
m_next_flags = ACK; m_should_send_ack = true;
m_thread_blocker.unblock(); m_thread_blocker.unblock();
} }
@ -562,6 +569,8 @@ namespace Kernel
pseudo_header.extra = packet.size(); pseudo_header.extra = packet.size();
header.checksum = calculate_internet_checksum(packet, pseudo_header); header.checksum = calculate_internet_checksum(packet, pseudo_header);
m_should_send_ack = false;
dprintln_if(DEBUG_TCP, "sending {} {8b}", (uint8_t)m_state, header.flags); dprintln_if(DEBUG_TCP, "sending {} {8b}", (uint8_t)m_state, header.flags);
dprintln_if(DEBUG_TCP, " ack {}", (uint32_t)header.ack_number); dprintln_if(DEBUG_TCP, " ack {}", (uint32_t)header.ack_number);
dprintln_if(DEBUG_TCP, " seq {}", (uint32_t)header.seq_number); dprintln_if(DEBUG_TCP, " seq {}", (uint32_t)header.seq_number);
@ -569,7 +578,19 @@ namespace Kernel
void TCPSocket::receive_packet(BAN::ConstByteSpan buffer, const sockaddr* sender, socklen_t sender_len) void TCPSocket::receive_packet(BAN::ConstByteSpan buffer, const sockaddr* sender, socklen_t sender_len)
{ {
(void)sender_len; if (m_state == State::Listen)
{
auto socket =
[&]() -> BAN::RefPtr<TCPSocket> {
LockGuard _(m_mutex);
if (auto it = m_listen_children.find(ListenKey(sender, sender_len)); it != m_listen_children.end())
return it->value;
return {};
}();
if (socket)
return socket->receive_packet(buffer, sender, sender_len);
}
{ {
uint16_t checksum = 0; uint16_t checksum = 0;
@ -609,11 +630,14 @@ namespace Kernel
const bool hungup_before = has_hungup_impl(); const bool hungup_before = has_hungup_impl();
auto& header = buffer.as<const TCPHeader>(); auto& header = buffer.as<const TCPHeader>();
dprintln_if(DEBUG_TCP, "receiving {} {8b}", (uint8_t)m_state, header.flags); dprintln_if(DEBUG_TCP, "receiving {} {8b}", (uint8_t)m_state, header.flags);
dprintln_if(DEBUG_TCP, " ack {}", (uint32_t)header.ack_number); dprintln_if(DEBUG_TCP, " ack {}", (uint32_t)header.ack_number);
dprintln_if(DEBUG_TCP, " seq {}", (uint32_t)header.seq_number); dprintln_if(DEBUG_TCP, " seq {}", (uint32_t)header.seq_number);
m_send_window.non_scaled_size = header.window_size; m_send_window.non_scaled_size = header.window_size;
if (m_send_window.scaled_size() == 0)
m_send_window.had_zero_window = true;
bool check_payload = false; bool check_payload = false;
switch (m_state) switch (m_state)
@ -657,41 +681,27 @@ namespace Kernel
m_has_connected = true; m_has_connected = true;
break; break;
case State::Listen: case State::Listen:
if (header.flags == SYN) if (header.flags != SYN)
{ dprintln_if(DEBUG_TCP, "Unexpected packet to listening socket");
if (m_pending_connections.size() == m_pending_connections.capacity()) else if (m_pending_connections.size() == m_pending_connections.capacity())
dprintln_if(DEBUG_TCP, "No storage to store pending connection"); dprintln_if(DEBUG_TCP, "No storage to store pending connection");
else
{
const auto options = parse_tcp_options(header);
ConnectionInfo connection_info;
memcpy(&connection_info.address, sender, sender_len);
connection_info.address_len = sender_len;
connection_info.has_window_scale = options.window_scale.has_value();
MUST(m_pending_connections.emplace(
connection_info,
header.seq_number + 1,
options.maximum_seqment_size.value_or(s_default_mss),
options.window_scale.value_or(0)
));
epoll_notify(EPOLLIN);
m_thread_blocker.unblock();
}
}
else else
{ {
auto it = m_listen_children.find(ListenKey(sender, sender_len)); const auto options = parse_tcp_options(header);
if (it == m_listen_children.end())
{ ConnectionInfo connection_info;
dprintln_if(DEBUG_TCP, "Unexpected packet to listening socket"); memcpy(&connection_info.address, sender, sender_len);
break; connection_info.address_len = sender_len;
} connection_info.has_window_scale = options.window_scale.has_value();
auto socket = it->value; MUST(m_pending_connections.emplace(
m_mutex.unlock(); connection_info,
socket->receive_packet(buffer, sender, sender_len); header.seq_number + 1,
m_mutex.lock(); options.maximum_seqment_size.value_or(s_default_mss),
options.window_scale.value_or(0)
));
epoll_notify(EPOLLIN);
m_thread_blocker.unblock();
} }
return; return;
case State::Established: case State::Established:
@ -747,6 +757,8 @@ namespace Kernel
break; break;
} }
// TODO: even without SACKs, if other end sends seq [0, 1000] and our current seq is 100, we should accept
// packet with seq [100, 1000]
if (header.seq_number != m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte) if (header.seq_number != m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte)
dprintln_if(DEBUG_TCP, "Missing packets"); dprintln_if(DEBUG_TCP, "Missing packets");
else if (check_payload) else if (check_payload)
@ -776,16 +788,12 @@ namespace Kernel
dprintln_if(DEBUG_TCP, "Received {} bytes", nrecv); dprintln_if(DEBUG_TCP, "Received {} bytes", nrecv);
if (m_next_flags == 0) m_should_send_ack = true;
{
m_next_flags = ACK;
m_next_state = m_state;
}
} }
// make sure zero window is reported // make sure zero window is reported
if (m_next_flags == 0 && m_last_sent_window_size > 0 && m_recv_window.data_size == m_recv_window.buffer->size()) if (m_next_flags == 0 && m_last_sent_window_size > 0 && m_recv_window.data_size == m_recv_window.buffer->size())
m_next_flags = ACK; m_should_send_ack = true;
} }
if (!hungup_before && has_hungup_impl()) if (!hungup_before && has_hungup_impl())
@ -888,7 +896,7 @@ namespace Kernel
continue; continue;
} }
if (m_send_window.data_size > 0 && m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq) if (m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq)
{ {
const uint32_t acknowledged_bytes = m_send_window.current_ack - m_send_window.start_seq - m_send_window.has_ghost_byte; const uint32_t acknowledged_bytes = m_send_window.current_ack - m_send_window.start_seq - m_send_window.has_ghost_byte;
ASSERT(acknowledged_bytes <= m_send_window.data_size); ASSERT(acknowledged_bytes <= m_send_window.data_size);
@ -905,17 +913,22 @@ namespace Kernel
continue; continue;
} }
const bool should_retransmit = m_send_window.sent_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms; const bool should_retransmit = m_send_window.had_zero_window || (m_send_window.sent_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms);
if (m_send_window.data_size > m_send_window.sent_size || should_retransmit) if (m_send_window.sent_size < m_send_window.scaled_size() && (should_retransmit || m_send_window.data_size > m_send_window.sent_size))
{ {
m_send_window.had_zero_window = false;
ASSERT(m_connection_info.has_value()); ASSERT(m_connection_info.has_value());
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address); auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
auto target_address_len = m_connection_info->address_len; auto target_address_len = m_connection_info->address_len;
const size_t send_start_offset = should_retransmit ? 0 : m_send_window.sent_size; const size_t send_start_offset = should_retransmit ? 0 : m_send_window.sent_size;
const size_t total_send = BAN::Math::min<size_t>(m_send_window.data_size - send_start_offset, m_send_window.scaled_size()); const size_t total_send = BAN::Math::min<size_t>(
m_send_window.data_size - send_start_offset,
m_send_window.scaled_size() - m_send_window.sent_size
);
m_send_window.current_seq = m_send_window.start_seq + send_start_offset; m_send_window.current_seq = m_send_window.start_seq + send_start_offset;
@ -948,6 +961,17 @@ namespace Kernel
continue; continue;
} }
if (m_should_send_ack)
{
ASSERT(m_connection_info.has_value());
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
auto target_address_len = m_connection_info->address_len;
m_next_flags = ACK;
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())
dwarnln("{}", ret.error());
}
m_thread_blocker.unblock(); m_thread_blocker.unblock();
m_thread_blocker.block_with_wake_time_ms(current_ms + retransmit_timeout_ms, &m_mutex); m_thread_blocker.block_with_wake_time_ms(current_ms + retransmit_timeout_ms, &m_mutex);
} }