diff --git a/kernel/include/kernel/Networking/TCPSocket.h b/kernel/include/kernel/Networking/TCPSocket.h index 583e0286..24071414 100644 --- a/kernel/include/kernel/Networking/TCPSocket.h +++ b/kernel/include/kernel/Networking/TCPSocket.h @@ -97,6 +97,7 @@ namespace Kernel bool has_ghost_byte { false }; + uint32_t data_tail { 0 }; uint32_t data_size { 0 }; // number of bytes in this buffer uint8_t scale_shift { 0 }; // window scale BAN::UniqPtr buffer; @@ -117,6 +118,7 @@ namespace Kernel bool has_ghost_byte { false }; + uint32_t data_tail { 0 }; uint32_t data_size { 0 }; // number of bytes in this buffer uint32_t sent_size { 0 }; // number of bytes in this buffer that have been sent BAN::UniqPtr buffer; diff --git a/kernel/kernel/Networking/TCPSocket.cpp b/kernel/kernel/Networking/TCPSocket.cpp index c28bef49..8a2197f3 100644 --- a/kernel/kernel/Networking/TCPSocket.cpp +++ b/kernel/kernel/Networking/TCPSocket.cpp @@ -227,19 +227,22 @@ namespace Kernel size_t total_recv = 0; for (int i = 0; i < message.msg_iovlen; i++) { - auto* recv_buffer = reinterpret_cast(m_recv_window.buffer->vaddr()); + const auto* recv_base = reinterpret_cast(m_recv_window.buffer->vaddr()); + uint8_t* iov_base = static_cast(message.msg_iov[i].iov_base); const size_t nrecv = BAN::Math::min(message.msg_iov[i].iov_len, m_recv_window.data_size); - memcpy(message.msg_iov[i].iov_base, recv_buffer, nrecv); + + const size_t before_wrap = BAN::Math::min(nrecv, m_recv_window.buffer->size() - m_recv_window.data_tail); + memcpy(iov_base, recv_base + m_recv_window.data_tail, before_wrap); + if (const size_t after_wrap = nrecv - before_wrap) + memcpy(iov_base + before_wrap, recv_base, after_wrap); total_recv += nrecv; m_recv_window.data_size -= nrecv; m_recv_window.start_seq += nrecv; + m_recv_window.data_tail = (m_recv_window.data_tail + nrecv) % m_recv_window.buffer->size(); if (m_recv_window.data_size == 0) break; - - // TODO: use circular buffer to avoid this - memmove(recv_buffer, recv_buffer + nrecv, m_recv_window.data_size); } const bool should_update_window_size = @@ -281,10 +284,16 @@ namespace Kernel size_t total_sent = 0; for (int i = 0; i < message.msg_iovlen; i++) { - auto* send_buffer = reinterpret_cast(m_send_window.buffer->vaddr()); + auto* send_base = reinterpret_cast(m_send_window.buffer->vaddr()); + const auto* iov_base = static_cast(message.msg_iov[i].iov_base); const size_t nsend = BAN::Math::min(message.msg_iov[i].iov_len, m_send_window.buffer->size() - m_send_window.data_size); - memcpy(send_buffer + m_send_window.data_size, message.msg_iov[i].iov_base, nsend); + + const size_t send_head = (m_send_window.data_tail + m_send_window.data_size) % m_send_window.buffer->size(); + const size_t before_wrap = BAN::Math::min(nsend, m_send_window.buffer->size() - send_head); + memcpy(send_base + send_head, message.msg_iov[i].iov_base, before_wrap); + if (const size_t after_wrap = nsend - before_wrap) + memcpy(send_base, iov_base + before_wrap, after_wrap); total_sent += nsend; m_send_window.data_size += nsend; @@ -745,10 +754,16 @@ namespace Kernel auto payload = buffer.slice(header.data_offset * sizeof(uint32_t)); if (payload.size() > 0 && m_recv_window.data_size < m_recv_window.buffer->size()) { + auto* recv_base = reinterpret_cast(m_recv_window.buffer->vaddr()); + const size_t nrecv = BAN::Math::min(payload.size(), m_recv_window.buffer->size() - m_recv_window.data_size); - auto* buffer = reinterpret_cast(m_recv_window.buffer->vaddr()); - memcpy(buffer + m_recv_window.data_size, payload.data(), nrecv); + const size_t recv_head = (m_recv_window.data_tail + m_recv_window.data_size) % m_recv_window.buffer->size(); + const size_t before_wrap = BAN::Math::min(nrecv, m_recv_window.buffer->size() - recv_head); + memcpy(recv_base + recv_head, payload.data(), before_wrap); + if (const size_t after_wrap = nrecv - before_wrap) + memcpy(recv_base, payload.data() + before_wrap, after_wrap); + m_recv_window.data_size += nrecv; epoll_notify(EPOLLIN); @@ -869,19 +884,13 @@ namespace Kernel if (m_send_window.data_size > 0 && m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq) { - uint32_t acknowledged_bytes = m_send_window.current_ack - m_send_window.start_seq - m_send_window.has_ghost_byte; + const uint32_t acknowledged_bytes = m_send_window.current_ack - m_send_window.start_seq - m_send_window.has_ghost_byte; ASSERT(acknowledged_bytes <= m_send_window.data_size); m_send_window.data_size -= acknowledged_bytes; m_send_window.start_seq += acknowledged_bytes; - - if (m_send_window.data_size > 0) - { - auto* send_buffer = reinterpret_cast(m_send_window.buffer->vaddr()); - memmove(send_buffer, send_buffer + acknowledged_bytes, m_send_window.data_size); - } - m_send_window.sent_size -= acknowledged_bytes; + m_send_window.data_tail = (m_send_window.data_tail + acknowledged_bytes) % m_send_window.buffer->size(); epoll_notify(EPOLLOUT); @@ -898,18 +907,21 @@ namespace Kernel auto* target_address = reinterpret_cast(&m_connection_info->address); auto target_address_len = m_connection_info->address_len; - const uint32_t send_base = should_retransmit ? 0 : m_send_window.sent_size; + const size_t send_start_offset = should_retransmit ? 0 : m_send_window.sent_size; - const uint32_t total_send = BAN::Math::min(m_send_window.data_size - send_base, m_send_window.scaled_size()); + const size_t total_send = BAN::Math::min(m_send_window.data_size - send_start_offset, m_send_window.scaled_size()); - m_send_window.current_seq = m_send_window.start_seq + send_base; + m_send_window.current_seq = m_send_window.start_seq + send_start_offset; - auto* send_buffer = reinterpret_cast(m_send_window.buffer->vaddr() + send_base); - for (uint32_t i = 0; i < total_send;) + const auto* send_base = reinterpret_cast(m_send_window.buffer->vaddr()); + for (size_t i = 0; i < total_send;) { - const uint32_t to_send = BAN::Math::min(total_send - i, m_send_window.mss); + const size_t send_offset = (m_send_window.data_tail + send_start_offset + i) % m_send_window.buffer->size(); - auto message = BAN::ConstByteSpan(send_buffer + i, to_send); + const size_t max_send = BAN::Math::min(total_send - i, m_send_window.mss); + const size_t to_send = BAN::Math::min(max_send, m_send_window.buffer->size() - send_offset); + + auto message = BAN::ConstByteSpan(send_base + send_offset, to_send); m_next_flags = ACK; if (auto ret = m_network_layer.sendto(*this, message, target_address, target_address_len); ret.is_error())