Kernel: Cleanup TCP code
This commit is contained in:
parent
ccde8148a7
commit
f50b4be162
|
@ -73,17 +73,33 @@ namespace Kernel
|
|||
TimeWait,
|
||||
};
|
||||
|
||||
struct WindowInfo
|
||||
struct RecvWindowInfo
|
||||
{
|
||||
uint32_t mss { 0 };
|
||||
uint16_t size { 0 };
|
||||
uint8_t scale { 0 };
|
||||
uint32_t start_seq { 0 };
|
||||
uint32_t current_seq { 0 };
|
||||
BAN::Atomic<uint32_t> ack_number { 0 };
|
||||
uint32_t data_size { 0 };
|
||||
uint64_t send_time_ms { 0 };
|
||||
BAN::UniqPtr<VirtualRange> window;
|
||||
uint32_t start_seq { 0 }; // sequence number of first byte in buffer
|
||||
|
||||
bool has_ghost_byte { false };
|
||||
|
||||
uint32_t data_size { 0 }; // number of bytes in this buffer
|
||||
BAN::UniqPtr<VirtualRange> buffer;
|
||||
};
|
||||
|
||||
struct SendWindowInfo
|
||||
{
|
||||
uint32_t mss { 0 }; // maximum segment size
|
||||
uint16_t non_scaled_size { 0 }; // window size without scaling
|
||||
uint8_t scale { 0 }; // window scale
|
||||
uint32_t scaled_size() const { return (uint32_t)non_scaled_size << scale; }
|
||||
|
||||
uint32_t start_seq { 0 }; // sequence number of first byte in buffer
|
||||
uint32_t current_seq { 0 }; // sequence number of next send
|
||||
uint32_t current_ack { 0 }; // sequence number aknowledged by connection
|
||||
|
||||
uint64_t last_send_ms { 0 }; // last send time, used for retransmission timeout
|
||||
|
||||
bool has_ghost_byte { false };
|
||||
|
||||
uint32_t data_size { 0 }; // number of bytes in this buffer
|
||||
BAN::UniqPtr<VirtualRange> buffer;
|
||||
};
|
||||
|
||||
private:
|
||||
|
@ -104,8 +120,8 @@ namespace Kernel
|
|||
|
||||
BAN::Atomic<bool> m_should_ack { false };
|
||||
|
||||
WindowInfo m_recv_window;
|
||||
WindowInfo m_send_window;
|
||||
RecvWindowInfo m_recv_window;
|
||||
SendWindowInfo m_send_window;
|
||||
|
||||
struct ConnectionInfo
|
||||
{
|
||||
|
|
|
@ -27,7 +27,7 @@ namespace Kernel
|
|||
if (socket_ptr == nullptr)
|
||||
return BAN::Error::from_errno(ENOMEM);
|
||||
auto socket = BAN::RefPtr<TCPSocket>::adopt(socket_ptr);
|
||||
socket->m_recv_window.window = TRY(VirtualRange::create_to_vaddr_range(
|
||||
socket->m_recv_window.buffer = TRY(VirtualRange::create_to_vaddr_range(
|
||||
PageTable::kernel(),
|
||||
KERNEL_OFFSET,
|
||||
~(vaddr_t)0,
|
||||
|
@ -35,7 +35,7 @@ namespace Kernel
|
|||
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
|
||||
true
|
||||
));
|
||||
socket->m_send_window.window = TRY(VirtualRange::create_to_vaddr_range(
|
||||
socket->m_send_window.buffer = TRY(VirtualRange::create_to_vaddr_range(
|
||||
PageTable::kernel(),
|
||||
KERNEL_OFFSET,
|
||||
~(vaddr_t)0,
|
||||
|
@ -43,8 +43,6 @@ namespace Kernel
|
|||
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
|
||||
true
|
||||
));
|
||||
socket->m_recv_window.size = socket->m_recv_window.window->size();
|
||||
socket->m_recv_window.scale = 0;
|
||||
socket->m_process = Process::create_kernel(
|
||||
[](void* socket_ptr)
|
||||
{
|
||||
|
@ -58,7 +56,6 @@ namespace Kernel
|
|||
: NetworkSocket(network_layer, ino, inode_info)
|
||||
{
|
||||
m_send_window.start_seq = Random::get_u32() & 0x7FFFFFFF;
|
||||
m_send_window.ack_number = m_send_window.start_seq;
|
||||
m_send_window.current_seq = m_send_window.start_seq;
|
||||
}
|
||||
|
||||
|
@ -95,13 +92,8 @@ namespace Kernel
|
|||
case State::Listen: ASSERT_NOT_REACHED();
|
||||
}
|
||||
|
||||
ASSERT(m_connection_info.has_value());
|
||||
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
|
||||
auto target_address_len = m_connection_info->address_len;
|
||||
|
||||
m_state = State::FinWait1;
|
||||
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())
|
||||
dwarnln("{}", ret.error());
|
||||
m_should_ack = true;
|
||||
|
||||
dprintln_if(DEBUG_TCP, "Initiated close");
|
||||
}
|
||||
|
@ -140,8 +132,6 @@ namespace Kernel
|
|||
m_connection_info.emplace(sockaddr_storage {}, address_len);
|
||||
memcpy(&m_connection_info->address, address, address_len);
|
||||
|
||||
m_recv_window.mss = m_interface->payload_mtu() - m_network_layer.header_size();
|
||||
|
||||
TRY(m_network_layer.sendto(*this, {}, address, address_len));
|
||||
ASSERT(m_state == State::SynSent);
|
||||
dprintln_if(DEBUG_TCP, "Sent SYN");
|
||||
|
@ -211,10 +201,12 @@ namespace Kernel
|
|||
|
||||
header.dst_port = dst_port;
|
||||
header.src_port = m_port;
|
||||
header.seq_number = m_send_window.current_seq;
|
||||
header.ack_number = m_recv_window.ack_number.load();
|
||||
header.seq_number = m_send_window.current_seq + m_send_window.has_ghost_byte;
|
||||
header.ack_number = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte;
|
||||
header.data_offset = (sizeof(TCPHeader) + m_tcp_options_bytes) / sizeof(uint32_t);
|
||||
header.window_size = m_recv_window.window->size();
|
||||
header.window_size = m_recv_window.buffer->size();
|
||||
|
||||
ASSERT(m_recv_window.buffer->size() < 1 << (8 * sizeof(header.window_size)));
|
||||
|
||||
switch (m_state)
|
||||
{
|
||||
|
@ -222,9 +214,11 @@ namespace Kernel
|
|||
{
|
||||
LockGuard _(m_lock);
|
||||
header.syn = 1;
|
||||
add_tcp_header_option<0, TCPOption::MaximumSeqmentSize>(header, m_recv_window.mss);
|
||||
add_tcp_header_option<4, TCPOption::WindowScale>(header, m_recv_window.scale);
|
||||
add_tcp_header_option<0, TCPOption::MaximumSeqmentSize>(header, m_interface->payload_mtu() - m_network_layer.header_size());
|
||||
add_tcp_header_option<4, TCPOption::WindowScale>(header, 0);
|
||||
m_state = State::SynSent;
|
||||
m_send_window.start_seq++;
|
||||
m_send_window.current_seq = m_send_window.start_seq;
|
||||
break;
|
||||
}
|
||||
case State::SynSent:
|
||||
|
@ -242,9 +236,8 @@ namespace Kernel
|
|||
LockGuard _(m_lock);
|
||||
header.ack = 1;
|
||||
header.fin = 1;
|
||||
header.ack_number = header.ack_number + 1;
|
||||
m_state = State::LastAck;
|
||||
dprintln_if(DEBUG_TCP, "Waiting for last ack");
|
||||
dprintln_if(DEBUG_TCP, "Waiting for last ACK");
|
||||
break;
|
||||
}
|
||||
case State::FinWait1:
|
||||
|
@ -259,10 +252,9 @@ namespace Kernel
|
|||
{
|
||||
LockGuard _(m_lock);
|
||||
header.ack = 1;
|
||||
header.seq_number = header.seq_number + 1;
|
||||
header.ack_number = header.ack_number + 1;
|
||||
m_state = State::TimeWait;
|
||||
m_time_wait_start_ms = SystemTimer::get().ms_since_boot();
|
||||
dprintln_if(DEBUG_TCP, "Sent final ACK");
|
||||
break;
|
||||
}
|
||||
case State::Listen: ASSERT_NOT_REACHED();
|
||||
|
@ -307,7 +299,7 @@ namespace Kernel
|
|||
|
||||
auto& header = buffer.as<const TCPHeader>();
|
||||
|
||||
m_send_window.size = header.window_size;
|
||||
m_send_window.non_scaled_size = header.window_size;
|
||||
|
||||
auto payload = buffer.slice(header.data_offset * sizeof(uint32_t));
|
||||
|
||||
|
@ -322,20 +314,22 @@ namespace Kernel
|
|||
|
||||
LockGuard _(m_lock);
|
||||
|
||||
if (header.ack_number != m_send_window.current_seq)
|
||||
{
|
||||
dprintln_if(DEBUG_TCP, "Invalid ack number in SYN/ACK", (uint32_t)header.ack_number, m_send_window.current_seq);
|
||||
break;
|
||||
}
|
||||
|
||||
auto options = parse_tcp_options(header);
|
||||
if (options.maximum_seqment_size.has_value())
|
||||
m_send_window.mss = *options.maximum_seqment_size;
|
||||
if (options.window_scale.has_value())
|
||||
m_send_window.scale = *options.window_scale;
|
||||
else
|
||||
m_recv_window.scale = 0;
|
||||
|
||||
m_send_window.start_seq = m_send_window.start_seq + 1;
|
||||
m_send_window.ack_number = m_send_window.start_seq;
|
||||
m_send_window.current_seq = m_send_window.start_seq;
|
||||
m_send_window.start_seq = m_send_window.current_seq;
|
||||
m_send_window.current_ack = m_send_window.current_seq;
|
||||
|
||||
m_recv_window.start_seq = header.seq_number + 1;
|
||||
m_recv_window.ack_number = m_recv_window.start_seq;
|
||||
|
||||
dprintln_if(DEBUG_TCP, "Got SYN/ACK");
|
||||
|
||||
|
@ -344,11 +338,6 @@ namespace Kernel
|
|||
break;
|
||||
}
|
||||
case State::FinWait2:
|
||||
if (!header.ack)
|
||||
break;
|
||||
if (header.fin)
|
||||
m_should_ack = true;
|
||||
// fall through
|
||||
case State::TimeWait:
|
||||
case State::CloseWait:
|
||||
case State::Established:
|
||||
|
@ -357,20 +346,28 @@ namespace Kernel
|
|||
break;
|
||||
|
||||
LockGuard _(m_lock);
|
||||
|
||||
if (header.fin)
|
||||
{
|
||||
if (m_recv_window.start_seq + m_recv_window.data_size != header.seq_number)
|
||||
dprintln_if(DEBUG_TCP, "Got FIN, but missing packets");
|
||||
else
|
||||
{
|
||||
m_should_ack = true;
|
||||
if (m_state == State::FinWait2)
|
||||
m_send_window.has_ghost_byte = true;
|
||||
else
|
||||
m_state = State::CloseWait;
|
||||
|
||||
m_recv_window.has_ghost_byte = true;
|
||||
m_should_ack = true;
|
||||
dprintln_if(DEBUG_TCP, "Got FIN");
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (header.ack_number > m_send_window.ack_number)
|
||||
m_send_window.ack_number = header.ack_number;
|
||||
|
||||
if (header.ack_number > m_send_window.current_ack)
|
||||
m_send_window.current_ack = header.ack_number;
|
||||
|
||||
if (payload.size() > 0)
|
||||
{
|
||||
if (header.seq_number != m_recv_window.start_seq + m_recv_window.data_size)
|
||||
|
@ -379,13 +376,13 @@ namespace Kernel
|
|||
break;
|
||||
}
|
||||
|
||||
if (m_recv_window.data_size + payload.size() > m_recv_window.window->size())
|
||||
if (m_recv_window.data_size + payload.size() > m_recv_window.buffer->size())
|
||||
{
|
||||
dwarnln("Cannot fit received bytes to window");
|
||||
dprintln_if(DEBUG_TCP, "Cannot fit received bytes to window, waiting for retransmission");
|
||||
break;
|
||||
}
|
||||
|
||||
auto* buffer = reinterpret_cast<uint8_t*>(m_recv_window.window->vaddr());
|
||||
auto* buffer = reinterpret_cast<uint8_t*>(m_recv_window.buffer->vaddr());
|
||||
memcpy(buffer + m_recv_window.data_size, payload.data(), payload.size());
|
||||
m_recv_window.data_size += payload.size();
|
||||
|
||||
|
@ -393,13 +390,14 @@ namespace Kernel
|
|||
|
||||
dprintln_if(DEBUG_TCP, "Received {} bytes", payload.size());
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
case State::LastAck:
|
||||
if (!header.ack)
|
||||
break;
|
||||
set_connection_as_closed();
|
||||
dprintln_if(DEBUG_TCP, "Got final ACK");
|
||||
set_connection_as_closed();
|
||||
break;
|
||||
case State::Listen: ASSERT_NOT_REACHED();
|
||||
case State::SynReceived: ASSERT_NOT_REACHED();
|
||||
|
@ -426,7 +424,7 @@ namespace Kernel
|
|||
void TCPSocket::process_task()
|
||||
{
|
||||
// FIXME: this should be dynamic
|
||||
static constexpr uint32_t retransmit_timeout_ms = 100;
|
||||
static constexpr uint32_t retransmit_timeout_ms = 1000;
|
||||
|
||||
BAN::RefPtr<TCPSocket> keep_alive = this;
|
||||
|
||||
|
@ -434,13 +432,13 @@ namespace Kernel
|
|||
{
|
||||
uint64_t current_ms = SystemTimer::get().ms_since_boot();
|
||||
|
||||
if (m_state == State::TimeWait && current_ms >= m_time_wait_start_ms + 6'000)
|
||||
if (m_state == State::TimeWait && current_ms >= m_time_wait_start_ms + 30'000)
|
||||
set_connection_as_closed();
|
||||
|
||||
{
|
||||
LockGuard _(m_lock);
|
||||
|
||||
if (m_should_ack || m_recv_window.start_seq + m_recv_window.data_size != m_recv_window.ack_number)
|
||||
if (m_should_ack)
|
||||
{
|
||||
m_should_ack = false;
|
||||
|
||||
|
@ -448,49 +446,28 @@ namespace Kernel
|
|||
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
|
||||
auto target_address_len = m_connection_info->address_len;
|
||||
|
||||
m_recv_window.ack_number = m_recv_window.start_seq + m_recv_window.data_size;
|
||||
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())
|
||||
dwarnln("{}", ret.error());
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
bool is_send_open = false;
|
||||
switch (m_state)
|
||||
if (m_send_window.data_size > 0 && m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq)
|
||||
{
|
||||
case State::Listen:
|
||||
case State::Established:
|
||||
case State::CloseWait:
|
||||
case State::LastAck:
|
||||
is_send_open = true;
|
||||
break;
|
||||
case State::SynSent:
|
||||
case State::SynReceived:
|
||||
case State::FinWait1:
|
||||
case State::FinWait2:
|
||||
case State::TimeWait:
|
||||
case State::Closed:
|
||||
is_send_open = false;
|
||||
break;
|
||||
case State::Closing: ASSERT_NOT_REACHED();
|
||||
}
|
||||
|
||||
if (is_send_open && m_send_window.ack_number > m_send_window.start_seq)
|
||||
{
|
||||
uint32_t acknowledged_bytes = m_send_window.ack_number - m_send_window.start_seq;
|
||||
ASSERT(acknowledged_bytes <= m_send_window.data_size);
|
||||
uint32_t acknowledged_bytes = m_send_window.current_ack - m_send_window.start_seq - m_send_window.has_ghost_byte;
|
||||
ASSERT_LTE(acknowledged_bytes, m_send_window.data_size);
|
||||
|
||||
m_send_window.data_size -= acknowledged_bytes;
|
||||
m_send_window.start_seq += acknowledged_bytes;
|
||||
|
||||
if (m_send_window.data_size > 0)
|
||||
{
|
||||
auto* send_buffer = reinterpret_cast<uint8_t*>(m_send_window.window->vaddr());
|
||||
auto* send_buffer = reinterpret_cast<uint8_t*>(m_send_window.buffer->vaddr());
|
||||
memmove(send_buffer, send_buffer + acknowledged_bytes, m_send_window.data_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
m_send_window.send_time_ms = 0;
|
||||
m_send_window.last_send_ms = 0;
|
||||
}
|
||||
|
||||
dprintln_if(DEBUG_TCP, "Target acknowledged {} bytes", acknowledged_bytes);
|
||||
|
@ -498,20 +475,20 @@ namespace Kernel
|
|||
continue;
|
||||
}
|
||||
|
||||
if (is_send_open && m_send_window.data_size > 0 && current_ms >= m_send_window.send_time_ms + retransmit_timeout_ms)
|
||||
if (m_send_window.data_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms)
|
||||
{
|
||||
ASSERT(m_connection_info.has_value());
|
||||
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
|
||||
auto target_address_len = m_connection_info->address_len;
|
||||
|
||||
const uint32_t total_send = BAN::Math::min<uint32_t>(m_send_window.data_size, m_send_window.size << m_send_window.scale);
|
||||
const uint32_t total_send = BAN::Math::min<uint32_t>(m_send_window.data_size, m_send_window.scaled_size());
|
||||
|
||||
m_send_window.current_seq = m_send_window.start_seq;
|
||||
|
||||
auto* send_buffer = reinterpret_cast<const uint8_t*>(m_send_window.window->vaddr());
|
||||
auto* send_buffer = reinterpret_cast<const uint8_t*>(m_send_window.buffer->vaddr());
|
||||
for (uint32_t i = 0; i < total_send;)
|
||||
{
|
||||
uint32_t to_send = BAN::Math::min(total_send - i, m_send_window.mss);
|
||||
const uint32_t to_send = BAN::Math::min(total_send - i, m_send_window.mss);
|
||||
|
||||
auto message = BAN::ConstByteSpan(send_buffer + i, to_send);
|
||||
|
||||
|
@ -527,7 +504,7 @@ namespace Kernel
|
|||
i += to_send;
|
||||
}
|
||||
|
||||
m_send_window.send_time_ms = current_ms;
|
||||
m_send_window.last_send_ms = current_ms;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
@ -571,7 +548,7 @@ namespace Kernel
|
|||
|
||||
uint32_t to_recv = BAN::Math::min<uint32_t>(buffer.size(), m_recv_window.data_size);
|
||||
|
||||
auto* recv_buffer = reinterpret_cast<uint8_t*>(m_recv_window.window->vaddr());
|
||||
auto* recv_buffer = reinterpret_cast<uint8_t*>(m_recv_window.buffer->vaddr());
|
||||
memcpy(buffer.data(), recv_buffer, to_recv);
|
||||
|
||||
m_recv_window.data_size -= to_recv;
|
||||
|
@ -582,20 +559,28 @@ namespace Kernel
|
|||
return to_recv;
|
||||
}
|
||||
|
||||
BAN::ErrorOr<size_t> TCPSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t)
|
||||
BAN::ErrorOr<size_t> TCPSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len)
|
||||
{
|
||||
if (address)
|
||||
return BAN::Error::from_errno(EISCONN);
|
||||
|
||||
if (message.size() > m_send_window.window->size())
|
||||
return BAN::Error::from_errno(EMSGSIZE);
|
||||
if (message.size() > m_send_window.buffer->size())
|
||||
{
|
||||
for (size_t i = 0; i < message.size(); i++)
|
||||
{
|
||||
const size_t to_send = BAN::Math::min<size_t>(message.size() - i, m_send_window.buffer->size());
|
||||
TRY(sendto_impl(message.slice(i, to_send), address, address_len));
|
||||
i += to_send;
|
||||
}
|
||||
return message.size();
|
||||
}
|
||||
|
||||
LockGuard _(m_lock);
|
||||
|
||||
if (m_state == State::Closed)
|
||||
return BAN::Error::from_errno(ENOTCONN);
|
||||
|
||||
while (m_send_window.data_size + message.size() > m_send_window.window->size())
|
||||
while (true)
|
||||
{
|
||||
switch (m_state)
|
||||
{
|
||||
|
@ -614,12 +599,15 @@ namespace Kernel
|
|||
case State::Closing: ASSERT_NOT_REACHED();
|
||||
};
|
||||
|
||||
if (m_send_window.data_size + message.size() <= m_send_window.buffer->size())
|
||||
break;
|
||||
|
||||
LockFreeGuard free(m_lock);
|
||||
TRY(Thread::current().block_or_eintr_indefinite(m_semaphore));
|
||||
}
|
||||
|
||||
{
|
||||
auto* buffer = reinterpret_cast<uint8_t*>(m_send_window.window->vaddr());
|
||||
auto* buffer = reinterpret_cast<uint8_t*>(m_send_window.buffer->vaddr());
|
||||
memcpy(buffer + m_send_window.data_size, message.data(), message.size());
|
||||
m_send_window.data_size += message.size();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue