diff --git a/kernel/include/kernel/Networking/TCPSocket.h b/kernel/include/kernel/Networking/TCPSocket.h index 11abddba..23fc803b 100644 --- a/kernel/include/kernel/Networking/TCPSocket.h +++ b/kernel/include/kernel/Networking/TCPSocket.h @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include #include @@ -93,36 +93,32 @@ namespace Kernel struct RecvWindowInfo { - uint32_t start_seq { 0 }; // sequence number of first byte in buffer + uint32_t start_seq { 0 }; // sequence number of first byte in buffer - bool has_ghost_byte { false }; + bool has_ghost_byte { false }; - uint32_t data_tail { 0 }; - uint32_t data_size { 0 }; // number of bytes in this buffer - uint8_t scale_shift { 0 }; // window scale - BAN::UniqPtr buffer; + uint8_t scale_shift { 0 }; // window scale + BAN::UniqPtr buffer; }; struct SendWindowInfo { - uint32_t mss { 0 }; // maximum segment size - uint16_t non_scaled_size { 0 }; // window size without scaling - uint8_t scale_shift { 0 }; // window scale - uint32_t scaled_size() const { return (uint32_t)non_scaled_size << scale_shift; } + uint32_t mss { 0 }; // maximum segment size + uint16_t non_scaled_size { 0 }; // window size without scaling + uint8_t scale_shift { 0 }; // window scale + uint32_t scaled_size() const { return (uint32_t)non_scaled_size << scale_shift; } - uint32_t start_seq { 0 }; // sequence number of first byte in buffer - uint32_t current_seq { 0 }; // sequence number of next send - uint32_t current_ack { 0 }; // sequence number aknowledged by connection + uint32_t start_seq { 0 }; // sequence number of first byte in buffer + uint32_t current_seq { 0 }; // sequence number of next send + uint32_t current_ack { 0 }; // sequence number aknowledged by connection - uint64_t last_send_ms { 0 }; // last send time, used for retransmission timeout + uint64_t last_send_ms { 0 }; // last send time, used for retransmission timeout - bool has_ghost_byte { false }; - bool had_zero_window { false }; + bool has_ghost_byte { false }; + bool had_zero_window { false }; - uint32_t data_tail { 0 }; - uint32_t data_size { 0 }; // number of bytes in this buffer - uint32_t sent_size { 0 }; // number of bytes in this buffer that have been sent - BAN::UniqPtr buffer; + uint32_t sent_size { 0 }; // number of bytes in this buffer that have been sent + BAN::UniqPtr buffer; }; struct ConnectionInfo @@ -180,8 +176,8 @@ namespace Kernel bool m_keep_alive { false }; bool m_no_delay { false }; - bool m_should_send_ack { false }; bool m_should_send_zero_window { false }; + bool m_should_send_window_update { false }; uint64_t m_time_wait_start_ms { 0 }; diff --git a/kernel/kernel/Networking/TCPSocket.cpp b/kernel/kernel/Networking/TCPSocket.cpp index b699518b..b3f66c09 100644 --- a/kernel/kernel/Networking/TCPSocket.cpp +++ b/kernel/kernel/Networking/TCPSocket.cpp @@ -35,23 +35,9 @@ namespace Kernel { auto socket = TRY(BAN::RefPtr::create(network_layer, info)); socket->m_last_sent_window_size = s_recv_window_buffer_size; - socket->m_recv_window.buffer = TRY(VirtualRange::create_to_vaddr_range( - PageTable::kernel(), - KERNEL_OFFSET, - ~(vaddr_t)0, - s_recv_window_buffer_size, - PageTable::Flags::ReadWrite | PageTable::Flags::Present, - true, false - )); + socket->m_recv_window.buffer = TRY(ByteRingBuffer::create(s_recv_window_buffer_size)); socket->m_recv_window.scale_shift = s_window_shift; - socket->m_send_window.buffer = TRY(VirtualRange::create_to_vaddr_range( - PageTable::kernel(), - KERNEL_OFFSET, - ~(vaddr_t)0, - s_send_window_buffer_size, - PageTable::Flags::ReadWrite | PageTable::Flags::Present, - true, false - )); + socket->m_send_window.buffer = TRY(ByteRingBuffer::create(s_send_window_buffer_size)); socket->m_thread = TRY(Thread::create_kernel( [](void* socket_ptr) { @@ -206,8 +192,7 @@ namespace Kernel BAN::ErrorOr TCPSocket::recvmsg_impl(msghdr& message, int flags) { - flags &= (MSG_OOB | MSG_PEEK | MSG_WAITALL); - if (flags != 0) + if (flags & ~(MSG_PEEK)) { dwarnln("TODO: recvmsg with flags 0x{H}", flags); return BAN::Error::from_errno(ENOTSUP); @@ -222,7 +207,7 @@ namespace Kernel if (!m_has_connected) return BAN::Error::from_errno(ENOTCONN); - while (m_recv_window.data_size == 0) + while (m_recv_window.buffer->empty()) { if (m_state != State::Established) return return_with_maybe_zero(); @@ -232,37 +217,33 @@ namespace Kernel message.msg_flags = 0; size_t total_recv = 0; - for (int i = 0; i < message.msg_iovlen; i++) + for (int i = 0; i < message.msg_iovlen && total_recv < m_recv_window.buffer->size(); i++) { - const auto* recv_base = reinterpret_cast(m_recv_window.buffer->vaddr()); - uint8_t* iov_base = static_cast(message.msg_iov[i].iov_base); + auto& iov = message.msg_iov[i]; - const size_t nrecv = BAN::Math::min(message.msg_iov[i].iov_len, m_recv_window.data_size); - - const size_t before_wrap = BAN::Math::min(nrecv, m_recv_window.buffer->size() - m_recv_window.data_tail); - memcpy(iov_base, recv_base + m_recv_window.data_tail, before_wrap); - if (const size_t after_wrap = nrecv - before_wrap) - memcpy(iov_base + before_wrap, recv_base, after_wrap); + const size_t nrecv = BAN::Math::min(iov.iov_len, m_recv_window.buffer->size() - total_recv); + memcpy(iov.iov_base, m_recv_window.buffer->get_data().data() + total_recv, nrecv); total_recv += nrecv; - m_recv_window.data_size -= nrecv; - m_recv_window.start_seq += nrecv; - m_recv_window.data_tail = (m_recv_window.data_tail + nrecv) % m_recv_window.buffer->size(); - if (m_recv_window.data_size == 0) - break; } - const size_t update_window_threshold = m_recv_window.buffer->size() / 8; + if (!(flags & MSG_PEEK)) + { + m_recv_window.buffer->pop(total_recv); + m_recv_window.start_seq += total_recv; + } + + const size_t update_window_threshold = m_recv_window.buffer->capacity() / 8; const bool should_update_window_size = - m_last_sent_window_size != m_recv_window.buffer->size() && ( + m_last_sent_window_size != m_recv_window.buffer->capacity() && ( (m_last_sent_window_size == 0) || - (m_recv_window.data_size == 0) || - (m_last_sent_window_size + update_window_threshold <= m_recv_window.buffer->size() - m_recv_window.data_size) + (m_recv_window.buffer->empty()) || + (m_last_sent_window_size + update_window_threshold <= m_recv_window.buffer->free()) ); - if (should_update_window_size) + if (should_update_window_size || m_should_send_zero_window) { - m_should_send_ack = true; + m_should_send_window_update = true; m_thread_blocker.unblock(); } @@ -283,7 +264,7 @@ namespace Kernel if (!m_has_connected) return BAN::Error::from_errno(ENOTCONN); - while (m_send_window.data_size == m_send_window.buffer->size()) + while (m_send_window.buffer->full()) { if (m_state != State::Established) return return_with_maybe_zero(); @@ -293,23 +274,14 @@ namespace Kernel } size_t total_sent = 0; - for (int i = 0; i < message.msg_iovlen; i++) + for (int i = 0; i < message.msg_iovlen && !m_send_window.buffer->full(); i++) { - auto* send_base = reinterpret_cast(m_send_window.buffer->vaddr()); - const auto* iov_base = static_cast(message.msg_iov[i].iov_base); + const auto& iov = message.msg_iov[i]; - const size_t nsend = BAN::Math::min(message.msg_iov[i].iov_len, m_send_window.buffer->size() - m_send_window.data_size); - - const size_t send_head = (m_send_window.data_tail + m_send_window.data_size) % m_send_window.buffer->size(); - const size_t before_wrap = BAN::Math::min(nsend, m_send_window.buffer->size() - send_head); - memcpy(send_base + send_head, message.msg_iov[i].iov_base, before_wrap); - if (const size_t after_wrap = nsend - before_wrap) - memcpy(send_base, iov_base + before_wrap, after_wrap); + const size_t nsend = BAN::Math::min(iov.iov_len, m_send_window.buffer->free()); + m_send_window.buffer->push({ static_cast(iov.iov_base), nsend }); total_sent += nsend; - m_send_window.data_size += nsend; - if (m_send_window.data_size == m_send_window.buffer->size()) - break; } m_thread_blocker.unblock(); @@ -347,7 +319,7 @@ namespace Kernel result = m_send_window.scaled_size(); break; case SO_RCVBUF: - result = m_recv_window.buffer->size(); + result = m_recv_window.buffer->capacity(); break; default: dwarnln("getsockopt(SOL_SOCKET, {})", option); @@ -420,7 +392,7 @@ namespace Kernel switch (request) { case FIONREAD: - *static_cast(argument) = m_recv_window.data_size; + *static_cast(argument) = m_recv_window.buffer->size(); return 0; } @@ -433,14 +405,14 @@ namespace Kernel return true; if (m_state == State::Listen) return !m_pending_connections.empty(); - return m_recv_window.data_size > 0; + return !m_recv_window.buffer->empty(); } bool TCPSocket::can_write_impl() const { if (m_state != State::Established) return false; - return m_send_window.data_size < m_send_window.buffer->size(); + return !m_send_window.buffer->full(); } bool TCPSocket::has_hungup_impl() const @@ -530,19 +502,14 @@ namespace Kernel ASSERT(m_mutex.locker() == Thread::current().tid()); ASSERT(header_buffer.size() == protocol_header_size()); - m_last_sent_window_size = m_recv_window.buffer->size() - m_recv_window.data_size; - if (m_should_send_zero_window) - m_last_sent_window_size = 0; - - m_should_send_ack = false; - m_should_send_zero_window = false; + m_last_sent_window_size = m_should_send_zero_window ? 0 : m_recv_window.buffer->free(); auto& header = header_buffer.as(); header = { .src_port = bound_port(), .dst_port = dst_port, .seq_number = m_send_window.current_seq + m_send_window.has_ghost_byte, - .ack_number = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte, + .ack_number = m_recv_window.start_seq + m_recv_window.buffer->size() + m_recv_window.has_ghost_byte, .data_offset = (sizeof(TCPHeader) + m_tcp_options_bytes) / sizeof(uint32_t), .flags = m_next_flags, .window_size = BAN::Math::min(0xFFFF, m_last_sent_window_size >> m_recv_window.scale_shift), @@ -569,7 +536,7 @@ namespace Kernel if (m_connection_info->has_window_scale) add_tcp_header_option<4, TCPOption::WindowScale>(header, m_recv_window.scale_shift); - header.window_size = BAN::Math::min(0xFFFF, m_recv_window.buffer->size()); + header.window_size = BAN::Math::min(0xFFFF, m_recv_window.buffer->capacity()); m_send_window.start_seq++; m_send_window.current_seq = m_send_window.start_seq; @@ -722,7 +689,7 @@ namespace Kernel check_payload = true; if (!(header.flags & FIN)) break; - if (m_recv_window.start_seq + m_recv_window.data_size != header.seq_number) + if (m_recv_window.start_seq + m_recv_window.buffer->size() != header.seq_number) break; m_next_flags = FIN | ACK; m_next_state = State::LastAck; @@ -771,7 +738,7 @@ namespace Kernel break; } - const uint32_t expected_seq = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte; + const uint32_t expected_seq = m_recv_window.start_seq + m_recv_window.buffer->size() + m_recv_window.has_ghost_byte; if (header.seq_number > expected_seq) dprintln_if(DEBUG_TCP, "Missing packets"); @@ -794,21 +761,12 @@ namespace Kernel payload = {}; } - const bool can_receive_new_data = (payload.size() > 0 && m_recv_window.data_size < m_recv_window.buffer->size()); + const bool can_receive_new_data = (payload.size() > 0 && !m_recv_window.buffer->full()); if (can_receive_new_data) { - auto* recv_base = reinterpret_cast(m_recv_window.buffer->vaddr()); - - const size_t nrecv = BAN::Math::min(payload.size(), m_recv_window.buffer->size() - m_recv_window.data_size); - - const size_t recv_head = (m_recv_window.data_tail + m_recv_window.data_size) % m_recv_window.buffer->size(); - const size_t before_wrap = BAN::Math::min(nrecv, m_recv_window.buffer->size() - recv_head); - memcpy(recv_base + recv_head, payload.data(), before_wrap); - if (const size_t after_wrap = nrecv - before_wrap) - memcpy(recv_base, payload.data() + before_wrap, after_wrap); - - m_recv_window.data_size += nrecv; + const size_t nrecv = BAN::Math::min(payload.size(), m_recv_window.buffer->free()); + m_recv_window.buffer->push(payload.slice(0, nrecv)); epoll_notify(EPOLLIN); @@ -816,10 +774,13 @@ namespace Kernel } // make sure zero window is reported - if (m_last_sent_window_size > 0 && m_recv_window.data_size == m_recv_window.buffer->size()) + if (m_last_sent_window_size > 0 && m_recv_window.buffer->full()) m_should_send_zero_window = true; else if (can_receive_new_data) - m_should_send_ack = true; + { + m_next_flags = ACK; + m_next_state = m_state; + } } if (!hungup_before && has_hungup_impl()) @@ -925,12 +886,11 @@ namespace Kernel if (m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq) { const uint32_t acknowledged_bytes = m_send_window.current_ack - m_send_window.start_seq - m_send_window.has_ghost_byte; - ASSERT(acknowledged_bytes <= m_send_window.data_size); + ASSERT(acknowledged_bytes <= m_send_window.buffer->size()); - m_send_window.data_size -= acknowledged_bytes; m_send_window.start_seq += acknowledged_bytes; m_send_window.sent_size -= acknowledged_bytes; - m_send_window.data_tail = (m_send_window.data_tail + acknowledged_bytes) % m_send_window.buffer->size(); + m_send_window.buffer->pop(acknowledged_bytes); epoll_notify(EPOLLOUT); @@ -941,7 +901,7 @@ namespace Kernel const bool should_retransmit = m_send_window.had_zero_window || (m_send_window.sent_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms); - const bool can_send_new_data = (m_send_window.data_size > m_send_window.sent_size && m_send_window.sent_size < m_send_window.scaled_size()); + const bool can_send_new_data = (m_send_window.buffer->size() > m_send_window.sent_size && m_send_window.sent_size < m_send_window.scaled_size()); if (m_send_window.scaled_size() > 0 && (should_retransmit || can_send_new_data)) { @@ -951,24 +911,20 @@ namespace Kernel auto* target_address = reinterpret_cast(&m_connection_info->address); auto target_address_len = m_connection_info->address_len; - const size_t send_start_offset = should_retransmit ? 0 : m_send_window.sent_size; + const size_t send_offset = should_retransmit ? 0 : m_send_window.sent_size; const size_t total_send = BAN::Math::min( - m_send_window.data_size - send_start_offset, - m_send_window.scaled_size() - send_start_offset + m_send_window.buffer->size() - send_offset, + m_send_window.scaled_size() - send_offset ); - m_send_window.current_seq = m_send_window.start_seq + send_start_offset; + m_send_window.current_seq = m_send_window.start_seq + send_offset; - const auto* send_base = reinterpret_cast(m_send_window.buffer->vaddr()); for (size_t i = 0; i < total_send;) { - const size_t send_offset = (m_send_window.data_tail + send_start_offset + i) % m_send_window.buffer->size(); + const size_t to_send = BAN::Math::min(total_send - i, m_send_window.mss); - const size_t max_send = BAN::Math::min(total_send - i, m_send_window.mss); - const size_t to_send = BAN::Math::min(max_send, m_send_window.buffer->size() - send_offset); - - auto message = BAN::ConstByteSpan(send_base + send_offset, to_send); + auto message = m_send_window.buffer->get_data().slice(send_offset + i, to_send); m_next_flags = ACK; if (auto ret = m_network_layer.sendto(*this, message, target_address, target_address_len); ret.is_error()) @@ -979,9 +935,10 @@ namespace Kernel dprintln_if(DEBUG_TCP, "Sent {} bytes", to_send); - m_send_window.sent_size += to_send; - m_send_window.current_seq += to_send; i += to_send; + m_send_window.current_seq += to_send; + if (send_offset + i > m_send_window.sent_size) + m_send_window.sent_size = send_offset + i; } m_send_window.last_send_ms = current_ms; @@ -989,13 +946,23 @@ namespace Kernel continue; } - if (const size_t ack_count = m_should_send_ack + m_should_send_zero_window) + if (m_last_sent_window_size == 0) + m_should_send_zero_window = false; + + if (m_should_send_zero_window || m_should_send_window_update) { ASSERT(m_connection_info.has_value()); auto* target_address = reinterpret_cast(&m_connection_info->address); auto target_address_len = m_connection_info->address_len; - for (size_t i = 0; i < ack_count; i++) + m_next_flags = ACK; + if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error()) + dwarnln("{}", ret.error()); + + m_should_send_zero_window = false; + m_should_send_window_update = false; + + if (m_last_sent_window_size == 0 && !m_recv_window.buffer->full()) { m_next_flags = ACK; if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())