Kernel: Make TCP sockets use the new ring buffer
Also fix a race condition that sometimes prevented window updates from being sent after a zero window, effectively hanging the whole socket
This commit is contained in:
parent
8b8af1a9d9
commit
812ae77cd7
|
|
@ -4,7 +4,7 @@
|
|||
#include <BAN/Endianness.h>
|
||||
#include <BAN/Queue.h>
|
||||
#include <kernel/Lock/Mutex.h>
|
||||
#include <kernel/Memory/VirtualRange.h>
|
||||
#include <kernel/Memory/ByteRingBuffer.h>
|
||||
#include <kernel/Networking/NetworkInterface.h>
|
||||
#include <kernel/Networking/NetworkSocket.h>
|
||||
#include <kernel/Thread.h>
|
||||
|
|
@ -97,10 +97,8 @@ namespace Kernel
|
|||
|
||||
bool has_ghost_byte { false };
|
||||
|
||||
uint32_t data_tail { 0 };
|
||||
uint32_t data_size { 0 }; // number of bytes in this buffer
|
||||
uint8_t scale_shift { 0 }; // window scale
|
||||
BAN::UniqPtr<VirtualRange> buffer;
|
||||
BAN::UniqPtr<ByteRingBuffer> buffer;
|
||||
};
|
||||
|
||||
struct SendWindowInfo
|
||||
|
|
@ -119,10 +117,8 @@ namespace Kernel
|
|||
bool has_ghost_byte { false };
|
||||
bool had_zero_window { false };
|
||||
|
||||
uint32_t data_tail { 0 };
|
||||
uint32_t data_size { 0 }; // number of bytes in this buffer
|
||||
uint32_t sent_size { 0 }; // number of bytes in this buffer that have been sent
|
||||
BAN::UniqPtr<VirtualRange> buffer;
|
||||
BAN::UniqPtr<ByteRingBuffer> buffer;
|
||||
};
|
||||
|
||||
struct ConnectionInfo
|
||||
|
|
@ -180,8 +176,8 @@ namespace Kernel
|
|||
bool m_keep_alive { false };
|
||||
bool m_no_delay { false };
|
||||
|
||||
bool m_should_send_ack { false };
|
||||
bool m_should_send_zero_window { false };
|
||||
bool m_should_send_window_update { false };
|
||||
|
||||
uint64_t m_time_wait_start_ms { 0 };
|
||||
|
||||
|
|
|
|||
|
|
@ -35,23 +35,9 @@ namespace Kernel
|
|||
{
|
||||
auto socket = TRY(BAN::RefPtr<TCPSocket>::create(network_layer, info));
|
||||
socket->m_last_sent_window_size = s_recv_window_buffer_size;
|
||||
socket->m_recv_window.buffer = TRY(VirtualRange::create_to_vaddr_range(
|
||||
PageTable::kernel(),
|
||||
KERNEL_OFFSET,
|
||||
~(vaddr_t)0,
|
||||
s_recv_window_buffer_size,
|
||||
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
|
||||
true, false
|
||||
));
|
||||
socket->m_recv_window.buffer = TRY(ByteRingBuffer::create(s_recv_window_buffer_size));
|
||||
socket->m_recv_window.scale_shift = s_window_shift;
|
||||
socket->m_send_window.buffer = TRY(VirtualRange::create_to_vaddr_range(
|
||||
PageTable::kernel(),
|
||||
KERNEL_OFFSET,
|
||||
~(vaddr_t)0,
|
||||
s_send_window_buffer_size,
|
||||
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
|
||||
true, false
|
||||
));
|
||||
socket->m_send_window.buffer = TRY(ByteRingBuffer::create(s_send_window_buffer_size));
|
||||
socket->m_thread = TRY(Thread::create_kernel(
|
||||
[](void* socket_ptr)
|
||||
{
|
||||
|
|
@ -206,8 +192,7 @@ namespace Kernel
|
|||
|
||||
BAN::ErrorOr<size_t> TCPSocket::recvmsg_impl(msghdr& message, int flags)
|
||||
{
|
||||
flags &= (MSG_OOB | MSG_PEEK | MSG_WAITALL);
|
||||
if (flags != 0)
|
||||
if (flags & ~(MSG_PEEK))
|
||||
{
|
||||
dwarnln("TODO: recvmsg with flags 0x{H}", flags);
|
||||
return BAN::Error::from_errno(ENOTSUP);
|
||||
|
|
@ -222,7 +207,7 @@ namespace Kernel
|
|||
if (!m_has_connected)
|
||||
return BAN::Error::from_errno(ENOTCONN);
|
||||
|
||||
while (m_recv_window.data_size == 0)
|
||||
while (m_recv_window.buffer->empty())
|
||||
{
|
||||
if (m_state != State::Established)
|
||||
return return_with_maybe_zero();
|
||||
|
|
@ -232,37 +217,33 @@ namespace Kernel
|
|||
message.msg_flags = 0;
|
||||
|
||||
size_t total_recv = 0;
|
||||
for (int i = 0; i < message.msg_iovlen; i++)
|
||||
for (int i = 0; i < message.msg_iovlen && total_recv < m_recv_window.buffer->size(); i++)
|
||||
{
|
||||
const auto* recv_base = reinterpret_cast<const uint8_t*>(m_recv_window.buffer->vaddr());
|
||||
uint8_t* iov_base = static_cast<uint8_t*>(message.msg_iov[i].iov_base);
|
||||
auto& iov = message.msg_iov[i];
|
||||
|
||||
const size_t nrecv = BAN::Math::min<size_t>(message.msg_iov[i].iov_len, m_recv_window.data_size);
|
||||
|
||||
const size_t before_wrap = BAN::Math::min(nrecv, m_recv_window.buffer->size() - m_recv_window.data_tail);
|
||||
memcpy(iov_base, recv_base + m_recv_window.data_tail, before_wrap);
|
||||
if (const size_t after_wrap = nrecv - before_wrap)
|
||||
memcpy(iov_base + before_wrap, recv_base, after_wrap);
|
||||
const size_t nrecv = BAN::Math::min(iov.iov_len, m_recv_window.buffer->size() - total_recv);
|
||||
memcpy(iov.iov_base, m_recv_window.buffer->get_data().data() + total_recv, nrecv);
|
||||
|
||||
total_recv += nrecv;
|
||||
m_recv_window.data_size -= nrecv;
|
||||
m_recv_window.start_seq += nrecv;
|
||||
m_recv_window.data_tail = (m_recv_window.data_tail + nrecv) % m_recv_window.buffer->size();
|
||||
if (m_recv_window.data_size == 0)
|
||||
break;
|
||||
}
|
||||
|
||||
const size_t update_window_threshold = m_recv_window.buffer->size() / 8;
|
||||
if (!(flags & MSG_PEEK))
|
||||
{
|
||||
m_recv_window.buffer->pop(total_recv);
|
||||
m_recv_window.start_seq += total_recv;
|
||||
}
|
||||
|
||||
const size_t update_window_threshold = m_recv_window.buffer->capacity() / 8;
|
||||
const bool should_update_window_size =
|
||||
m_last_sent_window_size != m_recv_window.buffer->size() && (
|
||||
m_last_sent_window_size != m_recv_window.buffer->capacity() && (
|
||||
(m_last_sent_window_size == 0) ||
|
||||
(m_recv_window.data_size == 0) ||
|
||||
(m_last_sent_window_size + update_window_threshold <= m_recv_window.buffer->size() - m_recv_window.data_size)
|
||||
(m_recv_window.buffer->empty()) ||
|
||||
(m_last_sent_window_size + update_window_threshold <= m_recv_window.buffer->free())
|
||||
);
|
||||
|
||||
if (should_update_window_size)
|
||||
if (should_update_window_size || m_should_send_zero_window)
|
||||
{
|
||||
m_should_send_ack = true;
|
||||
m_should_send_window_update = true;
|
||||
m_thread_blocker.unblock();
|
||||
}
|
||||
|
||||
|
|
@ -283,7 +264,7 @@ namespace Kernel
|
|||
if (!m_has_connected)
|
||||
return BAN::Error::from_errno(ENOTCONN);
|
||||
|
||||
while (m_send_window.data_size == m_send_window.buffer->size())
|
||||
while (m_send_window.buffer->full())
|
||||
{
|
||||
if (m_state != State::Established)
|
||||
return return_with_maybe_zero();
|
||||
|
|
@ -293,23 +274,14 @@ namespace Kernel
|
|||
}
|
||||
|
||||
size_t total_sent = 0;
|
||||
for (int i = 0; i < message.msg_iovlen; i++)
|
||||
for (int i = 0; i < message.msg_iovlen && !m_send_window.buffer->full(); i++)
|
||||
{
|
||||
auto* send_base = reinterpret_cast<uint8_t*>(m_send_window.buffer->vaddr());
|
||||
const auto* iov_base = static_cast<const uint8_t*>(message.msg_iov[i].iov_base);
|
||||
const auto& iov = message.msg_iov[i];
|
||||
|
||||
const size_t nsend = BAN::Math::min<size_t>(message.msg_iov[i].iov_len, m_send_window.buffer->size() - m_send_window.data_size);
|
||||
|
||||
const size_t send_head = (m_send_window.data_tail + m_send_window.data_size) % m_send_window.buffer->size();
|
||||
const size_t before_wrap = BAN::Math::min(nsend, m_send_window.buffer->size() - send_head);
|
||||
memcpy(send_base + send_head, message.msg_iov[i].iov_base, before_wrap);
|
||||
if (const size_t after_wrap = nsend - before_wrap)
|
||||
memcpy(send_base, iov_base + before_wrap, after_wrap);
|
||||
const size_t nsend = BAN::Math::min(iov.iov_len, m_send_window.buffer->free());
|
||||
m_send_window.buffer->push({ static_cast<const uint8_t*>(iov.iov_base), nsend });
|
||||
|
||||
total_sent += nsend;
|
||||
m_send_window.data_size += nsend;
|
||||
if (m_send_window.data_size == m_send_window.buffer->size())
|
||||
break;
|
||||
}
|
||||
|
||||
m_thread_blocker.unblock();
|
||||
|
|
@ -347,7 +319,7 @@ namespace Kernel
|
|||
result = m_send_window.scaled_size();
|
||||
break;
|
||||
case SO_RCVBUF:
|
||||
result = m_recv_window.buffer->size();
|
||||
result = m_recv_window.buffer->capacity();
|
||||
break;
|
||||
default:
|
||||
dwarnln("getsockopt(SOL_SOCKET, {})", option);
|
||||
|
|
@ -420,7 +392,7 @@ namespace Kernel
|
|||
switch (request)
|
||||
{
|
||||
case FIONREAD:
|
||||
*static_cast<int*>(argument) = m_recv_window.data_size;
|
||||
*static_cast<int*>(argument) = m_recv_window.buffer->size();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
@ -433,14 +405,14 @@ namespace Kernel
|
|||
return true;
|
||||
if (m_state == State::Listen)
|
||||
return !m_pending_connections.empty();
|
||||
return m_recv_window.data_size > 0;
|
||||
return !m_recv_window.buffer->empty();
|
||||
}
|
||||
|
||||
bool TCPSocket::can_write_impl() const
|
||||
{
|
||||
if (m_state != State::Established)
|
||||
return false;
|
||||
return m_send_window.data_size < m_send_window.buffer->size();
|
||||
return !m_send_window.buffer->full();
|
||||
}
|
||||
|
||||
bool TCPSocket::has_hungup_impl() const
|
||||
|
|
@ -530,19 +502,14 @@ namespace Kernel
|
|||
ASSERT(m_mutex.locker() == Thread::current().tid());
|
||||
ASSERT(header_buffer.size() == protocol_header_size());
|
||||
|
||||
m_last_sent_window_size = m_recv_window.buffer->size() - m_recv_window.data_size;
|
||||
if (m_should_send_zero_window)
|
||||
m_last_sent_window_size = 0;
|
||||
|
||||
m_should_send_ack = false;
|
||||
m_should_send_zero_window = false;
|
||||
m_last_sent_window_size = m_should_send_zero_window ? 0 : m_recv_window.buffer->free();
|
||||
|
||||
auto& header = header_buffer.as<TCPHeader>();
|
||||
header = {
|
||||
.src_port = bound_port(),
|
||||
.dst_port = dst_port,
|
||||
.seq_number = m_send_window.current_seq + m_send_window.has_ghost_byte,
|
||||
.ack_number = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte,
|
||||
.ack_number = m_recv_window.start_seq + m_recv_window.buffer->size() + m_recv_window.has_ghost_byte,
|
||||
.data_offset = (sizeof(TCPHeader) + m_tcp_options_bytes) / sizeof(uint32_t),
|
||||
.flags = m_next_flags,
|
||||
.window_size = BAN::Math::min<size_t>(0xFFFF, m_last_sent_window_size >> m_recv_window.scale_shift),
|
||||
|
|
@ -569,7 +536,7 @@ namespace Kernel
|
|||
|
||||
if (m_connection_info->has_window_scale)
|
||||
add_tcp_header_option<4, TCPOption::WindowScale>(header, m_recv_window.scale_shift);
|
||||
header.window_size = BAN::Math::min<size_t>(0xFFFF, m_recv_window.buffer->size());
|
||||
header.window_size = BAN::Math::min<size_t>(0xFFFF, m_recv_window.buffer->capacity());
|
||||
|
||||
m_send_window.start_seq++;
|
||||
m_send_window.current_seq = m_send_window.start_seq;
|
||||
|
|
@ -722,7 +689,7 @@ namespace Kernel
|
|||
check_payload = true;
|
||||
if (!(header.flags & FIN))
|
||||
break;
|
||||
if (m_recv_window.start_seq + m_recv_window.data_size != header.seq_number)
|
||||
if (m_recv_window.start_seq + m_recv_window.buffer->size() != header.seq_number)
|
||||
break;
|
||||
m_next_flags = FIN | ACK;
|
||||
m_next_state = State::LastAck;
|
||||
|
|
@ -771,7 +738,7 @@ namespace Kernel
|
|||
break;
|
||||
}
|
||||
|
||||
const uint32_t expected_seq = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte;
|
||||
const uint32_t expected_seq = m_recv_window.start_seq + m_recv_window.buffer->size() + m_recv_window.has_ghost_byte;
|
||||
|
||||
if (header.seq_number > expected_seq)
|
||||
dprintln_if(DEBUG_TCP, "Missing packets");
|
||||
|
|
@ -794,21 +761,12 @@ namespace Kernel
|
|||
payload = {};
|
||||
}
|
||||
|
||||
const bool can_receive_new_data = (payload.size() > 0 && m_recv_window.data_size < m_recv_window.buffer->size());
|
||||
const bool can_receive_new_data = (payload.size() > 0 && !m_recv_window.buffer->full());
|
||||
|
||||
if (can_receive_new_data)
|
||||
{
|
||||
auto* recv_base = reinterpret_cast<uint8_t*>(m_recv_window.buffer->vaddr());
|
||||
|
||||
const size_t nrecv = BAN::Math::min(payload.size(), m_recv_window.buffer->size() - m_recv_window.data_size);
|
||||
|
||||
const size_t recv_head = (m_recv_window.data_tail + m_recv_window.data_size) % m_recv_window.buffer->size();
|
||||
const size_t before_wrap = BAN::Math::min(nrecv, m_recv_window.buffer->size() - recv_head);
|
||||
memcpy(recv_base + recv_head, payload.data(), before_wrap);
|
||||
if (const size_t after_wrap = nrecv - before_wrap)
|
||||
memcpy(recv_base, payload.data() + before_wrap, after_wrap);
|
||||
|
||||
m_recv_window.data_size += nrecv;
|
||||
const size_t nrecv = BAN::Math::min(payload.size(), m_recv_window.buffer->free());
|
||||
m_recv_window.buffer->push(payload.slice(0, nrecv));
|
||||
|
||||
epoll_notify(EPOLLIN);
|
||||
|
||||
|
|
@ -816,10 +774,13 @@ namespace Kernel
|
|||
}
|
||||
|
||||
// make sure zero window is reported
|
||||
if (m_last_sent_window_size > 0 && m_recv_window.data_size == m_recv_window.buffer->size())
|
||||
if (m_last_sent_window_size > 0 && m_recv_window.buffer->full())
|
||||
m_should_send_zero_window = true;
|
||||
else if (can_receive_new_data)
|
||||
m_should_send_ack = true;
|
||||
{
|
||||
m_next_flags = ACK;
|
||||
m_next_state = m_state;
|
||||
}
|
||||
}
|
||||
|
||||
if (!hungup_before && has_hungup_impl())
|
||||
|
|
@ -925,12 +886,11 @@ namespace Kernel
|
|||
if (m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq)
|
||||
{
|
||||
const uint32_t acknowledged_bytes = m_send_window.current_ack - m_send_window.start_seq - m_send_window.has_ghost_byte;
|
||||
ASSERT(acknowledged_bytes <= m_send_window.data_size);
|
||||
ASSERT(acknowledged_bytes <= m_send_window.buffer->size());
|
||||
|
||||
m_send_window.data_size -= acknowledged_bytes;
|
||||
m_send_window.start_seq += acknowledged_bytes;
|
||||
m_send_window.sent_size -= acknowledged_bytes;
|
||||
m_send_window.data_tail = (m_send_window.data_tail + acknowledged_bytes) % m_send_window.buffer->size();
|
||||
m_send_window.buffer->pop(acknowledged_bytes);
|
||||
|
||||
epoll_notify(EPOLLOUT);
|
||||
|
||||
|
|
@ -941,7 +901,7 @@ namespace Kernel
|
|||
|
||||
const bool should_retransmit = m_send_window.had_zero_window || (m_send_window.sent_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms);
|
||||
|
||||
const bool can_send_new_data = (m_send_window.data_size > m_send_window.sent_size && m_send_window.sent_size < m_send_window.scaled_size());
|
||||
const bool can_send_new_data = (m_send_window.buffer->size() > m_send_window.sent_size && m_send_window.sent_size < m_send_window.scaled_size());
|
||||
|
||||
if (m_send_window.scaled_size() > 0 && (should_retransmit || can_send_new_data))
|
||||
{
|
||||
|
|
@ -951,24 +911,20 @@ namespace Kernel
|
|||
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
|
||||
auto target_address_len = m_connection_info->address_len;
|
||||
|
||||
const size_t send_start_offset = should_retransmit ? 0 : m_send_window.sent_size;
|
||||
const size_t send_offset = should_retransmit ? 0 : m_send_window.sent_size;
|
||||
|
||||
const size_t total_send = BAN::Math::min<size_t>(
|
||||
m_send_window.data_size - send_start_offset,
|
||||
m_send_window.scaled_size() - send_start_offset
|
||||
m_send_window.buffer->size() - send_offset,
|
||||
m_send_window.scaled_size() - send_offset
|
||||
);
|
||||
|
||||
m_send_window.current_seq = m_send_window.start_seq + send_start_offset;
|
||||
m_send_window.current_seq = m_send_window.start_seq + send_offset;
|
||||
|
||||
const auto* send_base = reinterpret_cast<const uint8_t*>(m_send_window.buffer->vaddr());
|
||||
for (size_t i = 0; i < total_send;)
|
||||
{
|
||||
const size_t send_offset = (m_send_window.data_tail + send_start_offset + i) % m_send_window.buffer->size();
|
||||
const size_t to_send = BAN::Math::min<size_t>(total_send - i, m_send_window.mss);
|
||||
|
||||
const size_t max_send = BAN::Math::min<size_t>(total_send - i, m_send_window.mss);
|
||||
const size_t to_send = BAN::Math::min(max_send, m_send_window.buffer->size() - send_offset);
|
||||
|
||||
auto message = BAN::ConstByteSpan(send_base + send_offset, to_send);
|
||||
auto message = m_send_window.buffer->get_data().slice(send_offset + i, to_send);
|
||||
|
||||
m_next_flags = ACK;
|
||||
if (auto ret = m_network_layer.sendto(*this, message, target_address, target_address_len); ret.is_error())
|
||||
|
|
@ -979,9 +935,10 @@ namespace Kernel
|
|||
|
||||
dprintln_if(DEBUG_TCP, "Sent {} bytes", to_send);
|
||||
|
||||
m_send_window.sent_size += to_send;
|
||||
m_send_window.current_seq += to_send;
|
||||
i += to_send;
|
||||
m_send_window.current_seq += to_send;
|
||||
if (send_offset + i > m_send_window.sent_size)
|
||||
m_send_window.sent_size = send_offset + i;
|
||||
}
|
||||
|
||||
m_send_window.last_send_ms = current_ms;
|
||||
|
|
@ -989,13 +946,23 @@ namespace Kernel
|
|||
continue;
|
||||
}
|
||||
|
||||
if (const size_t ack_count = m_should_send_ack + m_should_send_zero_window)
|
||||
if (m_last_sent_window_size == 0)
|
||||
m_should_send_zero_window = false;
|
||||
|
||||
if (m_should_send_zero_window || m_should_send_window_update)
|
||||
{
|
||||
ASSERT(m_connection_info.has_value());
|
||||
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
|
||||
auto target_address_len = m_connection_info->address_len;
|
||||
|
||||
for (size_t i = 0; i < ack_count; i++)
|
||||
m_next_flags = ACK;
|
||||
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())
|
||||
dwarnln("{}", ret.error());
|
||||
|
||||
m_should_send_zero_window = false;
|
||||
m_should_send_window_update = false;
|
||||
|
||||
if (m_last_sent_window_size == 0 && !m_recv_window.buffer->full())
|
||||
{
|
||||
m_next_flags = ACK;
|
||||
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())
|
||||
|
|
|
|||
Loading…
Reference in New Issue