Kernel: Cleanup network APIs and error messages

This commit is contained in:
Bananymous 2024-02-08 18:33:49 +02:00
parent 5a939cf252
commit acf79570ef
9 changed files with 82 additions and 62 deletions

View File

@ -28,6 +28,8 @@ namespace Kernel
virtual bool link_up() override { return m_link_up; }
virtual int link_speed() override;
virtual size_t payload_mtu() const { return E1000_RX_BUFFER_SIZE; }
virtual void handle_irq() final override;
protected:
@ -67,7 +69,7 @@ namespace Kernel
BAN::UniqPtr<DMARegion> m_tx_descriptor_region;
BAN::MACAddress m_mac_address {};
bool m_link_up { false };
bool m_link_up { false };
friend class BAN::RefPtr<E1000>;
};

View File

@ -29,27 +29,6 @@ namespace Kernel
BAN::NetworkEndian<uint16_t> checksum { 0 };
BAN::IPv4Address src_address;
BAN::IPv4Address dst_address;
constexpr uint16_t calculate_checksum() const
{
uint32_t total_sum = 0;
for (size_t i = 0; i < sizeof(IPv4Header) / sizeof(uint16_t); i++)
total_sum += reinterpret_cast<const BAN::NetworkEndian<uint16_t>*>(this)[i];
total_sum -= checksum;
while (total_sum >> 16)
total_sum = (total_sum >> 16) + (total_sum & 0xFFFF);
return ~(uint16_t)total_sum;
}
constexpr bool is_valid_checksum() const
{
uint32_t total_sum = 0;
for (size_t i = 0; i < sizeof(IPv4Header) / sizeof(uint16_t); i++)
total_sum += reinterpret_cast<const BAN::NetworkEndian<uint16_t>*>(this)[i];
while (total_sum >> 16)
total_sum = (total_sum >> 16) + (total_sum & 0xFFFF);
return total_sum == 0xFFFF;
}
};
static_assert(sizeof(IPv4Header) == 20);
@ -69,7 +48,7 @@ namespace Kernel
virtual void unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) override;
virtual BAN::ErrorOr<void> bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) override;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, const sys_sendto_t*) override;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) override;
private:
IPv4Layer();

View File

@ -52,6 +52,8 @@ namespace Kernel
virtual bool link_up() = 0;
virtual int link_speed() = 0;
virtual size_t payload_mtu() const = 0;
virtual dev_t rdev() const override { return m_rdev; }
virtual BAN::StringView name() const override { return m_name; }

View File

@ -5,6 +5,15 @@
namespace Kernel
{
struct PseudoHeader
{
BAN::IPv4Address src_ipv4 { 0 };
BAN::IPv4Address dst_ipv4 { 0 };
BAN::NetworkEndian<uint16_t> protocol { 0 };
BAN::NetworkEndian<uint16_t> extra { 0 };
};
static_assert(sizeof(PseudoHeader) == 12);
class NetworkSocket;
enum class SocketType;
@ -16,10 +25,22 @@ namespace Kernel
virtual void unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) = 0;
virtual BAN::ErrorOr<void> bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) = 0;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, const sys_sendto_t*) = 0;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) = 0;
protected:
NetworkLayer() = default;
};
static uint16_t calculate_internet_checksum(BAN::ConstByteSpan packet, const PseudoHeader& pseudo_header)
{
uint32_t checksum = 0;
for (size_t i = 0; i < sizeof(pseudo_header) / sizeof(uint16_t); i++)
checksum += BAN::host_to_network_endian(reinterpret_cast<const uint16_t*>(&pseudo_header)[i]);
for (size_t i = 0; i < packet.size() / sizeof(uint16_t); i++)
checksum += BAN::host_to_network_endian(reinterpret_cast<const uint16_t*>(packet.data())[i]);
while (checksum >> 16)
checksum = (checksum >> 16) + (checksum & 0xFFFF);
return ~(uint16_t)checksum;
}
}

View File

@ -32,7 +32,7 @@ namespace Kernel
NetworkInterface& interface() { ASSERT(m_interface); return *m_interface; }
virtual size_t protocol_header_size() const = 0;
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) = 0;
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) = 0;
virtual NetworkProtocol protocol() const = 0;
virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_address, uint16_t sender_port) = 0;

View File

