Kernel: Cleanup and optimize TCP
We now only send enough data to fill the other end's window, not past that. Previous logic had a bug that allowed sending too much data, leading to retransmissions. When the target sends a zero window and later updates the window size, immediately retransmit non-acknowledged bytes. Don't validate packets through the listening socket twice. The actual socket will already verify the checksum, so the listening socket does not have to.
This commit is contained in:
parent
2ea0a24795
commit
ff378e4538
|
|
@ -117,6 +117,7 @@ namespace Kernel
|
|||
uint64_t last_send_ms { 0 }; // last send time, used for retransmission timeout
|
||||
|
||||
bool has_ghost_byte { false };
|
||||
bool had_zero_window { false };
|
||||
|
||||
uint32_t data_tail { 0 };
|
||||
uint32_t data_size { 0 }; // number of bytes in this buffer
|
||||
|
|
@ -179,6 +180,8 @@ namespace Kernel
|
|||
bool m_keep_alive { false };
|
||||
bool m_no_delay { false };
|
||||
|
||||
bool m_should_send_ack { false };
|
||||
|
||||
uint64_t m_time_wait_start_ms { 0 };
|
||||
|
||||
ThreadBlocker m_thread_blocker;
|
||||
|
|
|
|||
|
|
@ -25,6 +25,9 @@ namespace Kernel
|
|||
static constexpr size_t s_recv_window_buffer_size = 16 * PAGE_SIZE;
|
||||
static constexpr size_t s_send_window_buffer_size = 16 * PAGE_SIZE;
|
||||
|
||||
// allows up to 1 MiB windows
|
||||
static constexpr uint8_t s_window_shift = 4;
|
||||
|
||||
// https://www.rfc-editor.org/rfc/rfc1122 4.2.2.6
|
||||
static constexpr uint16_t s_default_mss = 536;
|
||||
|
||||
|
|
@ -40,7 +43,7 @@ namespace Kernel
|
|||
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
|
||||
true, false
|
||||
));
|
||||
socket->m_recv_window.scale_shift = PAGE_SIZE_SHIFT; // use PAGE_SIZE windows
|
||||
socket->m_recv_window.scale_shift = s_window_shift;
|
||||
socket->m_send_window.buffer = TRY(VirtualRange::create_to_vaddr_range(
|
||||
PageTable::kernel(),
|
||||
KERNEL_OFFSET,
|
||||
|
|
@ -212,7 +215,7 @@ namespace Kernel
|
|||
|
||||
if (CMSG_FIRSTHDR(&message))
|
||||
{
|
||||
dwarnln("ignoring recvmsg control message");
|
||||
dprintln_if(DEBUG_TCP, "ignoring recvmsg control message");
|
||||
message.msg_controllen = 0;
|
||||
}
|
||||
|
||||
|
|
@ -249,13 +252,17 @@ namespace Kernel
|
|||
break;
|
||||
}
|
||||
|
||||
const size_t update_window_threshold = m_recv_window.buffer->size() / 8;
|
||||
const bool should_update_window_size =
|
||||
(m_last_sent_window_size == 0) ||
|
||||
(m_recv_window.data_size == 0) ||
|
||||
(m_last_sent_window_size + PAGE_SIZE < m_recv_window.buffer->size() - m_recv_window.data_size);
|
||||
if (m_next_flags == 0 && should_update_window_size)
|
||||
m_last_sent_window_size != m_recv_window.buffer->size() && (
|
||||
(m_last_sent_window_size == 0) ||
|
||||
(m_recv_window.data_size == 0) ||
|
||||
(m_last_sent_window_size + update_window_threshold <= m_recv_window.buffer->size() - m_recv_window.data_size)
|
||||
);
|
||||
|
||||
if (should_update_window_size)
|
||||
{
|
||||
m_next_flags = ACK;
|
||||
m_should_send_ack = true;
|
||||
m_thread_blocker.unblock();
|
||||
}
|
||||
|
||||
|
|
@ -562,6 +569,8 @@ namespace Kernel
|
|||
pseudo_header.extra = packet.size();
|
||||
header.checksum = calculate_internet_checksum(packet, pseudo_header);
|
||||
|
||||
m_should_send_ack = false;
|
||||
|
||||
dprintln_if(DEBUG_TCP, "sending {} {8b}", (uint8_t)m_state, header.flags);
|
||||
dprintln_if(DEBUG_TCP, " ack {}", (uint32_t)header.ack_number);
|
||||
dprintln_if(DEBUG_TCP, " seq {}", (uint32_t)header.seq_number);
|
||||
|
|
@ -569,7 +578,19 @@ namespace Kernel
|
|||
|
||||
void TCPSocket::receive_packet(BAN::ConstByteSpan buffer, const sockaddr* sender, socklen_t sender_len)
|
||||
{
|
||||
(void)sender_len;
|
||||
if (m_state == State::Listen)
|
||||
{
|
||||
auto socket =
|
||||
[&]() -> BAN::RefPtr<TCPSocket> {
|
||||
LockGuard _(m_mutex);
|
||||
if (auto it = m_listen_children.find(ListenKey(sender, sender_len)); it != m_listen_children.end())
|
||||
return it->value;
|
||||
return {};
|
||||
}();
|
||||
|
||||
if (socket)
|
||||
return socket->receive_packet(buffer, sender, sender_len);
|
||||
}
|
||||
|
||||
{
|
||||
uint16_t checksum = 0;
|
||||
|
|
@ -609,11 +630,14 @@ namespace Kernel
|
|||
const bool hungup_before = has_hungup_impl();
|
||||
|
||||
auto& header = buffer.as<const TCPHeader>();
|
||||
|
||||
dprintln_if(DEBUG_TCP, "receiving {} {8b}", (uint8_t)m_state, header.flags);
|
||||
dprintln_if(DEBUG_TCP, " ack {}", (uint32_t)header.ack_number);
|
||||
dprintln_if(DEBUG_TCP, " seq {}", (uint32_t)header.seq_number);
|
||||
|
||||
m_send_window.non_scaled_size = header.window_size;
|
||||
if (m_send_window.scaled_size() == 0)
|
||||
m_send_window.had_zero_window = true;
|
||||
|
||||
bool check_payload = false;
|
||||
switch (m_state)
|
||||
|
|
@ -657,41 +681,27 @@ namespace Kernel
|
|||
m_has_connected = true;
|
||||
break;
|
||||
case State::Listen:
|
||||
if (header.flags == SYN)
|
||||
{
|
||||
if (m_pending_connections.size() == m_pending_connections.capacity())
|
||||
dprintln_if(DEBUG_TCP, "No storage to store pending connection");
|
||||
else
|
||||
{
|
||||
const auto options = parse_tcp_options(header);
|
||||
|
||||
ConnectionInfo connection_info;
|
||||
memcpy(&connection_info.address, sender, sender_len);
|
||||
connection_info.address_len = sender_len;
|
||||
connection_info.has_window_scale = options.window_scale.has_value();
|
||||
MUST(m_pending_connections.emplace(
|
||||
connection_info,
|
||||
header.seq_number + 1,
|
||||
options.maximum_seqment_size.value_or(s_default_mss),
|
||||
options.window_scale.value_or(0)
|
||||
));
|
||||
|
||||
epoll_notify(EPOLLIN);
|
||||
m_thread_blocker.unblock();
|
||||
}
|
||||
}
|
||||
if (header.flags != SYN)
|
||||
dprintln_if(DEBUG_TCP, "Unexpected packet to listening socket");
|
||||
else if (m_pending_connections.size() == m_pending_connections.capacity())
|
||||
dprintln_if(DEBUG_TCP, "No storage to store pending connection");
|
||||
else
|
||||
{
|
||||
auto it = m_listen_children.find(ListenKey(sender, sender_len));
|
||||
if (it == m_listen_children.end())
|
||||
{
|
||||
dprintln_if(DEBUG_TCP, "Unexpected packet to listening socket");
|
||||
break;
|
||||
}
|
||||
auto socket = it->value;
|
||||
m_mutex.unlock();
|
||||
socket->receive_packet(buffer, sender, sender_len);
|
||||
m_mutex.lock();
|
||||
const auto options = parse_tcp_options(header);
|
||||
|
||||
ConnectionInfo connection_info;
|
||||
memcpy(&connection_info.address, sender, sender_len);
|
||||
connection_info.address_len = sender_len;
|
||||
connection_info.has_window_scale = options.window_scale.has_value();
|
||||
MUST(m_pending_connections.emplace(
|
||||
connection_info,
|
||||
header.seq_number + 1,
|
||||
options.maximum_seqment_size.value_or(s_default_mss),
|
||||
options.window_scale.value_or(0)
|
||||
));
|
||||
|
||||
epoll_notify(EPOLLIN);
|
||||
m_thread_blocker.unblock();
|
||||
}
|
||||
return;
|
||||
case State::Established:
|
||||
|
|
@ -747,6 +757,8 @@ namespace Kernel
|
|||
break;
|
||||
}
|
||||
|
||||
// TODO: even without SACKs, if other end sends seq [0, 1000] and our current seq is 100, we should accept
|
||||
// packet with seq [100, 1000]
|
||||
if (header.seq_number != m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte)
|
||||
dprintln_if(DEBUG_TCP, "Missing packets");
|
||||
else if (check_payload)
|
||||
|
|
@ -776,16 +788,12 @@ namespace Kernel
|
|||
|
||||
dprintln_if(DEBUG_TCP, "Received {} bytes", nrecv);
|
||||
|
||||
if (m_next_flags == 0)
|
||||
{
|
||||
m_next_flags = ACK;
|
||||
m_next_state = m_state;
|
||||
}
|
||||
m_should_send_ack = true;
|
||||
}
|
||||
|
||||
// make sure zero window is reported
|
||||
if (m_next_flags == 0 && m_last_sent_window_size > 0 && m_recv_window.data_size == m_recv_window.buffer->size())
|
||||
m_next_flags = ACK;
|
||||
m_should_send_ack = true;
|
||||
}
|
||||
|
||||
if (!hungup_before && has_hungup_impl())
|
||||
|
|
@ -888,7 +896,7 @@ namespace Kernel
|
|||
continue;
|
||||
}
|
||||
|
||||
if (m_send_window.data_size > 0 && m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq)
|
||||
if (m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq)
|
||||
{
|
||||
const uint32_t acknowledged_bytes = m_send_window.current_ack - m_send_window.start_seq - m_send_window.has_ghost_byte;
|
||||
ASSERT(acknowledged_bytes <= m_send_window.data_size);
|
||||
|
|
@ -905,17 +913,22 @@ namespace Kernel
|
|||
continue;
|
||||
}
|
||||
|
||||
const bool should_retransmit = m_send_window.sent_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms;
|
||||
const bool should_retransmit = m_send_window.had_zero_window || (m_send_window.sent_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms);
|
||||
|
||||
if (m_send_window.data_size > m_send_window.sent_size || should_retransmit)
|
||||
if (m_send_window.sent_size < m_send_window.scaled_size() && (should_retransmit || m_send_window.data_size > m_send_window.sent_size))
|
||||
{
|
||||
m_send_window.had_zero_window = false;
|
||||
|
||||
ASSERT(m_connection_info.has_value());
|
||||
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
|
||||
auto target_address_len = m_connection_info->address_len;
|
||||
|
||||
const size_t send_start_offset = should_retransmit ? 0 : m_send_window.sent_size;
|
||||
|
||||
const size_t total_send = BAN::Math::min<size_t>(m_send_window.data_size - send_start_offset, m_send_window.scaled_size());
|
||||
const size_t total_send = BAN::Math::min<size_t>(
|
||||
m_send_window.data_size - send_start_offset,
|
||||
m_send_window.scaled_size() - m_send_window.sent_size
|
||||
);
|
||||
|
||||
m_send_window.current_seq = m_send_window.start_seq + send_start_offset;
|
||||
|
||||
|
|
@ -948,6 +961,17 @@ namespace Kernel
|
|||
continue;
|
||||
}
|
||||
|
||||
if (m_should_send_ack)
|
||||
{
|
||||
ASSERT(m_connection_info.has_value());
|
||||
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
|
||||
auto target_address_len = m_connection_info->address_len;
|
||||
|
||||
m_next_flags = ACK;
|
||||
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())
|
||||
dwarnln("{}", ret.error());
|
||||
}
|
||||
|
||||
m_thread_blocker.unblock();
|
||||
m_thread_blocker.block_with_wake_time_ms(current_ms + retransmit_timeout_ms, &m_mutex);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue