From 435636a655955ba4b535003ecbf80b19260e6016 Mon Sep 17 00:00:00 2001 From: Bananymous Date: Mon, 12 Feb 2024 04:27:50 +0200 Subject: [PATCH] Kernel: Implement super simple TCP stack No SACK support and windows are fixed size --- kernel/CMakeLists.txt | 1 + .../include/kernel/Networking/NetworkSocket.h | 1 + kernel/include/kernel/Networking/TCPSocket.h | 118 ++++ kernel/kernel/Networking/IPv4Layer.cpp | 8 + kernel/kernel/Networking/NetworkManager.cpp | 28 +- kernel/kernel/Networking/TCPSocket.cpp | 656 ++++++++++++++++++ 6 files changed, 804 insertions(+), 8 deletions(-) create mode 100644 kernel/include/kernel/Networking/TCPSocket.h create mode 100644 kernel/kernel/Networking/TCPSocket.cpp diff --git a/kernel/CMakeLists.txt b/kernel/CMakeLists.txt index 72b9137a2e..69f1a8303f 100644 --- a/kernel/CMakeLists.txt +++ b/kernel/CMakeLists.txt @@ -57,6 +57,7 @@ set(KERNEL_SOURCES kernel/Networking/NetworkInterface.cpp kernel/Networking/NetworkManager.cpp kernel/Networking/NetworkSocket.cpp + kernel/Networking/TCPSocket.cpp kernel/Networking/UDPSocket.cpp kernel/Networking/UNIX/Socket.cpp kernel/OpenFileDescriptorSet.cpp diff --git a/kernel/include/kernel/Networking/NetworkSocket.h b/kernel/include/kernel/Networking/NetworkSocket.h index 19424a099d..04986fa7f5 100644 --- a/kernel/include/kernel/Networking/NetworkSocket.h +++ b/kernel/include/kernel/Networking/NetworkSocket.h @@ -12,6 +12,7 @@ namespace Kernel enum NetworkProtocol : uint8_t { ICMP = 0x01, + TCP = 0x06, UDP = 0x11, }; diff --git a/kernel/include/kernel/Networking/TCPSocket.h b/kernel/include/kernel/Networking/TCPSocket.h new file mode 100644 index 0000000000..8df8fa7e76 --- /dev/null +++ b/kernel/include/kernel/Networking/TCPSocket.h @@ -0,0 +1,118 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace Kernel +{ + + struct TCPHeader + { + BAN::NetworkEndian src_port { 0 }; + BAN::NetworkEndian dst_port { 0 }; + BAN::NetworkEndian seq_number { 0 }; + BAN::NetworkEndian ack_number { 0 }; + uint8_t reserved : 4 { 0 }; + uint8_t data_offset : 4 { 0 }; + uint8_t fin : 1 { 0 }; + uint8_t syn : 1 { 0 }; + uint8_t rst : 1 { 0 }; + uint8_t psh : 1 { 0 }; + uint8_t ack : 1 { 0 }; + uint8_t urg : 1 { 0 }; + uint8_t ece : 1 { 0 }; + uint8_t cwr : 1 { 0 }; + BAN::NetworkEndian window_size { 0 }; + BAN::NetworkEndian checksum { 0 }; + BAN::NetworkEndian urgent_pointer { 0 }; + uint8_t options[0]; + }; + static_assert(sizeof(TCPHeader) == 20); + + class TCPSocket final : public NetworkSocket + { + public: + static constexpr size_t m_tcp_options_bytes = 4; + + public: + static BAN::ErrorOr> create(NetworkLayer&, ino_t, const TmpInodeInfo&); + ~TCPSocket(); + + virtual NetworkProtocol protocol() const override { return NetworkProtocol::TCP; } + + virtual size_t protocol_header_size() const override { return sizeof(TCPHeader) + m_tcp_options_bytes; } + virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) override; + + protected: + virtual void on_close_impl() override; + + virtual BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) override; + + virtual void receive_packet(BAN::ConstByteSpan, const sockaddr_storage& sender) override; + + virtual BAN::ErrorOr sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) override; + virtual BAN::ErrorOr recvfrom_impl(BAN::ByteSpan message, sockaddr* address, socklen_t* address_len) override; + + private: + enum class State + { + Closed = 0, + Listen, + SynSent, + SynReceived, + Established, + FinWait1, + FinWait2, + CloseWait, + Closing, + LastAck, + TimeWait, + }; + + struct WindowInfo + { + uint32_t mss { 0 }; + uint16_t size { 0 }; + uint8_t scale { 0 }; + uint32_t start_seq { 0 }; + uint32_t current_seq { 0 }; + BAN::Atomic ack_number { 0 }; + uint32_t data_size { 0 }; + uint64_t send_time_ms { 0 }; + BAN::UniqPtr window; + }; + + private: + TCPSocket(NetworkLayer&, ino_t, const TmpInodeInfo&); + void process_task(); + + void set_connection_as_closed(); + + private: + State m_state = State::Closed; + + Process* m_process { nullptr }; + + uint64_t m_time_wait_start_ms { 0 }; + + RecursiveSpinLock m_lock; + Semaphore m_semaphore; + + BAN::Atomic m_should_ack { false }; + + WindowInfo m_recv_window; + WindowInfo m_send_window; + + struct ConnectionInfo + { + sockaddr_storage address; + socklen_t address_len; + }; + BAN::Optional m_connection_info; + }; + +} diff --git a/kernel/kernel/Networking/IPv4Layer.cpp b/kernel/kernel/Networking/IPv4Layer.cpp index 7ff8e2048a..42fca57a2b 100644 --- a/kernel/kernel/Networking/IPv4Layer.cpp +++ b/kernel/kernel/Networking/IPv4Layer.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -224,6 +225,13 @@ namespace Kernel src_port = udp_header.src_port; break; } + case NetworkProtocol::TCP: + { + auto& tcp_header = ipv4_data.as(); + dst_port = tcp_header.dst_port; + src_port = tcp_header.src_port; + break; + } default: dprintln_if(DEBUG_IPV4, "Unknown network protocol 0x{2H}", ipv4_header.protocol); return {}; diff --git a/kernel/kernel/Networking/NetworkManager.cpp b/kernel/kernel/Networking/NetworkManager.cpp index 4fe77f947f..530526dcc7 100644 --- a/kernel/kernel/Networking/NetworkManager.cpp +++ b/kernel/kernel/Networking/NetworkManager.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -76,15 +77,17 @@ namespace Kernel switch (domain) { case SocketDomain::INET: - { - if (type != SocketType::DGRAM) - return BAN::Error::from_errno(EPROTOTYPE); + switch (type) + { + case SocketType::DGRAM: + case SocketType::STREAM: + break; + default: + return BAN::Error::from_errno(EPROTOTYPE); + } break; - } case SocketDomain::UNIX: - { break; - } default: return BAN::Error::from_errno(EAFNOSUPPORT); } @@ -100,8 +103,17 @@ namespace Kernel { case SocketDomain::INET: { - if (type == SocketType::DGRAM) - socket = TRY(UDPSocket::create(*m_ipv4_layer, ino, inode_info)); + switch (type) + { + case SocketType::DGRAM: + socket = TRY(UDPSocket::create(*m_ipv4_layer, ino, inode_info)); + break; + case SocketType::STREAM: + socket = TRY(TCPSocket::create(*m_ipv4_layer, ino, inode_info)); + break; + default: + ASSERT_NOT_REACHED(); + } break; } case SocketDomain::UNIX: diff --git a/kernel/kernel/Networking/TCPSocket.cpp b/kernel/kernel/Networking/TCPSocket.cpp new file mode 100644 index 0000000000..327bd8d664 --- /dev/null +++ b/kernel/kernel/Networking/TCPSocket.cpp @@ -0,0 +1,656 @@ +#include +#include +#include +#include + +#include + +#define DEBUG_TCP 0 + +namespace Kernel +{ + + enum TCPOption : uint8_t + { + End = 0x00, + NOP = 0x01, + MaximumSeqmentSize = 0x02, + WindowScale = 0x03, + }; + + static constexpr size_t s_window_buffer_size = 15 * PAGE_SIZE; + static_assert(s_window_buffer_size <= UINT16_MAX); + + BAN::ErrorOr> TCPSocket::create(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info) + { + auto* socket_ptr = new TCPSocket(network_layer, ino, inode_info); + if (socket_ptr == nullptr) + return BAN::Error::from_errno(ENOMEM); + auto socket = BAN::RefPtr::adopt(socket_ptr); + socket->m_recv_window.window = TRY(VirtualRange::create_to_vaddr_range( + PageTable::kernel(), + KERNEL_OFFSET, + ~(vaddr_t)0, + s_window_buffer_size, + PageTable::Flags::ReadWrite | PageTable::Flags::Present, + true + )); + socket->m_send_window.window = TRY(VirtualRange::create_to_vaddr_range( + PageTable::kernel(), + KERNEL_OFFSET, + ~(vaddr_t)0, + s_window_buffer_size, + PageTable::Flags::ReadWrite | PageTable::Flags::Present, + true + )); + socket->m_recv_window.size = socket->m_recv_window.window->size(); + socket->m_recv_window.scale = 0; + socket->m_process = Process::create_kernel( + [](void* socket_ptr) + { + reinterpret_cast(socket_ptr)->process_task(); + }, socket.ptr() + ); + return socket; + } + + TCPSocket::TCPSocket(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info) + : NetworkSocket(network_layer, ino, inode_info) + { + m_send_window.start_seq = Random::get_u32() & 0x7FFFFFFF; + m_send_window.ack_number = m_send_window.start_seq; + m_send_window.current_seq = m_send_window.start_seq; + } + + TCPSocket::~TCPSocket() + { + ASSERT(!is_bound()); + ASSERT(m_process == nullptr); + dprintln_if(DEBUG_TCP, "socket destroyed"); + } + + void TCPSocket::on_close_impl() + { + LockGuard _(m_lock); + + if (!is_bound()) + return; + + switch (m_state) + { + case State::Established: + break; + case State::SynSent: + set_connection_as_closed(); + // fall through + case State::SynReceived: + case State::FinWait1: + case State::FinWait2: + case State::CloseWait: + case State::Closing: + case State::TimeWait: + case State::LastAck: + return; + case State::Closed: ASSERT_NOT_REACHED(); + case State::Listen: ASSERT_NOT_REACHED(); + } + + ASSERT(m_connection_info.has_value()); + auto* target_address = reinterpret_cast(&m_connection_info->address); + auto target_address_len = m_connection_info->address_len; + + m_state = State::FinWait1; + if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error()) + dwarnln("{}", ret.error()); + + dprintln_if(DEBUG_TCP, "Initiated close"); + } + + BAN::ErrorOr TCPSocket::connect_impl(const sockaddr* address, socklen_t address_len) + { + if (address_len > (socklen_t)sizeof(sockaddr_storage)) + address_len = sizeof(sockaddr_storage); + + LockGuard _(m_lock); + + ASSERT(!m_connection_info.has_value()); + + switch (m_state) + { + case State::Closed: + break; + case State::SynSent: + case State::SynReceived: + return BAN::Error::from_errno(EALREADY); + case State::Established: + case State::FinWait1: + case State::FinWait2: + case State::CloseWait: + case State::Closing: + case State::LastAck: + case State::TimeWait: + return BAN::Error::from_errno(EISCONN); + case State::Listen: + return BAN::Error::from_errno(EOPNOTSUPP); + }; + + if (!is_bound()) + TRY(m_network_layer.bind_socket_to_unused(this, address, address_len)); + + m_connection_info.emplace(sockaddr_storage {}, address_len); + memcpy(&m_connection_info->address, address, address_len); + + m_recv_window.mss = m_interface->payload_mtu() - m_network_layer.header_size(); + + TRY(m_network_layer.sendto(*this, {}, address, address_len)); + ASSERT(m_state == State::SynSent); + dprintln_if(DEBUG_TCP, "Sent SYN"); + + uint64_t wake_time_ms = SystemTimer::get().ms_since_boot() + 5000; + while (m_state != State::Established) + { + LockFreeGuard free(m_lock); + if (SystemTimer::get().ms_since_boot() >= wake_time_ms) + return BAN::Error::from_errno(ECONNREFUSED); + TRY(Thread::current().block_or_eintr_or_waketime(m_semaphore, wake_time_ms, true)); + } + + return {}; + } + + template + static void add_tcp_header_option(TCPHeader& header, uint32_t value) + { + if constexpr(Op == TCPOption::MaximumSeqmentSize) + { + header.options[Off + 0] = Op; + header.options[Off + 1] = 0x04; + header.options[Off + 2] = value >> 8; + header.options[Off + 3] = value; + } + else if constexpr(Op == TCPOption::WindowScale) + { + header.options[Off + 0] = Op; + header.options[Off + 1] = 0x03; + header.options[Off + 2] = value; + } + } + + struct ParsedTCPOptions + { + BAN::Optional maximum_seqment_size; + BAN::Optional window_scale; + }; + static ParsedTCPOptions parse_tcp_options(const TCPHeader& header) + { + ParsedTCPOptions result; + + for (size_t i = 0; i < header.data_offset * sizeof(uint32_t) - sizeof(TCPHeader) - 1; i++) + { + if (header.options[i] == TCPOption::End) + break; + if (header.options[i] == TCPOption::NOP) + continue; + if (header.options[i] == TCPOption::MaximumSeqmentSize) + result.maximum_seqment_size = BAN::host_to_network_endian(*reinterpret_cast(&header.options[i + 2])); + if (header.options[i] == TCPOption::WindowScale) + result.window_scale = header.options[i + 2]; + if (header.options[i + 1] == 0) + break; + i += header.options[i + 1] - 1; + } + + return result; + } + + void TCPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader pseudo_header) + { + auto& header = packet.as(); + memset(&header, 0, sizeof(TCPHeader)); + memset(header.options, TCPOption::End, m_tcp_options_bytes); + + header.dst_port = dst_port; + header.src_port = m_port; + header.seq_number = m_send_window.current_seq; + header.ack_number = m_recv_window.ack_number.load(); + header.data_offset = (sizeof(TCPHeader) + m_tcp_options_bytes) / sizeof(uint32_t); + header.window_size = m_recv_window.window->size(); + + switch (m_state) + { + case State::Closed: + { + LockGuard _(m_lock); + header.syn = 1; + add_tcp_header_option<0, TCPOption::MaximumSeqmentSize>(header, m_recv_window.mss); + add_tcp_header_option<4, TCPOption::WindowScale>(header, m_recv_window.scale); + m_state = State::SynSent; + break; + } + case State::SynSent: + header.ack = 1; + break; + case State::SynReceived: + header.ack = 1; + m_state = State::Established; + break; + case State::Established: + header.ack = 1; + break; + case State::CloseWait: + { + LockGuard _(m_lock); + header.ack = 1; + header.fin = 1; + header.ack_number = header.ack_number + 1; + m_state = State::LastAck; + dprintln_if(DEBUG_TCP, "Waiting for last ack"); + break; + } + case State::FinWait1: + { + LockGuard _(m_lock); + header.ack = 1; + header.fin = 1; + m_state = State::FinWait2; + break; + } + case State::FinWait2: + { + LockGuard _(m_lock); + header.ack = 1; + header.seq_number = header.seq_number + 1; + header.ack_number = header.ack_number + 1; + m_state = State::TimeWait; + m_time_wait_start_ms = SystemTimer::get().ms_since_boot(); + break; + } + case State::Listen: ASSERT_NOT_REACHED(); + case State::Closing: ASSERT_NOT_REACHED(); + case State::LastAck: ASSERT_NOT_REACHED(); + case State::TimeWait: ASSERT_NOT_REACHED(); + } + + pseudo_header.extra = packet.size(); + header.checksum = calculate_internet_checksum(packet, pseudo_header); + } + + void TCPSocket::receive_packet(BAN::ConstByteSpan buffer, const sockaddr_storage& sender) + { + { + uint16_t checksum = 0; + + if (sender.ss_family == AF_INET) + { + auto& sockaddr_in = *reinterpret_cast(&sender); + checksum = calculate_internet_checksum(buffer, + PseudoHeader { + .src_ipv4 = BAN::IPv4Address(sockaddr_in.sin_addr.s_addr), + .dst_ipv4 = m_interface->get_ipv4_address(), + .protocol = NetworkProtocol::TCP, + .extra = buffer.size() + } + ); + } + else + { + dwarnln("No tcp checksum validation for socket family {}", sender.ss_family); + return; + } + + if (checksum != 0) + { + dprintln("Checksum does not match"); + return; + } + } + + auto& header = buffer.as(); + + m_send_window.size = header.window_size; + + auto payload = buffer.slice(header.data_offset * sizeof(uint32_t)); + + switch (m_state) + { + case State::Closed: + break; + case State::SynSent: + { + if (!header.ack || !header.syn) + break; + + LockGuard _(m_lock); + + auto options = parse_tcp_options(header); + if (options.maximum_seqment_size.has_value()) + m_send_window.mss = *options.maximum_seqment_size; + if (options.window_scale.has_value()) + m_send_window.scale = *options.window_scale; + else + m_recv_window.scale = 0; + + m_send_window.start_seq = m_send_window.start_seq + 1; + m_send_window.ack_number = m_send_window.start_seq; + m_send_window.current_seq = m_send_window.start_seq; + + m_recv_window.start_seq = header.seq_number + 1; + m_recv_window.ack_number = m_recv_window.start_seq; + + dprintln_if(DEBUG_TCP, "Got SYN/ACK"); + + m_should_ack = true; + m_state = State::SynReceived; + break; + } + case State::FinWait2: + if (!header.ack) + break; + if (header.fin) + m_should_ack = true; + // fall through + case State::TimeWait: + case State::CloseWait: + case State::Established: + { + if (!header.ack) + break; + + LockGuard _(m_lock); + if (header.fin) + { + if (m_recv_window.start_seq + m_recv_window.data_size != header.seq_number) + dprintln_if(DEBUG_TCP, "Got FIN, but missing packets"); + else + { + m_should_ack = true; + m_state = State::CloseWait; + dprintln_if(DEBUG_TCP, "Got FIN"); + } + break; + } + if (header.ack_number > m_send_window.ack_number) + m_send_window.ack_number = header.ack_number; + if (payload.size() > 0) + { + if (header.seq_number != m_recv_window.start_seq + m_recv_window.data_size) + { + dprintln_if(DEBUG_TCP, "Missing packet"); + break; + } + + if (m_recv_window.data_size + payload.size() > m_recv_window.window->size()) + { + dwarnln("Cannot fit received bytes to window"); + break; + } + + auto* buffer = reinterpret_cast(m_recv_window.window->vaddr()); + memcpy(buffer + m_recv_window.data_size, payload.data(), payload.size()); + m_recv_window.data_size += payload.size(); + + m_should_ack = true; + + dprintln_if(DEBUG_TCP, "Received {} bytes", payload.size()); + } + break; + } + case State::LastAck: + if (!header.ack) + break; + set_connection_as_closed(); + dprintln_if(DEBUG_TCP, "Got final ACK"); + break; + case State::Listen: ASSERT_NOT_REACHED(); + case State::SynReceived: ASSERT_NOT_REACHED(); + case State::FinWait1: ASSERT_NOT_REACHED(); + case State::Closing: ASSERT_NOT_REACHED(); + } + + m_semaphore.unblock(); + } + + void TCPSocket::set_connection_as_closed() + { + if (is_bound()) + { + m_network_layer.unbind_socket(this, m_port); + m_interface = nullptr; + m_port = PORT_NONE; + dprintln_if(DEBUG_TCP, "Socket unbound"); + } + + m_process = nullptr; + } + + void TCPSocket::process_task() + { + // FIXME: this should be dynamic + static constexpr uint32_t retransmit_timeout_ms = 100; + + BAN::RefPtr keep_alive = this; + + while (m_process) + { + uint64_t current_ms = SystemTimer::get().ms_since_boot(); + + if (m_state == State::TimeWait && current_ms >= m_time_wait_start_ms + 6'000) + set_connection_as_closed(); + + { + LockGuard _(m_lock); + + if (m_should_ack || m_recv_window.start_seq + m_recv_window.data_size != m_recv_window.ack_number) + { + m_should_ack = false; + + ASSERT(m_connection_info.has_value()); + auto* target_address = reinterpret_cast(&m_connection_info->address); + auto target_address_len = m_connection_info->address_len; + + m_recv_window.ack_number = m_recv_window.start_seq + m_recv_window.data_size; + if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error()) + dwarnln("{}", ret.error()); + + continue; + } + + bool is_send_open = false; + switch (m_state) + { + case State::Listen: + case State::Established: + case State::CloseWait: + case State::LastAck: + is_send_open = true; + break; + case State::SynSent: + case State::SynReceived: + case State::FinWait1: + case State::FinWait2: + case State::TimeWait: + case State::Closed: + is_send_open = false; + break; + case State::Closing: ASSERT_NOT_REACHED(); + } + + if (is_send_open && m_send_window.ack_number > m_send_window.start_seq) + { + uint32_t acknowledged_bytes = m_send_window.ack_number - m_send_window.start_seq; + ASSERT(acknowledged_bytes <= m_send_window.data_size); + + m_send_window.data_size -= acknowledged_bytes; + m_send_window.start_seq += acknowledged_bytes; + + if (m_send_window.data_size > 0) + { + auto* send_buffer = reinterpret_cast(m_send_window.window->vaddr()); + memmove(send_buffer, send_buffer + acknowledged_bytes, m_send_window.data_size); + } + else + { + m_send_window.send_time_ms = 0; + } + + dprintln_if(DEBUG_TCP, "Target acknowledged {} bytes", acknowledged_bytes); + + continue; + } + + if (is_send_open && m_send_window.data_size > 0 && current_ms >= m_send_window.send_time_ms + retransmit_timeout_ms) + { + ASSERT(m_connection_info.has_value()); + auto* target_address = reinterpret_cast(&m_connection_info->address); + auto target_address_len = m_connection_info->address_len; + + const uint32_t total_send = BAN::Math::min(m_send_window.data_size, m_send_window.size << m_send_window.scale); + + m_send_window.current_seq = m_send_window.start_seq; + + auto* send_buffer = reinterpret_cast(m_send_window.window->vaddr()); + for (uint32_t i = 0; i < total_send;) + { + uint32_t to_send = BAN::Math::min(total_send - i, m_send_window.mss); + + auto message = BAN::ConstByteSpan(send_buffer + i, to_send); + + if (auto ret = m_network_layer.sendto(*this, message, target_address, target_address_len); ret.is_error()) + { + dwarnln("{}", ret.error()); + break; + } + + dprintln_if(DEBUG_TCP, "Sent {} bytes", to_send); + + m_send_window.current_seq += to_send; + i += to_send; + } + + m_send_window.send_time_ms = current_ms; + + continue; + } + } + + m_semaphore.block_with_wake_time(current_ms + retransmit_timeout_ms); + } + + m_semaphore.unblock(); + } + + BAN::ErrorOr TCPSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr*, socklen_t*) + { + LockGuard _(m_lock); + + if (m_state == State::Closed) + return BAN::Error::from_errno(ENOTCONN); + + while (m_recv_window.data_size == 0) + { + switch (m_state) + { + case State::SynSent: + case State::SynReceived: + case State::Established: + case State::CloseWait: + case State::Listen: + break; + case State::FinWait1: + case State::FinWait2: + case State::LastAck: + case State::TimeWait: + return BAN::Error::from_errno(ECONNRESET); + case State::Closed: ASSERT_NOT_REACHED(); + case State::Closing: ASSERT_NOT_REACHED(); + }; + + LockFreeGuard free(m_lock); + TRY(Thread::current().block_or_eintr_indefinite(m_semaphore)); + } + + uint32_t to_recv = BAN::Math::min(buffer.size(), m_recv_window.data_size); + + auto* recv_buffer = reinterpret_cast(m_recv_window.window->vaddr()); + memcpy(buffer.data(), recv_buffer, to_recv); + + m_recv_window.data_size -= to_recv; + m_recv_window.start_seq += to_recv; + if (m_recv_window.data_size > 0) + memmove(recv_buffer, recv_buffer + to_recv, m_recv_window.data_size); + + return to_recv; + } + + BAN::ErrorOr TCPSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t) + { + if (address) + return BAN::Error::from_errno(EISCONN); + + if (message.size() > m_send_window.window->size()) + return BAN::Error::from_errno(EMSGSIZE); + + LockGuard _(m_lock); + + if (m_state == State::Closed) + return BAN::Error::from_errno(ENOTCONN); + + while (m_send_window.data_size + message.size() > m_send_window.window->size()) + { + switch (m_state) + { + case State::SynSent: + case State::SynReceived: + case State::Established: + case State::CloseWait: + case State::Listen: + break; + case State::FinWait1: + case State::FinWait2: + case State::LastAck: + case State::TimeWait: + return BAN::Error::from_errno(ECONNRESET); + case State::Closed: ASSERT_NOT_REACHED(); + case State::Closing: ASSERT_NOT_REACHED(); + }; + + LockFreeGuard free(m_lock); + TRY(Thread::current().block_or_eintr_indefinite(m_semaphore)); + } + + { + auto* buffer = reinterpret_cast(m_send_window.window->vaddr()); + memcpy(buffer + m_send_window.data_size, message.data(), message.size()); + m_send_window.data_size += message.size(); + } + + uint32_t target_ack = m_send_window.start_seq + m_send_window.data_size; + m_semaphore.unblock(); + + while (m_send_window.start_seq < target_ack) + { + switch (m_state) + { + case State::SynSent: + case State::SynReceived: + case State::Established: + case State::CloseWait: + case State::Listen: + case State::TimeWait: + case State::FinWait1: + case State::FinWait2: + break; + case State::LastAck: + return BAN::Error::from_errno(ECONNRESET); + case State::Closed: ASSERT_NOT_REACHED(); + case State::Closing: ASSERT_NOT_REACHED(); + }; + + LockFreeGuard free(m_lock); + TRY(Thread::current().block_or_eintr_indefinite(m_semaphore)); + } + + return message.size(); + } + +}