// banan-os/kernel/kernel/Networking/TCPSocket.cpp
//
// 1014 lines
// 30 KiB
// C++

#include <kernel/Lock/LockGuard.h>
#include <kernel/Networking/NetworkManager.h>
#include <kernel/Networking/TCPSocket.h>
#include <kernel/Process.h>
#include <kernel/Random.h>
#include <kernel/Timer/Timer.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/epoll.h>
#include <sys/ioctl.h>
namespace Kernel
{
// Option kinds that may appear in the options area of a TCP header
// (RFC 9293 section 3.1, window scale from RFC 7323).
enum TCPOption : uint8_t
{
	End                = 0,
	NOP                = 1,
	MaximumSeqmentSize = 2,
	WindowScale        = 3,
};
// Sizes of the per-socket receive and send ring buffers.
static constexpr size_t s_recv_window_buffer_size { 16 * PAGE_SIZE };
static constexpr size_t s_send_window_buffer_size { 16 * PAGE_SIZE };

// Window scale shift we offer to the peer. The 16-bit window field shifted
// by 4 allows advertising windows of up to roughly 1 MiB.
static constexpr uint8_t s_window_shift { 4 };

// MSS to assume when the peer does not send the option,
// see https://www.rfc-editor.org/rfc/rfc1122 4.2.2.6
static constexpr uint16_t s_default_mss { 536 };
// Creates a TCP socket.
//
// Allocates the kernel-mapped receive and send ring buffers and spawns the
// kernel thread that drives the connection (process_task()).
BAN::ErrorOr<BAN::RefPtr<TCPSocket>> TCPSocket::create(NetworkLayer& network_layer, const Info& info)
{
	// Both ring buffers use the same kernel mapping parameters; only the size differs.
	const auto allocate_window_buffer =
		[](size_t buffer_size)
		{
			return VirtualRange::create_to_vaddr_range(
				PageTable::kernel(),
				KERNEL_OFFSET,
				~(vaddr_t)0,
				buffer_size,
				PageTable::Flags::ReadWrite | PageTable::Flags::Present,
				true, false
			);
		};

	auto socket = TRY(BAN::RefPtr<TCPSocket>::create(network_layer, info));

	socket->m_last_sent_window_size = s_recv_window_buffer_size;
	socket->m_recv_window.buffer = TRY(allocate_window_buffer(s_recv_window_buffer_size));
	socket->m_recv_window.scale_shift = s_window_shift;
	socket->m_send_window.buffer = TRY(allocate_window_buffer(s_send_window_buffer_size));

	const auto thread_entry =
		[](void* socket_ptr)
		{
			reinterpret_cast<TCPSocket*>(socket_ptr)->process_task();
		};
	socket->m_thread = TRY(Thread::create_kernel(thread_entry, socket.ptr()));
	TRY(Processor::scheduler().add_thread(socket->m_thread));

	// The thread only takes its own reference once it starts running.
	// This extra reference keeps the socket alive until then;
	// process_task() drops it again.
	socket->ref();
	return socket;
}
// Starts the send window at a random initial sequence number (kept within 31 bits).
// Nothing has been sent yet, so the current sequence equals the start.
TCPSocket::TCPSocket(NetworkLayer& network_layer, const Info& info)
	: NetworkSocket(network_layer, info)
{
	const uint32_t initial_seq = Random::get_u32() & 0x7FFFFFFF;
	m_send_window.start_seq = initial_seq;
	m_send_window.current_seq = initial_seq;
}
// Destroys the socket.
// Before the last reference drops, the socket must already be unbound and its
// processing thread released; set_connection_as_closed() does both.
TCPSocket::~TCPSocket()
{
ASSERT(!is_bound());
ASSERT(m_thread == nullptr);
dprintln_if(DEBUG_TCP, "Socket destroyed");
}
// Accepts the oldest pending connection of a listening socket.
//
// Blocks until a connection is queued. It then:
// - creates a child socket and configures it from the queued SYN,
// - registers the child in m_listen_children so incoming segments are redirected to it,
// - waits (5 second deadline) for the child to complete the handshake,
// - opens the child as a file descriptor.
// Returns EINVAL when the socket is not listening, and ECONNABORTED when the
// sender already has a child socket.
BAN::ErrorOr<long> TCPSocket::accept_impl(sockaddr* address, socklen_t* address_len, int flags)
{
	if (m_state != State::Listen)
		return BAN::Error::from_errno(EINVAL);

	for (;;)
	{
		if (!m_pending_connections.empty())
			break;
		TRY(Thread::current().block_or_eintr_indefinite(m_thread_blocker, &m_mutex));
	}

	const auto pending = m_pending_connections.front();
	m_pending_connections.pop();

	auto listen_key = ListenKey(
		reinterpret_cast<const sockaddr*>(&pending.target.address),
		pending.target.address_len
	);
	if (m_listen_children.find(listen_key) != m_listen_children.end())
		return BAN::Error::from_errno(ECONNABORTED);

	BAN::RefPtr<TCPSocket> child;
	{
		auto generic_socket = TRY(NetworkManager::get().create_socket(m_network_layer.domain(), Socket::Type::STREAM, mode().mode & ~Mode::TYPE_MASK, uid(), gid()));
		child = static_cast<TCPSocket*>(generic_socket.ptr());
	}

	{
		LockGuard _(child->m_mutex);

		// the child shares our local address; the peer is the pending connection's sender
		memcpy(&child->m_address, &m_address, m_address_len);
		child->m_address_len = m_address_len;
		child->m_listen_parent = this;
		child->m_connection_info.emplace(pending.target);

		child->m_recv_window.start_seq = pending.target_start_seq;
		child->m_send_window.scale_shift = pending.window_scale;
		child->m_send_window.mss = pending.maximum_seqment_size;

		// the child's thread answers with SYN|ACK
		child->m_next_flags = SYN | ACK;
		child->m_next_state = State::SynReceived;

		// the peer did not offer window scaling, so ours must stay unscaled too
		if (!child->m_connection_info->has_window_scale)
			child->m_recv_window.scale_shift = 0;
	}

	TRY(m_listen_children.emplace(listen_key, child));

	const uint64_t deadline_ms = SystemTimer::get().ms_since_boot() + 5000;
	while (!child->m_has_connected)
		TRY(Thread::current().block_or_eintr_or_waketime_ms(child->m_thread_blocker, deadline_ms, true, &m_mutex));

	if (address)
	{
		ASSERT(address_len);
		*address_len = BAN::Math::min(*address_len, pending.target.address_len);
		memcpy(address, &pending.target.address, *address_len);
	}

	return TRY(Process::current().open_inode(VirtualFileSystem::File(child, "<tcp socket>"_sv), O_RDWR | flags));
}
// Actively opens a connection to `address`.
//
// Binds the socket if it is not bound yet, sends a SYN and waits (5 second
// deadline) until the processing thread reports the connection as established.
// Returns:
// - EISCONN if the socket is already connected (or listening),
// - EALREADY while a handshake is in progress,
// - ECONNREFUSED if the SYN could not be sent.
BAN::ErrorOr<void> TCPSocket::connect_impl(const sockaddr* address, socklen_t address_len)
{
	address_len = BAN::Math::min<socklen_t>(address_len, sizeof(sockaddr_storage));

	LockGuard _(m_mutex);

	if (m_connection_info.has_value())
		return BAN::Error::from_errno(EISCONN);
	if (m_state == State::SynSent || m_state == State::SynReceived)
		return BAN::Error::from_errno(EALREADY);
	if (m_state != State::Closed)
		return BAN::Error::from_errno(EISCONN);

	if (!is_bound())
		TRY(m_network_layer.bind_socket_with_target(this, address, address_len));

	m_connection_info.emplace(sockaddr_storage {}, address_len, true);
	memcpy(&m_connection_info->address, address, address_len);

	m_next_flags = SYN;
	const bool syn_failed = m_network_layer.sendto(*this, {}, address, address_len).is_error();
	if (syn_failed)
	{
		set_connection_as_closed();
		return BAN::Error::from_errno(ECONNREFUSED);
	}
	m_next_flags = 0;
	m_state = State::SynSent;

	const uint64_t deadline_ms = SystemTimer::get().ms_since_boot() + 5000;
	while (!m_has_connected)
		TRY(Thread::current().block_or_eintr_or_waketime_ms(m_thread_blocker, deadline_ms, true, &m_mutex));

	return {};
}
// Puts a bound, unconnected socket into the Listen state.
// Reserves room for `backlog` pending connections, clamped to [1, SOMAXCONN].
BAN::ErrorOr<void> TCPSocket::listen_impl(int backlog)
{
	if (!is_bound())
		return BAN::Error::from_errno(EDESTADDRREQ);
	if (m_connection_info.has_value())
		return BAN::Error::from_errno(EINVAL);

	const int queue_size = BAN::Math::clamp(backlog, 1, SOMAXCONN);
	TRY(m_pending_connections.reserve(queue_size));

	m_state = State::Listen;
	return {};
}
// Binds the socket to a local address. A socket can be bound only once (EINVAL otherwise).
BAN::ErrorOr<void> TCPSocket::bind_impl(const sockaddr* address, socklen_t address_len)
{
	if (!is_bound())
		return m_network_layer.bind_socket_to_address(this, address, address_len);
	return BAN::Error::from_errno(EINVAL);
}
// Copies buffered, in-order payload into the caller's iovecs.
//
// Blocking behaviour:
// - Blocks until data is available, unless MSG_DONTWAIT is given; then EAGAIN
//   is returned instead of blocking.
// - Once the connection has left the Established state and the buffer is
//   drained, the first call returns 0 (EOF) and later calls fail with
//   ECONNRESET (see return_with_maybe_zero()).
//
// Flags: MSG_OOB, MSG_PEEK and MSG_WAITALL are not supported (ENOTSUP);
// other flags are ignored.
//
// If enough buffer space was freed, the processing thread is asked to send a
// window update.
BAN::ErrorOr<size_t> TCPSocket::recvmsg_impl(msghdr& message, int flags)
{
	// MSG_DONTWAIT must be captured before the unsupported-flag mask below discards it
	const bool dont_wait = (flags & MSG_DONTWAIT);

	flags &= (MSG_OOB | MSG_PEEK | MSG_WAITALL);
	if (flags != 0)
	{
		dwarnln("TODO: recvmsg with flags 0x{H}", flags);
		return BAN::Error::from_errno(ENOTSUP);
	}

	if (CMSG_FIRSTHDR(&message))
	{
		dprintln_if(DEBUG_TCP, "ignoring recvmsg control message");
		message.msg_controllen = 0;
	}

	if (!m_has_connected)
		return BAN::Error::from_errno(ENOTCONN);

	while (m_recv_window.data_size == 0)
	{
		if (m_state != State::Established)
			return return_with_maybe_zero();
		if (dont_wait)
			return BAN::Error::from_errno(EAGAIN);
		TRY(Thread::current().block_or_eintr_indefinite(m_thread_blocker, &m_mutex));
	}

	message.msg_flags = 0;

	size_t total_recv = 0;
	for (int i = 0; i < message.msg_iovlen; i++)
	{
		const auto* recv_base = reinterpret_cast<const uint8_t*>(m_recv_window.buffer->vaddr());
		uint8_t* iov_base = static_cast<uint8_t*>(message.msg_iov[i].iov_base);

		// the receive buffer is a ring: copy up to its end, then wrap to the start
		const size_t nrecv = BAN::Math::min<size_t>(message.msg_iov[i].iov_len, m_recv_window.data_size);
		const size_t before_wrap = BAN::Math::min(nrecv, m_recv_window.buffer->size() - m_recv_window.data_tail);
		memcpy(iov_base, recv_base + m_recv_window.data_tail, before_wrap);
		if (const size_t after_wrap = nrecv - before_wrap)
			memcpy(iov_base + before_wrap, recv_base, after_wrap);

		total_recv += nrecv;
		m_recv_window.data_size -= nrecv;
		m_recv_window.start_seq += nrecv;
		m_recv_window.data_tail = (m_recv_window.data_tail + nrecv) % m_recv_window.buffer->size();

		if (m_recv_window.data_size == 0)
			break;
	}

	// Advertise the reopened window if it grew by at least 1/8 of the buffer,
	// the buffer was drained, or a zero window had been advertised.
	const size_t update_window_threshold = m_recv_window.buffer->size() / 8;
	const bool should_update_window_size =
		m_last_sent_window_size != m_recv_window.buffer->size() && (
			(m_last_sent_window_size == 0) ||
			(m_recv_window.data_size == 0) ||
			(m_last_sent_window_size + update_window_threshold <= m_recv_window.buffer->size() - m_recv_window.data_size)
		);

	if (should_update_window_size)
	{
		m_should_send_ack = true;
		m_thread_blocker.unblock();
	}

	return total_recv;
}
// Appends as much of the caller's iovecs as fits into the send ring buffer and
// wakes the processing thread to transmit it.
//
// Blocks while the buffer is full, unless MSG_DONTWAIT is given (EAGAIN).
// If the buffer is full and the connection has left Established, returns
// per return_with_maybe_zero().
BAN::ErrorOr<size_t> TCPSocket::sendmsg_impl(const msghdr& message, int flags)
{
	constexpr int supported_flags = MSG_NOSIGNAL | MSG_DONTWAIT;
	if (flags & ~supported_flags)
	{
		dwarnln("TODO: sendmsg with flags 0x{H}", flags);
		return BAN::Error::from_errno(ENOTSUP);
	}

	if (CMSG_FIRSTHDR(&message))
		dwarnln("ignoring sendmsg control message");

	if (!m_has_connected)
		return BAN::Error::from_errno(ENOTCONN);

	const size_t capacity = m_send_window.buffer->size();

	while (m_send_window.data_size == capacity)
	{
		if (m_state != State::Established)
			return return_with_maybe_zero();
		if (flags & MSG_DONTWAIT)
			return BAN::Error::from_errno(EAGAIN);
		TRY(Thread::current().block_or_eintr_indefinite(m_thread_blocker, &m_mutex));
	}

	auto* const ring = reinterpret_cast<uint8_t*>(m_send_window.buffer->vaddr());

	size_t total_sent = 0;
	for (int i = 0; i < message.msg_iovlen; i++)
	{
		const auto& iov = message.msg_iov[i];
		const auto* source = static_cast<const uint8_t*>(iov.iov_base);

		// new data goes right after the buffered data, wrapping around the ring end
		const size_t nsend = BAN::Math::min<size_t>(iov.iov_len, capacity - m_send_window.data_size);
		const size_t head = (m_send_window.data_tail + m_send_window.data_size) % capacity;
		const size_t first_chunk = BAN::Math::min(nsend, capacity - head);

		memcpy(ring + head, source, first_chunk);
		if (nsend > first_chunk)
			memcpy(ring, source + first_chunk, nsend - first_chunk);

		total_sent += nsend;
		m_send_window.data_size += nsend;

		if (m_send_window.data_size == capacity)
			break;
	}

	m_thread_blocker.unblock();
	return total_sent;
}
// Copies the peer address into `address`, truncated to *address_len bytes.
// *address_len is updated to the number of bytes copied.
// Fails with ENOTCONN when the socket has never connected.
BAN::ErrorOr<void> TCPSocket::getpeername_impl(sockaddr* address, socklen_t* address_len)
{
	const bool connected = m_has_connected || m_state == State::Established;
	if (!connected)
		return BAN::Error::from_errno(ENOTCONN);

	ASSERT(m_connection_info.has_value());

	const auto& peer = *m_connection_info;
	const size_t copy_len = BAN::Math::min(peer.address_len, *address_len);
	memcpy(address, &peer.address, copy_len);
	*address_len = copy_len;
	return {};
}
// Reads a socket option as an int into `value`, truncated to *value_len bytes.
// *value_len is set to sizeof(int).
//
// Supported options:
// - SOL_SOCKET: SO_KEEPALIVE; SO_ERROR (always 0); SO_SNDBUF and SO_RCVBUF,
//   which report the sizes of the local send and receive buffers.
// - IPPROTO_TCP: TCP_NODELAY.
//
// Unknown options fail with ENOPROTOOPT; unknown levels fail with EINVAL.
BAN::ErrorOr<void> TCPSocket::getsockopt_impl(int level, int option, void* value, socklen_t* value_len)
{
	int result;

	switch (level)
	{
		case SOL_SOCKET:
			switch (option)
			{
				case SO_KEEPALIVE:
					result = m_keep_alive;
					break;
				case SO_ERROR:
					result = 0;
					break;
				case SO_SNDBUF:
					// size of our own send buffer, not the window the peer advertised
					result = m_send_window.buffer->size();
					break;
				case SO_RCVBUF:
					result = m_recv_window.buffer->size();
					break;
				default:
					dwarnln("getsockopt(SOL_SOCKET, {})", option);
					return BAN::Error::from_errno(ENOPROTOOPT);
			}
			break;
		case IPPROTO_TCP:
			switch (option)
			{
				case TCP_NODELAY:
					result = m_no_delay;
					break;
				default:
					dwarnln("getsockopt(IPPROTO_TCP, {})", option);
					return BAN::Error::from_errno(ENOPROTOOPT);
			}
			break;
		default:
			dwarnln("getsockopt({}, {})", level, option);
			return BAN::Error::from_errno(EINVAL);
	}

	const size_t len = BAN::Math::min<size_t>(sizeof(result), *value_len);
	memcpy(value, &result, len);
	*value_len = sizeof(int);

	return {};
}
// Sets SO_KEEPALIVE (SOL_SOCKET) or TCP_NODELAY (IPPROTO_TCP).
// Both take an int whose size must be exactly sizeof(int) (EINVAL otherwise).
// Unknown options fail with ENOPROTOOPT; unknown levels fail with EINVAL.
BAN::ErrorOr<void> TCPSocket::setsockopt_impl(int level, int option, const void* value, socklen_t value_len)
{
	const auto read_int_value =
		[&]() -> BAN::ErrorOr<int>
		{
			if (value_len != sizeof(int))
				return BAN::Error::from_errno(EINVAL);
			return *static_cast<const int*>(value);
		};

	if (level == SOL_SOCKET)
	{
		if (option != SO_KEEPALIVE)
		{
			dwarnln("setsockopt(SOL_SOCKET, {})", option);
			return BAN::Error::from_errno(ENOPROTOOPT);
		}
		m_keep_alive = TRY(read_int_value());
		return {};
	}

	if (level == IPPROTO_TCP)
	{
		if (option != TCP_NODELAY)
		{
			dwarnln("setsockopt(IPPROTO_TCP, {})", option);
			return BAN::Error::from_errno(ENOPROTOOPT);
		}
		m_no_delay = TRY(read_int_value());
		return {};
	}

	dwarnln("setsockopt({}, {})", level, option);
	return BAN::Error::from_errno(EINVAL);
}
// FIONREAD stores the number of buffered bytes ready to be read.
// Every other request is delegated to NetworkSocket.
BAN::ErrorOr<long> TCPSocket::ioctl_impl(int request, void* argument)
{
	if (request != FIONREAD)
		return NetworkSocket::ioctl_impl(request, argument);

	auto& readable_bytes = *static_cast<int*>(argument);
	readable_bytes = m_recv_window.data_size;
	return 0;
}
// Reports whether a read/accept would make progress without blocking.
bool TCPSocket::can_read_impl() const
{
	// a listening socket is readable when a connection waits to be accepted
	if (m_state == State::Listen)
		return !m_pending_connections.empty();

	// a connection that left Established stays readable until EOF (0) has been returned once
	if (m_has_connected && !m_has_sent_zero && m_state != State::Established)
		return true;

	return m_recv_window.data_size > 0;
}
// Writable only while established and the send ring buffer has free space.
bool TCPSocket::can_write_impl() const
{
	return m_state == State::Established
		&& m_send_window.data_size < m_send_window.buffer->size();
}
// A socket has hung up once a connection existed and it is no longer established.
bool TCPSocket::has_hungup_impl() const
{
	if (!m_has_connected)
		return false;
	return m_state != State::Established;
}
// Result of I/O on a connection that has left Established: the first call
// reports EOF (0), every later call fails with ECONNRESET.
BAN::ErrorOr<size_t> TCPSocket::return_with_maybe_zero()
{
	ASSERT(m_state != State::Established);

	if (m_has_sent_zero)
		return BAN::Error::from_errno(ECONNRESET);

	m_has_sent_zero = true;
	return 0;
}
// Builds the key identifying a remote endpoint (IPv4 address and host-order port).
// Only AF_INET addresses are supported.
TCPSocket::ListenKey::ListenKey(const sockaddr* addr, socklen_t addr_len)
{
	ASSERT(addr->sa_family == AF_INET);
	ASSERT(addr_len >= (socklen_t)sizeof(sockaddr_in));

	const auto& inet = *reinterpret_cast<const sockaddr_in*>(addr);
	port = BAN::network_endian_to_host(inet.sin_port);
	address = BAN::IPv4Address(inet.sin_addr.s_addr);
}
// Two keys are equal when both the remote port and the remote address match.
bool TCPSocket::ListenKey::operator==(const ListenKey& other) const
{
	if (port != other.port)
		return false;
	return address == other.address;
}
// Combines the address and port hashes with XOR.
BAN::hash_t TCPSocket::ListenKeyHash::operator()(ListenKey key) const
{
	const BAN::hash_t address_hash = BAN::hash<BAN::IPv4Address>()(key.address);
	const BAN::hash_t port_hash = BAN::hash<uint16_t>()(key.port);
	return address_hash ^ port_hash;
}
// Writes one option of kind `Op` at byte offset `Off` of the header's options area.
// - MSS: a 16-bit big-endian value.
// - Window scale: a single shift byte.
// Any other option kind writes nothing.
template<size_t Off, TCPOption Op>
static void add_tcp_header_option(TCPHeader& header, uint32_t value)
{
	auto* option = &header.options[Off];

	if constexpr (Op == TCPOption::MaximumSeqmentSize)
	{
		option[0] = Op;
		option[1] = 4; // total option length in bytes
		option[2] = (value >> 8) & 0xFF;
		option[3] = value & 0xFF;
	}
	else if constexpr (Op == TCPOption::WindowScale)
	{
		option[0] = Op;
		option[1] = 3; // total option length in bytes
		option[2] = value & 0xFF;
	}
}
struct ParsedTCPOptions
{
BAN::Optional<uint16_t> maximum_seqment_size;
BAN::Optional<uint8_t> window_scale;
};
static ParsedTCPOptions parse_tcp_options(const TCPHeader& header)
{
ParsedTCPOptions result;
for (size_t i = 0; i < header.data_offset * sizeof(uint32_t) - sizeof(TCPHeader) - 1; i++)
{
if (header.options[i] == TCPOption::End)
break;
if (header.options[i] == TCPOption::NOP)
continue;
if (header.options[i] == TCPOption::MaximumSeqmentSize)
result.maximum_seqment_size = BAN::network_endian_to_host(*reinterpret_cast<const uint16_t*>(&header.options[i + 2]));
if (header.options[i] == TCPOption::WindowScale)
result.window_scale = header.options[i + 2];
if (header.options[i + 1] == 0)
break;
i += header.options[i + 1] - 1;
}
return result;
}
// Fills `header_buffer` with the TCP header for the next outgoing segment and
// computes its checksum over the pseudo header, the TCP header and `payload`.
//
// Consumes the pending flags in m_next_flags, which must be non-zero, and
// resets them to 0. Must be called with m_mutex held by the current thread.
//
// Side effects:
// - updates the advertised-window bookkeeping;
// - clears the pending ACK / zero-window requests;
// - marks our FIN as consuming a sequence number;
// - on handshake segments (state Closed or SynReceived), also advances the send
//   sequence past the SYN.
void TCPSocket::get_protocol_header(BAN::ByteSpan header_buffer, BAN::ConstByteSpan payload, uint16_t dst_port, PseudoHeader pseudo_header)
{
ASSERT(m_next_flags);
ASSERT(m_mutex.locker() == Thread::current().tid());
ASSERT(header_buffer.size() == protocol_header_size());
// Advertise the free space in the receive buffer, or an explicit zero window if
// receive_packet() found the buffer full.
m_last_sent_window_size = m_recv_window.buffer->size() - m_recv_window.data_size;
if (m_should_send_zero_window)
m_last_sent_window_size = 0;
// this segment carries the ACK / window update that was requested
m_should_send_ack = false;
m_should_send_zero_window = false;
auto& header = header_buffer.as<TCPHeader>();
// has_ghost_byte accounts for a FIN (ours for seq, the peer's for ack), which
// occupies one sequence number without carrying data.
header = {
.src_port = bound_port(),
.dst_port = dst_port,
.seq_number = m_send_window.current_seq + m_send_window.has_ghost_byte,
.ack_number = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte,
.data_offset = (sizeof(TCPHeader) + m_tcp_options_bytes) / sizeof(uint32_t),
.flags = m_next_flags,
.window_size = BAN::Math::min<size_t>(0xFFFF, m_last_sent_window_size >> m_recv_window.scale_shift),
.checksum = 0,
.urgent_pointer = 0,
};
memset(header.options, 0, m_tcp_options_bytes);
// set after the header is built, so the FIN itself uses the un-incremented seq
if (header.flags & FIN)
m_send_window.has_ghost_byte = true;
m_next_flags = 0;
// Handshake segment: announce our MSS and (if negotiated) window scale.
if (m_state == State::Closed || m_state == State::SynReceived)
{
const sockaddr_in target {
.sin_family = AF_INET,
.sin_port = dst_port,
.sin_addr = { .s_addr = pseudo_header.dst_ipv4.raw },
.sin_zero = {},
};
// NOTE(review): MUST presumably panics if no interface routes to the target
auto interface = MUST(this->interface(reinterpret_cast<const sockaddr*>(&target), sizeof(target)));
// MSS = interface payload MTU minus the network-layer header and this TCP header (including options)
add_tcp_header_option<0, TCPOption::MaximumSeqmentSize>(header, interface->payload_mtu() - m_network_layer.header_size() - protocol_header_size());
if (m_connection_info->has_window_scale)
add_tcp_header_option<4, TCPOption::WindowScale>(header, m_recv_window.scale_shift);
// the window field of a SYN segment is never scaled (RFC 7323)
header.window_size = BAN::Math::min<size_t>(0xFFFF, m_recv_window.buffer->size());
// the SYN consumes one sequence number
m_send_window.start_seq++;
m_send_window.current_seq = m_send_window.start_seq;
}
// checksum field is still 0 here, as required for computing it
const BAN::ConstByteSpan buffers[] {
BAN::ConstByteSpan::from(pseudo_header),
header_buffer,
payload,
};
header.checksum = calculate_internet_checksum({ buffers, sizeof(buffers) / sizeof(*buffers) });
dprintln_if(DEBUG_TCP, "sending {} {8b}", (uint8_t)m_state, header.flags);
dprintln_if(DEBUG_TCP, " ack {}", (uint32_t)header.ack_number);
dprintln_if(DEBUG_TCP, " seq {}", (uint32_t)header.seq_number);
}
// Handles one incoming TCP segment addressed to this socket.
//
// Processing order:
// 1. Reject segments whose header is truncated or whose data offset is
//    inconsistent with the segment length.
// 2. For a listening socket, redirect the segment to the child socket that owns
//    the sender's connection, if there is one.
// 3. Verify the checksum (IPv4 only).
// 4. Under m_mutex, drive the state machine. Transitions that require sending a
//    segment are recorded in m_next_flags / m_next_state and performed by
//    process_task().
// 5. Record acknowledgements, append in-order payload to the receive ring
//    buffer, and wake the processing thread and any waiters.
void TCPSocket::receive_packet(BAN::ConstByteSpan buffer, const sockaddr* sender, socklen_t sender_len)
{
	// Validate the header bounds before anything dereferences the header: the segment
	// must hold the fixed header, and data_offset must neither point inside it nor
	// past the end of the segment (the payload slice below relies on this).
	if (buffer.size() < sizeof(TCPHeader))
	{
		dprintln_if(DEBUG_TCP, "Packet too small for TCP header");
		return;
	}
	if (const size_t header_size = buffer.as<const TCPHeader>().data_offset * sizeof(uint32_t); header_size < sizeof(TCPHeader) || header_size > buffer.size())
	{
		dprintln_if(DEBUG_TCP, "Invalid TCP data offset");
		return;
	}

	if (m_state == State::Listen)
	{
		// segments of connections already handed to accept() belong to their child socket
		auto socket =
			[&]() -> BAN::RefPtr<TCPSocket> {
				LockGuard _(m_mutex);
				if (auto it = m_listen_children.find(ListenKey(sender, sender_len)); it != m_listen_children.end())
					return it->value;
				return {};
			}();

		if (socket)
			return socket->receive_packet(buffer, sender, sender_len);
	}

	{
		uint16_t checksum = 0;

		if (sender->sa_family == AF_INET)
		{
			auto interface_or_error = interface(sender, sender_len);
			if (interface_or_error.is_error())
				return;
			auto interface = interface_or_error.release_value();

			auto& addr_in = *reinterpret_cast<const sockaddr_in*>(sender);

			const PseudoHeader pseudo_header {
				.src_ipv4 = BAN::IPv4Address(addr_in.sin_addr.s_addr),
				.dst_ipv4 = interface->get_ipv4_address(),
				.protocol = NetworkProtocol::TCP,
				.length = buffer.size(),
			};

			const BAN::ConstByteSpan buffers[] {
				BAN::ConstByteSpan::from(pseudo_header),
				buffer
			};
			checksum = calculate_internet_checksum({ buffers, sizeof(buffers) / sizeof(*buffers) });
		}
		else
		{
			dwarnln("No tcp checksum validation for socket family {}", sender->sa_family);
			return;
		}

		// the checksum over a segment that includes a correct checksum field is 0
		if (checksum != 0)
		{
			dprintln("Checksum does not match");
			return;
		}
	}

	LockGuard _(m_mutex);

	const bool hungup_before = has_hungup_impl();

	auto& header = buffer.as<const TCPHeader>();
	dprintln_if(DEBUG_TCP, "receiving {} {8b}", (uint8_t)m_state, header.flags);
	dprintln_if(DEBUG_TCP, " ack {}", (uint32_t)header.ack_number);
	dprintln_if(DEBUG_TCP, " seq {}", (uint32_t)header.seq_number);

	m_send_window.non_scaled_size = header.window_size;
	if (m_send_window.scaled_size() == 0)
		m_send_window.had_zero_window = true;

	bool check_payload = false;
	switch (m_state)
	{
		case State::Closed:
			break;
		case State::SynSent:
		{
			if (header.flags != (SYN | ACK))
				break;

			if (header.ack_number != m_send_window.current_seq)
			{
				dprintln_if(DEBUG_TCP, "Invalid ack number in SYN/ACK");
				break;
			}

			auto options = parse_tcp_options(header);
			m_send_window.mss = options.maximum_seqment_size.value_or(s_default_mss);
			if (options.window_scale.has_value())
				m_send_window.scale_shift = *options.window_scale;
			else
			{
				// the peer does not scale windows, so ours must not be scaled either
				m_recv_window.scale_shift = 0;
				m_connection_info->has_window_scale = false;
			}

			m_send_window.start_seq = m_send_window.current_seq;
			m_send_window.current_ack = m_send_window.current_seq;

			// the peer's SYN consumes one sequence number
			m_recv_window.start_seq = header.seq_number + 1;

			m_next_flags = ACK;
			m_next_state = State::Established;
			break;
		}
		case State::SynReceived:
			if (header.flags != ACK)
				break;
			m_state = State::Established;
			m_has_connected = true;
			break;
		case State::Listen:
			if (header.flags != SYN)
				dprintln_if(DEBUG_TCP, "Unexpected packet to listening socket");
			else if (m_pending_connections.size() == m_pending_connections.capacity())
				dprintln_if(DEBUG_TCP, "No storage to store pending connection");
			else
			{
				const auto options = parse_tcp_options(header);

				ConnectionInfo connection_info;
				memcpy(&connection_info.address, sender, sender_len);
				connection_info.address_len = sender_len;
				connection_info.has_window_scale = options.window_scale.has_value();
				MUST(m_pending_connections.emplace(
					connection_info,
					header.seq_number + 1,
					options.maximum_seqment_size.value_or(s_default_mss),
					options.window_scale.value_or(0)
				));

				epoll_notify(EPOLLIN);
				m_thread_blocker.unblock();
			}
			return;
		case State::Established:
			check_payload = true;
			if (!(header.flags & FIN))
				break;
			if (m_recv_window.start_seq + m_recv_window.data_size != header.seq_number)
				break;
			m_next_flags = FIN | ACK;
			m_next_state = State::LastAck;
			break;
		case State::CloseWait:
			check_payload = true;
			if (!(header.flags & FIN))
				break;
			// every segment after the handshake must carry ACK
			m_next_flags = FIN | ACK;
			m_next_state = State::LastAck;
			break;
		case State::LastAck:
			check_payload = true;
			if (!(header.flags & ACK))
				break;
			set_connection_as_closed();
			break;
		case State::FinWait1:
			check_payload = true;
			if ((header.flags & (FIN | ACK)) == (FIN | ACK))
			{
				// the peer closed its side as well; its FIN has to be acknowledged
				m_next_flags = ACK;
				m_next_state = State::TimeWait;
				m_time_wait_start_ms = SystemTimer::get().ms_since_boot();
			}
			else if (header.flags & FIN)
			{
				// simultaneous close; acknowledge the peer's FIN
				m_next_flags = ACK;
				m_next_state = State::Closing;
			}
			else if (header.flags & ACK)
				m_state = State::FinWait2;
			break;
		case State::FinWait2:
			check_payload = true;
			if (!(header.flags & FIN))
				break;
			m_next_flags = ACK;
			m_next_state = State::TimeWait;
			m_time_wait_start_ms = SystemTimer::get().ms_since_boot();
			break;
		case State::Closing:
			check_payload = true;
			if (!(header.flags & ACK))
				break;
			m_state = State::TimeWait;
			m_time_wait_start_ms = SystemTimer::get().ms_since_boot();
			break;
		case State::TimeWait:
			check_payload = true;
			break;
	}

	const uint32_t expected_seq = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte;

	if (header.seq_number > expected_seq)
		dprintln_if(DEBUG_TCP, "Missing packets");
	else if (check_payload)
	{
		if (header.flags & FIN)
			m_recv_window.has_ghost_byte = true;

		// Only accept acknowledgements of data (and our FIN) that was actually sent.
		// Acknowledging more would break the send buffer bookkeeping in process_task().
		const uint32_t max_acceptable_ack = m_send_window.start_seq + m_send_window.sent_size + m_send_window.has_ghost_byte;
		if (header.ack_number > m_send_window.current_ack && header.ack_number <= max_acceptable_ack)
			m_send_window.current_ack = header.ack_number;

		auto payload = buffer.slice(header.data_offset * sizeof(uint32_t));

		// drop the prefix of a retransmitted segment that was already received
		if (header.seq_number < expected_seq)
		{
			const uint32_t already_received_bytes = expected_seq - header.seq_number;
			if (already_received_bytes <= payload.size())
				payload = payload.slice(already_received_bytes);
			else
				payload = {};
		}

		const bool can_receive_new_data = (payload.size() > 0 && m_recv_window.data_size < m_recv_window.buffer->size());

		if (can_receive_new_data)
		{
			auto* recv_base = reinterpret_cast<uint8_t*>(m_recv_window.buffer->vaddr());

			// append to the receive ring buffer, wrapping around its end
			const size_t nrecv = BAN::Math::min(payload.size(), m_recv_window.buffer->size() - m_recv_window.data_size);
			const size_t recv_head = (m_recv_window.data_tail + m_recv_window.data_size) % m_recv_window.buffer->size();
			const size_t before_wrap = BAN::Math::min(nrecv, m_recv_window.buffer->size() - recv_head);

			memcpy(recv_base + recv_head, payload.data(), before_wrap);
			if (const size_t after_wrap = nrecv - before_wrap)
				memcpy(recv_base, payload.data() + before_wrap, after_wrap);

			m_recv_window.data_size += nrecv;

			epoll_notify(EPOLLIN);

			dprintln_if(DEBUG_TCP, "Received {} bytes", nrecv);
		}

		// make sure zero window is reported
		if (m_last_sent_window_size > 0 && m_recv_window.data_size == m_recv_window.buffer->size())
			m_should_send_zero_window = true;
		else if (can_receive_new_data)
			m_should_send_ack = true;
	}

	if (!hungup_before && has_hungup_impl())
		epoll_notify(EPOLLHUP);

	m_thread_blocker.unblock();
}
// Tears the connection down.
//
// If the socket is bound, its address is released: directly through the network
// layer, or via the listening parent that owns the port. Clearing m_thread
// afterwards makes process_task() leave its loop.
void TCPSocket::set_connection_as_closed()
{
	if (!is_bound())
	{
		m_thread = nullptr;
		return;
	}

	// NOTE: Only listen socket can unbind the socket as
	//       listen socket is always alive to redirect packets
	if (!m_listen_parent)
		m_network_layer.unbind_socket(bound_port());
	else
		m_listen_parent->remove_listen_child(this);

	m_address.ss_family = AF_UNSPEC;
	m_address_len = 0;
	dprintln_if(DEBUG_TCP, "Socket unbound");

	m_thread = nullptr;
}
// Removes `socket` (keyed by its peer address) from this listening socket's children.
// An unknown socket is only reported with a warning.
void TCPSocket::remove_listen_child(BAN::RefPtr<TCPSocket> socket)
{
	LockGuard _(m_mutex);

	const auto& peer = *socket->m_connection_info;
	const auto key = ListenKey(reinterpret_cast<const sockaddr*>(&peer.address), peer.address_len);

	if (auto it = m_listen_children.find(key); it != m_listen_children.end())
		m_listen_children.remove(it);
	else
		dwarnln("remove_listen_child with non-mapped socket");
}
void TCPSocket::process_task()
{
// FIXME: this should be dynamic
static constexpr uint32_t retransmit_timeout_ms = 1000;
BAN::RefPtr<TCPSocket> keep_alive { this };
// socket's creation did a manual ref(), let's undo it here
this->unref();
LockGuard _(m_mutex);
while (m_thread)
{
const uint64_t current_ms = SystemTimer::get().ms_since_boot();
switch (m_state)
{
case State::TimeWait:
if (current_ms < m_time_wait_start_ms + 30'000)
break;
// TimeWait timeout
set_connection_as_closed();
continue;
case State::Listen:
if (ref_count() > 1)
break;
// Listen socket closed
// ref_count = keep_alieve
set_connection_as_closed();
continue;
case State::Established:
if (ref_count() > static_cast<uint32_t>(1 + !!m_listen_parent))
break;
// Connected socket closed
// ref_count = keep_alive + listen's hashmap
m_next_flags = FIN | ACK;
m_next_state = State::FinWait1;
break;
default:
break;
}
if (m_next_flags)
{
ASSERT(m_connection_info.has_value());
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
auto target_address_len = m_connection_info->address_len;
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())
dwarnln("{}", ret.error());
const bool hungup_before = has_hungup_impl();
m_state = m_next_state;
if (m_state == State::Established)
m_has_connected = true;
if (!hungup_before && has_hungup_impl())
epoll_notify(EPOLLHUP);
continue;
}
if (m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq)
{
const uint32_t acknowledged_bytes = m_send_window.current_ack - m_send_window.start_seq - m_send_window.has_ghost_byte;
ASSERT(acknowledged_bytes <= m_send_window.data_size);
m_send_window.data_size -= acknowledged_bytes;
m_send_window.start_seq += acknowledged_bytes;
m_send_window.sent_size -= acknowledged_bytes;
m_send_window.data_tail = (m_send_window.data_tail + acknowledged_bytes) % m_send_window.buffer->size();
epoll_notify(EPOLLOUT);
dprintln_if(DEBUG_TCP, "Target acknowledged {} bytes", acknowledged_bytes);
continue;
}
const bool should_retransmit = m_send_window.had_zero_window || (m_send_window.sent_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms);
const bool can_send_new_data = (m_send_window.data_size > m_send_window.sent_size && m_send_window.sent_size < m_send_window.scaled_size());
if (m_send_window.scaled_size() > 0 && (should_retransmit || can_send_new_data))
{
m_send_window.had_zero_window = false;
ASSERT(m_connection_info.has_value());
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
auto target_address_len = m_connection_info->address_len;
const size_t send_start_offset = should_retransmit ? 0 : m_send_window.sent_size;
const size_t total_send = BAN::Math::min<size_t>(
m_send_window.data_size - send_start_offset,
m_send_window.scaled_size() - send_start_offset
);
m_send_window.current_seq = m_send_window.start_seq + send_start_offset;
const auto* send_base = reinterpret_cast<const uint8_t*>(m_send_window.buffer->vaddr());
for (size_t i = 0; i < total_send;)
{
const size_t send_offset = (m_send_window.data_tail + send_start_offset + i) % m_send_window.buffer->size();
const size_t max_send = BAN::Math::min<size_t>(total_send - i, m_send_window.mss);
const size_t to_send = BAN::Math::min(max_send, m_send_window.buffer->size() - send_offset);
auto message = BAN::ConstByteSpan(send_base + send_offset, to_send);
m_next_flags = ACK;
if (auto ret = m_network_layer.sendto(*this, message, target_address, target_address_len); ret.is_error())
{
dwarnln("{}", ret.error());
break;
}
dprintln_if(DEBUG_TCP, "Sent {} bytes", to_send);
m_send_window.sent_size += to_send;
m_send_window.current_seq += to_send;
i += to_send;
}
m_send_window.last_send_ms = current_ms;
continue;
}
if (const size_t ack_count = m_should_send_ack + m_should_send_zero_window)
{
ASSERT(m_connection_info.has_value());
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
auto target_address_len = m_connection_info->address_len;
for (size_t i = 0; i < ack_count; i++)
{
m_next_flags = ACK;
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())
dwarnln("{}", ret.error());
}
}
m_thread_blocker.unblock();
m_thread_blocker.block_with_wake_time_ms(current_ms + retransmit_timeout_ms, &m_mutex);
}
m_thread_blocker.unblock();
}
}