diff --git a/kernel/CMakeLists.txt b/kernel/CMakeLists.txt index a0189ad3a1..51cac192df 100644 --- a/kernel/CMakeLists.txt +++ b/kernel/CMakeLists.txt @@ -53,7 +53,7 @@ set(KERNEL_SOURCES kernel/Networking/ARPTable.cpp kernel/Networking/E1000/E1000.cpp kernel/Networking/E1000/E1000E.cpp - kernel/Networking/IPv4.cpp + kernel/Networking/IPv4Layer.cpp kernel/Networking/NetworkInterface.cpp kernel/Networking/NetworkManager.cpp kernel/Networking/NetworkSocket.cpp diff --git a/kernel/include/kernel/Networking/E1000/E1000.h b/kernel/include/kernel/Networking/E1000/E1000.h index 2fba6e122e..1d5d67b19f 100644 --- a/kernel/include/kernel/Networking/E1000/E1000.h +++ b/kernel/include/kernel/Networking/E1000/E1000.h @@ -42,7 +42,7 @@ namespace Kernel uint32_t read32(uint16_t reg); void write32(uint16_t reg, uint32_t value); - virtual BAN::ErrorOr send_raw_bytes(BAN::ConstByteSpan) override; + virtual BAN::ErrorOr send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan) override; private: BAN::ErrorOr read_mac_address(); diff --git a/kernel/include/kernel/Networking/ICMP.h b/kernel/include/kernel/Networking/ICMP.h new file mode 100644 index 0000000000..689bde47bf --- /dev/null +++ b/kernel/include/kernel/Networking/ICMP.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include + +namespace Kernel +{ + + struct ICMPHeader + { + uint8_t type; + uint8_t code; + BAN::NetworkEndian checksum; + BAN::NetworkEndian rest; + }; + static_assert(sizeof(ICMPHeader) == 8); + + enum ICMPType : uint8_t + { + EchoReply = 0x00, + EchoRequest = 0x08, + }; + +} diff --git a/kernel/include/kernel/Networking/IPv4.h b/kernel/include/kernel/Networking/IPv4.h deleted file mode 100644 index 610e64d410..0000000000 --- a/kernel/include/kernel/Networking/IPv4.h +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -namespace Kernel -{ - - struct IPv4Header - { - uint8_t version_IHL; - uint8_t DSCP_ECN; - BAN::NetworkEndian total_length { 0 }; - BAN::NetworkEndian identification { 0 }; - BAN::NetworkEndian flags_frament { 0 }; - uint8_t time_to_live; - uint8_t protocol; - BAN::NetworkEndian checksum { 0 }; - BAN::IPv4Address src_address; - BAN::IPv4Address dst_address; - - constexpr uint16_t calculate_checksum() const - { - uint32_t total_sum = 0; - for (size_t i = 0; i < sizeof(IPv4Header) / sizeof(uint16_t); i++) - total_sum += reinterpret_cast*>(this)[i]; - total_sum -= checksum; - while (total_sum >> 16) - total_sum = (total_sum >> 16) + (total_sum & 0xFFFF); - return ~(uint16_t)total_sum; - } - }; - static_assert(sizeof(IPv4Header) == 20); - - void add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol); - -} diff --git a/kernel/include/kernel/Networking/IPv4Layer.h b/kernel/include/kernel/Networking/IPv4Layer.h new file mode 100644 index 0000000000..bf9d295270 --- /dev/null +++ b/kernel/include/kernel/Networking/IPv4Layer.h @@ -0,0 +1,105 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Kernel +{ + + struct IPv4Header + { + uint8_t version_IHL; + uint8_t DSCP_ECN; + BAN::NetworkEndian total_length { 0 }; + BAN::NetworkEndian identification { 0 }; + BAN::NetworkEndian flags_frament { 0 }; + uint8_t time_to_live; + uint8_t protocol; + BAN::NetworkEndian checksum { 0 }; + BAN::IPv4Address src_address; + BAN::IPv4Address dst_address; + + constexpr uint16_t calculate_checksum() const + { + uint32_t total_sum = 0; + for (size_t i = 0; i < sizeof(IPv4Header) / sizeof(uint16_t); i++) + total_sum += reinterpret_cast*>(this)[i]; + total_sum -= checksum; + while (total_sum >> 16) + total_sum = (total_sum >> 16) + (total_sum & 0xFFFF); + return ~(uint16_t)total_sum; + } + + constexpr bool is_valid_checksum() const + { + uint32_t total_sum = 0; + for (size_t i = 0; i < sizeof(IPv4Header) / sizeof(uint16_t); i++) + total_sum += reinterpret_cast*>(this)[i]; + while (total_sum >> 16) + total_sum = (total_sum >> 16) + (total_sum & 0xFFFF); + return total_sum == 0xFFFF; + } + }; + static_assert(sizeof(IPv4Header) == 20); + + class IPv4Layer : public NetworkLayer + { + BAN_NON_COPYABLE(IPv4Layer); + BAN_NON_MOVABLE(IPv4Layer); + + public: + static BAN::ErrorOr> create(); + ~IPv4Layer(); + + ARPTable& arp_table() { return *m_arp_table; } + + void add_ipv4_packet(NetworkInterface&, BAN::ConstByteSpan); + + virtual void unbind_socket(uint16_t port, BAN::RefPtr) override; + virtual BAN::ErrorOr bind_socket(uint16_t port, BAN::RefPtr) override; + + virtual BAN::ErrorOr sendto(NetworkSocket&, const sys_sendto_t*) override; + + private: + IPv4Layer(); + + void add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol) const; + + void packet_handle_task(); + BAN::ErrorOr handle_ipv4_packet(NetworkInterface&, BAN::ByteSpan); + + private: + struct PendingIPv4Packet + { + NetworkInterface& interface; + }; + + private: + SpinLock m_lock; + + BAN::UniqPtr m_arp_table; + Process* m_process { nullptr }; + + static constexpr size_t pending_packet_buffer_size = 128 * PAGE_SIZE; + BAN::UniqPtr m_pending_packet_buffer; + BAN::CircularQueue m_pending_packets; + Semaphore m_pending_semaphore; + size_t m_pending_total_size { 0 }; + + BAN::HashMap> m_bound_sockets; + + friend class BAN::UniqPtr; + }; + +} diff --git a/kernel/include/kernel/Networking/NetworkInterface.h b/kernel/include/kernel/Networking/NetworkInterface.h index 50ff40fdc3..24f18eecb0 100644 --- a/kernel/include/kernel/Networking/NetworkInterface.h +++ b/kernel/include/kernel/Networking/NetworkInterface.h @@ -1,10 +1,10 @@ #pragma once -#include #include +#include +#include #include #include -#include namespace Kernel { @@ -52,13 +52,10 @@ namespace Kernel virtual bool link_up() = 0; virtual int link_speed() = 0; - size_t interface_header_size() const; - void add_interface_header(BAN::ByteSpan packet, BAN::MACAddress destination); - virtual dev_t rdev() const override { return m_rdev; } virtual BAN::StringView name() const override { return m_name; } - virtual BAN::ErrorOr send_raw_bytes(BAN::ConstByteSpan) = 0; + virtual BAN::ErrorOr send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan) = 0; private: const Type m_type; diff --git a/kernel/include/kernel/Networking/NetworkLayer.h b/kernel/include/kernel/Networking/NetworkLayer.h new file mode 100644 index 0000000000..74db08d189 --- /dev/null +++ b/kernel/include/kernel/Networking/NetworkLayer.h @@ -0,0 +1,25 @@ +#pragma once + +#include + +namespace Kernel +{ + + class NetworkSocket; + enum class SocketType; + + class NetworkLayer + { + public: + virtual ~NetworkLayer() {} + + virtual void unbind_socket(uint16_t port, BAN::RefPtr) = 0; + virtual BAN::ErrorOr bind_socket(uint16_t port, BAN::RefPtr) = 0; + + virtual BAN::ErrorOr sendto(NetworkSocket&, const sys_sendto_t*) = 0; + + protected: + NetworkLayer() = default; + }; + +} diff --git a/kernel/include/kernel/Networking/NetworkManager.h b/kernel/include/kernel/Networking/NetworkManager.h index 1477406738..475cbde8d6 100644 --- a/kernel/include/kernel/Networking/NetworkManager.h +++ b/kernel/include/kernel/Networking/NetworkManager.h @@ -2,7 +2,7 @@ #include #include -#include +#include #include #include #include @@ -17,26 +17,15 @@ namespace Kernel BAN_NON_COPYABLE(NetworkManager); BAN_NON_MOVABLE(NetworkManager); - public: - enum class SocketType - { - STREAM, - DGRAM, - SEQPACKET, - }; - public: static BAN::ErrorOr initialize(); static NetworkManager& get(); - ARPTable& arp_table() { return *m_arp_table; } - BAN::ErrorOr add_interface(PCI::Device& pci_device); - void unbind_socket(uint16_t port, BAN::RefPtr); - BAN::ErrorOr bind_socket(uint16_t port, BAN::RefPtr); + BAN::Vector> interfaces() { return m_interfaces; } - BAN::ErrorOr> create_socket(SocketType, mode_t, uid_t, gid_t); + BAN::ErrorOr> create_socket(SocketDomain, SocketType, mode_t, uid_t, gid_t); void on_receive(NetworkInterface&, BAN::ConstByteSpan); @@ -44,9 +33,8 @@ namespace Kernel NetworkManager(); private: - BAN::UniqPtr m_arp_table; - BAN::Vector> m_interfaces; - BAN::HashMap> m_bound_sockets; + BAN::UniqPtr m_ipv4_layer; + BAN::Vector> m_interfaces; }; } diff --git a/kernel/include/kernel/Networking/NetworkSocket.h b/kernel/include/kernel/Networking/NetworkSocket.h index 5171a936d8..adf8c5ad5f 100644 --- a/kernel/include/kernel/Networking/NetworkSocket.h +++ b/kernel/include/kernel/Networking/NetworkSocket.h @@ -3,6 +3,7 @@ #include #include #include +#include #include @@ -11,9 +12,24 @@ namespace Kernel enum NetworkProtocol : uint8_t { + ICMP = 0x01, UDP = 0x11, }; + enum class SocketDomain + { + INET, + INET6, + UNIX, + }; + + enum class SocketType + { + STREAM, + DGRAM, + SEQPACKET, + }; + class NetworkSocket : public TmpInode, public BAN::Weakable { BAN_NON_COPYABLE(NetworkSocket); @@ -26,14 +42,16 @@ namespace Kernel void bind_interface_and_port(NetworkInterface*, uint16_t port); ~NetworkSocket(); + NetworkInterface& interface() { ASSERT(m_interface); return *m_interface; } + virtual size_t protocol_header_size() const = 0; - virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t src_port, uint16_t dst_port) = 0; + virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) = 0; virtual NetworkProtocol protocol() const = 0; virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_address, uint16_t sender_port) = 0; protected: - NetworkSocket(mode_t mode, uid_t uid, gid_t gid); + NetworkSocket(NetworkLayer&, ino_t, const TmpInodeInfo&); virtual BAN::ErrorOr read_packet(BAN::ByteSpan, sockaddr_in* sender_address) = 0; @@ -46,6 +64,7 @@ namespace Kernel virtual BAN::ErrorOr ioctl_impl(int request, void* arg) override; protected: + NetworkLayer& m_network_layer; NetworkInterface* m_interface = nullptr; uint16_t m_port = PORT_NONE; }; diff --git a/kernel/include/kernel/Networking/UDPSocket.h b/kernel/include/kernel/Networking/UDPSocket.h index 9b9fffb7d6..74955924af 100644 --- a/kernel/include/kernel/Networking/UDPSocket.h +++ b/kernel/include/kernel/Networking/UDPSocket.h @@ -22,10 +22,10 @@ namespace Kernel class UDPSocket final : public NetworkSocket { public: - static BAN::ErrorOr> create(mode_t, uid_t, gid_t); + static BAN::ErrorOr> create(NetworkLayer&, ino_t, const TmpInodeInfo&); virtual size_t protocol_header_size() const override { return sizeof(UDPHeader); } - virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t src_port, uint16_t dst_port) override; + virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) override; virtual NetworkProtocol protocol() const override { return NetworkProtocol::UDP; } protected: @@ -33,7 +33,7 @@ namespace Kernel virtual BAN::ErrorOr read_packet(BAN::ByteSpan, sockaddr_in* sender_address) override; private: - UDPSocket(mode_t, uid_t, gid_t); + UDPSocket(NetworkLayer&, ino_t, const TmpInodeInfo&); struct PacketInfo { diff --git a/kernel/kernel/Networking/ARPTable.cpp b/kernel/kernel/Networking/ARPTable.cpp index ba98a5a3e3..030d44b028 100644 --- a/kernel/kernel/Networking/ARPTable.cpp +++ b/kernel/kernel/Networking/ARPTable.cpp @@ -45,6 +45,9 @@ namespace Kernel if (ipv4_address == s_broadcast_ipv4) return s_broadcast_mac; + if (interface.get_ipv4_address() == BAN::IPv4Address { 0 }) + return BAN::Error::from_errno(EINVAL); + if (interface.get_ipv4_address().mask(interface.get_netmask()) != ipv4_address.mask(interface.get_netmask())) ipv4_address = interface.get_gateway(); @@ -54,16 +57,7 @@ namespace Kernel return m_arp_table[ipv4_address]; } - BAN::Vector full_packet_buffer; - TRY(full_packet_buffer.resize(sizeof(ARPPacket) + sizeof(EthernetHeader))); - auto full_packet = BAN::ByteSpan { full_packet_buffer.span() }; - - auto& ethernet_header = full_packet.as(); - ethernet_header.dst_mac = s_broadcast_mac; - ethernet_header.src_mac = interface.get_mac_address(); - ethernet_header.ether_type = EtherType::ARP; - - auto& arp_request = full_packet.slice(sizeof(EthernetHeader)).as(); + ARPPacket arp_request; arp_request.htype = 0x0001; arp_request.ptype = EtherType::IPv4; arp_request.hlen = 0x06; @@ -74,7 +68,7 @@ namespace Kernel arp_request.tha = {{ 0, 0, 0, 0, 0, 0 }}; arp_request.tpa = ipv4_address; - TRY(interface.send_raw_bytes(full_packet)); + TRY(interface.send_bytes(s_broadcast_mac, EtherType::ARP, BAN::ConstByteSpan::from(arp_request))); uint64_t timeout = SystemTimer::get().ms_since_boot() + 1'000; while (SystemTimer::get().ms_since_boot() < timeout) @@ -104,27 +98,17 @@ namespace Kernel { if (packet.tpa == interface.get_ipv4_address()) { - BAN::Vector full_packet_buffer; - TRY(full_packet_buffer.resize(sizeof(ARPPacket) + sizeof(EthernetHeader))); - auto full_packet = BAN::ByteSpan { full_packet_buffer.span() }; - - auto& ethernet_header = full_packet.as(); - ethernet_header.dst_mac = packet.sha; - ethernet_header.src_mac = interface.get_mac_address(); - ethernet_header.ether_type = EtherType::ARP; - - auto& arp_request = full_packet.slice(sizeof(EthernetHeader)).as(); - arp_request.htype = 0x0001; - arp_request.ptype = EtherType::IPv4; - arp_request.hlen = 0x06; - arp_request.plen = 0x04; - arp_request.oper = ARPOperation::Reply; - arp_request.sha = interface.get_mac_address(); - arp_request.spa = interface.get_ipv4_address(); - arp_request.tha = packet.sha; - arp_request.tpa = packet.spa; - - TRY(interface.send_raw_bytes(full_packet)); + ARPPacket arp_reply; + arp_reply.htype = 0x0001; + arp_reply.ptype = EtherType::IPv4; + arp_reply.hlen = 0x06; + arp_reply.plen = 0x04; + arp_reply.oper = ARPOperation::Reply; + arp_reply.sha = interface.get_mac_address(); + arp_reply.spa = interface.get_ipv4_address(); + arp_reply.tha = packet.sha; + arp_reply.tpa = packet.spa; + TRY(interface.send_bytes(packet.sha, EtherType::ARP, BAN::ConstByteSpan::from(arp_reply))); } break; } diff --git a/kernel/kernel/Networking/E1000/E1000.cpp b/kernel/kernel/Networking/E1000/E1000.cpp index 4f5fb1076c..31f07f6bbb 100644 --- a/kernel/kernel/Networking/E1000/E1000.cpp +++ b/kernel/kernel/Networking/E1000/E1000.cpp @@ -256,19 +256,26 @@ namespace Kernel return {}; } - BAN::ErrorOr E1000::send_raw_bytes(BAN::ConstByteSpan buffer) + + BAN::ErrorOr E1000::send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan buffer) { - ASSERT_LTE(buffer.size(), E1000_TX_BUFFER_SIZE); + ASSERT_LTE(buffer.size() + sizeof(EthernetHeader), E1000_TX_BUFFER_SIZE); CriticalScope _; size_t tx_current = read32(REG_TDT) % E1000_TX_DESCRIPTOR_COUNT; - auto* tx_buffer = reinterpret_cast(m_tx_buffer_region->vaddr() + E1000_TX_BUFFER_SIZE * tx_current); - memcpy(tx_buffer, buffer.data(), buffer.size()); + auto* tx_buffer = reinterpret_cast(m_tx_buffer_region->vaddr() + E1000_TX_BUFFER_SIZE * tx_current); + + auto& ethernet_header = *reinterpret_cast(tx_buffer); + ethernet_header.dst_mac = destination; + ethernet_header.src_mac = get_mac_address(); + ethernet_header.ether_type = protocol; + + memcpy(tx_buffer + sizeof(EthernetHeader), buffer.data(), buffer.size()); auto& descriptor = reinterpret_cast(m_tx_descriptor_region->vaddr())[tx_current]; - descriptor.length = buffer.size(); + descriptor.length = sizeof(EthernetHeader) + buffer.size(); descriptor.status = 0; descriptor.cmd = CMD_EOP | CMD_IFCS | CMD_RS; diff --git a/kernel/kernel/Networking/IPv4.cpp b/kernel/kernel/Networking/IPv4.cpp deleted file mode 100644 index 33f9044d7f..0000000000 --- a/kernel/kernel/Networking/IPv4.cpp +++ /dev/null @@ -1,22 +0,0 @@ -#include -#include - -namespace Kernel -{ - - void add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol) - { - auto& header = packet.as(); - header.version_IHL = 0x45; - header.DSCP_ECN = 0x00; - header.total_length = packet.size(); - header.identification = 1; - header.flags_frament = 0x00; - header.time_to_live = 0x40; - header.protocol = protocol; - header.src_address = src_ipv4; - header.dst_address = dst_ipv4; - header.checksum = header.calculate_checksum(); - } - -} diff --git a/kernel/kernel/Networking/IPv4Layer.cpp b/kernel/kernel/Networking/IPv4Layer.cpp new file mode 100644 index 0000000000..1660d66cdf --- /dev/null +++ b/kernel/kernel/Networking/IPv4Layer.cpp @@ -0,0 +1,284 @@ +#include +#include +#include +#include +#include +#include + +#include + +#define DEBUG_IPV4 0 + +namespace Kernel +{ + + BAN::ErrorOr> IPv4Layer::create() + { + auto ipv4_manager = TRY(BAN::UniqPtr::create()); + ipv4_manager->m_process = Process::create_kernel( + [](void* ipv4_manager_ptr) + { + auto& ipv4_manager = *reinterpret_cast(ipv4_manager_ptr); + ipv4_manager.packet_handle_task(); + }, ipv4_manager.ptr() + ); + ASSERT(ipv4_manager->m_process); + ipv4_manager->m_pending_packet_buffer = TRY(VirtualRange::create_to_vaddr_range( + PageTable::kernel(), + KERNEL_OFFSET, + ~(uintptr_t)0, + pending_packet_buffer_size, + PageTable::Flags::ReadWrite | PageTable::Flags::Present, + true + )); + ipv4_manager->m_arp_table = TRY(ARPTable::create()); + return ipv4_manager; + } + + IPv4Layer::IPv4Layer() + { } + + IPv4Layer::~IPv4Layer() + { + if (m_process) + m_process->exit(0, SIGKILL); + m_process = nullptr; + } + + void IPv4Layer::add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol) const + { + auto& header = packet.as(); + header.version_IHL = 0x45; + header.DSCP_ECN = 0x00; + header.total_length = packet.size(); + header.identification = 1; + header.flags_frament = 0x00; + header.time_to_live = 0x40; + header.protocol = protocol; + header.src_address = src_ipv4; + header.dst_address = dst_ipv4; + header.checksum = header.calculate_checksum(); + } + + void IPv4Layer::unbind_socket(uint16_t port, BAN::RefPtr socket) + { + LockGuard _(m_lock); + if (m_bound_sockets.contains(port)) + { + ASSERT(m_bound_sockets[port].valid()); + ASSERT(m_bound_sockets[port].lock() == socket); + m_bound_sockets.remove(port); + } + NetworkManager::get().TmpFileSystem::remove_from_cache(socket); + } + + BAN::ErrorOr IPv4Layer::bind_socket(uint16_t port, BAN::RefPtr socket) + { + if (NetworkManager::get().interfaces().empty()) + return BAN::Error::from_errno(EADDRNOTAVAIL); + + LockGuard _(m_lock); + + if (port == NetworkSocket::PORT_NONE) + { + for (uint32_t temp = 0xC000; temp < 0xFFFF; temp++) + { + if (!m_bound_sockets.contains(temp)) + { + port = temp; + break; + } + } + if (port == NetworkSocket::PORT_NONE) + { + dwarnln("No ports available"); + return BAN::Error::from_errno(EAGAIN); + } + } + + if (m_bound_sockets.contains(port)) + return BAN::Error::from_errno(EADDRINUSE); + TRY(m_bound_sockets.insert(port, socket)); + + // FIXME: actually determine proper interface + auto interface = NetworkManager::get().interfaces().front(); + socket->bind_interface_and_port(interface.ptr(), port); + + return {}; + } + + BAN::ErrorOr IPv4Layer::sendto(NetworkSocket& socket, const sys_sendto_t* arguments) + { + if (arguments->dest_addr->sa_family != AF_INET) + return BAN::Error::from_errno(EINVAL); + auto& sockaddr_in = *reinterpret_cast(arguments->dest_addr); + + auto dst_port = BAN::host_to_network_endian(sockaddr_in.sin_port); + auto dst_ipv4 = BAN::IPv4Address { sockaddr_in.sin_addr.s_addr }; + auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(socket.interface(), dst_ipv4)); + + BAN::Vector packet_buffer; + TRY(packet_buffer.resize(arguments->length + sizeof(IPv4Header) + socket.protocol_header_size())); + auto packet = BAN::ByteSpan { packet_buffer.span() }; + + memcpy( + packet.slice(sizeof(IPv4Header)).slice(socket.protocol_header_size()).data(), + arguments->message, + arguments->length + ); + socket.add_protocol_header( + packet.slice(sizeof(IPv4Header)), + dst_port + ); + add_ipv4_header( + packet, + socket.interface().get_ipv4_address(), + dst_ipv4, + socket.protocol() + ); + + TRY(socket.interface().send_bytes(dst_mac, EtherType::IPv4, packet)); + + return arguments->length; + } + + static uint16_t calculate_internet_checksum(BAN::ConstByteSpan packet) + { + uint32_t checksum = 0; + for (size_t i = 0; i < packet.size() / sizeof(uint16_t); i++) + checksum += BAN::host_to_network_endian(reinterpret_cast(packet.data())[i]); + while (checksum >> 16) + checksum = (checksum >> 16) | (checksum & 0xFFFF); + return ~(uint16_t)checksum; + } + + BAN::ErrorOr IPv4Layer::handle_ipv4_packet(NetworkInterface& interface, BAN::ByteSpan packet) + { + auto& ipv4_header = packet.as(); + auto ipv4_data = packet.slice(sizeof(IPv4Header)); + + ASSERT(ipv4_header.is_valid_checksum()); + + auto src_ipv4 = ipv4_header.src_address; + switch (ipv4_header.protocol) + { + case NetworkProtocol::ICMP: + { + auto& icmp_header = ipv4_data.as(); + switch (icmp_header.type) + { + case ICMPType::EchoRequest: + { + auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(interface, src_ipv4)); + + auto& reply_icmp_header = ipv4_data.as(); + reply_icmp_header.type = ICMPType::EchoReply; + reply_icmp_header.checksum = 0; + reply_icmp_header.checksum = calculate_internet_checksum(ipv4_data); + + add_ipv4_header(packet, interface.get_ipv4_address(), src_ipv4, NetworkProtocol::ICMP); + + TRY(interface.send_bytes(dst_mac, EtherType::IPv4, packet)); + break; + } + default: + dprintln("Unhandleded ICMP packet (type {2H})", icmp_header.type); + break; + } + break; + } + case NetworkProtocol::UDP: + { + auto& udp_header = ipv4_data.as(); + uint16_t src_port = udp_header.src_port; + uint16_t dst_port = udp_header.dst_port; + + LockGuard _(m_lock); + + if (!m_bound_sockets.contains(dst_port) || !m_bound_sockets[dst_port].valid()) + { + dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port); + return {}; + } + + auto udp_data = ipv4_data.slice(sizeof(UDPHeader)); + m_bound_sockets[dst_port].lock()->add_packet(udp_data, src_ipv4, src_port); + break; + } + default: + dprintln_if(DEBUG_IPV4, "Unknown network protocol 0x{2H}", ipv4_header.protocol); + break; + } + + return {}; + } + + void IPv4Layer::packet_handle_task() + { + for (;;) + { + BAN::Optional pending; + + { + CriticalScope _; + if (!m_pending_packets.empty()) + { + pending = m_pending_packets.front(); + m_pending_packets.pop(); + } + } + + if (!pending.has_value()) + { + m_pending_semaphore.block(); + continue; + } + + uint8_t* buffer_start = reinterpret_cast(m_pending_packet_buffer->vaddr()); + const size_t ipv4_packet_size = reinterpret_cast(buffer_start)->total_length; + + if (auto ret = handle_ipv4_packet(pending->interface, BAN::ByteSpan(buffer_start, ipv4_packet_size)); ret.is_error()) + dwarnln("{}", ret.error()); + + CriticalScope _; + m_pending_total_size -= ipv4_packet_size; + if (m_pending_total_size) + memmove(buffer_start, buffer_start + ipv4_packet_size, m_pending_total_size); + } + } + + void IPv4Layer::add_ipv4_packet(NetworkInterface& interface, BAN::ConstByteSpan buffer) + { + if (m_pending_packets.full()) + { + dwarnln("IPv4 packet queue full"); + return; + } + + if (m_pending_total_size + buffer.size() > m_pending_packet_buffer->size()) + { + dwarnln("IPv4 packet queue full"); + return; + } + + auto& ipv4_header = buffer.as(); + if (!ipv4_header.is_valid_checksum()) + { + dwarnln("Invalid IPv4 packet"); + return; + } + if (ipv4_header.total_length > buffer.size()) + { + dwarnln("Too short IPv4 packet"); + return; + } + + uint8_t* buffer_start = reinterpret_cast(m_pending_packet_buffer->vaddr()); + memcpy(buffer_start + m_pending_total_size, buffer.data(), ipv4_header.total_length); + m_pending_total_size += ipv4_header.total_length; + + m_pending_packets.push({ .interface = interface }); + m_pending_semaphore.unblock(); + } + +} diff --git a/kernel/kernel/Networking/NetworkInterface.cpp b/kernel/kernel/Networking/NetworkInterface.cpp index 0656f6d945..d4d1be2a13 100644 --- a/kernel/kernel/Networking/NetworkInterface.cpp +++ b/kernel/kernel/Networking/NetworkInterface.cpp @@ -32,19 +32,4 @@ namespace Kernel m_name[3] = minor(m_rdev) + '0'; } - size_t NetworkInterface::interface_header_size() const - { - ASSERT(m_type == Type::Ethernet); - return sizeof(EthernetHeader); - } - - void NetworkInterface::add_interface_header(BAN::ByteSpan packet, BAN::MACAddress destination) - { - ASSERT(m_type == Type::Ethernet); - auto& header = packet.as(); - header.dst_mac = destination; - header.src_mac = get_mac_address(); - header.ether_type = 0x0800; - } - } diff --git a/kernel/kernel/Networking/NetworkManager.cpp b/kernel/kernel/Networking/NetworkManager.cpp index 00ab0c023d..884e47464d 100644 --- a/kernel/kernel/Networking/NetworkManager.cpp +++ b/kernel/kernel/Networking/NetworkManager.cpp @@ -3,10 +3,12 @@ #include #include #include -#include +#include #include #include +#define DEBUG_ETHERTYPE 0 + namespace Kernel { @@ -19,8 +21,8 @@ namespace Kernel if (manager_ptr == nullptr) return BAN::Error::from_errno(ENOMEM); auto manager = BAN::UniqPtr::adopt(manager_ptr); - manager->m_arp_table = TRY(ARPTable::create()); TRY(manager->TmpFileSystem::initialize(0777, 0, 0)); + manager->m_ipv4_layer = TRY(IPv4Layer::create()); s_instance = BAN::move(manager); return {}; } @@ -68,45 +70,41 @@ namespace Kernel return {}; } - BAN::ErrorOr> NetworkManager::create_socket(SocketType type, mode_t mode, uid_t uid, gid_t gid) + BAN::ErrorOr> NetworkManager::create_socket(SocketDomain domain, SocketType type, mode_t mode, uid_t uid, gid_t gid) { + switch (domain) + { + case SocketDomain::INET: + { + if (type != SocketType::DGRAM) + return BAN::Error::from_errno(EPROTOTYPE); + break; + } + default: + return BAN::Error::from_errno(EAFNOSUPPORT); + } + ASSERT((mode & Inode::Mode::TYPE_MASK) == 0); + mode |= Inode::Mode::IFSOCK; - if (type != SocketType::DGRAM) - return BAN::Error::from_errno(EPROTOTYPE); + auto inode_info = create_inode_info(mode, uid, gid); + ino_t ino = TRY(allocate_inode(inode_info)); - auto udp_socket = TRY(UDPSocket::create(mode | Inode::Mode::IFSOCK, uid, gid)); - return BAN::RefPtr(udp_socket); - } - - void NetworkManager::unbind_socket(uint16_t port, BAN::RefPtr socket) - { - if (m_bound_sockets.contains(port)) + BAN::RefPtr socket; + switch (domain) { - ASSERT(m_bound_sockets[port].valid()); - ASSERT(m_bound_sockets[port].lock() == socket); - m_bound_sockets.remove(port); - } - NetworkManager::get().remove_from_cache(socket); - } - - BAN::ErrorOr NetworkManager::bind_socket(uint16_t port, BAN::RefPtr socket) - { - if (m_interfaces.empty()) - return BAN::Error::from_errno(EADDRNOTAVAIL); - - if (port != NetworkSocket::PORT_NONE) - { - if (m_bound_sockets.contains(port)) - return BAN::Error::from_errno(EADDRINUSE); - TRY(m_bound_sockets.insert(port, socket)); + case SocketDomain::INET: + { + if (type == SocketType::DGRAM) + socket = TRY(UDPSocket::create(*m_ipv4_layer, ino, inode_info)); + break; + } + default: + ASSERT_NOT_REACHED(); } - // FIXME: actually determine proper interface - auto interface = m_interfaces.front(); - socket->bind_interface_and_port(interface.ptr(), port); - - return {}; + ASSERT(socket); + return socket; } void NetworkManager::on_receive(NetworkInterface& interface, BAN::ConstByteSpan packet) @@ -117,41 +115,16 @@ namespace Kernel { case EtherType::ARP: { - m_arp_table->add_arp_packet(interface, packet.slice(sizeof(EthernetHeader))); + m_ipv4_layer->arp_table().add_arp_packet(interface, packet.slice(sizeof(EthernetHeader))); break; } case EtherType::IPv4: { - auto ipv4 = packet.slice(sizeof(EthernetHeader)); - auto& ipv4_header = ipv4.as(); - auto src_ipv4 = ipv4_header.src_address; - switch (ipv4_header.protocol) - { - case NetworkProtocol::UDP: - { - auto udp = ipv4.slice(sizeof(IPv4Header)); - auto& udp_header = udp.as(); - uint16_t src_port = udp_header.src_port; - uint16_t dst_port = udp_header.dst_port; - - if (!m_bound_sockets.contains(dst_port)) - { - dprintln("no one is listening on port {}", dst_port); - return; - } - - auto raw = udp.slice(8); - m_bound_sockets[dst_port].lock()->add_packet(raw, src_ipv4, src_port); - break; - } - default: - dprintln("Unknown network protocol 0x{2H}", ipv4_header.protocol); - break; - } + m_ipv4_layer->add_ipv4_packet(interface, packet.slice(sizeof(EthernetHeader))); break; } default: - dprintln("Unknown EtherType 0x{4H}", (uint16_t)ethernet_header.ether_type); + dprintln_if(DEBUG_ETHERTYPE, "Unknown EtherType 0x{4H}", (uint16_t)ethernet_header.ether_type); break; } } diff --git a/kernel/kernel/Networking/NetworkSocket.cpp b/kernel/kernel/Networking/NetworkSocket.cpp index e0943e0f51..2c34f56ed2 100644 --- a/kernel/kernel/Networking/NetworkSocket.cpp +++ b/kernel/kernel/Networking/NetworkSocket.cpp @@ -1,4 +1,3 @@ -#include #include #include @@ -7,13 +6,9 @@ namespace Kernel { - NetworkSocket::NetworkSocket(mode_t mode, uid_t uid, gid_t gid) - // FIXME: what the fuck is this - : TmpInode( - NetworkManager::get(), - MUST(NetworkManager::get().allocate_inode(create_inode_info(mode, uid, gid))), - create_inode_info(mode, uid, gid) - ) + NetworkSocket::NetworkSocket(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info) + : TmpInode(NetworkManager::get(), ino, inode_info) + , m_network_layer(network_layer) { } NetworkSocket::~NetworkSocket() @@ -23,7 +18,7 @@ namespace Kernel void NetworkSocket::on_close_impl() { if (m_interface) - NetworkManager::get().unbind_socket(m_port, this); + m_network_layer.unbind_socket(m_port, this); } void NetworkSocket::bind_interface_and_port(NetworkInterface* interface, uint16_t port) @@ -36,17 +31,15 @@ namespace Kernel BAN::ErrorOr NetworkSocket::bind_impl(const sockaddr* address, socklen_t address_len) { - if (address_len != sizeof(sockaddr_in)) + if (m_interface || address_len != sizeof(sockaddr_in)) return BAN::Error::from_errno(EINVAL); auto* addr_in = reinterpret_cast(address); uint16_t dst_port = BAN::host_to_network_endian(addr_in->sin_port); - return NetworkManager::get().bind_socket(dst_port, this); + return m_network_layer.bind_socket(dst_port, this); } BAN::ErrorOr NetworkSocket::sendto_impl(const sys_sendto_t* arguments) { - if (arguments->dest_len != sizeof(sockaddr_in)) - return BAN::Error::from_errno(EINVAL); if (arguments->flags) { dprintln("flags not supported"); @@ -54,42 +47,9 @@ namespace Kernel } if (!m_interface) - TRY(NetworkManager::get().bind_socket(PORT_NONE, this)); + TRY(m_network_layer.bind_socket(PORT_NONE, this)); - auto* destination = reinterpret_cast(arguments->dest_addr); - auto message = BAN::ConstByteSpan((const uint8_t*)arguments->message, arguments->length); - - uint16_t dst_port = BAN::host_to_network_endian(destination->sin_port); - if (dst_port == PORT_NONE) - return BAN::Error::from_errno(EINVAL); - - auto dst_addr = BAN::IPv4Address(destination->sin_addr.s_addr); - auto dst_mac = TRY(NetworkManager::get().arp_table().get_mac_from_ipv4(*m_interface, dst_addr)); - - const size_t interface_header_offset = 0; - const size_t interface_header_size = m_interface->interface_header_size(); - - const size_t ipv4_header_offset = interface_header_offset + interface_header_size; - const size_t ipv4_header_size = sizeof(IPv4Header); - - const size_t protocol_header_offset = ipv4_header_offset + ipv4_header_size; - const size_t protocol_header_size = this->protocol_header_size(); - - const size_t payload_offset = protocol_header_offset + protocol_header_size; - const size_t payload_size = message.size(); - - BAN::Vector full_packet; - TRY(full_packet.resize(payload_offset + payload_size)); - - BAN::ByteSpan packet_bytespan { full_packet.span() }; - - memcpy(full_packet.data() + payload_offset, message.data(), payload_size); - add_protocol_header(packet_bytespan.slice(protocol_header_offset), m_port, dst_port); - add_ipv4_header(packet_bytespan.slice(ipv4_header_offset), m_interface->get_ipv4_address(), dst_addr, protocol()); - m_interface->add_interface_header(packet_bytespan.slice(interface_header_offset), dst_mac); - TRY(m_interface->send_raw_bytes(packet_bytespan)); - - return arguments->length; + return TRY(m_network_layer.sendto(*this, arguments)); } BAN::ErrorOr NetworkSocket::recvfrom_impl(sys_recvfrom_t* arguments) diff --git a/kernel/kernel/Networking/UDPSocket.cpp b/kernel/kernel/Networking/UDPSocket.cpp index 35d19f8214..09525f2f03 100644 --- a/kernel/kernel/Networking/UDPSocket.cpp +++ b/kernel/kernel/Networking/UDPSocket.cpp @@ -5,9 +5,9 @@ namespace Kernel { - BAN::ErrorOr> UDPSocket::create(mode_t mode, uid_t uid, gid_t gid) + BAN::ErrorOr> UDPSocket::create(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info) { - auto socket = TRY(BAN::RefPtr::create(mode, uid, gid)); + auto socket = TRY(BAN::RefPtr::create(network_layer, ino, inode_info)); socket->m_packet_buffer = TRY(VirtualRange::create_to_vaddr_range( PageTable::kernel(), KERNEL_OFFSET, @@ -19,14 +19,14 @@ namespace Kernel return socket; } - UDPSocket::UDPSocket(mode_t mode, uid_t uid, gid_t gid) - : NetworkSocket(mode, uid, gid) + UDPSocket::UDPSocket(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info) + : NetworkSocket(network_layer, ino, inode_info) { } - void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t src_port, uint16_t dst_port) + void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) { auto& header = packet.as(); - header.src_port = src_port; + header.src_port = m_port; header.dst_port = dst_port; header.length = packet.size(); header.checksum = 0; diff --git a/kernel/kernel/OpenFileDescriptorSet.cpp b/kernel/kernel/OpenFileDescriptorSet.cpp index 7cce348d6d..89cfa94db5 100644 --- a/kernel/kernel/OpenFileDescriptorSet.cpp +++ b/kernel/kernel/OpenFileDescriptorSet.cpp @@ -80,13 +80,25 @@ namespace Kernel BAN::ErrorOr OpenFileDescriptorSet::socket(int domain, int type, int protocol) { - using SocketType = NetworkManager::SocketType; - - if (domain != AF_INET) - return BAN::Error::from_errno(EAFNOSUPPORT); if (protocol != 0) return BAN::Error::from_errno(EPROTONOSUPPORT); + SocketDomain sock_domain; + switch (domain) + { + case AF_INET: + sock_domain = SocketDomain::INET; + break; + case AF_INET6: + sock_domain = SocketDomain::INET6; + break; + case AF_UNIX: + sock_domain = SocketDomain::UNIX; + break; + default: + return BAN::Error::from_errno(EPROTOTYPE); + } + SocketType sock_type; switch (type) { @@ -103,7 +115,7 @@ namespace Kernel return BAN::Error::from_errno(EPROTOTYPE); } - auto socket = TRY(NetworkManager::get().create_socket(sock_type, 0777, m_credentials.euid(), m_credentials.egid())); + auto socket = TRY(NetworkManager::get().create_socket(sock_domain, sock_type, 0777, m_credentials.euid(), m_credentials.egid())); int fd = TRY(get_free_fd()); m_open_files[fd] = TRY(BAN::RefPtr::create(socket, "no-path"sv, 0, O_RDWR));