diff --git a/kernel/include/kernel/Networking/TCPSocket.h b/kernel/include/kernel/Networking/TCPSocket.h index 3f6f41aa..583e0286 100644 --- a/kernel/include/kernel/Networking/TCPSocket.h +++ b/kernel/include/kernel/Networking/TCPSocket.h @@ -167,6 +167,8 @@ namespace Kernel State m_next_state { State::Closed }; uint8_t m_next_flags { 0 }; + size_t m_last_sent_window_size { 0 }; + Thread* m_thread { nullptr }; // TODO: actually support these diff --git a/kernel/kernel/Networking/TCPSocket.cpp b/kernel/kernel/Networking/TCPSocket.cpp index b85d655a..c28bef49 100644 --- a/kernel/kernel/Networking/TCPSocket.cpp +++ b/kernel/kernel/Networking/TCPSocket.cpp @@ -28,6 +28,7 @@ namespace Kernel BAN::ErrorOr> TCPSocket::create(NetworkLayer& network_layer, const Info& info) { auto socket = TRY(BAN::RefPtr::create(network_layer, info)); + socket->m_last_sent_window_size = s_recv_window_buffer_size; socket->m_recv_window.buffer = TRY(VirtualRange::create_to_vaddr_range( PageTable::kernel(), KERNEL_OFFSET, @@ -241,6 +242,16 @@ namespace Kernel memmove(recv_buffer, recv_buffer + nrecv, m_recv_window.data_size); } + const bool should_update_window_size = + (m_last_sent_window_size == 0) || + (m_recv_window.data_size == 0) || + (m_last_sent_window_size + PAGE_SIZE < m_recv_window.buffer->size() - m_recv_window.data_size); + if (m_next_flags == 0 && should_update_window_size) + { + m_next_flags = ACK; + m_thread_blocker.unblock(); + } + return total_recv; } @@ -502,12 +513,14 @@ namespace Kernel memset(&header, 0, sizeof(TCPHeader)); memset(header.options, TCPOption::End, m_tcp_options_bytes); + m_last_sent_window_size = m_recv_window.buffer->size() - m_recv_window.data_size; + header.src_port = bound_port(); header.dst_port = dst_port; header.seq_number = m_send_window.current_seq + m_send_window.has_ghost_byte; header.ack_number = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte; header.data_offset = (sizeof(TCPHeader) + m_tcp_options_bytes) / sizeof(uint32_t); - header.window_size = BAN::Math::min(0xFFFF, m_recv_window.buffer->size() >> m_recv_window.scale_shift); + header.window_size = BAN::Math::min(0xFFFF, m_last_sent_window_size >> m_recv_window.scale_shift); header.flags = m_next_flags; if (header.flags & FIN) m_send_window.has_ghost_byte = true; @@ -730,27 +743,28 @@ namespace Kernel m_send_window.current_ack = header.ack_number; auto payload = buffer.slice(header.data_offset * sizeof(uint32_t)); - if (payload.size() > 0) + if (payload.size() > 0 && m_recv_window.data_size < m_recv_window.buffer->size()) { - if (m_recv_window.data_size + payload.size() > m_recv_window.buffer->size()) - dprintln_if(DEBUG_TCP, "Cannot fit received bytes to window, waiting for retransmission"); - else + const size_t nrecv = BAN::Math::min(payload.size(), m_recv_window.buffer->size() - m_recv_window.data_size); + + auto* buffer = reinterpret_cast(m_recv_window.buffer->vaddr()); + memcpy(buffer + m_recv_window.data_size, payload.data(), nrecv); + m_recv_window.data_size += nrecv; + + epoll_notify(EPOLLIN); + + dprintln_if(DEBUG_TCP, "Received {} bytes", nrecv); + + if (m_next_flags == 0) { - auto* buffer = reinterpret_cast(m_recv_window.buffer->vaddr()); - memcpy(buffer + m_recv_window.data_size, payload.data(), payload.size()); - m_recv_window.data_size += payload.size(); - - epoll_notify(EPOLLIN); - - dprintln_if(DEBUG_TCP, "Received {} bytes", payload.size()); - - if (m_next_flags == 0) - { - m_next_flags = ACK; - m_next_state = m_state; - } + m_next_flags = ACK; + m_next_state = m_state; } } + + // make sure zero window is reported + if (m_next_flags == 0 && m_last_sent_window_size > 0 && m_recv_window.data_size == m_recv_window.buffer->size()) + m_next_flags = ACK; } if (!hungup_before && has_hungup_impl())