Kernel: Cleanup network APIs and error messages

This commit is contained in:
Bananymous 2024-02-08 18:33:49 +02:00
parent 5a939cf252
commit acf79570ef
9 changed files with 82 additions and 62 deletions

View File

@ -28,6 +28,8 @@ namespace Kernel
virtual bool link_up() override { return m_link_up; } virtual bool link_up() override { return m_link_up; }
virtual int link_speed() override; virtual int link_speed() override;
virtual size_t payload_mtu() const { return E1000_RX_BUFFER_SIZE; }
virtual void handle_irq() final override; virtual void handle_irq() final override;
protected: protected:
@ -67,7 +69,7 @@ namespace Kernel
BAN::UniqPtr<DMARegion> m_tx_descriptor_region; BAN::UniqPtr<DMARegion> m_tx_descriptor_region;
BAN::MACAddress m_mac_address {}; BAN::MACAddress m_mac_address {};
bool m_link_up { false }; bool m_link_up { false };
friend class BAN::RefPtr<E1000>; friend class BAN::RefPtr<E1000>;
}; };

View File

@ -29,27 +29,6 @@ namespace Kernel
BAN::NetworkEndian<uint16_t> checksum { 0 }; BAN::NetworkEndian<uint16_t> checksum { 0 };
BAN::IPv4Address src_address; BAN::IPv4Address src_address;
BAN::IPv4Address dst_address; BAN::IPv4Address dst_address;
constexpr uint16_t calculate_checksum() const
{
uint32_t total_sum = 0;
for (size_t i = 0; i < sizeof(IPv4Header) / sizeof(uint16_t); i++)
total_sum += reinterpret_cast<const BAN::NetworkEndian<uint16_t>*>(this)[i];
total_sum -= checksum;
while (total_sum >> 16)
total_sum = (total_sum >> 16) + (total_sum & 0xFFFF);
return ~(uint16_t)total_sum;
}
constexpr bool is_valid_checksum() const
{
uint32_t total_sum = 0;
for (size_t i = 0; i < sizeof(IPv4Header) / sizeof(uint16_t); i++)
total_sum += reinterpret_cast<const BAN::NetworkEndian<uint16_t>*>(this)[i];
while (total_sum >> 16)
total_sum = (total_sum >> 16) + (total_sum & 0xFFFF);
return total_sum == 0xFFFF;
}
}; };
static_assert(sizeof(IPv4Header) == 20); static_assert(sizeof(IPv4Header) == 20);
@ -69,7 +48,7 @@ namespace Kernel
virtual void unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) override; virtual void unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) override;
virtual BAN::ErrorOr<void> bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) override; virtual BAN::ErrorOr<void> bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) override;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, const sys_sendto_t*) override; virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) override;
private: private:
IPv4Layer(); IPv4Layer();

View File

@ -52,6 +52,8 @@ namespace Kernel
virtual bool link_up() = 0; virtual bool link_up() = 0;
virtual int link_speed() = 0; virtual int link_speed() = 0;
virtual size_t payload_mtu() const = 0;
virtual dev_t rdev() const override { return m_rdev; } virtual dev_t rdev() const override { return m_rdev; }
virtual BAN::StringView name() const override { return m_name; } virtual BAN::StringView name() const override { return m_name; }

View File

@ -5,6 +5,15 @@
namespace Kernel namespace Kernel
{ {
struct PseudoHeader
{
BAN::IPv4Address src_ipv4 { 0 };
BAN::IPv4Address dst_ipv4 { 0 };
BAN::NetworkEndian<uint16_t> protocol { 0 };
BAN::NetworkEndian<uint16_t> extra { 0 };
};
static_assert(sizeof(PseudoHeader) == 12);
class NetworkSocket; class NetworkSocket;
enum class SocketType; enum class SocketType;
@ -16,10 +25,22 @@ namespace Kernel
virtual void unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) = 0; virtual void unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) = 0;
virtual BAN::ErrorOr<void> bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) = 0; virtual BAN::ErrorOr<void> bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) = 0;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, const sys_sendto_t*) = 0; virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) = 0;
protected: protected:
NetworkLayer() = default; NetworkLayer() = default;
}; };
static uint16_t calculate_internet_checksum(BAN::ConstByteSpan packet, const PseudoHeader& pseudo_header)
{
uint32_t checksum = 0;
for (size_t i = 0; i < sizeof(pseudo_header) / sizeof(uint16_t); i++)
checksum += BAN::host_to_network_endian(reinterpret_cast<const uint16_t*>(&pseudo_header)[i]);
for (size_t i = 0; i < packet.size() / sizeof(uint16_t); i++)
checksum += BAN::host_to_network_endian(reinterpret_cast<const uint16_t*>(packet.data())[i]);
while (checksum >> 16)
checksum = (checksum >> 16) + (checksum & 0xFFFF);
return ~(uint16_t)checksum;
}
} }