@ -24,10 +24,11 @@ namespace Kernel
public:
static BAN::ErrorOr<BAN::RefPtr<UDPSocket>> create(NetworkLayer&, ino_t, const TmpInodeInfo&);
virtual size_t protocol_header_size() const override { return sizeof(UDPHeader); }
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) override;
virtual NetworkProtocol protocol() const override { return NetworkProtocol::UDP; }
virtual size_t protocol_header_size() const override { return sizeof(UDPHeader); }
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) override;
protected:
virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_addr, uint16_t sender_port) override;
virtual BAN::ErrorOr<size_t> read_packet(BAN::ByteSpan, sockaddr_in* sender_address) override;
@ -47,7 +48,8 @@ namespace Kernel
BAN::UniqPtr<VirtualRange> m_packet_buffer;
BAN::CircularQueue<PacketInfo, 128> m_packets;
size_t m_packet_total_size { 0 };
Semaphore m_semaphore;
SpinLock m_packet_lock;
Semaphore m_packet_semaphore;
friend class BAN::RefPtr<UDPSocket>;
};

View File

@ -12,6 +12,11 @@
namespace Kernel
{
enum IPv4Flags : uint16_t
{
DF = 1 << 14,
};
BAN::ErrorOr<BAN::UniqPtr<IPv4Layer>> IPv4Layer::create()
{
auto ipv4_manager = TRY(BAN::UniqPtr<IPv4Layer>::create());
@ -57,7 +62,8 @@ namespace Kernel
header.protocol = protocol;
header.src_address = src_ipv4;
header.dst_address = dst_ipv4;
header.checksum = header.calculate_checksum();
header.checksum = 0;
header.checksum = calculate_internet_checksum(BAN::ConstByteSpan::from(header), {});
}
void IPv4Layer::unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket> socket)
@ -98,7 +104,7 @@ namespace Kernel
if (m_bound_sockets.contains(port))
return BAN::Error::from_errno(EADDRINUSE);
TRY(m_bound_sockets.insert(port, socket));
TRY(m_bound_sockets.insert(port, TRY(socket->get_weak_ptr())));
// FIXME: actually determine proper interface
auto interface = NetworkManager::get().interfaces().front();
@ -107,28 +113,37 @@ namespace Kernel
return {};
}
BAN::ErrorOr<size_t> IPv4Layer::sendto(NetworkSocket& socket, const sys_sendto_t* arguments)
BAN::ErrorOr<size_t> IPv4Layer::sendto(NetworkSocket& socket, BAN::ConstByteSpan buffer, const sockaddr* address, socklen_t address_len)
{
if (arguments->dest_addr->sa_family != AF_INET)
if (address->sa_family != AF_INET)
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(arguments->dest_addr);
if (address == nullptr || address_len != sizeof(sockaddr_in))
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(address);
auto dst_port = BAN::host_to_network_endian(sockaddr_in.sin_port);
auto dst_ipv4 = BAN::IPv4Address { sockaddr_in.sin_addr.s_addr };
auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(socket.interface(), dst_ipv4));
BAN::Vector<uint8_t> packet_buffer;
TRY(packet_buffer.resize(arguments->length + sizeof(IPv4Header) + socket.protocol_header_size()));
TRY(packet_buffer.resize(buffer.size() + sizeof(IPv4Header) + socket.protocol_header_size()));
auto packet = BAN::ByteSpan { packet_buffer.span() };
auto pseudo_header = PseudoHeader {
.src_ipv4 = socket.interface().get_ipv4_address(),
.dst_ipv4 = dst_ipv4,
.protocol = socket.protocol()
};
memcpy(
packet.slice(sizeof(IPv4Header)).slice(socket.protocol_header_size()).data(),
arguments->message,
arguments->length
buffer.data(),
buffer.size()
);
socket.add_protocol_header(
packet.slice(sizeof(IPv4Header)),
dst_port
dst_port,
pseudo_header
);
add_ipv4_header(
packet,
@ -139,17 +154,7 @@ namespace Kernel
TRY(socket.interface().send_bytes(dst_mac, EtherType::IPv4, packet));
return arguments->length;
}
static uint16_t calculate_internet_checksum(BAN::ConstByteSpan packet)
{
uint32_t checksum = 0;
for (size_t i = 0; i < packet.size() / sizeof(uint16_t); i++)
checksum += BAN::host_to_network_endian(reinterpret_cast<const uint16_t*>(packet.data())[i]);
while (checksum >> 16)
checksum = (checksum >> 16) | (checksum & 0xFFFF);
return ~(uint16_t)checksum;
return buffer.size();
}
BAN::ErrorOr<void> IPv4Layer::handle_ipv4_packet(NetworkInterface& interface, BAN::ByteSpan packet)
@ -157,8 +162,6 @@ namespace Kernel
auto& ipv4_header = packet.as<const IPv4Header>();
auto ipv4_data = packet.slice(sizeof(IPv4Header));
ASSERT(ipv4_header.is_valid_checksum());
auto src_ipv4 = ipv4_header.src_address;
switch (ipv4_header.protocol)
{
@ -174,7 +177,7 @@ namespace Kernel
auto& reply_icmp_header = ipv4_data.as<ICMPHeader>();
reply_icmp_header.type = ICMPType::EchoReply;
reply_icmp_header.checksum = 0;
reply_icmp_header.checksum = calculate_internet_checksum(ipv4_data);
reply_icmp_header.checksum = calculate_internet_checksum(ipv4_data, {});
add_ipv4_header(packet, interface.get_ipv4_address(), src_ipv4, NetworkProtocol::ICMP);
@ -195,14 +198,20 @@ namespace Kernel
LockGuard _(m_lock);
if (!m_bound_sockets.contains(dst_port) || !m_bound_sockets[dst_port].valid())
if (!m_bound_sockets.contains(dst_port))
{
dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port);
return {};
}
auto socket = m_bound_sockets[dst_port].lock();
if (!socket)
{
dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port);
return {};
}
auto udp_data = ipv4_data.slice(sizeof(UDPHeader));
m_bound_sockets[dst_port].lock()->add_packet(udp_data, src_ipv4, src_port);
socket->add_packet(udp_data, src_ipv4, src_port);
break;
}
default:
@ -262,14 +271,17 @@ namespace Kernel
}
auto& ipv4_header = buffer.as<const IPv4Header>();
if (!ipv4_header.is_valid_checksum())
if (calculate_internet_checksum(BAN::ConstByteSpan::from(ipv4_header), {}) != 0)
{
dwarnln("Invalid IPv4 packet");
return;
}
if (ipv4_header.total_length > buffer.size())
if (ipv4_header.total_length > buffer.size() || ipv4_header.total_length > interface.payload_mtu())
{
dwarnln("Too short IPv4 packet");
if (ipv4_header.flags_frament & IPv4Flags::DF)
dwarnln("Invalid IPv4 packet");
else
dwarnln("IPv4 fragmentation not supported");
return;
}

