forked from Bananymous/banan-os
396 lines
12 KiB
C++
396 lines
12 KiB
C++
#include <kernel/Memory/Heap.h>
|
|
#include <kernel/Memory/PageTable.h>
|
|
#include <kernel/Networking/ICMP.h>
|
|
#include <kernel/Networking/IPv4Layer.h>
|
|
#include <kernel/Networking/NetworkManager.h>
|
|
#include <kernel/Networking/TCPSocket.h>
|
|
#include <kernel/Networking/UDPSocket.h>
|
|
#include <kernel/Random.h>
|
|
|
|
#include <netinet/in.h>
|
|
|
|
namespace Kernel
|
|
{
|
|
|
|
enum IPv4Flags : uint16_t
|
|
{
|
|
DF = 1 << 14,
|
|
};
|
|
|
|
BAN::ErrorOr<BAN::UniqPtr<IPv4Layer>> IPv4Layer::create()
|
|
{
|
|
auto ipv4_manager = TRY(BAN::UniqPtr<IPv4Layer>::create());
|
|
ipv4_manager->m_process = Process::create_kernel(
|
|
[](void* ipv4_manager_ptr)
|
|
{
|
|
auto& ipv4_manager = *reinterpret_cast<IPv4Layer*>(ipv4_manager_ptr);
|
|
ipv4_manager.packet_handle_task();
|
|
}, ipv4_manager.ptr()
|
|
);
|
|
ASSERT(ipv4_manager->m_process);
|
|
ipv4_manager->m_pending_packet_buffer = TRY(VirtualRange::create_to_vaddr_range(
|
|
PageTable::kernel(),
|
|
KERNEL_OFFSET,
|
|
~(uintptr_t)0,
|
|
pending_packet_buffer_size,
|
|
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
|
|
true
|
|
));
|
|
ipv4_manager->m_arp_table = TRY(ARPTable::create());
|
|
return ipv4_manager;
|
|
}
|
|
|
|
IPv4Layer::IPv4Layer()
|
|
{ }
|
|
|
|
IPv4Layer::~IPv4Layer()
|
|
{
|
|
if (m_process)
|
|
m_process->exit(0, SIGKILL);
|
|
m_process = nullptr;
|
|
}
|
|
|
|
void IPv4Layer::add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol) const
|
|
{
|
|
auto& header = packet.as<IPv4Header>();
|
|
header.version_IHL = 0x45;
|
|
header.DSCP_ECN = 0x00;
|
|
header.total_length = packet.size();
|
|
header.identification = 1;
|
|
header.flags_frament = 0x00;
|
|
header.time_to_live = 0x40;
|
|
header.protocol = protocol;
|
|
header.src_address = src_ipv4;
|
|
header.dst_address = dst_ipv4;
|
|
header.checksum = 0;
|
|
header.checksum = calculate_internet_checksum(BAN::ConstByteSpan::from(header), {});
|
|
}
|
|
|
|
void IPv4Layer::unbind_socket(uint16_t port)
|
|
{
|
|
SpinLockGuard _(m_bound_socket_lock);
|
|
auto it = m_bound_sockets.find(port);
|
|
ASSERT(it != m_bound_sockets.end());
|
|
m_bound_sockets.remove(it);
|
|
}
|
|
|
|
BAN::ErrorOr<void> IPv4Layer::bind_socket_to_unused(BAN::RefPtr<NetworkSocket> socket, const sockaddr* address, socklen_t address_len)
|
|
{
|
|
if (!address || address_len < (socklen_t)sizeof(sockaddr_in))
|
|
return BAN::Error::from_errno(EINVAL);
|
|
if (address->sa_family != AF_INET)
|
|
return BAN::Error::from_errno(EAFNOSUPPORT);
|
|
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(address);
|
|
|
|
SpinLockGuard _(m_bound_socket_lock);
|
|
|
|
uint16_t port = NetworkSocket::PORT_NONE;
|
|
for (uint32_t i = 0; i < 100 && port == NetworkSocket::PORT_NONE; i++)
|
|
if (uint32_t temp = 0xC000 | (Random::get_u32() & 0x3FFF); !m_bound_sockets.contains(temp))
|
|
port = temp;
|
|
for (uint32_t temp = 0xC000; temp < 0xFFFF && port == NetworkSocket::PORT_NONE; temp++)
|
|
if (!m_bound_sockets.contains(temp))
|
|
port = temp;
|
|
if (port == NetworkSocket::PORT_NONE)
|
|
{
|
|
dwarnln("No ports available");
|
|
return BAN::Error::from_errno(EAGAIN);
|
|
}
|
|
dprintln_if(DEBUG_IPV4, "using port {}", port);
|
|
|
|
struct sockaddr_in target;
|
|
target.sin_family = AF_INET;
|
|
target.sin_port = BAN::host_to_network_endian(port);
|
|
target.sin_addr.s_addr = sockaddr_in.sin_addr.s_addr;
|
|
return bind_socket_to_address(socket, (sockaddr*)&target, sizeof(sockaddr_in));
|
|
}
|
|
|
|
BAN::ErrorOr<void> IPv4Layer::bind_socket_to_address(BAN::RefPtr<NetworkSocket> socket, const sockaddr* address, socklen_t address_len)
|
|
{
|
|
if (NetworkManager::get().interfaces().empty())
|
|
return BAN::Error::from_errno(EADDRNOTAVAIL);
|
|
|
|
if (!address || address_len < (socklen_t)sizeof(sockaddr_in))
|
|
return BAN::Error::from_errno(EINVAL);
|
|
if (address->sa_family != AF_INET)
|
|
return BAN::Error::from_errno(EAFNOSUPPORT);
|
|
|
|
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(address);
|
|
const uint16_t port = BAN::host_to_network_endian(sockaddr_in.sin_port);
|
|
if (port == NetworkSocket::PORT_NONE)
|
|
return bind_socket_to_unused(socket, address, address_len);
|
|
|
|
SpinLockGuard _(m_bound_socket_lock);
|
|
|
|
if (m_bound_sockets.contains(port))
|
|
return BAN::Error::from_errno(EADDRINUSE);
|
|
TRY(m_bound_sockets.insert(port, TRY(socket->get_weak_ptr())));
|
|
|
|
// FIXME: actually determine proper interface
|
|
auto interface = NetworkManager::get().interfaces().front();
|
|
socket->bind_interface_and_port(interface.ptr(), port);
|
|
|
|
return {};
|
|
}
|
|
|
|
BAN::ErrorOr<void> IPv4Layer::get_socket_address(BAN::RefPtr<NetworkSocket> socket, sockaddr* address, socklen_t* address_len)
|
|
{
|
|
if (*address_len < (socklen_t)sizeof(sockaddr_in))
|
|
return BAN::Error::from_errno(ENOBUFS);
|
|
|
|
sockaddr_in* in_addr = reinterpret_cast<sockaddr_in*>(address);
|
|
|
|
SpinLockGuard _(m_bound_socket_lock);
|
|
for (auto& [bound_port, bound_socket] : m_bound_sockets)
|
|
{
|
|
if (socket != bound_socket.lock())
|
|
continue;
|
|
// FIXME: sockets should have bound address
|
|
in_addr->sin_family = AF_INET;
|
|
in_addr->sin_port = bound_port;
|
|
in_addr->sin_addr.s_addr = INADDR_ANY;
|
|
return {};
|
|
}
|
|
|
|
return {};
|
|
}
|
|
|
|
BAN::ErrorOr<size_t> IPv4Layer::sendto(NetworkSocket& socket, BAN::ConstByteSpan buffer, const sockaddr* address, socklen_t address_len)
|
|
{
|
|
if (address->sa_family != AF_INET)
|
|
return BAN::Error::from_errno(EINVAL);
|
|
if (address == nullptr || address_len != sizeof(sockaddr_in))
|
|
return BAN::Error::from_errno(EINVAL);
|
|
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(address);
|
|
|
|
auto dst_port = BAN::host_to_network_endian(sockaddr_in.sin_port);
|
|
auto dst_ipv4 = BAN::IPv4Address { sockaddr_in.sin_addr.s_addr };
|
|
auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(socket.interface(), dst_ipv4));
|
|
|
|
BAN::Vector<uint8_t> packet_buffer;
|
|
TRY(packet_buffer.resize(buffer.size() + sizeof(IPv4Header) + socket.protocol_header_size()));
|
|
auto packet = BAN::ByteSpan { packet_buffer.span() };
|
|
|
|
auto pseudo_header = PseudoHeader {
|
|
.src_ipv4 = socket.interface().get_ipv4_address(),
|
|
.dst_ipv4 = dst_ipv4,
|
|
.protocol = socket.protocol()
|
|
};
|
|
|
|
memcpy(
|
|
packet.slice(sizeof(IPv4Header)).slice(socket.protocol_header_size()).data(),
|
|
buffer.data(),
|
|
buffer.size()
|
|
);
|
|
socket.add_protocol_header(
|
|
packet.slice(sizeof(IPv4Header)),
|
|
dst_port,
|
|
pseudo_header
|
|
);
|
|
add_ipv4_header(
|
|
packet,
|
|
socket.interface().get_ipv4_address(),
|
|
dst_ipv4,
|
|
socket.protocol()
|
|
);
|
|
|
|
TRY(socket.interface().send_bytes(dst_mac, EtherType::IPv4, packet));
|
|
|
|
return buffer.size();
|
|
}
|
|
|
|
BAN::ErrorOr<void> IPv4Layer::handle_ipv4_packet(NetworkInterface& interface, BAN::ByteSpan packet)
|
|
{
|
|
ASSERT(packet.size() >= sizeof(IPv4Header));
|
|
auto& ipv4_header = packet.as<const IPv4Header>();
|
|
auto ipv4_data = packet.slice(sizeof(IPv4Header));
|
|
|
|
auto src_ipv4 = ipv4_header.src_address;
|
|
|
|
uint16_t dst_port = NetworkSocket::PORT_NONE;
|
|
uint16_t src_port = NetworkSocket::PORT_NONE;
|
|
|
|
switch (ipv4_header.protocol)
|
|
{
|
|
case NetworkProtocol::ICMP:
|
|
{
|
|
if (ipv4_data.size() < sizeof(ICMPHeader))
|
|
{
|
|
dwarnln("IPv4 packet too small for ICMP");
|
|
return {};
|
|
}
|
|
auto& icmp_header = ipv4_data.as<const ICMPHeader>();
|
|
switch (icmp_header.type)
|
|
{
|
|
case ICMPType::EchoRequest:
|
|
{
|
|
auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(interface, src_ipv4));
|
|
|
|
auto& reply_icmp_header = ipv4_data.as<ICMPHeader>();
|
|
reply_icmp_header.type = ICMPType::EchoReply;
|
|
reply_icmp_header.checksum = 0;
|
|
reply_icmp_header.checksum = calculate_internet_checksum(ipv4_data, {});
|
|
|
|
add_ipv4_header(packet, interface.get_ipv4_address(), src_ipv4, NetworkProtocol::ICMP);
|
|
|
|
TRY(interface.send_bytes(dst_mac, EtherType::IPv4, packet));
|
|
break;
|
|
}
|
|
case ICMPType::DestinationUnreachable:
|
|
{
|
|
auto& ipv4_header = ipv4_data.slice(sizeof(ICMPHeader)).as<const IPv4Header>();
|
|
dprintln("Destination '{}' unreachable, code {2H}", ipv4_header.dst_address, icmp_header.code);
|
|
// FIXME: inform the socket
|
|
break;
|
|
}
|
|
default:
|
|
dprintln("Unhandleded ICMP packet (type {2H})", icmp_header.type);
|
|
break;
|
|
}
|
|
return {};
|
|
}
|
|
case NetworkProtocol::UDP:
|
|
{
|
|
if (ipv4_data.size() < sizeof(UDPHeader))
|
|
{
|
|
dwarnln("IPv4 packet too small for UDP");
|
|
return {};
|
|
}
|
|
auto& udp_header = ipv4_data.as<const UDPHeader>();
|
|
dst_port = udp_header.dst_port;
|
|
src_port = udp_header.src_port;
|
|
break;
|
|
}
|
|
case NetworkProtocol::TCP:
|
|
{
|
|
if (ipv4_data.size() < sizeof(TCPHeader))
|
|
{
|
|
dwarnln("IPv4 packet too small for TCP");
|
|
return {};
|
|
}
|
|
auto& tcp_header = ipv4_data.as<const TCPHeader>();
|
|
dst_port = tcp_header.dst_port;
|
|
src_port = tcp_header.src_port;
|
|
break;
|
|
}
|
|
default:
|
|
dprintln_if(DEBUG_IPV4, "Unknown network protocol 0x{2H}", ipv4_header.protocol);
|
|
return {};
|
|
}
|
|
|
|
ASSERT(dst_port != NetworkSocket::PORT_NONE);
|
|
ASSERT(src_port != NetworkSocket::PORT_NONE);
|
|
|
|
BAN::RefPtr<Kernel::NetworkSocket> bound_socket;
|
|
|
|
{
|
|
SpinLockGuard _(m_bound_socket_lock);
|
|
auto it = m_bound_sockets.find(dst_port);
|
|
if (it == m_bound_sockets.end())
|
|
{
|
|
dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port);
|
|
return {};
|
|
}
|
|
bound_socket = it->value.lock();
|
|
}
|
|
|
|
if (!bound_socket)
|
|
{
|
|
dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port);
|
|
return {};
|
|
}
|
|
|
|
if (bound_socket->protocol() != ipv4_header.protocol)
|
|
{
|
|
dprintln_if(DEBUG_IPV4, "got data with wrong protocol ({}) on port {} (bound as {})", ipv4_header.protocol, dst_port, (uint8_t)bound_socket->protocol());
|
|
return {};
|
|
}
|
|
|
|
sockaddr_in sender;
|
|
sender.sin_family = AF_INET;
|
|
sender.sin_port = BAN::host_to_network_endian(src_port);
|
|
sender.sin_addr.s_addr = src_ipv4.raw;
|
|
bound_socket->receive_packet(ipv4_data, reinterpret_cast<const sockaddr*>(&sender), sizeof(sender));
|
|
|
|
return {};
|
|
}
|
|
|
|
void IPv4Layer::packet_handle_task()
|
|
{
|
|
for (;;)
|
|
{
|
|
PendingIPv4Packet pending = ({
|
|
auto state = m_pending_lock.lock();
|
|
while (m_pending_packets.empty())
|
|
{
|
|
m_pending_lock.unlock(state);
|
|
m_pending_thread_blocker.block_indefinite();
|
|
state = m_pending_lock.lock();
|
|
}
|
|
auto packet = m_pending_packets.front();
|
|
m_pending_packets.pop();
|
|
m_pending_lock.unlock(state);
|
|
|
|
packet;
|
|
});
|
|
|
|
uint8_t* buffer_start = reinterpret_cast<uint8_t*>(m_pending_packet_buffer->vaddr());
|
|
const size_t ipv4_packet_size = reinterpret_cast<const IPv4Header*>(buffer_start)->total_length;
|
|
|
|
if (auto ret = handle_ipv4_packet(pending.interface, BAN::ByteSpan(buffer_start, ipv4_packet_size)); ret.is_error())
|
|
dwarnln_if(DEBUG_IPV4, "{}", ret.error());
|
|
|
|
SpinLockGuard _(m_pending_lock);
|
|
m_pending_total_size -= ipv4_packet_size;
|
|
if (m_pending_total_size)
|
|
memmove(buffer_start, buffer_start + ipv4_packet_size, m_pending_total_size);
|
|
}
|
|
}
|
|
|
|
void IPv4Layer::add_ipv4_packet(NetworkInterface& interface, BAN::ConstByteSpan buffer)
|
|
{
|
|
if (buffer.size() < sizeof(IPv4Header))
|
|
{
|
|
dwarnln_if(DEBUG_IPV4, "IPv4 packet too small");
|
|
return;
|
|
}
|
|
|
|
SpinLockGuard _(m_pending_lock);
|
|
|
|
if (m_pending_packets.full())
|
|
{
|
|
dwarnln_if(DEBUG_IPV4, "IPv4 packet queue full");
|
|
return;
|
|
}
|
|
|
|
if (m_pending_total_size + buffer.size() > m_pending_packet_buffer->size())
|
|
{
|
|
dwarnln_if(DEBUG_IPV4, "IPv4 packet queue full");
|
|
return;
|
|
}
|
|
|
|
auto& ipv4_header = buffer.as<const IPv4Header>();
|
|
if (calculate_internet_checksum(BAN::ConstByteSpan::from(ipv4_header), {}) != 0)
|
|
{
|
|
dwarnln_if(DEBUG_IPV4, "Invalid IPv4 packet");
|
|
return;
|
|
}
|
|
if (ipv4_header.total_length > buffer.size() || ipv4_header.total_length > interface.payload_mtu())
|
|
{
|
|
if (ipv4_header.flags_frament & IPv4Flags::DF)
|
|
dwarnln_if(DEBUG_IPV4, "Invalid IPv4 packet");
|
|
else
|
|
dwarnln_if(DEBUG_IPV4, "IPv4 fragmentation not supported");
|
|
return;
|
|
}
|
|
|
|
uint8_t* buffer_start = reinterpret_cast<uint8_t*>(m_pending_packet_buffer->vaddr());
|
|
memcpy(buffer_start + m_pending_total_size, buffer.data(), ipv4_header.total_length);
|
|
m_pending_total_size += ipv4_header.total_length;
|
|
|
|
m_pending_packets.push({ .interface = interface });
|
|
m_pending_thread_blocker.unblock();
|
|
}
|
|
|
|
}
|