diff --git a/kernel/include/kernel/Networking/TCPSocket.h b/kernel/include/kernel/Networking/TCPSocket.h index 8df8fa7e76..068f39a4f2 100644 --- a/kernel/include/kernel/Networking/TCPSocket.h +++ b/kernel/include/kernel/Networking/TCPSocket.h @@ -73,17 +73,33 @@ namespace Kernel TimeWait, }; - struct WindowInfo + struct RecvWindowInfo { - uint32_t mss { 0 }; - uint16_t size { 0 }; - uint8_t scale { 0 }; - uint32_t start_seq { 0 }; - uint32_t current_seq { 0 }; - BAN::Atomic ack_number { 0 }; - uint32_t data_size { 0 }; - uint64_t send_time_ms { 0 }; - BAN::UniqPtr window; + uint32_t start_seq { 0 }; // sequence number of first byte in buffer + + bool has_ghost_byte { false }; + + uint32_t data_size { 0 }; // number of bytes in this buffer + BAN::UniqPtr buffer; + }; + + struct SendWindowInfo + { + uint32_t mss { 0 }; // maximum segment size + uint16_t non_scaled_size { 0 }; // window size without scaling + uint8_t scale { 0 }; // window scale + uint32_t scaled_size() const { return (uint32_t)non_scaled_size << scale; } + + uint32_t start_seq { 0 }; // sequence number of first byte in buffer + uint32_t current_seq { 0 }; // sequence number of next send + uint32_t current_ack { 0 }; // sequence number aknowledged by connection + + uint64_t last_send_ms { 0 }; // last send time, used for retransmission timeout + + bool has_ghost_byte { false }; + + uint32_t data_size { 0 }; // number of bytes in this buffer + BAN::UniqPtr buffer; }; private: @@ -104,8 +120,8 @@ namespace Kernel BAN::Atomic m_should_ack { false }; - WindowInfo m_recv_window; - WindowInfo m_send_window; + RecvWindowInfo m_recv_window; + SendWindowInfo m_send_window; struct ConnectionInfo { diff --git a/kernel/kernel/Networking/TCPSocket.cpp b/kernel/kernel/Networking/TCPSocket.cpp index 327bd8d664..7cda7baa7b 100644 --- a/kernel/kernel/Networking/TCPSocket.cpp +++ b/kernel/kernel/Networking/TCPSocket.cpp @@ -27,7 +27,7 @@ namespace Kernel if (socket_ptr == nullptr) return BAN::Error::from_errno(ENOMEM); auto socket = BAN::RefPtr::adopt(socket_ptr); - socket->m_recv_window.window = TRY(VirtualRange::create_to_vaddr_range( + socket->m_recv_window.buffer = TRY(VirtualRange::create_to_vaddr_range( PageTable::kernel(), KERNEL_OFFSET, ~(vaddr_t)0, @@ -35,7 +35,7 @@ namespace Kernel PageTable::Flags::ReadWrite | PageTable::Flags::Present, true )); - socket->m_send_window.window = TRY(VirtualRange::create_to_vaddr_range( + socket->m_send_window.buffer = TRY(VirtualRange::create_to_vaddr_range( PageTable::kernel(), KERNEL_OFFSET, ~(vaddr_t)0, @@ -43,8 +43,6 @@ namespace Kernel PageTable::Flags::ReadWrite | PageTable::Flags::Present, true )); - socket->m_recv_window.size = socket->m_recv_window.window->size(); - socket->m_recv_window.scale = 0; socket->m_process = Process::create_kernel( [](void* socket_ptr) { @@ -58,7 +56,6 @@ namespace Kernel : NetworkSocket(network_layer, ino, inode_info) { m_send_window.start_seq = Random::get_u32() & 0x7FFFFFFF; - m_send_window.ack_number = m_send_window.start_seq; m_send_window.current_seq = m_send_window.start_seq; } @@ -95,13 +92,8 @@ namespace Kernel case State::Listen: ASSERT_NOT_REACHED(); } - ASSERT(m_connection_info.has_value()); - auto* target_address = reinterpret_cast(&m_connection_info->address); - auto target_address_len = m_connection_info->address_len; - m_state = State::FinWait1; - if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error()) - dwarnln("{}", ret.error()); + m_should_ack = true; dprintln_if(DEBUG_TCP, "Initiated close"); } @@ -140,8 +132,6 @@ namespace Kernel m_connection_info.emplace(sockaddr_storage {}, address_len); memcpy(&m_connection_info->address, address, address_len); - m_recv_window.mss = m_interface->payload_mtu() - m_network_layer.header_size(); - TRY(m_network_layer.sendto(*this, {}, address, address_len)); ASSERT(m_state == State::SynSent); dprintln_if(DEBUG_TCP, "Sent SYN"); @@ -211,10 +201,12 @@ namespace Kernel header.dst_port = dst_port; header.src_port = m_port; - header.seq_number = m_send_window.current_seq; - header.ack_number = m_recv_window.ack_number.load(); + header.seq_number = m_send_window.current_seq + m_send_window.has_ghost_byte; + header.ack_number = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte; header.data_offset = (sizeof(TCPHeader) + m_tcp_options_bytes) / sizeof(uint32_t); - header.window_size = m_recv_window.window->size(); + header.window_size = m_recv_window.buffer->size(); + + ASSERT(m_recv_window.buffer->size() < 1 << (8 * sizeof(header.window_size))); switch (m_state) { @@ -222,9 +214,11 @@ namespace Kernel { LockGuard _(m_lock); header.syn = 1; - add_tcp_header_option<0, TCPOption::MaximumSeqmentSize>(header, m_recv_window.mss); - add_tcp_header_option<4, TCPOption::WindowScale>(header, m_recv_window.scale); + add_tcp_header_option<0, TCPOption::MaximumSeqmentSize>(header, m_interface->payload_mtu() - m_network_layer.header_size()); + add_tcp_header_option<4, TCPOption::WindowScale>(header, 0); m_state = State::SynSent; + m_send_window.start_seq++; + m_send_window.current_seq = m_send_window.start_seq; break; } case State::SynSent: @@ -242,9 +236,8 @@ namespace Kernel LockGuard _(m_lock); header.ack = 1; header.fin = 1; - header.ack_number = header.ack_number + 1; m_state = State::LastAck; - dprintln_if(DEBUG_TCP, "Waiting for last ack"); + dprintln_if(DEBUG_TCP, "Waiting for last ACK"); break; } case State::FinWait1: @@ -259,10 +252,9 @@ namespace Kernel { LockGuard _(m_lock); header.ack = 1; - header.seq_number = header.seq_number + 1; - header.ack_number = header.ack_number + 1; m_state = State::TimeWait; m_time_wait_start_ms = SystemTimer::get().ms_since_boot(); + dprintln_if(DEBUG_TCP, "Sent final ACK"); break; } case State::Listen: ASSERT_NOT_REACHED(); @@ -307,7 +299,7 @@ namespace Kernel auto& header = buffer.as(); - m_send_window.size = header.window_size; + m_send_window.non_scaled_size = header.window_size; auto payload = buffer.slice(header.data_offset * sizeof(uint32_t)); @@ -322,20 +314,22 @@ namespace Kernel LockGuard _(m_lock); + if (header.ack_number != m_send_window.current_seq) + { + dprintln_if(DEBUG_TCP, "Invalid ack number in SYN/ACK", (uint32_t)header.ack_number, m_send_window.current_seq); + break; + } + auto options = parse_tcp_options(header); if (options.maximum_seqment_size.has_value()) m_send_window.mss = *options.maximum_seqment_size; if (options.window_scale.has_value()) m_send_window.scale = *options.window_scale; - else - m_recv_window.scale = 0; - m_send_window.start_seq = m_send_window.start_seq + 1; - m_send_window.ack_number = m_send_window.start_seq; - m_send_window.current_seq = m_send_window.start_seq; + m_send_window.start_seq = m_send_window.current_seq; + m_send_window.current_ack = m_send_window.current_seq; m_recv_window.start_seq = header.seq_number + 1; - m_recv_window.ack_number = m_recv_window.start_seq; dprintln_if(DEBUG_TCP, "Got SYN/ACK"); @@ -344,11 +338,6 @@ namespace Kernel break; } case State::FinWait2: - if (!header.ack) - break; - if (header.fin) - m_should_ack = true; - // fall through case State::TimeWait: case State::CloseWait: case State::Established: @@ -357,20 +346,28 @@ namespace Kernel break; LockGuard _(m_lock); + if (header.fin) { if (m_recv_window.start_seq + m_recv_window.data_size != header.seq_number) dprintln_if(DEBUG_TCP, "Got FIN, but missing packets"); else { + if (m_state == State::FinWait2) + m_send_window.has_ghost_byte = true; + else + m_state = State::CloseWait; + + m_recv_window.has_ghost_byte = true; m_should_ack = true; - m_state = State::CloseWait; dprintln_if(DEBUG_TCP, "Got FIN"); } break; } - if (header.ack_number > m_send_window.ack_number) - m_send_window.ack_number = header.ack_number; + + if (header.ack_number > m_send_window.current_ack) + m_send_window.current_ack = header.ack_number; + if (payload.size() > 0) { if (header.seq_number != m_recv_window.start_seq + m_recv_window.data_size) @@ -379,13 +376,13 @@ namespace Kernel break; } - if (m_recv_window.data_size + payload.size() > m_recv_window.window->size()) + if (m_recv_window.data_size + payload.size() > m_recv_window.buffer->size()) { - dwarnln("Cannot fit received bytes to window"); + dprintln_if(DEBUG_TCP, "Cannot fit received bytes to window, waiting for retransmission"); break; } - auto* buffer = reinterpret_cast(m_recv_window.window->vaddr()); + auto* buffer = reinterpret_cast(m_recv_window.buffer->vaddr()); memcpy(buffer + m_recv_window.data_size, payload.data(), payload.size()); m_recv_window.data_size += payload.size(); @@ -393,13 +390,14 @@ namespace Kernel dprintln_if(DEBUG_TCP, "Received {} bytes", payload.size()); } + break; } case State::LastAck: if (!header.ack) break; - set_connection_as_closed(); dprintln_if(DEBUG_TCP, "Got final ACK"); + set_connection_as_closed(); break; case State::Listen: ASSERT_NOT_REACHED(); case State::SynReceived: ASSERT_NOT_REACHED(); @@ -426,7 +424,7 @@ namespace Kernel void TCPSocket::process_task() { // FIXME: this should be dynamic - static constexpr uint32_t retransmit_timeout_ms = 100; + static constexpr uint32_t retransmit_timeout_ms = 1000; BAN::RefPtr keep_alive = this; @@ -434,13 +432,13 @@ namespace Kernel { uint64_t current_ms = SystemTimer::get().ms_since_boot(); - if (m_state == State::TimeWait && current_ms >= m_time_wait_start_ms + 6'000) + if (m_state == State::TimeWait && current_ms >= m_time_wait_start_ms + 30'000) set_connection_as_closed(); { LockGuard _(m_lock); - if (m_should_ack || m_recv_window.start_seq + m_recv_window.data_size != m_recv_window.ack_number) + if (m_should_ack) { m_should_ack = false; @@ -448,49 +446,28 @@ namespace Kernel auto* target_address = reinterpret_cast(&m_connection_info->address); auto target_address_len = m_connection_info->address_len; - m_recv_window.ack_number = m_recv_window.start_seq + m_recv_window.data_size; if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error()) dwarnln("{}", ret.error()); continue; } - bool is_send_open = false; - switch (m_state) + if (m_send_window.data_size > 0 && m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq) { - case State::Listen: - case State::Established: - case State::CloseWait: - case State::LastAck: - is_send_open = true; - break; - case State::SynSent: - case State::SynReceived: - case State::FinWait1: - case State::FinWait2: - case State::TimeWait: - case State::Closed: - is_send_open = false; - break; - case State::Closing: ASSERT_NOT_REACHED(); - } - - if (is_send_open && m_send_window.ack_number > m_send_window.start_seq) - { - uint32_t acknowledged_bytes = m_send_window.ack_number - m_send_window.start_seq; - ASSERT(acknowledged_bytes <= m_send_window.data_size); + uint32_t acknowledged_bytes = m_send_window.current_ack - m_send_window.start_seq - m_send_window.has_ghost_byte; + ASSERT_LTE(acknowledged_bytes, m_send_window.data_size); m_send_window.data_size -= acknowledged_bytes; m_send_window.start_seq += acknowledged_bytes; if (m_send_window.data_size > 0) { - auto* send_buffer = reinterpret_cast(m_send_window.window->vaddr()); + auto* send_buffer = reinterpret_cast(m_send_window.buffer->vaddr()); memmove(send_buffer, send_buffer + acknowledged_bytes, m_send_window.data_size); } else { - m_send_window.send_time_ms = 0; + m_send_window.last_send_ms = 0; } dprintln_if(DEBUG_TCP, "Target acknowledged {} bytes", acknowledged_bytes); @@ -498,20 +475,20 @@ namespace Kernel continue; } - if (is_send_open && m_send_window.data_size > 0 && current_ms >= m_send_window.send_time_ms + retransmit_timeout_ms) + if (m_send_window.data_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms) { ASSERT(m_connection_info.has_value()); auto* target_address = reinterpret_cast(&m_connection_info->address); auto target_address_len = m_connection_info->address_len; - const uint32_t total_send = BAN::Math::min(m_send_window.data_size, m_send_window.size << m_send_window.scale); + const uint32_t total_send = BAN::Math::min(m_send_window.data_size, m_send_window.scaled_size()); m_send_window.current_seq = m_send_window.start_seq; - auto* send_buffer = reinterpret_cast(m_send_window.window->vaddr()); + auto* send_buffer = reinterpret_cast(m_send_window.buffer->vaddr()); for (uint32_t i = 0; i < total_send;) { - uint32_t to_send = BAN::Math::min(total_send - i, m_send_window.mss); + const uint32_t to_send = BAN::Math::min(total_send - i, m_send_window.mss); auto message = BAN::ConstByteSpan(send_buffer + i, to_send); @@ -527,7 +504,7 @@ namespace Kernel i += to_send; } - m_send_window.send_time_ms = current_ms; + m_send_window.last_send_ms = current_ms; continue; } @@ -571,7 +548,7 @@ namespace Kernel uint32_t to_recv = BAN::Math::min(buffer.size(), m_recv_window.data_size); - auto* recv_buffer = reinterpret_cast(m_recv_window.window->vaddr()); + auto* recv_buffer = reinterpret_cast(m_recv_window.buffer->vaddr()); memcpy(buffer.data(), recv_buffer, to_recv); m_recv_window.data_size -= to_recv; @@ -582,20 +559,28 @@ namespace Kernel return to_recv; } - BAN::ErrorOr TCPSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t) + BAN::ErrorOr TCPSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) { if (address) return BAN::Error::from_errno(EISCONN); - if (message.size() > m_send_window.window->size()) - return BAN::Error::from_errno(EMSGSIZE); + if (message.size() > m_send_window.buffer->size()) + { + for (size_t i = 0; i < message.size(); i++) + { + const size_t to_send = BAN::Math::min(message.size() - i, m_send_window.buffer->size()); + TRY(sendto_impl(message.slice(i, to_send), address, address_len)); + i += to_send; + } + return message.size(); + } LockGuard _(m_lock); if (m_state == State::Closed) return BAN::Error::from_errno(ENOTCONN); - while (m_send_window.data_size + message.size() > m_send_window.window->size()) + while (true) { switch (m_state) { @@ -614,12 +599,15 @@ namespace Kernel case State::Closing: ASSERT_NOT_REACHED(); }; + if (m_send_window.data_size + message.size() <= m_send_window.buffer->size()) + break; + LockFreeGuard free(m_lock); TRY(Thread::current().block_or_eintr_indefinite(m_semaphore)); } { - auto* buffer = reinterpret_cast(m_send_window.window->vaddr()); + auto* buffer = reinterpret_cast(m_send_window.buffer->vaddr()); memcpy(buffer + m_send_window.data_size, message.data(), message.size()); m_send_window.data_size += message.size(); }