View File

@ -32,7 +32,7 @@ namespace Kernel
NetworkInterface& interface() { ASSERT(m_interface); return *m_interface; } NetworkInterface& interface() { ASSERT(m_interface); return *m_interface; }
virtual size_t protocol_header_size() const = 0; virtual size_t protocol_header_size() const = 0;
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) = 0; virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) = 0;
virtual NetworkProtocol protocol() const = 0; virtual NetworkProtocol protocol() const = 0;
virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_address, uint16_t sender_port) = 0; virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_address, uint16_t sender_port) = 0;

View File

@ -24,10 +24,11 @@ namespace Kernel
public: public:
static BAN::ErrorOr<BAN::RefPtr<UDPSocket>> create(NetworkLayer&, ino_t, const TmpInodeInfo&); static BAN::ErrorOr<BAN::RefPtr<UDPSocket>> create(NetworkLayer&, ino_t, const TmpInodeInfo&);
virtual size_t protocol_header_size() const override { return sizeof(UDPHeader); }
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) override;
virtual NetworkProtocol protocol() const override { return NetworkProtocol::UDP; } virtual NetworkProtocol protocol() const override { return NetworkProtocol::UDP; }
virtual size_t protocol_header_size() const override { return sizeof(UDPHeader); }
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) override;
protected: protected:
virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_addr, uint16_t sender_port) override; virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_addr, uint16_t sender_port) override;
virtual BAN::ErrorOr<size_t> read_packet(BAN::ByteSpan, sockaddr_in* sender_address) override; virtual BAN::ErrorOr<size_t> read_packet(BAN::ByteSpan, sockaddr_in* sender_address) override;
@ -47,7 +48,8 @@ namespace Kernel
BAN::UniqPtr<VirtualRange> m_packet_buffer; BAN::UniqPtr<VirtualRange> m_packet_buffer;
BAN::CircularQueue<PacketInfo, 128> m_packets; BAN::CircularQueue<PacketInfo, 128> m_packets;
size_t m_packet_total_size { 0 }; size_t m_packet_total_size { 0 };
Semaphore m_semaphore; SpinLock m_packet_lock;
Semaphore m_packet_semaphore;
friend class BAN::RefPtr<UDPSocket>; friend class BAN::RefPtr<UDPSocket>;
}; };

View File