View File

@ -49,7 +49,8 @@ namespace Kernel
if (!m_interface)
TRY(m_network_layer.bind_socket(PORT_NONE, this));
return TRY(m_network_layer.sendto(*this, arguments));
auto buffer = BAN::ConstByteSpan { reinterpret_cast<const uint8_t*>(arguments->message), arguments->length };
return TRY(m_network_layer.sendto(*this, buffer, arguments->dest_addr, arguments->dest_len));
}
BAN::ErrorOr<size_t> NetworkSocket::recvfrom_impl(sys_recvfrom_t* arguments)

View File

@ -1,3 +1,4 @@
#include <kernel/LockGuard.h>
#include <kernel/Memory/Heap.h>
#include <kernel/Networking/UDPSocket.h>
#include <kernel/Thread.h>
@ -23,7 +24,7 @@ namespace Kernel
: NetworkSocket(network_layer, ino, inode_info)
{ }
void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port)
void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader)
{
auto& header = packet.as<UDPHeader>();
header.src_port = m_port;
@ -34,7 +35,7 @@ namespace Kernel
void UDPSocket::add_packet(BAN::ConstByteSpan packet, BAN::IPv4Address sender_addr, uint16_t sender_port)
{
CriticalScope _;
LockGuard _(m_packet_lock);
if (m_packets.full())
{
@ -58,15 +59,15 @@ namespace Kernel
});
m_packet_total_size += packet.size();
m_semaphore.unblock();
m_packet_semaphore.unblock();
}
BAN::ErrorOr<size_t> UDPSocket::read_packet(BAN::ByteSpan buffer, sockaddr_in* sender_addr)
{
while (m_packets.empty())
TRY(Thread::current().block_or_eintr(m_semaphore));
TRY(Thread::current().block_or_eintr(m_packet_semaphore));
CriticalScope _;
LockGuard _(m_packet_lock);
if (m_packets.empty())
return read_packet(buffer, sender_addr);