Kernel: Make TCP sockets use the new ring buffer
Also fix race condition that sometimes prevented window updates not being sent after zero window effectively hanging the whole socket
This commit is contained in:
@@ -35,23 +35,9 @@ namespace Kernel
|
||||
{
|
||||
auto socket = TRY(BAN::RefPtr<TCPSocket>::create(network_layer, info));
|
||||
socket->m_last_sent_window_size = s_recv_window_buffer_size;
|
||||
socket->m_recv_window.buffer = TRY(VirtualRange::create_to_vaddr_range(
|
||||
PageTable::kernel(),
|
||||
KERNEL_OFFSET,
|
||||
~(vaddr_t)0,
|
||||
s_recv_window_buffer_size,
|
||||
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
|
||||
true, false
|
||||
));
|
||||
socket->m_recv_window.buffer = TRY(ByteRingBuffer::create(s_recv_window_buffer_size));
|
||||
socket->m_recv_window.scale_shift = s_window_shift;
|
||||
socket->m_send_window.buffer = TRY(VirtualRange::create_to_vaddr_range(
|
||||
PageTable::kernel(),
|
||||
KERNEL_OFFSET,
|
||||
~(vaddr_t)0,
|
||||
s_send_window_buffer_size,
|
||||
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
|
||||
true, false
|
||||
));
|
||||
socket->m_send_window.buffer = TRY(ByteRingBuffer::create(s_send_window_buffer_size));
|
||||
socket->m_thread = TRY(Thread::create_kernel(
|
||||
[](void* socket_ptr)
|
||||
{
|
||||
@@ -206,8 +192,7 @@ namespace Kernel
|
||||
|
||||
BAN::ErrorOr<size_t> TCPSocket::recvmsg_impl(msghdr& message, int flags)
|
||||
{
|
||||
flags &= (MSG_OOB | MSG_PEEK | MSG_WAITALL);
|
||||
if (flags != 0)
|
||||
if (flags & ~(MSG_PEEK))
|
||||
{
|
||||
dwarnln("TODO: recvmsg with flags 0x{H}", flags);
|
||||
return BAN::Error::from_errno(ENOTSUP);
|
||||
@@ -222,7 +207,7 @@ namespace Kernel
|
||||
if (!m_has_connected)
|
||||
return BAN::Error::from_errno(ENOTCONN);
|
||||
|
||||
while (m_recv_window.data_size == 0)
|
||||
while (m_recv_window.buffer->empty())
|
||||
{
|
||||
if (m_state != State::Established)
|
||||
return return_with_maybe_zero();
|
||||
@@ -232,37 +217,33 @@ namespace Kernel
|
||||
message.msg_flags = 0;
|
||||
|
||||
size_t total_recv = 0;
|
||||
for (int i = 0; i < message.msg_iovlen; i++)
|
||||
for (int i = 0; i < message.msg_iovlen && total_recv < m_recv_window.buffer->size(); i++)
|
||||
{
|
||||
const auto* recv_base = reinterpret_cast<const uint8_t*>(m_recv_window.buffer->vaddr());
|
||||
uint8_t* iov_base = static_cast<uint8_t*>(message.msg_iov[i].iov_base);
|
||||
auto& iov = message.msg_iov[i];
|
||||
|
||||
const size_t nrecv = BAN::Math::min<size_t>(message.msg_iov[i].iov_len, m_recv_window.data_size);
|
||||
|
||||
const size_t before_wrap = BAN::Math::min(nrecv, m_recv_window.buffer->size() - m_recv_window.data_tail);
|
||||
memcpy(iov_base, recv_base + m_recv_window.data_tail, before_wrap);
|
||||
if (const size_t after_wrap = nrecv - before_wrap)
|
||||
memcpy(iov_base + before_wrap, recv_base, after_wrap);
|
||||
const size_t nrecv = BAN::Math::min(iov.iov_len, m_recv_window.buffer->size() - total_recv);
|
||||
memcpy(iov.iov_base, m_recv_window.buffer->get_data().data() + total_recv, nrecv);
|
||||
|
||||
total_recv += nrecv;
|
||||
m_recv_window.data_size -= nrecv;
|
||||
m_recv_window.start_seq += nrecv;
|
||||
m_recv_window.data_tail = (m_recv_window.data_tail + nrecv) % m_recv_window.buffer->size();
|
||||
if (m_recv_window.data_size == 0)
|
||||
break;
|
||||
}
|
||||
|
||||
const size_t update_window_threshold = m_recv_window.buffer->size() / 8;
|
||||
if (!(flags & MSG_PEEK))
|
||||
{
|
||||
m_recv_window.buffer->pop(total_recv);
|
||||
m_recv_window.start_seq += total_recv;
|
||||
}
|
||||
|
||||
const size_t update_window_threshold = m_recv_window.buffer->capacity() / 8;
|
||||
const bool should_update_window_size =
|
||||
m_last_sent_window_size != m_recv_window.buffer->size() && (
|
||||
m_last_sent_window_size != m_recv_window.buffer->capacity() && (
|
||||
(m_last_sent_window_size == 0) ||
|
||||
(m_recv_window.data_size == 0) ||
|
||||
(m_last_sent_window_size + update_window_threshold <= m_recv_window.buffer->size() - m_recv_window.data_size)
|
||||
(m_recv_window.buffer->empty()) ||
|
||||
(m_last_sent_window_size + update_window_threshold <= m_recv_window.buffer->free())
|
||||
);
|
||||
|
||||
if (should_update_window_size)
|
||||
if (should_update_window_size || m_should_send_zero_window)
|
||||
{
|
||||
m_should_send_ack = true;
|
||||
m_should_send_window_update = true;
|
||||
m_thread_blocker.unblock();
|
||||
}
|
||||
|
||||
@@ -283,7 +264,7 @@ namespace Kernel
|
||||
if (!m_has_connected)
|
||||
return BAN::Error::from_errno(ENOTCONN);
|
||||
|
||||
while (m_send_window.data_size == m_send_window.buffer->size())
|
||||
while (m_send_window.buffer->full())
|
||||
{
|
||||
if (m_state != State::Established)
|
||||
return return_with_maybe_zero();
|
||||
@@ -293,23 +274,14 @@ namespace Kernel
|
||||
}
|
||||
|
||||
size_t total_sent = 0;
|
||||
for (int i = 0; i < message.msg_iovlen; i++)
|
||||
for (int i = 0; i < message.msg_iovlen && !m_send_window.buffer->full(); i++)
|
||||
{
|
||||
auto* send_base = reinterpret_cast<uint8_t*>(m_send_window.buffer->vaddr());
|
||||
const auto* iov_base = static_cast<const uint8_t*>(message.msg_iov[i].iov_base);
|
||||
const auto& iov = message.msg_iov[i];
|
||||
|
||||
const size_t nsend = BAN::Math::min<size_t>(message.msg_iov[i].iov_len, m_send_window.buffer->size() - m_send_window.data_size);
|
||||
|
||||
const size_t send_head = (m_send_window.data_tail + m_send_window.data_size) % m_send_window.buffer->size();
|
||||
const size_t before_wrap = BAN::Math::min(nsend, m_send_window.buffer->size() - send_head);
|
||||
memcpy(send_base + send_head, message.msg_iov[i].iov_base, before_wrap);
|
||||
if (const size_t after_wrap = nsend - before_wrap)
|
||||
memcpy(send_base, iov_base + before_wrap, after_wrap);
|
||||
const size_t nsend = BAN::Math::min(iov.iov_len, m_send_window.buffer->free());
|
||||
m_send_window.buffer->push({ static_cast<const uint8_t*>(iov.iov_base), nsend });
|
||||
|
||||
total_sent += nsend;
|
||||
m_send_window.data_size += nsend;
|
||||
if (m_send_window.data_size == m_send_window.buffer->size())
|
||||
break;
|
||||
}
|
||||
|
||||
m_thread_blocker.unblock();
|
||||
@@ -347,7 +319,7 @@ namespace Kernel
|
||||
result = m_send_window.scaled_size();
|
||||
break;
|
||||
case SO_RCVBUF:
|
||||
result = m_recv_window.buffer->size();
|
||||
result = m_recv_window.buffer->capacity();
|
||||
break;
|
||||
default:
|
||||
dwarnln("getsockopt(SOL_SOCKET, {})", option);
|
||||
@@ -420,7 +392,7 @@ namespace Kernel
|
||||
switch (request)
|
||||
{
|
||||
case FIONREAD:
|
||||
*static_cast<int*>(argument) = m_recv_window.data_size;
|
||||
*static_cast<int*>(argument) = m_recv_window.buffer->size();
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -433,14 +405,14 @@ namespace Kernel
|
||||
return true;
|
||||
if (m_state == State::Listen)
|
||||
return !m_pending_connections.empty();
|
||||
return m_recv_window.data_size > 0;
|
||||
return !m_recv_window.buffer->empty();
|
||||
}
|
||||
|
||||
bool TCPSocket::can_write_impl() const
|
||||
{
|
||||
if (m_state != State::Established)
|
||||
return false;
|
||||
return m_send_window.data_size < m_send_window.buffer->size();
|
||||
return !m_send_window.buffer->full();
|
||||
}
|
||||
|
||||
bool TCPSocket::has_hungup_impl() const
|
||||
@@ -530,19 +502,14 @@ namespace Kernel
|
||||
ASSERT(m_mutex.locker() == Thread::current().tid());
|
||||
ASSERT(header_buffer.size() == protocol_header_size());
|
||||
|
||||
m_last_sent_window_size = m_recv_window.buffer->size() - m_recv_window.data_size;
|
||||
if (m_should_send_zero_window)
|
||||
m_last_sent_window_size = 0;
|
||||
|
||||
m_should_send_ack = false;
|
||||
m_should_send_zero_window = false;
|
||||
m_last_sent_window_size = m_should_send_zero_window ? 0 : m_recv_window.buffer->free();
|
||||
|
||||
auto& header = header_buffer.as<TCPHeader>();
|
||||
header = {
|
||||
.src_port = bound_port(),
|
||||
.dst_port = dst_port,
|
||||
.seq_number = m_send_window.current_seq + m_send_window.has_ghost_byte,
|
||||
.ack_number = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte,
|
||||
.ack_number = m_recv_window.start_seq + m_recv_window.buffer->size() + m_recv_window.has_ghost_byte,
|
||||
.data_offset = (sizeof(TCPHeader) + m_tcp_options_bytes) / sizeof(uint32_t),
|
||||
.flags = m_next_flags,
|
||||
.window_size = BAN::Math::min<size_t>(0xFFFF, m_last_sent_window_size >> m_recv_window.scale_shift),
|
||||
@@ -569,7 +536,7 @@ namespace Kernel
|
||||
|
||||
if (m_connection_info->has_window_scale)
|
||||
add_tcp_header_option<4, TCPOption::WindowScale>(header, m_recv_window.scale_shift);
|
||||
header.window_size = BAN::Math::min<size_t>(0xFFFF, m_recv_window.buffer->size());
|
||||
header.window_size = BAN::Math::min<size_t>(0xFFFF, m_recv_window.buffer->capacity());
|
||||
|
||||
m_send_window.start_seq++;
|
||||
m_send_window.current_seq = m_send_window.start_seq;
|
||||
@@ -722,7 +689,7 @@ namespace Kernel
|
||||
check_payload = true;
|
||||
if (!(header.flags & FIN))
|
||||
break;
|
||||
if (m_recv_window.start_seq + m_recv_window.data_size != header.seq_number)
|
||||
if (m_recv_window.start_seq + m_recv_window.buffer->size() != header.seq_number)
|
||||
break;
|
||||
m_next_flags = FIN | ACK;
|
||||
m_next_state = State::LastAck;
|
||||
@@ -771,7 +738,7 @@ namespace Kernel
|
||||
break;
|
||||
}
|
||||
|
||||
const uint32_t expected_seq = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte;
|
||||
const uint32_t expected_seq = m_recv_window.start_seq + m_recv_window.buffer->size() + m_recv_window.has_ghost_byte;
|
||||
|
||||
if (header.seq_number > expected_seq)
|
||||
dprintln_if(DEBUG_TCP, "Missing packets");
|
||||
@@ -794,21 +761,12 @@ namespace Kernel
|
||||
payload = {};
|
||||
}
|
||||
|
||||
const bool can_receive_new_data = (payload.size() > 0 && m_recv_window.data_size < m_recv_window.buffer->size());
|
||||
const bool can_receive_new_data = (payload.size() > 0 && !m_recv_window.buffer->full());
|
||||
|
||||
if (can_receive_new_data)
|
||||
{
|
||||
auto* recv_base = reinterpret_cast<uint8_t*>(m_recv_window.buffer->vaddr());
|
||||
|
||||
const size_t nrecv = BAN::Math::min(payload.size(), m_recv_window.buffer->size() - m_recv_window.data_size);
|
||||
|
||||
const size_t recv_head = (m_recv_window.data_tail + m_recv_window.data_size) % m_recv_window.buffer->size();
|
||||
const size_t before_wrap = BAN::Math::min(nrecv, m_recv_window.buffer->size() - recv_head);
|
||||
memcpy(recv_base + recv_head, payload.data(), before_wrap);
|
||||
if (const size_t after_wrap = nrecv - before_wrap)
|
||||
memcpy(recv_base, payload.data() + before_wrap, after_wrap);
|
||||
|
||||
m_recv_window.data_size += nrecv;
|
||||
const size_t nrecv = BAN::Math::min(payload.size(), m_recv_window.buffer->free());
|
||||
m_recv_window.buffer->push(payload.slice(0, nrecv));
|
||||
|
||||
epoll_notify(EPOLLIN);
|
||||
|
||||
@@ -816,10 +774,13 @@ namespace Kernel
|
||||
}
|
||||
|
||||
// make sure zero window is reported
|
||||
if (m_last_sent_window_size > 0 && m_recv_window.data_size == m_recv_window.buffer->size())
|
||||
if (m_last_sent_window_size > 0 && m_recv_window.buffer->full())
|
||||
m_should_send_zero_window = true;
|
||||
else if (can_receive_new_data)
|
||||
m_should_send_ack = true;
|
||||
{
|
||||
m_next_flags = ACK;
|
||||
m_next_state = m_state;
|
||||
}
|
||||
}
|
||||
|
||||
if (!hungup_before && has_hungup_impl())
|
||||
@@ -925,12 +886,11 @@ namespace Kernel
|
||||
if (m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq)
|
||||
{
|
||||
const uint32_t acknowledged_bytes = m_send_window.current_ack - m_send_window.start_seq - m_send_window.has_ghost_byte;
|
||||
ASSERT(acknowledged_bytes <= m_send_window.data_size);
|
||||
ASSERT(acknowledged_bytes <= m_send_window.buffer->size());
|
||||
|
||||
m_send_window.data_size -= acknowledged_bytes;
|
||||
m_send_window.start_seq += acknowledged_bytes;
|
||||
m_send_window.sent_size -= acknowledged_bytes;
|
||||
m_send_window.data_tail = (m_send_window.data_tail + acknowledged_bytes) % m_send_window.buffer->size();
|
||||
m_send_window.buffer->pop(acknowledged_bytes);
|
||||
|
||||
epoll_notify(EPOLLOUT);
|
||||
|
||||
@@ -941,7 +901,7 @@ namespace Kernel
|
||||
|
||||
const bool should_retransmit = m_send_window.had_zero_window || (m_send_window.sent_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms);
|
||||
|
||||
const bool can_send_new_data = (m_send_window.data_size > m_send_window.sent_size && m_send_window.sent_size < m_send_window.scaled_size());
|
||||
const bool can_send_new_data = (m_send_window.buffer->size() > m_send_window.sent_size && m_send_window.sent_size < m_send_window.scaled_size());
|
||||
|
||||
if (m_send_window.scaled_size() > 0 && (should_retransmit || can_send_new_data))
|
||||
{
|
||||
@@ -951,24 +911,20 @@ namespace Kernel
|
||||
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
|
||||
auto target_address_len = m_connection_info->address_len;
|
||||
|
||||
const size_t send_start_offset = should_retransmit ? 0 : m_send_window.sent_size;
|
||||
const size_t send_offset = should_retransmit ? 0 : m_send_window.sent_size;
|
||||
|
||||
const size_t total_send = BAN::Math::min<size_t>(
|
||||
m_send_window.data_size - send_start_offset,
|
||||
m_send_window.scaled_size() - send_start_offset
|
||||
m_send_window.buffer->size() - send_offset,
|
||||
m_send_window.scaled_size() - send_offset
|
||||
);
|
||||
|
||||
m_send_window.current_seq = m_send_window.start_seq + send_start_offset;
|
||||
m_send_window.current_seq = m_send_window.start_seq + send_offset;
|
||||
|
||||
const auto* send_base = reinterpret_cast<const uint8_t*>(m_send_window.buffer->vaddr());
|
||||
for (size_t i = 0; i < total_send;)
|
||||
{
|
||||
const size_t send_offset = (m_send_window.data_tail + send_start_offset + i) % m_send_window.buffer->size();
|
||||
const size_t to_send = BAN::Math::min<size_t>(total_send - i, m_send_window.mss);
|
||||
|
||||
const size_t max_send = BAN::Math::min<size_t>(total_send - i, m_send_window.mss);
|
||||
const size_t to_send = BAN::Math::min(max_send, m_send_window.buffer->size() - send_offset);
|
||||
|
||||
auto message = BAN::ConstByteSpan(send_base + send_offset, to_send);
|
||||
auto message = m_send_window.buffer->get_data().slice(send_offset + i, to_send);
|
||||
|
||||
m_next_flags = ACK;
|
||||
if (auto ret = m_network_layer.sendto(*this, message, target_address, target_address_len); ret.is_error())
|
||||
@@ -979,9 +935,10 @@ namespace Kernel
|
||||
|
||||
dprintln_if(DEBUG_TCP, "Sent {} bytes", to_send);
|
||||
|
||||
m_send_window.sent_size += to_send;
|
||||
m_send_window.current_seq += to_send;
|
||||
i += to_send;
|
||||
m_send_window.current_seq += to_send;
|
||||
if (send_offset + i > m_send_window.sent_size)
|
||||
m_send_window.sent_size = send_offset + i;
|
||||
}
|
||||
|
||||
m_send_window.last_send_ms = current_ms;
|
||||
@@ -989,13 +946,23 @@ namespace Kernel
|
||||
continue;
|
||||
}
|
||||
|
||||
if (const size_t ack_count = m_should_send_ack + m_should_send_zero_window)
|
||||
if (m_last_sent_window_size == 0)
|
||||
m_should_send_zero_window = false;
|
||||
|
||||
if (m_should_send_zero_window || m_should_send_window_update)
|
||||
{
|
||||
ASSERT(m_connection_info.has_value());
|
||||
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
|
||||
auto target_address_len = m_connection_info->address_len;
|
||||
|
||||
for (size_t i = 0; i < ack_count; i++)
|
||||
m_next_flags = ACK;
|
||||
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())
|
||||
dwarnln("{}", ret.error());
|
||||
|
||||
m_should_send_zero_window = false;
|
||||
m_should_send_window_update = false;
|
||||
|
||||
if (m_last_sent_window_size == 0 && !m_recv_window.buffer->full())
|
||||
{
|
||||
m_next_flags = ACK;
|
||||
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())
|
||||
|
||||
Reference in New Issue
Block a user