@ -12,6 +12,11 @@
namespace Kernel namespace Kernel
{ {
enum IPv4Flags : uint16_t
{
DF = 1 << 14,
};
BAN::ErrorOr<BAN::UniqPtr<IPv4Layer>> IPv4Layer::create() BAN::ErrorOr<BAN::UniqPtr<IPv4Layer>> IPv4Layer::create()
{ {
auto ipv4_manager = TRY(BAN::UniqPtr<IPv4Layer>::create()); auto ipv4_manager = TRY(BAN::UniqPtr<IPv4Layer>::create());
@ -57,7 +62,8 @@ namespace Kernel
header.protocol = protocol; header.protocol = protocol;
header.src_address = src_ipv4; header.src_address = src_ipv4;
header.dst_address = dst_ipv4; header.dst_address = dst_ipv4;
header.checksum = header.calculate_checksum(); header.checksum = 0;
header.checksum = calculate_internet_checksum(BAN::ConstByteSpan::from(header), {});
} }
void IPv4Layer::unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket> socket) void IPv4Layer::unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket> socket)
@ -98,7 +104,7 @@ namespace Kernel
if (m_bound_sockets.contains(port)) if (m_bound_sockets.contains(port))
return BAN::Error::from_errno(EADDRINUSE); return BAN::Error::from_errno(EADDRINUSE);
TRY(m_bound_sockets.insert(port, socket)); TRY(m_bound_sockets.insert(port, TRY(socket->get_weak_ptr())));
// FIXME: actually determine proper interface // FIXME: actually determine proper interface
auto interface = NetworkManager::get().interfaces().front(); auto interface = NetworkManager::get().interfaces().front();
@ -107,28 +113,37 @@ namespace Kernel
return {}; return {};
} }
BAN::ErrorOr<size_t> IPv4Layer::sendto(NetworkSocket& socket, const sys_sendto_t* arguments) BAN::ErrorOr<size_t> IPv4Layer::sendto(NetworkSocket& socket, BAN::ConstByteSpan buffer, const sockaddr* address, socklen_t address_len)
{ {
if (arguments->dest_addr->sa_family != AF_INET) if (address->sa_family != AF_INET)
return BAN::Error::from_errno(EINVAL); return BAN::Error::from_errno(EINVAL);
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(arguments->dest_addr); if (address == nullptr || address_len != sizeof(sockaddr_in))
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(address);
auto dst_port = BAN::host_to_network_endian(sockaddr_in.sin_port); auto dst_port = BAN::host_to_network_endian(sockaddr_in.sin_port);
auto dst_ipv4 = BAN::IPv4Address { sockaddr_in.sin_addr.s_addr }; auto dst_ipv4 = BAN::IPv4Address { sockaddr_in.sin_addr.s_addr };
auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(socket.interface(), dst_ipv4)); auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(socket.interface(), dst_ipv4));
BAN::Vector<uint8_t> packet_buffer; BAN::Vector<uint8_t> packet_buffer;
TRY(packet_buffer.resize(arguments->length + sizeof(IPv4Header) + socket.protocol_header_size())); TRY(packet_buffer.resize(buffer.size() + sizeof(IPv4Header) + socket.protocol_header_size()));
auto packet = BAN::ByteSpan { packet_buffer.span() }; auto packet = BAN::ByteSpan { packet_buffer.span() };
auto pseudo_header = PseudoHeader {
.src_ipv4 = socket.interface().get_ipv4_address(),
.dst_ipv4 = dst_ipv4,
.protocol = socket.protocol()
};
memcpy( memcpy(
packet.slice(sizeof(IPv4Header)).slice(socket.protocol_header_size()).data(), packet.slice(sizeof(IPv4Header)).slice(socket.protocol_header_size()).data(),
arguments->message, buffer.data(),
arguments->length buffer.size()
); );
socket.add_protocol_header( socket.add_protocol_header(
packet.slice(sizeof(IPv4Header)), packet.slice(sizeof(IPv4Header)),
dst_port dst_port,
pseudo_header
); );
add_ipv4_header( add_ipv4_header(
packet, packet,
@ -139,17 +154,7 @@ namespace Kernel
TRY(socket.interface().send_bytes(dst_mac, EtherType::IPv4, packet)); TRY(socket.interface().send_bytes(dst_mac, EtherType::IPv4, packet));
return arguments->length; return buffer.size();
}
static uint16_t calculate_internet_checksum(BAN::ConstByteSpan packet)
{
uint32_t checksum = 0;
for (size_t i = 0; i < packet.size() / sizeof(uint16_t); i++)
checksum += BAN::host_to_network_endian(reinterpret_cast<const uint16_t*>(packet.data())[i]);
while (checksum >> 16)
checksum = (checksum >> 16) | (checksum & 0xFFFF);
return ~(uint16_t)checksum;
} }
BAN::ErrorOr<void> IPv4Layer::handle_ipv4_packet(NetworkInterface& interface, BAN::ByteSpan packet) BAN::ErrorOr<void> IPv4Layer::handle_ipv4_packet(NetworkInterface& interface, BAN::ByteSpan packet)
@ -157,8 +162,6 @@ namespace Kernel
auto& ipv4_header = packet.as<const IPv4Header>(); auto& ipv4_header = packet.as<const IPv4Header>();
auto ipv4_data = packet.slice(sizeof(IPv4Header)); auto ipv4_data = packet.slice(sizeof(IPv4Header));
ASSERT(ipv4_header.is_valid_checksum());
auto src_ipv4 = ipv4_header.src_address; auto src_ipv4 = ipv4_header.src_address;
switch (ipv4_header.protocol) switch (ipv4_header.protocol)
{ {
@ -174,7 +177,7 @@ namespace Kernel
auto& reply_icmp_header = ipv4_data.as<ICMPHeader>(); auto& reply_icmp_header = ipv4_data.as<ICMPHeader>();
reply_icmp_header.type = ICMPType::EchoReply; reply_icmp_header.type = ICMPType::EchoReply;
reply_icmp_header.checksum = 0; reply_icmp_header.checksum = 0;
reply_icmp_header.checksum = calculate_internet_checksum(ipv4_data); reply_icmp_header.checksum = calculate_internet_checksum(ipv4_data, {});
add_ipv4_header(packet, interface.get_ipv4_address(), src_ipv4, NetworkProtocol::ICMP); add_ipv4_header(packet, interface.get_ipv4_address(), src_ipv4, NetworkProtocol::ICMP);
@ -195,14 +198,20 @@ namespace Kernel
LockGuard _(m_lock); LockGuard _(m_lock);
if (!m_bound_sockets.contains(dst_port) || !m_bound_sockets[dst_port].valid()) if (!m_bound_sockets.contains(dst_port))
{
dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port);
return {};
}
auto socket = m_bound_sockets[dst_port].lock();
if (!socket)
{ {
dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port); dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port);
return {}; return {};
} }
auto udp_data = ipv4_data.slice(sizeof(UDPHeader)); auto udp_data = ipv4_data.slice(sizeof(UDPHeader));
m_bound_sockets[dst_port].lock()->add_packet(udp_data, src_ipv4, src_port); socket->add_packet(udp_data, src_ipv4, src_port);
break; break;
} }
default: default:
@ -262,14 +271,17 @@ namespace Kernel
} }
auto& ipv4_header = buffer.as<const IPv4Header>(); auto& ipv4_header = buffer.as<const IPv4Header>();
if (!ipv4_header.is_valid_checksum()) if (calculate_internet_checksum(BAN::ConstByteSpan::from(ipv4_header), {}) != 0)
{ {
dwarnln("Invalid IPv4 packet"); dwarnln("Invalid IPv4 packet");
return; return;
} }
if (ipv4_header.total_length > buffer.size()) if (ipv4_header.total_length > buffer.size() || ipv4_header.total_length > interface.payload_mtu())
{ {
dwarnln("Too short IPv4 packet"); if (ipv4_header.flags_frament & IPv4Flags::DF)
dwarnln("Invalid IPv4 packet");
else
dwarnln("IPv4 fragmentation not supported");
return; return;
} }

