From 4f0457a268787623236b01045b72f4c0a2399ad4 Mon Sep 17 00:00:00 2001 From: Bananymous Date: Thu, 20 Jun 2024 13:26:50 +0300 Subject: [PATCH] Kernel: Rewrite a lot of TCP code and implement TCP server sockets TCP stack is now implemented much closer to spec --- kernel/include/kernel/Networking/IPv4Layer.h | 1 + .../include/kernel/Networking/NetworkLayer.h | 2 + .../include/kernel/Networking/NetworkSocket.h | 2 +- kernel/include/kernel/Networking/TCPSocket.h | 91 ++- kernel/include/kernel/Networking/UDPSocket.h | 2 +- kernel/kernel/Networking/IPv4Layer.cpp | 4 +- kernel/kernel/Networking/TCPSocket.cpp | 706 ++++++++++-------- kernel/kernel/Networking/UDPSocket.cpp | 12 +- 8 files changed, 484 insertions(+), 336 deletions(-) diff --git a/kernel/include/kernel/Networking/IPv4Layer.h b/kernel/include/kernel/Networking/IPv4Layer.h index 0167529a..f56345b3 100644 --- a/kernel/include/kernel/Networking/IPv4Layer.h +++ b/kernel/include/kernel/Networking/IPv4Layer.h @@ -51,6 +51,7 @@ namespace Kernel virtual BAN::ErrorOr sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) override; + virtual SocketDomain domain() const override { return SocketDomain::INET ;} virtual size_t header_size() const override { return sizeof(IPv4Header); } private: diff --git a/kernel/include/kernel/Networking/NetworkLayer.h b/kernel/include/kernel/Networking/NetworkLayer.h index 68dfe92c..9e35ce7e 100644 --- a/kernel/include/kernel/Networking/NetworkLayer.h +++ b/kernel/include/kernel/Networking/NetworkLayer.h @@ -15,6 +15,7 @@ namespace Kernel static_assert(sizeof(PseudoHeader) == 12); class NetworkSocket; + enum class SocketDomain; enum class SocketType; class NetworkLayer @@ -29,6 +30,7 @@ namespace Kernel virtual BAN::ErrorOr sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) = 0; + virtual SocketDomain domain() const = 0; virtual size_t header_size() const = 0; protected: diff --git a/kernel/include/kernel/Networking/NetworkSocket.h b/kernel/include/kernel/Networking/NetworkSocket.h index 72c28281..66e6271c 100644 --- a/kernel/include/kernel/Networking/NetworkSocket.h +++ b/kernel/include/kernel/Networking/NetworkSocket.h @@ -34,7 +34,7 @@ namespace Kernel virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) = 0; virtual NetworkProtocol protocol() const = 0; - virtual void receive_packet(BAN::ConstByteSpan, const sockaddr_storage& sender) = 0; + virtual void receive_packet(BAN::ConstByteSpan, const sockaddr* sender, socklen_t sender_len) = 0; bool is_bound() const { return m_interface != nullptr; } diff --git a/kernel/include/kernel/Networking/TCPSocket.h b/kernel/include/kernel/Networking/TCPSocket.h index a031d99d..5fcddfed 100644 --- a/kernel/include/kernel/Networking/TCPSocket.h +++ b/kernel/include/kernel/Networking/TCPSocket.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -11,6 +12,18 @@ namespace Kernel { + enum TCPFlags : uint8_t + { + FIN = 0x01, + SYN = 0x02, + RST = 0x04, + PSH = 0x08, + ACK = 0x10, + URG = 0x20, + ECE = 0x40, + CWR = 0x80, + }; + struct TCPHeader { BAN::NetworkEndian src_port { 0 }; @@ -19,14 +32,7 @@ namespace Kernel BAN::NetworkEndian ack_number { 0 }; uint8_t reserved : 4 { 0 }; uint8_t data_offset : 4 { 0 }; - uint8_t fin : 1 { 0 }; - uint8_t syn : 1 { 0 }; - uint8_t rst : 1 { 0 }; - uint8_t psh : 1 { 0 }; - uint8_t ack : 1 { 0 }; - uint8_t urg : 1 { 0 }; - uint8_t ece : 1 { 0 }; - uint8_t cwr : 1 { 0 }; + uint8_t flags { }; BAN::NetworkEndian window_size { 0 }; BAN::NetworkEndian checksum { 0 }; BAN::NetworkEndian urgent_pointer { 0 }; @@ -49,16 +55,18 @@ namespace Kernel virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) override; protected: + virtual BAN::ErrorOr accept_impl(sockaddr*, socklen_t*) override; virtual BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) override; + virtual BAN::ErrorOr listen_impl(int) override; + virtual BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) override; + virtual BAN::ErrorOr sendto_impl(BAN::ConstByteSpan, const sockaddr*, socklen_t) override; + virtual BAN::ErrorOr recvfrom_impl(BAN::ByteSpan, sockaddr*, socklen_t*) override; - virtual void receive_packet(BAN::ConstByteSpan, const sockaddr_storage& sender) override; + virtual void receive_packet(BAN::ConstByteSpan, const sockaddr* sender, socklen_t sender_len) override; - virtual BAN::ErrorOr sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) override; - virtual BAN::ErrorOr recvfrom_impl(BAN::ByteSpan message, sockaddr* address, socklen_t* address_len) override; - - virtual bool can_read_impl() const override { return m_recv_window.data_size; } - virtual bool can_write_impl() const override { return m_state == State::Established; } - virtual bool has_error_impl() const override { return m_state != State::Established && m_state != State::Listen && m_state != State::SynSent && m_state != State::SynReceived; } + virtual bool can_read_impl() const override; + virtual bool can_write_impl() const override; + virtual bool has_error_impl() const override { return false; } private: enum class State @@ -105,6 +113,33 @@ namespace Kernel BAN::UniqPtr buffer; }; + struct ConnectionInfo + { + sockaddr_storage address; + socklen_t address_len; + }; + + struct PendingConnection + { + ConnectionInfo target; + uint32_t target_start_seq; + }; + + struct ListenKey + { + ListenKey(const sockaddr* addr, socklen_t addr_len); + ListenKey(BAN::IPv4Address addr, uint16_t port) + : address(addr), port(port) + {} + bool operator==(const ListenKey& other) const; + BAN::IPv4Address address { 0 }; + uint16_t port { 0 }; + }; + struct ListenKeyHash + { + BAN::hash_t operator()(ListenKey key) const; + }; + private: TCPSocket(NetworkLayer&, ino_t, const TmpInodeInfo&); void process_task(); @@ -112,27 +147,35 @@ namespace Kernel void start_close_sequence(); void set_connection_as_closed(); + void remove_listen_child(BAN::RefPtr); + + BAN::ErrorOr return_with_maybe_zero(); + private: State m_state = State::Closed; + State m_next_state { State::Closed }; + uint8_t m_next_flags { 0 }; + Process* m_process { nullptr }; uint64_t m_time_wait_start_ms { 0 }; - Mutex m_lock; - Semaphore m_semaphore; - - BAN::Atomic m_should_ack { false }; + Semaphore m_semaphore; RecvWindowInfo m_recv_window; SendWindowInfo m_send_window; - struct ConnectionInfo - { - sockaddr_storage address; - socklen_t address_len; - }; + bool m_has_connected { false }; + bool m_has_sent_zero { false }; + BAN::Optional m_connection_info; + BAN::Queue m_pending_connections; + + BAN::RefPtr m_listen_parent; + BAN::HashMap, ListenKeyHash> m_listen_children; + + friend class BAN::RefPtr; }; } diff --git a/kernel/include/kernel/Networking/UDPSocket.h b/kernel/include/kernel/Networking/UDPSocket.h index 1ba27f61..9bd5572d 100644 --- a/kernel/include/kernel/Networking/UDPSocket.h +++ b/kernel/include/kernel/Networking/UDPSocket.h @@ -31,7 +31,7 @@ namespace Kernel virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) override; protected: - virtual void receive_packet(BAN::ConstByteSpan, const sockaddr_storage& sender) override; + virtual void receive_packet(BAN::ConstByteSpan, const sockaddr* sender, socklen_t sender_len) override; virtual BAN::ErrorOr bind_impl(const sockaddr* address, socklen_t address_len) override; virtual BAN::ErrorOr sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) override; diff --git a/kernel/kernel/Networking/IPv4Layer.cpp b/kernel/kernel/Networking/IPv4Layer.cpp index 9ef3d629..f4301002 100644 --- a/kernel/kernel/Networking/IPv4Layer.cpp +++ b/kernel/kernel/Networking/IPv4Layer.cpp @@ -298,9 +298,9 @@ namespace Kernel sockaddr_in sender; sender.sin_family = AF_INET; - sender.sin_port = BAN::NetworkEndian(src_port); + sender.sin_port = BAN::host_to_network_endian(src_port); sender.sin_addr.s_addr = src_ipv4.raw; - bound_socket->receive_packet(ipv4_data, *reinterpret_cast(&sender)); + bound_socket->receive_packet(ipv4_data, reinterpret_cast(&sender), sizeof(sender)); return {}; } diff --git a/kernel/kernel/Networking/TCPSocket.cpp b/kernel/kernel/Networking/TCPSocket.cpp index 21253f5a..352cdf47 100644 --- a/kernel/kernel/Networking/TCPSocket.cpp +++ b/kernel/kernel/Networking/TCPSocket.cpp @@ -1,8 +1,10 @@ #include +#include #include #include #include +#include #include #define DEBUG_TCP 0 @@ -23,10 +25,7 @@ namespace Kernel BAN::ErrorOr> TCPSocket::create(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info) { - auto* socket_ptr = new TCPSocket(network_layer, ino, inode_info); - if (socket_ptr == nullptr) - return BAN::Error::from_errno(ENOMEM); - auto socket = BAN::RefPtr::adopt(socket_ptr); + auto socket = TRY(BAN::RefPtr::create(network_layer, ino, inode_info)); socket->m_recv_window.buffer = TRY(VirtualRange::create_to_vaddr_range( PageTable::kernel(), KERNEL_OFFSET, @@ -63,6 +62,66 @@ namespace Kernel { ASSERT(!is_bound()); ASSERT(m_process == nullptr); + dprintln_if(DEBUG_TCP, "Socket destroyed"); + } + + BAN::ErrorOr TCPSocket::accept_impl(sockaddr* address, socklen_t* address_len) + { + if (m_state != State::Listen) + return BAN::Error::from_errno(EINVAL); + + while (m_pending_connections.empty()) + { + LockFreeGuard _(m_mutex); + TRY(Thread::current().block_or_eintr_indefinite(m_semaphore)); + } + + auto connection = m_pending_connections.front(); + m_pending_connections.pop(); + + + auto listen_key = ListenKey( + reinterpret_cast(&connection.target.address), + connection.target.address_len + ); + if (auto it = m_listen_children.find(listen_key); it != m_listen_children.end()) + return BAN::Error::from_errno(ECONNABORTED); + + BAN::RefPtr return_inode; + { + auto return_inode_tmp = TRY(NetworkManager::get().create_socket(m_network_layer.domain(), SocketType::STREAM, mode().mode & ~Mode::TYPE_MASK, uid(), gid())); + return_inode = static_cast(return_inode_tmp.ptr()); + } + + return_inode->m_mutex.lock(); + return_inode->m_port = m_port; + return_inode->m_interface = m_interface; + return_inode->m_listen_parent = this; + return_inode->m_connection_info.emplace(connection.target); + return_inode->m_recv_window.start_seq = connection.target_start_seq; + return_inode->m_next_flags = SYN | ACK; + return_inode->m_next_state = State::SynReceived; + return_inode->m_mutex.unlock(); + + TRY(m_listen_children.emplace(listen_key, return_inode)); + + const uint64_t wake_time_ms = SystemTimer::get().ms_since_boot() + 5000; + while (!return_inode->m_has_connected) + { + if (SystemTimer::get().ms_since_boot() >= wake_time_ms) + return BAN::Error::from_errno(ECONNABORTED); + LockFreeGuard free(m_mutex); + TRY(Thread::current().block_or_eintr_or_waketime(return_inode->m_semaphore, wake_time_ms, true)); + } + + if (address) + { + ASSERT(address_len); + *address_len = BAN::Math::min(*address_len, connection.target.address_len); + memcpy(address, &connection.target.address, *address_len); + } + + return TRY(Process::current().open_inode(return_inode, O_RDWR)); } BAN::ErrorOr TCPSocket::connect_impl(const sockaddr* address, socklen_t address_len) @@ -88,9 +147,8 @@ namespace Kernel case State::Closing: case State::LastAck: case State::TimeWait: - return BAN::Error::from_errno(EISCONN); case State::Listen: - return BAN::Error::from_errno(EOPNOTSUPP); + return BAN::Error::from_errno(EISCONN); }; if (!is_bound()) @@ -99,22 +157,166 @@ namespace Kernel m_connection_info.emplace(sockaddr_storage {}, address_len); memcpy(&m_connection_info->address, address, address_len); + m_next_flags = SYN; TRY(m_network_layer.sendto(*this, {}, address, address_len)); - ASSERT(m_state == State::SynSent); - dprintln_if(DEBUG_TCP, "Sent SYN"); + m_next_flags = 0; + m_state = State::SynSent; - uint64_t wake_time_ms = SystemTimer::get().ms_since_boot() + 5000; - while (m_state != State::Established) + const uint64_t wake_time_ms = SystemTimer::get().ms_since_boot() + 5000; + while (!m_has_connected) { - LockFreeGuard free(m_mutex); if (SystemTimer::get().ms_since_boot() >= wake_time_ms) return BAN::Error::from_errno(ECONNREFUSED); + LockFreeGuard free(m_mutex); TRY(Thread::current().block_or_eintr_or_waketime(m_semaphore, wake_time_ms, true)); } return {}; } + BAN::ErrorOr TCPSocket::listen_impl(int backlog) + { + if (!is_bound()) + return BAN::Error::from_errno(EDESTADDRREQ); + if (m_connection_info.has_value()) + return BAN::Error::from_errno(EINVAL); + + backlog = BAN::Math::clamp(backlog, 1, SOMAXCONN); + TRY(m_pending_connections.reserve(backlog)); + m_state = State::Listen; + + return {}; + } + + BAN::ErrorOr TCPSocket::bind_impl(const sockaddr* address, socklen_t address_len) + { + if (is_bound()) + return BAN::Error::from_errno(EINVAL); + return m_network_layer.bind_socket_to_address(this, address, address_len); + } + + BAN::ErrorOr TCPSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr*, socklen_t*) + { + if (!m_has_connected) + return BAN::Error::from_errno(ENOTCONN); + + while (m_recv_window.data_size == 0) + { + if (m_state != State::Established) + return return_with_maybe_zero(); + LockFreeGuard free(m_mutex); + TRY(Thread::current().block_or_eintr_indefinite(m_semaphore)); + } + + const uint32_t to_recv = BAN::Math::min(buffer.size(), m_recv_window.data_size); + + auto* recv_buffer = reinterpret_cast(m_recv_window.buffer->vaddr()); + memcpy(buffer.data(), recv_buffer, to_recv); + + m_recv_window.data_size -= to_recv; + m_recv_window.start_seq += to_recv; + if (m_recv_window.data_size > 0) + memmove(recv_buffer, recv_buffer + to_recv, m_recv_window.data_size); + + return to_recv; + } + + BAN::ErrorOr TCPSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) + { + if (address) + return BAN::Error::from_errno(EISCONN); + if (!m_has_connected) + return BAN::Error::from_errno(ENOTCONN); + + if (message.size() > m_send_window.buffer->size()) + { + size_t nsent = 0; + while (nsent < message.size()) + { + const size_t to_send = BAN::Math::min(message.size() - nsent, m_send_window.buffer->size()); + TRY(sendto_impl(message.slice(nsent, to_send), address, address_len)); + nsent += to_send; + } + return nsent; + } + + while (true) + { + if (m_state != State::Established) + return return_with_maybe_zero(); + if (m_send_window.data_size + message.size() <= m_send_window.buffer->size()) + break; + LockFreeGuard free(m_mutex); + TRY(Thread::current().block_or_eintr_indefinite(m_semaphore)); + } + + { + auto* buffer = reinterpret_cast(m_send_window.buffer->vaddr()); + memcpy(buffer + m_send_window.data_size, message.data(), message.size()); + m_send_window.data_size += message.size(); + } + + const uint32_t target_ack = m_send_window.start_seq + m_send_window.data_size; + m_semaphore.unblock(); + + while (m_send_window.start_seq < target_ack) + { + if (m_state != State::Established) + return return_with_maybe_zero(); + LockFreeGuard free(m_mutex); + TRY(Thread::current().block_or_eintr_indefinite(m_semaphore)); + } + + return message.size(); + } + + bool TCPSocket::can_read_impl() const + { + if (m_has_connected && !m_has_sent_zero && m_state != State::Established && m_state != State::Listen) + return true; + if (m_state == State::Listen) + return !m_pending_connections.empty(); + return m_recv_window.data_size > 0; + } + + bool TCPSocket::can_write_impl() const + { + if (m_state != State::Established) + return false; + return m_send_window.data_size < m_send_window.buffer->size(); + } + + BAN::ErrorOr TCPSocket::return_with_maybe_zero() + { + ASSERT(m_state != State::Established); + if (!m_has_sent_zero) + { + m_has_sent_zero = true; + return 0; + } + return BAN::Error::from_errno(ECONNRESET); + } + + TCPSocket::ListenKey::ListenKey(const sockaddr* addr, socklen_t addr_len) + { + ASSERT(addr->sa_family == AF_INET); + ASSERT(addr_len >= (socklen_t)sizeof(sockaddr_in)); + + const auto* addr_in = reinterpret_cast(addr); + address = BAN::IPv4Address(addr_in->sin_addr.s_addr); + port = BAN::network_endian_to_host(addr_in->sin_port); + } + + bool TCPSocket::ListenKey::operator==(const ListenKey& other) const + { + return address == other.address && port == other.port; + } + + BAN::hash_t TCPSocket::ListenKeyHash::operator()(ListenKey key) const + { + return BAN::hash()(key.address) ^ BAN::hash()(key.port); + } + template static void add_tcp_header_option(TCPHeader& header, uint32_t value) { @@ -162,6 +364,9 @@ namespace Kernel void TCPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader pseudo_header) { + ASSERT(m_next_flags); + ASSERT(m_mutex.locker() == Thread::current().tid()); + auto& header = packet.as(); memset(&header, 0, sizeof(TCPHeader)); memset(header.options, TCPOption::End, m_tcp_options_bytes); @@ -172,79 +377,43 @@ namespace Kernel header.ack_number = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte; header.data_offset = (sizeof(TCPHeader) + m_tcp_options_bytes) / sizeof(uint32_t); header.window_size = m_recv_window.buffer->size(); + header.flags = m_next_flags; + if (header.flags & FIN) + m_send_window.has_ghost_byte = true; + m_next_flags = 0; - ASSERT(m_recv_window.buffer->size() < 1 << (8 * sizeof(header.window_size))); + ASSERT(m_recv_window.buffer->size() < (1 << (8 * sizeof(header.window_size)))); - switch (m_state) + if (m_state == State::Closed) { - case State::Closed: - { - LockGuard _(m_mutex); - header.syn = 1; - add_tcp_header_option<0, TCPOption::MaximumSeqmentSize>(header, m_interface->payload_mtu() - m_network_layer.header_size()); - add_tcp_header_option<4, TCPOption::WindowScale>(header, 0); - m_state = State::SynSent; - m_send_window.start_seq++; - m_send_window.current_seq = m_send_window.start_seq; - break; - } - case State::SynSent: - header.ack = 1; - break; - case State::SynReceived: - header.ack = 1; - m_state = State::Established; - break; - case State::Established: - header.ack = 1; - break; - case State::CloseWait: - { - LockGuard _(m_mutex); - header.ack = 1; - header.fin = 1; - m_state = State::LastAck; - dprintln_if(DEBUG_TCP, "Waiting for last ACK"); - break; - } - case State::FinWait1: - { - LockGuard _(m_mutex); - header.ack = 1; - header.fin = 1; - m_state = State::FinWait2; - break; - } - case State::FinWait2: - { - LockGuard _(m_mutex); - header.ack = 1; - m_state = State::TimeWait; - m_time_wait_start_ms = SystemTimer::get().ms_since_boot(); - dprintln_if(DEBUG_TCP, "Sent final ACK"); - break; - } - case State::Listen: ASSERT_NOT_REACHED(); - case State::Closing: ASSERT_NOT_REACHED(); - case State::LastAck: ASSERT_NOT_REACHED(); - case State::TimeWait: ASSERT_NOT_REACHED(); + add_tcp_header_option<0, TCPOption::MaximumSeqmentSize>(header, m_interface->payload_mtu() - m_network_layer.header_size()); + add_tcp_header_option<4, TCPOption::WindowScale>(header, 0); + m_send_window.mss = 1440; + m_send_window.start_seq++; + m_send_window.current_seq = m_send_window.start_seq; } pseudo_header.extra = packet.size(); header.checksum = calculate_internet_checksum(packet, pseudo_header); + + dprintln_if(DEBUG_TCP, "sending {} {8b}", (uint8_t)m_state, header.flags); + dprintln_if(DEBUG_TCP, " {}", (uint32_t)header.ack_number); + dprintln_if(DEBUG_TCP, " {}", (uint32_t)header.seq_number); } - void TCPSocket::receive_packet(BAN::ConstByteSpan buffer, const sockaddr_storage& sender) + void TCPSocket::receive_packet(BAN::ConstByteSpan buffer, const sockaddr* sender, socklen_t sender_len) { + (void)sender_len; + { uint16_t checksum = 0; - if (sender.ss_family == AF_INET) + if (sender->sa_family == AF_INET) { - auto& sockaddr_in = *reinterpret_cast(&sender); + auto& addr_in = *reinterpret_cast(sender); checksum = calculate_internet_checksum(buffer, PseudoHeader { - .src_ipv4 = BAN::IPv4Address(sockaddr_in.sin_addr.s_addr), + .src_ipv4 = BAN::IPv4Address(addr_in.sin_addr.s_addr), .dst_ipv4 = m_interface->get_ipv4_address(), .protocol = NetworkProtocol::TCP, .extra = buffer.size() @@ -253,7 +422,7 @@ namespace Kernel } else { - dwarnln("No tcp checksum validation for socket family {}", sender.ss_family); + dwarnln("No tcp checksum validation for socket family {}", sender->sa_family); return; } @@ -264,26 +433,28 @@ namespace Kernel } } + LockGuard _(m_mutex); + auto& header = buffer.as(); + dprintln_if(DEBUG_TCP, "receiving {} {8b}", (uint8_t)m_state, header.flags); + dprintln_if(DEBUG_TCP, " {}", (uint32_t)header.ack_number); + dprintln_if(DEBUG_TCP, " {}", (uint32_t)header.seq_number); m_send_window.non_scaled_size = header.window_size; - auto payload = buffer.slice(header.data_offset * sizeof(uint32_t)); - + bool check_payload = false; switch (m_state) { case State::Closed: break; case State::SynSent: { - if (!header.ack || !header.syn) + if (header.flags != (SYN | ACK)) break; - LockGuard _(m_mutex); - if (header.ack_number != m_send_window.current_seq) { - dprintln_if(DEBUG_TCP, "Invalid ack number in SYN/ACK", (uint32_t)header.ack_number, m_send_window.current_seq); + dprintln_if(DEBUG_TCP, "Invalid ack number in SYN/ACK"); break; } @@ -298,120 +469,149 @@ namespace Kernel m_recv_window.start_seq = header.seq_number + 1; - dprintln_if(DEBUG_TCP, "Got SYN/ACK"); - - m_should_ack = true; - m_state = State::SynReceived; + m_next_flags = ACK; + m_next_state = State::Established; break; } - case State::FinWait2: - case State::TimeWait: - case State::CloseWait: - case State::Established: - { - if (!header.ack) + case State::SynReceived: + if (header.flags != ACK) break; - - LockGuard _(m_mutex); - - if (header.fin) + m_state = State::Established; + m_has_connected = true; + break; + case State::Listen: + if (header.flags == SYN) { - if (m_recv_window.start_seq + m_recv_window.data_size != header.seq_number) - dprintln_if(DEBUG_TCP, "Got FIN, but missing packets"); + if (m_pending_connections.size() == m_pending_connections.capacity()) + dprintln_if(DEBUG_TCP, "No storage to store pending connection"); else { - if (m_state == State::FinWait2) - m_send_window.has_ghost_byte = true; - else - m_state = State::CloseWait; - - m_recv_window.has_ghost_byte = true; - m_should_ack = true; - dprintln_if(DEBUG_TCP, "Got FIN"); + ConnectionInfo connection_info; + memcpy(&connection_info.address, sender, sender_len); + connection_info.address_len = sender_len; + MUST(m_pending_connections.emplace( + connection_info, + header.seq_number + 1 + )); } - break; } - - if (header.ack_number > m_send_window.current_ack) - m_send_window.current_ack = header.ack_number; - - if (payload.size() > 0) + else { - if (header.seq_number != m_recv_window.start_seq + m_recv_window.data_size) + auto it = m_listen_children.find(ListenKey(sender, sender_len)); + if (it == m_listen_children.end()) { - dprintln_if(DEBUG_TCP, "Missing packet"); + dprintln_if(DEBUG_TCP, "Unexpected packet to listening socket"); break; } + auto socket = it->value; - if (m_recv_window.data_size + payload.size() > m_recv_window.buffer->size()) - { - dprintln_if(DEBUG_TCP, "Cannot fit received bytes to window, waiting for retransmission"); - break; - } + LockFreeGuard _(m_mutex); + socket->receive_packet(buffer, sender, sender_len); + return; + } + break; + case State::Established: + check_payload = true; + if (!(header.flags & FIN)) + break; + if (m_recv_window.start_seq + m_recv_window.data_size != header.seq_number) + break; + m_next_flags = FIN | ACK; + m_next_state = State::LastAck; + break; + case State::CloseWait: + check_payload = true; + if (!(header.flags & FIN)) + break; + m_next_flags = FIN; + m_next_state = State::LastAck; + break; + case State::LastAck: + check_payload = true; + if (!(header.flags & ACK)) + break; + set_connection_as_closed(); + break; + case State::FinWait1: + check_payload = true; + if (!(header.flags & (FIN | ACK))) + break; + if ((header.flags & (FIN | ACK)) == (FIN | ACK)) + m_next_state = State::TimeWait; + if (header.flags & FIN) + m_next_state = State::Closing; + if (header.flags & ACK) + m_state = State::FinWait2; + else + m_next_flags = ACK; + break; + case State::FinWait2: + check_payload = true; + if (!(header.flags & FIN)) + break; + m_next_flags = ACK; + m_next_state = State::TimeWait; + break; + case State::Closing: + check_payload = true; + if (!(header.flags & ACK)) + break; + m_state = State::TimeWait; + break; + case State::TimeWait: + check_payload = true; + break; + } + if (header.seq_number != m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte) + dprintln_if(DEBUG_TCP, "Missing packets"); + else if (check_payload) + { + if (header.flags & FIN) + m_recv_window.has_ghost_byte = true; + + if (header.ack_number > m_send_window.current_ack) + m_send_window.current_ack = header.ack_number; + + auto payload = buffer.slice(header.data_offset * sizeof(uint32_t)); + if (payload.size() > 0) + { + if (m_recv_window.data_size + payload.size() > m_recv_window.buffer->size()) + dprintln_if(DEBUG_TCP, "Cannot fit received bytes to window, waiting for retransmission"); + else + { auto* buffer = reinterpret_cast(m_recv_window.buffer->vaddr()); memcpy(buffer + m_recv_window.data_size, payload.data(), payload.size()); m_recv_window.data_size += payload.size(); - m_should_ack = true; - dprintln_if(DEBUG_TCP, "Received {} bytes", payload.size()); - } - break; + if (m_next_flags == 0) + { + m_next_flags = ACK; + m_next_state = m_state; + } + } } - case State::LastAck: - if (!header.ack) - break; - dprintln_if(DEBUG_TCP, "Got final ACK"); - set_connection_as_closed(); - break; - case State::Listen: ASSERT_NOT_REACHED(); - case State::SynReceived: ASSERT_NOT_REACHED(); - case State::FinWait1: ASSERT_NOT_REACHED(); - case State::Closing: ASSERT_NOT_REACHED(); } m_semaphore.unblock(); } - - void TCPSocket::start_close_sequence() - { - LockGuard _(m_mutex); - - if (!is_bound()) - return; - - switch (m_state) - { - case State::Established: - break; - case State::SynSent: - set_connection_as_closed(); - return; - case State::SynReceived: - case State::FinWait1: - case State::FinWait2: - case State::CloseWait: - case State::Closing: - case State::TimeWait: - case State::LastAck: - return; - case State::Closed: ASSERT_NOT_REACHED(); - case State::Listen: ASSERT_NOT_REACHED(); - } - - m_state = State::FinWait1; - m_should_ack = true; - - dprintln_if(DEBUG_TCP, "Initiated close"); - } void TCPSocket::set_connection_as_closed() { if (is_bound()) { - m_network_layer.unbind_socket(this, m_port); + // NOTE: Only listen socket can unbind the socket as + // listen socket is always alive to redirect packets + if (!m_listen_parent) + m_network_layer.unbind_socket(this, m_port); + else + { + m_listen_parent->remove_listen_child(this); + // Listen children are not actually bound, so they have to be manually removed + NetworkManager::get().TmpFileSystem::remove_from_cache(this); + } m_interface = nullptr; m_port = PORT_NONE; dprintln_if(DEBUG_TCP, "Socket unbound"); @@ -420,44 +620,68 @@ namespace Kernel m_process = nullptr; } + void TCPSocket::remove_listen_child(BAN::RefPtr socket) + { + LockGuard _(m_mutex); + + auto it = m_listen_children.find(ListenKey( + reinterpret_cast(&socket->m_connection_info->address), + socket->m_connection_info->address_len + )); + if (it == m_listen_children.end()) + { + dwarnln("remove_listen_child with non-mapped socket"); + return; + } + + m_listen_children.remove(it); + } + void TCPSocket::process_task() { // FIXME: this should be dynamic static constexpr uint32_t retransmit_timeout_ms = 1000; - BAN::RefPtr keep_alive = this; - bool started_close_sequence = false; + BAN::RefPtr keep_alive { this }; while (m_process) { const uint64_t current_ms = SystemTimer::get().ms_since_boot(); - if (m_state == State::TimeWait && current_ms >= m_time_wait_start_ms + 30'000) - { - set_connection_as_closed(); - continue; - } - - // This is the last instance - if (!started_close_sequence && ref_count() == 1) - { - start_close_sequence(); - started_close_sequence = true; - continue; - } - { LockGuard _(m_mutex); - if (m_should_ack.compare_exchange(true, false)) + if (m_state == State::TimeWait && current_ms >= m_time_wait_start_ms + 30'000) + { + set_connection_as_closed(); + continue; + } + + // This is the last instance (one instance in network manager and another keep_alive) + if (ref_count() == 2) + { + if (m_state == State::Listen) + { + set_connection_as_closed(); + continue; + } + if (m_state == State::Established) + { + m_next_flags = FIN | ACK; + m_next_state = State::FinWait1; + } + } + + if (m_next_flags) { ASSERT(m_connection_info.has_value()); auto* target_address = reinterpret_cast(&m_connection_info->address); auto target_address_len = m_connection_info->address_len; - if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error()) dwarnln("{}", ret.error()); - + m_state = m_next_state; + if (m_state == State::Established) + m_has_connected = true; continue; } @@ -501,6 +725,7 @@ namespace Kernel auto message = BAN::ConstByteSpan(send_buffer + i, to_send); + m_next_flags = ACK; if (auto ret = m_network_layer.sendto(*this, message, target_address, target_address_len); ret.is_error()) { dwarnln("{}", ret.error()); @@ -526,129 +751,4 @@ namespace Kernel m_semaphore.unblock(); } - BAN::ErrorOr TCPSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr*, socklen_t*) - { - LockGuard _(m_mutex); - - if (m_state == State::Closed) - return BAN::Error::from_errno(ENOTCONN); - - while (m_recv_window.data_size == 0) - { - switch (m_state) - { - case State::SynSent: - case State::SynReceived: - case State::Established: - case State::CloseWait: - case State::Listen: - break; - case State::FinWait1: - case State::FinWait2: - case State::LastAck: - case State::TimeWait: - return BAN::Error::from_errno(ECONNRESET); - case State::Closed: ASSERT_NOT_REACHED(); - case State::Closing: ASSERT_NOT_REACHED(); - }; - - LockFreeGuard free(m_mutex); - TRY(Thread::current().block_or_eintr_indefinite(m_semaphore)); - } - - uint32_t to_recv = BAN::Math::min(buffer.size(), m_recv_window.data_size); - - auto* recv_buffer = reinterpret_cast(m_recv_window.buffer->vaddr()); - memcpy(buffer.data(), recv_buffer, to_recv); - - m_recv_window.data_size -= to_recv; - m_recv_window.start_seq += to_recv; - if (m_recv_window.data_size > 0) - memmove(recv_buffer, recv_buffer + to_recv, m_recv_window.data_size); - - return to_recv; - } - - BAN::ErrorOr TCPSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) - { - if (address) - return BAN::Error::from_errno(EISCONN); - - if (message.size() > m_send_window.buffer->size()) - { - for (size_t i = 0; i < message.size(); i++) - { - const size_t to_send = BAN::Math::min(message.size() - i, m_send_window.buffer->size()); - TRY(sendto_impl(message.slice(i, to_send), address, address_len)); - i += to_send; - } - return message.size(); - } - - LockGuard _(m_mutex); - - if (m_state == State::Closed) - return BAN::Error::from_errno(ENOTCONN); - - while (true) - { - switch (m_state) - { - case State::SynSent: - case State::SynReceived: - case State::Established: - case State::CloseWait: - case State::Listen: - break; - case State::FinWait1: - case State::FinWait2: - case State::LastAck: - case State::TimeWait: - return BAN::Error::from_errno(ECONNRESET); - case State::Closed: ASSERT_NOT_REACHED(); - case State::Closing: ASSERT_NOT_REACHED(); - }; - - if (m_send_window.data_size + message.size() <= m_send_window.buffer->size()) - break; - - LockFreeGuard free(m_mutex); - TRY(Thread::current().block_or_eintr_indefinite(m_semaphore)); - } - - { - auto* buffer = reinterpret_cast(m_send_window.buffer->vaddr()); - memcpy(buffer + m_send_window.data_size, message.data(), message.size()); - m_send_window.data_size += message.size(); - } - - uint32_t target_ack = m_send_window.start_seq + m_send_window.data_size; - m_semaphore.unblock(); - - while (m_send_window.start_seq < target_ack) - { - switch (m_state) - { - case State::SynSent: - case State::SynReceived: - case State::Established: - case State::CloseWait: - case State::Listen: - case State::TimeWait: - case State::FinWait1: - case State::FinWait2: - break; - case State::LastAck: - return BAN::Error::from_errno(ECONNRESET); - case State::Closed: ASSERT_NOT_REACHED(); - case State::Closing: ASSERT_NOT_REACHED(); - }; - - LockFreeGuard free(m_mutex); - TRY(Thread::current().block_or_eintr_indefinite(m_semaphore)); - } - - return message.size(); - } - } diff --git a/kernel/kernel/Networking/UDPSocket.cpp b/kernel/kernel/Networking/UDPSocket.cpp index def27747..a6e9d2f4 100644 --- a/kernel/kernel/Networking/UDPSocket.cpp +++ b/kernel/kernel/Networking/UDPSocket.cpp @@ -40,8 +40,10 @@ namespace Kernel header.checksum = 0; } - void UDPSocket::receive_packet(BAN::ConstByteSpan packet, const sockaddr_storage& sender) + void UDPSocket::receive_packet(BAN::ConstByteSpan packet, const sockaddr* sender, socklen_t sender_len) { + (void)sender_len; + //auto& header = packet.as(); auto payload = packet.slice(sizeof(UDPHeader)); @@ -62,10 +64,10 @@ namespace Kernel void* buffer = reinterpret_cast(m_packet_buffer->vaddr() + m_packet_total_size); memcpy(buffer, payload.data(), payload.size()); - m_packets.emplace(PacketInfo { - .sender = sender, - .packet_size = payload.size() - }); + PacketInfo packet_info; + memcpy(&packet_info.sender, sender, sender_len); + packet_info.packet_size = payload.size(); + m_packets.emplace(packet_info); m_packet_total_size += payload.size(); m_packet_semaphore.unblock();