diff --git a/kernel/include/kernel/Networking/E1000/E1000.h b/kernel/include/kernel/Networking/E1000/E1000.h index 1d5d67b1..4637c987 100644 --- a/kernel/include/kernel/Networking/E1000/E1000.h +++ b/kernel/include/kernel/Networking/E1000/E1000.h @@ -28,6 +28,8 @@ namespace Kernel virtual bool link_up() override { return m_link_up; } virtual int link_speed() override; + virtual size_t payload_mtu() const { return E1000_RX_BUFFER_SIZE; } + virtual void handle_irq() final override; protected: @@ -67,7 +69,7 @@ namespace Kernel BAN::UniqPtr m_tx_descriptor_region; BAN::MACAddress m_mac_address {}; - bool m_link_up { false }; + bool m_link_up { false }; friend class BAN::RefPtr; }; diff --git a/kernel/include/kernel/Networking/IPv4Layer.h b/kernel/include/kernel/Networking/IPv4Layer.h index bf9d2952..4a33b8ab 100644 --- a/kernel/include/kernel/Networking/IPv4Layer.h +++ b/kernel/include/kernel/Networking/IPv4Layer.h @@ -29,27 +29,6 @@ namespace Kernel BAN::NetworkEndian checksum { 0 }; BAN::IPv4Address src_address; BAN::IPv4Address dst_address; - - constexpr uint16_t calculate_checksum() const - { - uint32_t total_sum = 0; - for (size_t i = 0; i < sizeof(IPv4Header) / sizeof(uint16_t); i++) - total_sum += reinterpret_cast*>(this)[i]; - total_sum -= checksum; - while (total_sum >> 16) - total_sum = (total_sum >> 16) + (total_sum & 0xFFFF); - return ~(uint16_t)total_sum; - } - - constexpr bool is_valid_checksum() const - { - uint32_t total_sum = 0; - for (size_t i = 0; i < sizeof(IPv4Header) / sizeof(uint16_t); i++) - total_sum += reinterpret_cast*>(this)[i]; - while (total_sum >> 16) - total_sum = (total_sum >> 16) + (total_sum & 0xFFFF); - return total_sum == 0xFFFF; - } }; static_assert(sizeof(IPv4Header) == 20); @@ -69,7 +48,7 @@ namespace Kernel virtual void unbind_socket(uint16_t port, BAN::RefPtr) override; virtual BAN::ErrorOr bind_socket(uint16_t port, BAN::RefPtr) override; - virtual BAN::ErrorOr sendto(NetworkSocket&, const sys_sendto_t*) override; + virtual BAN::ErrorOr sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) override; private: IPv4Layer(); diff --git a/kernel/include/kernel/Networking/NetworkInterface.h b/kernel/include/kernel/Networking/NetworkInterface.h index 24f18eec..6980e621 100644 --- a/kernel/include/kernel/Networking/NetworkInterface.h +++ b/kernel/include/kernel/Networking/NetworkInterface.h @@ -52,6 +52,8 @@ namespace Kernel virtual bool link_up() = 0; virtual int link_speed() = 0; + virtual size_t payload_mtu() const = 0; + virtual dev_t rdev() const override { return m_rdev; } virtual BAN::StringView name() const override { return m_name; } diff --git a/kernel/include/kernel/Networking/NetworkLayer.h b/kernel/include/kernel/Networking/NetworkLayer.h index 74db08d1..7d11a21b 100644 --- a/kernel/include/kernel/Networking/NetworkLayer.h +++ b/kernel/include/kernel/Networking/NetworkLayer.h @@ -5,6 +5,15 @@ namespace Kernel { + struct PseudoHeader + { + BAN::IPv4Address src_ipv4 { 0 }; + BAN::IPv4Address dst_ipv4 { 0 }; + BAN::NetworkEndian protocol { 0 }; + BAN::NetworkEndian extra { 0 }; + }; + static_assert(sizeof(PseudoHeader) == 12); + class NetworkSocket; enum class SocketType; @@ -16,10 +25,22 @@ namespace Kernel virtual void unbind_socket(uint16_t port, BAN::RefPtr) = 0; virtual BAN::ErrorOr bind_socket(uint16_t port, BAN::RefPtr) = 0; - virtual BAN::ErrorOr sendto(NetworkSocket&, const sys_sendto_t*) = 0; + virtual BAN::ErrorOr sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) = 0; protected: NetworkLayer() = default; }; + static uint16_t calculate_internet_checksum(BAN::ConstByteSpan packet, const PseudoHeader& pseudo_header) + { + uint32_t checksum = 0; + for (size_t i = 0; i < sizeof(pseudo_header) / sizeof(uint16_t); i++) + checksum += BAN::host_to_network_endian(reinterpret_cast(&pseudo_header)[i]); + for (size_t i = 0; i < packet.size() / sizeof(uint16_t); i++) + checksum += BAN::host_to_network_endian(reinterpret_cast(packet.data())[i]); + while (checksum >> 16) + checksum = (checksum >> 16) + (checksum & 0xFFFF); + return ~(uint16_t)checksum; + } + } diff --git a/kernel/include/kernel/Networking/NetworkSocket.h b/kernel/include/kernel/Networking/NetworkSocket.h index da25acd7..cc9cf492 100644 --- a/kernel/include/kernel/Networking/NetworkSocket.h +++ b/kernel/include/kernel/Networking/NetworkSocket.h @@ -32,7 +32,7 @@ namespace Kernel NetworkInterface& interface() { ASSERT(m_interface); return *m_interface; } virtual size_t protocol_header_size() const = 0; - virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) = 0; + virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) = 0; virtual NetworkProtocol protocol() const = 0; virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_address, uint16_t sender_port) = 0; diff --git a/kernel/include/kernel/Networking/UDPSocket.h b/kernel/include/kernel/Networking/UDPSocket.h index 74955924..0aff1e38 100644 --- a/kernel/include/kernel/Networking/UDPSocket.h +++ b/kernel/include/kernel/Networking/UDPSocket.h @@ -24,10 +24,11 @@ namespace Kernel public: static BAN::ErrorOr> create(NetworkLayer&, ino_t, const TmpInodeInfo&); - virtual size_t protocol_header_size() const override { return sizeof(UDPHeader); } - virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) override; virtual NetworkProtocol protocol() const override { return NetworkProtocol::UDP; } + virtual size_t protocol_header_size() const override { return sizeof(UDPHeader); } + virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) override; + protected: virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_addr, uint16_t sender_port) override; virtual BAN::ErrorOr read_packet(BAN::ByteSpan, sockaddr_in* sender_address) override; @@ -47,7 +48,8 @@ namespace Kernel BAN::UniqPtr m_packet_buffer; BAN::CircularQueue m_packets; size_t m_packet_total_size { 0 }; - Semaphore m_semaphore; + SpinLock m_packet_lock; + Semaphore m_packet_semaphore; friend class BAN::RefPtr; }; diff --git a/kernel/kernel/Networking/IPv4Layer.cpp b/kernel/kernel/Networking/IPv4Layer.cpp index 1660d66c..7878cfa8 100644 --- a/kernel/kernel/Networking/IPv4Layer.cpp +++ b/kernel/kernel/Networking/IPv4Layer.cpp @@ -12,6 +12,11 @@ namespace Kernel { + enum IPv4Flags : uint16_t + { + DF = 1 << 14, + }; + BAN::ErrorOr> IPv4Layer::create() { auto ipv4_manager = TRY(BAN::UniqPtr::create()); @@ -57,7 +62,8 @@ namespace Kernel header.protocol = protocol; header.src_address = src_ipv4; header.dst_address = dst_ipv4; - header.checksum = header.calculate_checksum(); + header.checksum = 0; + header.checksum = calculate_internet_checksum(BAN::ConstByteSpan::from(header), {}); } void IPv4Layer::unbind_socket(uint16_t port, BAN::RefPtr socket) @@ -98,7 +104,7 @@ namespace Kernel if (m_bound_sockets.contains(port)) return BAN::Error::from_errno(EADDRINUSE); - TRY(m_bound_sockets.insert(port, socket)); + TRY(m_bound_sockets.insert(port, TRY(socket->get_weak_ptr()))); // FIXME: actually determine proper interface auto interface = NetworkManager::get().interfaces().front(); @@ -107,28 +113,37 @@ namespace Kernel return {}; } - BAN::ErrorOr IPv4Layer::sendto(NetworkSocket& socket, const sys_sendto_t* arguments) + BAN::ErrorOr IPv4Layer::sendto(NetworkSocket& socket, BAN::ConstByteSpan buffer, const sockaddr* address, socklen_t address_len) { - if (arguments->dest_addr->sa_family != AF_INET) + if (address->sa_family != AF_INET) return BAN::Error::from_errno(EINVAL); - auto& sockaddr_in = *reinterpret_cast(arguments->dest_addr); + if (address == nullptr || address_len != sizeof(sockaddr_in)) + return BAN::Error::from_errno(EINVAL); + auto& sockaddr_in = *reinterpret_cast(address); auto dst_port = BAN::host_to_network_endian(sockaddr_in.sin_port); auto dst_ipv4 = BAN::IPv4Address { sockaddr_in.sin_addr.s_addr }; auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(socket.interface(), dst_ipv4)); BAN::Vector packet_buffer; - TRY(packet_buffer.resize(arguments->length + sizeof(IPv4Header) + socket.protocol_header_size())); + TRY(packet_buffer.resize(buffer.size() + sizeof(IPv4Header) + socket.protocol_header_size())); auto packet = BAN::ByteSpan { packet_buffer.span() }; + auto pseudo_header = PseudoHeader { + .src_ipv4 = socket.interface().get_ipv4_address(), + .dst_ipv4 = dst_ipv4, + .protocol = socket.protocol() + }; + memcpy( packet.slice(sizeof(IPv4Header)).slice(socket.protocol_header_size()).data(), - arguments->message, - arguments->length + buffer.data(), + buffer.size() ); socket.add_protocol_header( packet.slice(sizeof(IPv4Header)), - dst_port + dst_port, + pseudo_header ); add_ipv4_header( packet, @@ -139,17 +154,7 @@ namespace Kernel TRY(socket.interface().send_bytes(dst_mac, EtherType::IPv4, packet)); - return arguments->length; - } - - static uint16_t calculate_internet_checksum(BAN::ConstByteSpan packet) - { - uint32_t checksum = 0; - for (size_t i = 0; i < packet.size() / sizeof(uint16_t); i++) - checksum += BAN::host_to_network_endian(reinterpret_cast(packet.data())[i]); - while (checksum >> 16) - checksum = (checksum >> 16) | (checksum & 0xFFFF); - return ~(uint16_t)checksum; + return buffer.size(); } BAN::ErrorOr IPv4Layer::handle_ipv4_packet(NetworkInterface& interface, BAN::ByteSpan packet) @@ -157,8 +162,6 @@ namespace Kernel auto& ipv4_header = packet.as(); auto ipv4_data = packet.slice(sizeof(IPv4Header)); - ASSERT(ipv4_header.is_valid_checksum()); - auto src_ipv4 = ipv4_header.src_address; switch (ipv4_header.protocol) { @@ -174,7 +177,7 @@ namespace Kernel auto& reply_icmp_header = ipv4_data.as(); reply_icmp_header.type = ICMPType::EchoReply; reply_icmp_header.checksum = 0; - reply_icmp_header.checksum = calculate_internet_checksum(ipv4_data); + reply_icmp_header.checksum = calculate_internet_checksum(ipv4_data, {}); add_ipv4_header(packet, interface.get_ipv4_address(), src_ipv4, NetworkProtocol::ICMP); @@ -195,14 +198,20 @@ namespace Kernel LockGuard _(m_lock); - if (!m_bound_sockets.contains(dst_port) || !m_bound_sockets[dst_port].valid()) + if (!m_bound_sockets.contains(dst_port)) + { + dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port); + return {}; + } + auto socket = m_bound_sockets[dst_port].lock(); + if (!socket) { dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port); return {}; } auto udp_data = ipv4_data.slice(sizeof(UDPHeader)); - m_bound_sockets[dst_port].lock()->add_packet(udp_data, src_ipv4, src_port); + socket->add_packet(udp_data, src_ipv4, src_port); break; } default: @@ -262,14 +271,17 @@ namespace Kernel } auto& ipv4_header = buffer.as(); - if (!ipv4_header.is_valid_checksum()) + if (calculate_internet_checksum(BAN::ConstByteSpan::from(ipv4_header), {}) != 0) { dwarnln("Invalid IPv4 packet"); return; } - if (ipv4_header.total_length > buffer.size()) + if (ipv4_header.total_length > buffer.size() || ipv4_header.total_length > interface.payload_mtu()) { - dwarnln("Too short IPv4 packet"); + if (ipv4_header.flags_frament & IPv4Flags::DF) + dwarnln("Invalid IPv4 packet"); + else + dwarnln("IPv4 fragmentation not supported"); return; } diff --git a/kernel/kernel/Networking/NetworkSocket.cpp b/kernel/kernel/Networking/NetworkSocket.cpp index 53eb745e..eee19713 100644 --- a/kernel/kernel/Networking/NetworkSocket.cpp +++ b/kernel/kernel/Networking/NetworkSocket.cpp @@ -49,7 +49,8 @@ namespace Kernel if (!m_interface) TRY(m_network_layer.bind_socket(PORT_NONE, this)); - return TRY(m_network_layer.sendto(*this, arguments)); + auto buffer = BAN::ConstByteSpan { reinterpret_cast(arguments->message), arguments->length }; + return TRY(m_network_layer.sendto(*this, buffer, arguments->dest_addr, arguments->dest_len)); } BAN::ErrorOr NetworkSocket::recvfrom_impl(sys_recvfrom_t* arguments) diff --git a/kernel/kernel/Networking/UDPSocket.cpp b/kernel/kernel/Networking/UDPSocket.cpp index 09525f2f..b4a291da 100644 --- a/kernel/kernel/Networking/UDPSocket.cpp +++ b/kernel/kernel/Networking/UDPSocket.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -23,7 +24,7 @@ namespace Kernel : NetworkSocket(network_layer, ino, inode_info) { } - void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) + void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) { auto& header = packet.as(); header.src_port = m_port; @@ -34,7 +35,7 @@ namespace Kernel void UDPSocket::add_packet(BAN::ConstByteSpan packet, BAN::IPv4Address sender_addr, uint16_t sender_port) { - CriticalScope _; + LockGuard _(m_packet_lock); if (m_packets.full()) { @@ -58,15 +59,15 @@ namespace Kernel }); m_packet_total_size += packet.size(); - m_semaphore.unblock(); + m_packet_semaphore.unblock(); } BAN::ErrorOr UDPSocket::read_packet(BAN::ByteSpan buffer, sockaddr_in* sender_addr) { while (m_packets.empty()) - TRY(Thread::current().block_or_eintr(m_semaphore)); + TRY(Thread::current().block_or_eintr(m_packet_semaphore)); - CriticalScope _; + LockGuard _(m_packet_lock); if (m_packets.empty()) return read_packet(buffer, sender_addr);