View File

@ -49,7 +49,8 @@ namespace Kernel
if (!m_interface) if (!m_interface)
TRY(m_network_layer.bind_socket(PORT_NONE, this)); TRY(m_network_layer.bind_socket(PORT_NONE, this));
return TRY(m_network_layer.sendto(*this, arguments)); auto buffer = BAN::ConstByteSpan { reinterpret_cast<const uint8_t*>(arguments->message), arguments->length };
return TRY(m_network_layer.sendto(*this, buffer, arguments->dest_addr, arguments->dest_len));
} }
BAN::ErrorOr<size_t> NetworkSocket::recvfrom_impl(sys_recvfrom_t* arguments) BAN::ErrorOr<size_t> NetworkSocket::recvfrom_impl(sys_recvfrom_t* arguments)

View File

@ -1,3 +1,4 @@
#include <kernel/LockGuard.h>
#include <kernel/Memory/Heap.h> #include <kernel/Memory/Heap.h>
#include <kernel/Networking/UDPSocket.h> #include <kernel/Networking/UDPSocket.h>
#include <kernel/Thread.h> #include <kernel/Thread.h>
@ -23,7 +24,7 @@ namespace Kernel
: NetworkSocket(network_layer, ino, inode_info) : NetworkSocket(network_layer, ino, inode_info)
{ } { }
void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader)
{ {
auto& header = packet.as<UDPHeader>(); auto& header = packet.as<UDPHeader>();
header.src_port = m_port; header.src_port = m_port;
@ -34,7 +35,7 @@ namespace Kernel
void UDPSocket::add_packet(BAN::ConstByteSpan packet, BAN::IPv4Address sender_addr, uint16_t sender_port) void UDPSocket::add_packet(BAN::ConstByteSpan packet, BAN::IPv4Address sender_addr, uint16_t sender_port)
{ {
CriticalScope _; LockGuard _(m_packet_lock);
if (m_packets.full()) if (m_packets.full())
{ {
@ -58,15 +59,15 @@ namespace Kernel
}); });
m_packet_total_size += packet.size(); m_packet_total_size += packet.size();
m_semaphore.unblock(); m_packet_semaphore.unblock();
} }
BAN::ErrorOr<size_t> UDPSocket::read_packet(BAN::ByteSpan buffer, sockaddr_in* sender_addr) BAN::ErrorOr<size_t> UDPSocket::read_packet(BAN::ByteSpan buffer, sockaddr_in* sender_addr)
{ {
while (m_packets.empty()) while (m_packets.empty())
TRY(Thread::current().block_or_eintr(m_semaphore)); TRY(Thread::current().block_or_eintr(m_packet_semaphore));
CriticalScope _; LockGuard _(m_packet_lock);
if (m_packets.empty()) if (m_packets.empty())
return read_packet(buffer, sender_addr); return read_packet(buffer, sender_addr);