398 lines
12 KiB
C++
398 lines
12 KiB
C++
#include <kernel/Memory/Heap.h>
|
|
#include <kernel/Memory/PageTable.h>
|
|
#include <kernel/Lock/SpinLockAsMutex.h>
|
|
#include <kernel/Networking/ICMP.h>
|
|
#include <kernel/Networking/IPv4Layer.h>
|
|
#include <kernel/Networking/NetworkManager.h>
|
|
#include <kernel/Networking/TCPSocket.h>
|
|
#include <kernel/Networking/UDPSocket.h>
|
|
#include <kernel/Random.h>
|
|
|
|
#include <netinet/in.h>
|
|
|
|
namespace Kernel
|
|
{
|
|
|
|
// Bit masks applied to the combined flags / fragment-offset field of an
// IPv4 header (IPv4Header::flags_frament).
enum IPv4Flags : uint16_t
{
    // "Don't Fragment" bit (bit 14 of the 16-bit field).
    DF = 0x4000,
};
|
|
|
|
// Builds a new IPv4Layer and attaches a freshly created ARP table to it.
// A failure of either creation step is propagated to the caller by TRY.
BAN::ErrorOr<BAN::UniqPtr<IPv4Layer>> IPv4Layer::create()
{
    auto layer = TRY(BAN::UniqPtr<IPv4Layer>::create());
    layer->m_arp_table = TRY(ARPTable::create());
    return layer;
}
|
|
|
|
// Builds an IPv4 header with its checksum already filled in.
//
// packet_size is stored verbatim in the total_length field.
// src_ipv4 / dst_ipv4 and protocol are stored in the matching header fields.
static IPv4Header get_ipv4_header(size_t packet_size, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol)
{
    IPv4Header result {
        .version_IHL    = 0x45, // version 4, header length 5 * 32-bit words
        .DSCP_ECN       = 0,
        .total_length   = packet_size,
        .identification = 1,
        .flags_frament  = 0,
        .time_to_live   = 64,
        .protocol       = protocol,
        .checksum       = 0,    // zero while the checksum below is computed
        .src_address    = src_ipv4,
        .dst_address    = dst_ipv4,
    };

    const auto header_bytes = BAN::ConstByteSpan::from(result);
    result.checksum = calculate_internet_checksum(header_bytes);

    return result;
}
|
|
|
|
void IPv4Layer::unbind_socket(uint16_t port)
|
|
{
|
|
SpinLockGuard _(m_bound_socket_lock);
|
|
auto it = m_bound_sockets.find(port);
|
|
ASSERT(it != m_bound_sockets.end());
|
|
m_bound_sockets.remove(it);
|
|
}
|
|
|
|
// Picks a port in 0xC000..0xFFFF (inclusive) that has no entry in
// m_bound_sockets.
//
// Strategy: up to 100 random probes, then a linear scan of the whole range.
// Returns EAGAIN when every port in the range is bound.
//
// The returned port is only looked up, not inserted; the caller is
// responsible for binding it.
BAN::ErrorOr<in_port_t> IPv4Layer::find_free_port()
{
    constexpr uint32_t first_port = 0xC000;
    constexpr uint32_t last_port  = 0xFFFF;

    SpinLockGuard _(m_bound_socket_lock);

    for (uint32_t attempt = 0; attempt < 100; attempt++)
        if (uint32_t port = first_port | (Random::get_u32() & 0x3FFF); !m_bound_sockets.contains(port))
            return port;

    // The upper bound is inclusive: 0xFFFF is a port the random pass above can
    // also hand out, so the exhaustive scan has to consider it as well.
    // `port` is 32 bits wide, so `port <= 0xFFFF` cannot wrap around.
    for (uint32_t port = first_port; port <= last_port; port++)
        if (!m_bound_sockets.contains(port))
            return port;

    dwarnln("No ports available");
    return BAN::Error::from_errno(EAGAIN);
}
|
|
|
|
// Binds `socket` to a local address suitable for reaching `target`.
//
// The local interface is chosen as follows:
//   1. the first interface whose subnet contains the target address,
//   2. otherwise the first non-loopback interface.
// A free port is then obtained from find_free_port() and the actual binding
// is delegated to bind_socket_to_address().
//
// Errors:
//   EINVAL        target is null or shorter than sockaddr_in
//   EAFNOSUPPORT  target is not AF_INET
//   EHOSTUNREACH  no usable interface was found
//   plus anything returned by find_free_port() / bind_socket_to_address()
BAN::ErrorOr<void> IPv4Layer::bind_socket_with_target(BAN::RefPtr<NetworkSocket> socket, const sockaddr* target, socklen_t target_len)
{
    if (!target || target_len < (socklen_t)sizeof(sockaddr_in))
        return BAN::Error::from_errno(EINVAL);
    if (target->sa_family != AF_INET)
        return BAN::Error::from_errno(EAFNOSUPPORT);

    const auto& target_in = *reinterpret_cast<const struct sockaddr_in*>(target);

    auto interface =
        TRY([&target_in]() -> BAN::ErrorOr<BAN::RefPtr<NetworkInterface>> {
            const auto ipv4 = BAN::IPv4Address { target_in.sin_addr.s_addr };

            // try to find an interface in the same subnet
            const auto& all_interfaces = NetworkManager::get().interfaces();
            for (const auto& interface : all_interfaces)
            {
                const auto netmask = interface->get_netmask();
                if (ipv4.mask(netmask) == interface->get_ipv4_address().mask(netmask))
                    return interface;
            }

            // fallback to non-loopback interface
            // FIXME: make sure target is reachable
            for (const auto& interface : all_interfaces)
                if (interface->type() != NetworkInterface::Type::Loopback)
                    return interface;

            return BAN::Error::from_errno(EHOSTUNREACH);
        }());

    // FIXME: race condition with port allocation/binding
    //
    // Value-initialise the whole structure: only three members are assigned
    // below, and bind_socket_to_address() copies the complete sockaddr_in, so
    // any remaining bytes would otherwise carry stale stack contents.
    struct sockaddr_in bind_address {};
    bind_address.sin_family = AF_INET;
    bind_address.sin_port = BAN::host_to_network_endian(TRY(find_free_port()));
    bind_address.sin_addr.s_addr = interface->get_ipv4_address().raw;

    return bind_socket_to_address(socket, reinterpret_cast<const sockaddr*>(&bind_address), sizeof(bind_address));
}
|
|
|
|
// Binds `socket` to the local AF_INET address/port given in `address`.
//
// The address must be either 0 (any), the exact address of an Ethernet
// interface, or an address inside the subnet of a loopback interface;
// otherwise EADDRNOTAVAIL is returned. A port of 0 requests a port from
// find_free_port(). EADDRINUSE is returned if the port already has an entry
// in the bound-socket table.
//
// Errors:
//   EINVAL        address is null or shorter than sockaddr_in
//   EAFNOSUPPORT  address is not AF_INET
BAN::ErrorOr<void> IPv4Layer::bind_socket_to_address(BAN::RefPtr<NetworkSocket> socket, const sockaddr* address, socklen_t address_len)
{
    if (!address || address_len < (socklen_t)sizeof(sockaddr_in))
        return BAN::Error::from_errno(EINVAL);
    if (address->sa_family != AF_INET)
        return BAN::Error::from_errno(EAFNOSUPPORT);

    const auto& requested = *reinterpret_cast<const struct sockaddr_in*>(address);

    // Checks that the requested address is one this host may bind to.
    const auto validate_local_address = [&requested]() -> BAN::ErrorOr<void> {
        const auto wanted = BAN::IPv4Address { requested.sin_addr.s_addr };

        // address 0 is accepted without consulting the interfaces
        if (wanted == 0)
            return {};

        const auto& all_interfaces = NetworkManager::get().interfaces();
        for (const auto& candidate : all_interfaces)
        {
            const auto kind = candidate->type();
            if (kind == NetworkInterface::Type::Ethernet)
            {
                // exact match required
                if (wanted == candidate->get_ipv4_address())
                    return {};
            }
            else if (kind == NetworkInterface::Type::Loopback)
            {
                // any address inside the loopback subnet is acceptable
                const auto netmask = candidate->get_netmask();
                if (wanted.mask(netmask) == candidate->get_ipv4_address().mask(netmask))
                    return {};
            }
        }

        return BAN::Error::from_errno(EADDRNOTAVAIL);
    };
    TRY(validate_local_address());

    // Work on a private copy so that the port can be filled in.
    struct sockaddr_in bind_address;
    memcpy(&bind_address, address, sizeof(bind_address));

    SpinLockGuard guard(m_bound_socket_lock);

    if (bind_address.sin_port == 0)
        bind_address.sin_port = BAN::host_to_network_endian(TRY(find_free_port()));

    // The table is keyed by the port in host byte order.
    const uint16_t port = BAN::network_endian_to_host(bind_address.sin_port);
    if (m_bound_sockets.contains(port))
        return BAN::Error::from_errno(EADDRINUSE);

    TRY(m_bound_sockets.insert(port, TRY(socket->get_weak_ptr())));

    socket->bind_address_and_port(reinterpret_cast<struct sockaddr*>(&bind_address), sizeof(bind_address));

    return {};
}
|
|
|
|
// Writes the bound address of `socket` into `address` as a sockaddr_in.
//
// `*address_len` must be at least sizeof(sockaddr_in), otherwise ENOBUFS is
// returned. The reported IP address is always INADDR_ANY (see FIXME below).
//
// NOTE(review): if the socket has no entry in the bound-socket table the
// function returns success without touching `address` - confirm that callers
// expect this.
BAN::ErrorOr<void> IPv4Layer::get_socket_address(BAN::RefPtr<NetworkSocket> socket, sockaddr* address, socklen_t* address_len)
{
    if (*address_len < (socklen_t)sizeof(sockaddr_in))
        return BAN::Error::from_errno(ENOBUFS);

    sockaddr_in* in_addr = reinterpret_cast<sockaddr_in*>(address);

    SpinLockGuard _(m_bound_socket_lock);
    for (auto& [bound_port, bound_socket] : m_bound_sockets)
    {
        if (socket != bound_socket.lock())
            continue;

        // FIXME: sockets should have bound address
        in_addr->sin_family = AF_INET;
        // The table key is the port in host byte order (it is inserted after
        // network_endian_to_host in bind_socket_to_address), while sin_port
        // carries network byte order everywhere else in this file.
        in_addr->sin_port = BAN::host_to_network_endian(static_cast<in_port_t>(bound_port));
        in_addr->sin_addr.s_addr = INADDR_ANY;
        return {};
    }

    return {};
}
|
|
|
|
// Sends `payload` from `socket` to the AF_INET destination in `address`.
//
// The packet is assembled from three buffers - IPv4 header, protocol header
// produced by the socket, and the payload - and handed to the interface
// selected by socket.interface().
//
// Returns the number of payload bytes on success.
//
// Errors:
//   EINVAL         address is null, has the wrong length or is not AF_INET
//   EADDRNOTAVAIL  loopback destination port has no live bound socket
//   plus anything returned by interface lookup, ARP resolution or send_bytes
BAN::ErrorOr<size_t> IPv4Layer::sendto(NetworkSocket& socket, BAN::ConstByteSpan payload, const sockaddr* address, socklen_t address_len)
{
    // Validate the pointer before reading through it.
    if (address == nullptr || address_len != sizeof(sockaddr_in))
        return BAN::Error::from_errno(EINVAL);
    if (address->sa_family != AF_INET)
        return BAN::Error::from_errno(EINVAL);

    auto interface = TRY(socket.interface(address, address_len));

    const auto& target = *reinterpret_cast<const struct sockaddr_in*>(address);
    // sin_port is in network byte order; the bound-socket table and the
    // protocol header callback below receive it in host byte order.
    auto dst_port = BAN::network_endian_to_host(target.sin_port);
    auto dst_ipv4 = BAN::IPv4Address { target.sin_addr.s_addr };
    auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(*interface, dst_ipv4));

    if (interface->type() == NetworkInterface::Type::Loopback)
    {
        // On loopback the receiver is local: fail early when nobody is bound
        // to the destination port.
        BAN::RefPtr<NetworkSocket> receiver;

        {
            SpinLockGuard _(m_bound_socket_lock);
            auto receiver_it = m_bound_sockets.find(dst_port);
            if (receiver_it != m_bound_sockets.end())
                receiver = receiver_it->value.lock();
        }

        if (!receiver)
            return BAN::Error::from_errno(EADDRNOTAVAIL);
    }

    const auto ipv4_header = get_ipv4_header(
        sizeof(IPv4Header) + socket.protocol_header_size() + payload.size(),
        interface->get_ipv4_address(),
        dst_ipv4,
        socket.protocol()
    );

    const auto pseudo_header = PseudoHeader {
        .src_ipv4 = interface->get_ipv4_address(),
        .dst_ipv4 = dst_ipv4,
        .protocol = socket.protocol(),
        .length = socket.protocol_header_size() + payload.size()
    };

    uint8_t protocol_header_buffer[32];
    // A header that fills the buffer exactly still fits.
    ASSERT(socket.protocol_header_size() <= sizeof(protocol_header_buffer));

    auto protocol_header = BAN::ByteSpan::from(protocol_header_buffer).slice(0, socket.protocol_header_size());
    socket.get_protocol_header(protocol_header, payload, dst_port, pseudo_header);

    BAN::ConstByteSpan buffers[] {
        BAN::ConstByteSpan::from(ipv4_header),
        protocol_header,
        payload,
    };

    TRY(interface->send_bytes(dst_mac, EtherType::IPv4, { buffers, sizeof(buffers) / sizeof(*buffers) }));

    return payload.size();
}
|
|
|
|
// Processes one received IPv4 packet.
//
// Malformed or unsupported packets are logged and dropped; in those cases the
// function still returns success. Errors are only propagated from ARP lookup
// and from sending an ICMP echo reply.
//
// Handling by protocol:
//   ICMP  echo requests are answered on `interface`; destination-unreachable
//         messages are logged; everything else is logged as unhandled.
//   UDP / TCP  the packet (without the IPv4 header) is delivered to the socket
//         bound to the destination port, if its protocol matches.
BAN::ErrorOr<void> IPv4Layer::handle_ipv4_packet(NetworkInterface& interface, BAN::ConstByteSpan packet)
{
    if (packet.size() < sizeof(IPv4Header))
    {
        dwarnln_if(DEBUG_IPV4, "Too small IPv4 packet");
        return {};
    }

    auto& ipv4_header = packet.as<const IPv4Header>();
    if (calculate_internet_checksum(BAN::ConstByteSpan::from(ipv4_header)) != 0)
    {
        dwarnln_if(DEBUG_IPV4, "IPv4 packet checksum failed");
        return {};
    }
    if (ipv4_header.total_length > packet.size() || ipv4_header.total_length > interface.payload_mtu() || ipv4_header.total_length < sizeof(IPv4Header))
    {
        if (ipv4_header.flags_frament & IPv4Flags::DF)
            dwarnln_if(DEBUG_IPV4, "Invalid IPv4 packet");
        else
            dwarnln_if(DEBUG_IPV4, "IPv4 fragmentation not supported");
        return {};
    }

    // Everything after the fixed-size IPv4 header, trimmed to total_length.
    auto ipv4_data = packet.slice(0, ipv4_header.total_length).slice(sizeof(IPv4Header));

    auto src_ipv4 = ipv4_header.src_address;

    uint16_t dst_port = NetworkSocket::PORT_NONE;
    uint16_t src_port = NetworkSocket::PORT_NONE;

    switch (ipv4_header.protocol)
    {
        case NetworkProtocol::ICMP:
        {
            if (ipv4_data.size() < sizeof(ICMPHeader))
            {
                dwarnln("IPv4 packet too small for ICMP");
                return {};
            }
            auto& icmp_header = ipv4_data.as<const ICMPHeader>();
            switch (icmp_header.type)
            {
                case ICMPType::EchoRequest:
                {
                    auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(interface, src_ipv4));

                    // get_ipv4_header() stores its first argument as
                    // total_length, which covers the IPv4 header as well as
                    // the data that follows it.
                    auto send_ipv4_header = get_ipv4_header(
                        sizeof(IPv4Header) + ipv4_data.size(),
                        interface.get_ipv4_address(),
                        src_ipv4,
                        NetworkProtocol::ICMP
                    );

                    ICMPHeader send_icmp_header {
                        .type = ICMPType::EchoReply,
                        .code = icmp_header.code,
                        .checksum = 0,
                        .rest = icmp_header.rest,
                    };

                    // Echo the request's payload back unchanged.
                    auto send_payload = ipv4_data.slice(sizeof(ICMPHeader));

                    const BAN::ConstByteSpan send_buffers[] {
                        BAN::ConstByteSpan::from(send_ipv4_header),
                        BAN::ConstByteSpan::from(send_icmp_header),
                        send_payload
                    };
                    auto send_buffers_span = BAN::Span { send_buffers, sizeof(send_buffers) / sizeof(*send_buffers) };

                    // ICMP checksum covers the ICMP header and payload only,
                    // hence the IPv4 header buffer is skipped.
                    send_icmp_header.checksum = calculate_internet_checksum(send_buffers_span.slice(1));

                    TRY(interface.send_bytes(dst_mac, EtherType::IPv4, send_buffers_span));

                    break;
                }
                case ICMPType::DestinationUnreachable:
                {
                    // The message body is read as the IPv4 header of the
                    // packet that could not be delivered; make sure the
                    // received data is actually long enough to hold one.
                    auto unreachable_data = ipv4_data.slice(sizeof(ICMPHeader));
                    if (unreachable_data.size() < sizeof(IPv4Header))
                    {
                        dwarnln("IPv4 packet too small for ICMP destination unreachable");
                        return {};
                    }
                    auto& unreachable_header = unreachable_data.as<const IPv4Header>();
                    dprintln("Destination '{}' unreachable, code {2H}", unreachable_header.dst_address, icmp_header.code);
                    // FIXME: inform the socket
                    break;
                }
                default:
                    dprintln("Unhandleded ICMP packet (type {2H})", icmp_header.type);
                    break;
            }
            return {};
        }
        case NetworkProtocol::UDP:
        {
            if (ipv4_data.size() < sizeof(UDPHeader))
            {
                dwarnln("IPv4 packet too small for UDP");
                return {};
            }
            auto& udp_header = ipv4_data.as<const UDPHeader>();
            dst_port = udp_header.dst_port;
            src_port = udp_header.src_port;
            break;
        }
        case NetworkProtocol::TCP:
        {
            if (ipv4_data.size() < sizeof(TCPHeader))
            {
                dwarnln("IPv4 packet too small for TCP");
                return {};
            }
            auto& tcp_header = ipv4_data.as<const TCPHeader>();
            dst_port = tcp_header.dst_port;
            src_port = tcp_header.src_port;
            break;
        }
        default:
            dprintln_if(DEBUG_IPV4, "Unknown network protocol 0x{2H}", ipv4_header.protocol);
            return {};
    }

    ASSERT(dst_port != NetworkSocket::PORT_NONE);
    ASSERT(src_port != NetworkSocket::PORT_NONE);

    BAN::RefPtr<Kernel::NetworkSocket> bound_socket;

    {
        // Hold the lock only for the table lookup, not for packet delivery.
        SpinLockGuard _(m_bound_socket_lock);
        auto it = m_bound_sockets.find(dst_port);
        if (it == m_bound_sockets.end())
        {
            dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port);
            return {};
        }
        bound_socket = it->value.lock();
    }

    // The table stores weak pointers, so the entry may outlive the socket.
    if (!bound_socket)
    {
        dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port);
        return {};
    }

    if (bound_socket->protocol() != ipv4_header.protocol)
    {
        dprintln_if(DEBUG_IPV4, "got data with wrong protocol ({}) on port {} (bound as {})", ipv4_header.protocol, dst_port, (uint8_t)bound_socket->protocol());
        return {};
    }

    // Value-initialised so that bytes not assigned below are zero.
    sockaddr_in sender {};
    sender.sin_family = AF_INET;
    sender.sin_port = BAN::host_to_network_endian(src_port);
    sender.sin_addr.s_addr = src_ipv4.raw;
    bound_socket->receive_packet(ipv4_data, reinterpret_cast<const sockaddr*>(&sender), sizeof(sender));

    return {};
}
|
|
|
|
}
|