diff --git a/kernel/include/kernel/Networking/ARPTable.h b/kernel/include/kernel/Networking/ARPTable.h index 3a420dc6..0aca85c0 100644 --- a/kernel/include/kernel/Networking/ARPTable.h +++ b/kernel/include/kernel/Networking/ARPTable.h @@ -31,35 +31,18 @@ namespace Kernel public: static BAN::ErrorOr> create(); - ~ARPTable(); BAN::ErrorOr get_mac_from_ipv4(NetworkInterface&, BAN::IPv4Address); - void add_arp_packet(NetworkInterface&, BAN::ConstByteSpan); + BAN::ErrorOr handle_arp_packet(NetworkInterface&, BAN::ConstByteSpan); private: - ARPTable(); - - void packet_handle_task(); - BAN::ErrorOr handle_arp_packet(NetworkInterface&, const ARPPacket&); + ARPTable() = default; private: - struct PendingArpPacket - { - NetworkInterface& interface; - ARPPacket packet; - }; - - private: - SpinLock m_table_lock; - SpinLock m_pending_lock; - + SpinLock m_arp_table_lock; BAN::HashMap m_arp_table; - Thread* m_thread { nullptr }; - BAN::CircularQueue m_pending_packets; - ThreadBlocker m_pending_thread_blocker; - friend class BAN::UniqPtr; }; diff --git a/kernel/include/kernel/Networking/E1000/E1000.h b/kernel/include/kernel/Networking/E1000/E1000.h index 967f66b5..0135bb34 100644 --- a/kernel/include/kernel/Networking/E1000/E1000.h +++ b/kernel/include/kernel/Networking/E1000/E1000.h @@ -23,14 +23,14 @@ namespace Kernel static BAN::ErrorOr> create(PCI::Device&); ~E1000(); - virtual BAN::MACAddress get_mac_address() const override { return m_mac_address; } + BAN::MACAddress get_mac_address() const override { return m_mac_address; } - virtual bool link_up() override { return m_link_up; } - virtual int link_speed() override; + bool link_up() override { return m_link_up; } + int link_speed() override; - virtual size_t payload_mtu() const override { return E1000_RX_BUFFER_SIZE - sizeof(EthernetHeader); } + size_t payload_mtu() const override { return E1000_RX_BUFFER_SIZE - sizeof(EthernetHeader); } - virtual void handle_irq() final override; + void handle_irq() final override; 
protected: E1000(PCI::Device& pci_device) @@ -45,12 +45,12 @@ namespace Kernel uint32_t read32(uint16_t reg); void write32(uint16_t reg, uint32_t value); - virtual BAN::ErrorOr send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan) override; + BAN::ErrorOr send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::Span payload) override; - virtual bool can_read_impl() const override { return false; } - virtual bool can_write_impl() const override { return false; } - virtual bool has_error_impl() const override { return false; } - virtual bool has_hungup_impl() const override { return false; } + bool can_read_impl() const override { return false; } + bool can_write_impl() const override { return false; } + bool has_error_impl() const override { return false; } + bool has_hungup_impl() const override { return false; } private: BAN::ErrorOr read_mac_address(); @@ -61,7 +61,7 @@ namespace Kernel void enable_link(); BAN::ErrorOr enable_interrupt(); - void handle_receive(); + void receive_thread(); protected: PCI::Device& m_pci_device; @@ -75,6 +75,10 @@ namespace Kernel BAN::UniqPtr m_tx_descriptor_region; SpinLock m_lock; + bool m_thread_should_die { false }; + BAN::Atomic m_thread_is_dead { true }; + ThreadBlocker m_thread_blocker; + BAN::MACAddress m_mac_address {}; bool m_link_up { false }; diff --git a/kernel/include/kernel/Networking/E1000/E1000E.h b/kernel/include/kernel/Networking/E1000/E1000E.h index 20dea114..203e3137 100644 --- a/kernel/include/kernel/Networking/E1000/E1000E.h +++ b/kernel/include/kernel/Networking/E1000/E1000E.h @@ -12,8 +12,8 @@ namespace Kernel static BAN::ErrorOr> create(PCI::Device&); protected: - virtual void detect_eeprom() override; - virtual uint32_t eeprom_read(uint8_t addr) override; + void detect_eeprom() override; + uint32_t eeprom_read(uint8_t addr) override; private: E1000E(PCI::Device& pci_device) diff --git a/kernel/include/kernel/Networking/IPv4Layer.h 
b/kernel/include/kernel/Networking/IPv4Layer.h index b0a0bc8e..a2eb1409 100644 --- a/kernel/include/kernel/Networking/IPv4Layer.h +++ b/kernel/include/kernel/Networking/IPv4Layer.h @@ -38,11 +38,10 @@ namespace Kernel public: static BAN::ErrorOr> create(); - ~IPv4Layer(); ARPTable& arp_table() { return *m_arp_table; } - void add_ipv4_packet(NetworkInterface&, BAN::ConstByteSpan); + BAN::ErrorOr handle_ipv4_packet(NetworkInterface&, BAN::ConstByteSpan); virtual void unbind_socket(uint16_t port) override; virtual BAN::ErrorOr bind_socket_with_target(BAN::RefPtr, const sockaddr* target_address, socklen_t target_address_len) override; @@ -55,35 +54,15 @@ namespace Kernel virtual size_t header_size() const override { return sizeof(IPv4Header); } private: - IPv4Layer(); - - void add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol) const; + IPv4Layer() = default; BAN::ErrorOr find_free_port(); - void packet_handle_task(); - BAN::ErrorOr handle_ipv4_packet(NetworkInterface&, BAN::ByteSpan); - private: - struct PendingIPv4Packet - { - NetworkInterface& interface; - }; + BAN::UniqPtr m_arp_table; - private: - RecursiveSpinLock m_bound_socket_lock; - - BAN::UniqPtr m_arp_table; - Thread* m_thread { nullptr }; - - static constexpr size_t pending_packet_buffer_size = 128 * PAGE_SIZE; - BAN::UniqPtr m_pending_packet_buffer; - BAN::CircularQueue m_pending_packets; - ThreadBlocker m_pending_thread_blocker; - SpinLock m_pending_lock; - size_t m_pending_total_size { 0 }; - - BAN::HashMap> m_bound_sockets; + RecursiveSpinLock m_bound_socket_lock; + BAN::HashMap> m_bound_sockets; friend class BAN::UniqPtr; }; diff --git a/kernel/include/kernel/Networking/Loopback.h b/kernel/include/kernel/Networking/Loopback.h index 289b4796..9a28b549 100644 --- a/kernel/include/kernel/Networking/Loopback.h +++ b/kernel/include/kernel/Networking/Loopback.h @@ -9,6 +9,7 @@ namespace Kernel { public: static constexpr size_t buffer_size = 
BAN::numeric_limits::max() + 1; + static constexpr size_t buffer_count = 32; public: static BAN::ErrorOr> create(); @@ -24,8 +25,9 @@ namespace Kernel LoopbackInterface() : NetworkInterface(Type::Loopback) {} + ~LoopbackInterface(); - BAN::ErrorOr send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan) override; + BAN::ErrorOr send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::Span payload) override; bool can_read_impl() const override { return false; } bool can_write_impl() const override { return false; } @@ -33,8 +35,27 @@ namespace Kernel bool has_hungup_impl() const override { return false; } private: - SpinLock m_buffer_lock; + void receive_thread(); + + private: + struct Descriptor + { + uint8_t* addr; + uint32_t size; + uint8_t state; + }; + + private: + Mutex m_buffer_lock; BAN::UniqPtr m_buffer; + + uint32_t m_buffer_tail { 0 }; + uint32_t m_buffer_head { 0 }; + Descriptor m_descriptors[buffer_count] {}; + + bool m_thread_should_die { false }; + BAN::Atomic m_thread_is_dead { true }; + ThreadBlocker m_thread_blocker; }; } diff --git a/kernel/include/kernel/Networking/NetworkInterface.h b/kernel/include/kernel/Networking/NetworkInterface.h index 892c6434..e24ba0cd 100644 --- a/kernel/include/kernel/Networking/NetworkInterface.h +++ b/kernel/include/kernel/Networking/NetworkInterface.h @@ -60,7 +60,11 @@ namespace Kernel virtual dev_t rdev() const override { return m_rdev; } virtual BAN::StringView name() const override { return m_name; } - virtual BAN::ErrorOr send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan) = 0; + BAN::ErrorOr send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan payload) + { + return send_bytes(destination, protocol, { &payload, 1 }); + } + virtual BAN::ErrorOr send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::Span payload) = 0; private: const Type m_type; diff --git a/kernel/include/kernel/Networking/NetworkLayer.h 
b/kernel/include/kernel/Networking/NetworkLayer.h index 54183603..5cc83fc9 100644 --- a/kernel/include/kernel/Networking/NetworkLayer.h +++ b/kernel/include/kernel/Networking/NetworkLayer.h @@ -11,7 +11,7 @@ namespace Kernel BAN::IPv4Address src_ipv4 { 0 }; BAN::IPv4Address dst_ipv4 { 0 }; BAN::NetworkEndian protocol { 0 }; - BAN::NetworkEndian extra { 0 }; + BAN::NetworkEndian length { 0 }; }; static_assert(sizeof(PseudoHeader) == 12); @@ -36,6 +36,7 @@ namespace Kernel NetworkLayer() = default; }; - uint16_t calculate_internet_checksum(BAN::ConstByteSpan packet, const PseudoHeader& pseudo_header); + uint16_t calculate_internet_checksum(BAN::ConstByteSpan buffer); + uint16_t calculate_internet_checksum(BAN::Span buffers); } diff --git a/kernel/include/kernel/Networking/NetworkSocket.h b/kernel/include/kernel/Networking/NetworkSocket.h index 6c33fb0a..19fd42eb 100644 --- a/kernel/include/kernel/Networking/NetworkSocket.h +++ b/kernel/include/kernel/Networking/NetworkSocket.h @@ -32,7 +32,7 @@ namespace Kernel BAN::ErrorOr> interface(const sockaddr* target, socklen_t target_len); virtual size_t protocol_header_size() const = 0; - virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) = 0; + virtual void get_protocol_header(BAN::ByteSpan header, BAN::ConstByteSpan payload, uint16_t dst_port, PseudoHeader) = 0; virtual NetworkProtocol protocol() const = 0; virtual void receive_packet(BAN::ConstByteSpan, const sockaddr* sender, socklen_t sender_len) = 0; diff --git a/kernel/include/kernel/Networking/RTL8169/RTL8169.h b/kernel/include/kernel/Networking/RTL8169/RTL8169.h index 4bc596f6..77146192 100644 --- a/kernel/include/kernel/Networking/RTL8169/RTL8169.h +++ b/kernel/include/kernel/Networking/RTL8169/RTL8169.h @@ -29,9 +29,11 @@ namespace Kernel : NetworkInterface(Type::Ethernet) , m_pci_device(pci_device) { } + ~RTL8169(); + BAN::ErrorOr initialize(); - virtual BAN::ErrorOr send_bytes(BAN::MACAddress destination, EtherType protocol, 
BAN::ConstByteSpan) override; + virtual BAN::ErrorOr send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::Span) override; virtual bool can_read_impl() const override { return false; } virtual bool can_write_impl() const override { return false; } @@ -47,7 +49,7 @@ namespace Kernel void enable_link(); BAN::ErrorOr enable_interrupt(); - void handle_receive(); + void receive_thread(); protected: PCI::Device& m_pci_device; @@ -63,6 +65,9 @@ namespace Kernel BAN::UniqPtr m_tx_descriptor_region; SpinLock m_lock; + + bool m_thread_should_die { false }; + BAN::Atomic m_thread_is_dead { true }; ThreadBlocker m_thread_blocker; uint32_t m_rx_current { 0 }; diff --git a/kernel/include/kernel/Networking/TCPSocket.h b/kernel/include/kernel/Networking/TCPSocket.h index 221080d8..11abddba 100644 --- a/kernel/include/kernel/Networking/TCPSocket.h +++ b/kernel/include/kernel/Networking/TCPSocket.h @@ -50,30 +50,30 @@ namespace Kernel static BAN::ErrorOr> create(NetworkLayer&, const Info&); ~TCPSocket(); - virtual NetworkProtocol protocol() const override { return NetworkProtocol::TCP; } + NetworkProtocol protocol() const override { return NetworkProtocol::TCP; } - virtual size_t protocol_header_size() const override { return sizeof(TCPHeader) + m_tcp_options_bytes; } - virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) override; + size_t protocol_header_size() const override { return sizeof(TCPHeader) + m_tcp_options_bytes; } + void get_protocol_header(BAN::ByteSpan header, BAN::ConstByteSpan payload, uint16_t dst_port, PseudoHeader) override; protected: - virtual BAN::ErrorOr accept_impl(sockaddr*, socklen_t*, int) override; - virtual BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) override; - virtual BAN::ErrorOr listen_impl(int) override; - virtual BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) override; - virtual BAN::ErrorOr recvmsg_impl(msghdr& message, int flags) override; - virtual BAN::ErrorOr sendmsg_impl(const 
msghdr& message, int flags) override; - virtual BAN::ErrorOr getpeername_impl(sockaddr*, socklen_t*) override; - virtual BAN::ErrorOr getsockopt_impl(int, int, void*, socklen_t*) override; - virtual BAN::ErrorOr setsockopt_impl(int, int, const void*, socklen_t) override; + BAN::ErrorOr accept_impl(sockaddr*, socklen_t*, int) override; + BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) override; + BAN::ErrorOr listen_impl(int) override; + BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) override; + BAN::ErrorOr recvmsg_impl(msghdr& message, int flags) override; + BAN::ErrorOr sendmsg_impl(const msghdr& message, int flags) override; + BAN::ErrorOr getpeername_impl(sockaddr*, socklen_t*) override; + BAN::ErrorOr getsockopt_impl(int, int, void*, socklen_t*) override; + BAN::ErrorOr setsockopt_impl(int, int, const void*, socklen_t) override; - virtual BAN::ErrorOr ioctl_impl(int, void*) override; + BAN::ErrorOr ioctl_impl(int, void*) override; - virtual void receive_packet(BAN::ConstByteSpan, const sockaddr* sender, socklen_t sender_len) override; + void receive_packet(BAN::ConstByteSpan, const sockaddr* sender, socklen_t sender_len) override; - virtual bool can_read_impl() const override; - virtual bool can_write_impl() const override; - virtual bool has_error_impl() const override { return false; } - virtual bool has_hungup_impl() const override; + bool can_read_impl() const override; + bool can_write_impl() const override; + bool has_error_impl() const override { return false; } + bool has_hungup_impl() const override; private: enum class State @@ -181,6 +181,7 @@ namespace Kernel bool m_no_delay { false }; bool m_should_send_ack { false }; + bool m_should_send_zero_window { false }; uint64_t m_time_wait_start_ms { 0 }; diff --git a/kernel/include/kernel/Networking/UDPSocket.h b/kernel/include/kernel/Networking/UDPSocket.h index 81ff5bb2..aa5706d9 100644 --- a/kernel/include/kernel/Networking/UDPSocket.h +++ b/kernel/include/kernel/Networking/UDPSocket.h @@ 
-25,28 +25,28 @@ namespace Kernel public: static BAN::ErrorOr> create(NetworkLayer&, const Socket::Info&); - virtual NetworkProtocol protocol() const override { return NetworkProtocol::UDP; } + NetworkProtocol protocol() const override { return NetworkProtocol::UDP; } - virtual size_t protocol_header_size() const override { return sizeof(UDPHeader); } - virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) override; + size_t protocol_header_size() const override { return sizeof(UDPHeader); } + void get_protocol_header(BAN::ByteSpan header, BAN::ConstByteSpan payload, uint16_t dst_port, PseudoHeader) override; protected: - virtual void receive_packet(BAN::ConstByteSpan, const sockaddr* sender, socklen_t sender_len) override; + void receive_packet(BAN::ConstByteSpan, const sockaddr* sender, socklen_t sender_len) override; - virtual BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) override; - virtual BAN::ErrorOr bind_impl(const sockaddr* address, socklen_t address_len) override; - virtual BAN::ErrorOr recvmsg_impl(msghdr& message, int flags) override; - virtual BAN::ErrorOr sendmsg_impl(const msghdr& message, int flags) override; - virtual BAN::ErrorOr getpeername_impl(sockaddr*, socklen_t*) override { return BAN::Error::from_errno(ENOTCONN); } - virtual BAN::ErrorOr getsockopt_impl(int, int, void*, socklen_t*) override; - virtual BAN::ErrorOr setsockopt_impl(int, int, const void*, socklen_t) override; + BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) override; + BAN::ErrorOr bind_impl(const sockaddr* address, socklen_t address_len) override; + BAN::ErrorOr recvmsg_impl(msghdr& message, int flags) override; + BAN::ErrorOr sendmsg_impl(const msghdr& message, int flags) override; + BAN::ErrorOr getpeername_impl(sockaddr*, socklen_t*) override { return BAN::Error::from_errno(ENOTCONN); } + BAN::ErrorOr getsockopt_impl(int, int, void*, socklen_t*) override; + BAN::ErrorOr setsockopt_impl(int, int, const void*, socklen_t) 
override; - virtual BAN::ErrorOr ioctl_impl(int, void*) override; + BAN::ErrorOr ioctl_impl(int, void*) override; - virtual bool can_read_impl() const override { return !m_packets.empty(); } - virtual bool can_write_impl() const override { return true; } - virtual bool has_error_impl() const override { return false; } - virtual bool has_hungup_impl() const override { return false; } + bool can_read_impl() const override { return !m_packets.empty(); } + bool can_write_impl() const override { return true; } + bool has_error_impl() const override { return false; } + bool has_hungup_impl() const override { return false; } private: UDPSocket(NetworkLayer&, const Socket::Info&); diff --git a/kernel/kernel/Networking/ARPTable.cpp b/kernel/kernel/Networking/ARPTable.cpp index b255ec71..6c2ee529 100644 --- a/kernel/kernel/Networking/ARPTable.cpp +++ b/kernel/kernel/Networking/ARPTable.cpp @@ -17,27 +17,7 @@ namespace Kernel BAN::ErrorOr> ARPTable::create() { - auto arp_table = TRY(BAN::UniqPtr::create()); - arp_table->m_thread = TRY(Thread::create_kernel( - [](void* arp_table_ptr) - { - auto& arp_table = *reinterpret_cast(arp_table_ptr); - arp_table.packet_handle_task(); - }, arp_table.ptr() - )); - TRY(Processor::scheduler().add_thread(arp_table->m_thread)); - return arp_table; - } - - ARPTable::ARPTable() - { - } - - ARPTable::~ARPTable() - { - if (m_thread) - m_thread->add_signal(SIGKILL, {}); - m_thread = nullptr; + return TRY(BAN::UniqPtr::create()); } BAN::ErrorOr ARPTable::get_mac_from_ipv4(NetworkInterface& interface, BAN::IPv4Address ipv4_address) @@ -64,7 +44,7 @@ namespace Kernel ipv4_address = interface.get_gateway(); { - SpinLockGuard _(m_table_lock); + SpinLockGuard _(m_arp_table_lock); auto it = m_arp_table.find(ipv4_address); if (it != m_arp_table.end()) return it->value; @@ -87,7 +67,7 @@ namespace Kernel while (SystemTimer::get().ms_since_boot() < timeout) { { - SpinLockGuard _(m_table_lock); + SpinLockGuard _(m_arp_table_lock); auto it = 
m_arp_table.find(ipv4_address); if (it != m_arp_table.end()) return it->value; @@ -98,8 +78,16 @@ namespace Kernel return BAN::Error::from_errno(ETIMEDOUT); } - BAN::ErrorOr ARPTable::handle_arp_packet(NetworkInterface& interface, const ARPPacket& packet) + BAN::ErrorOr ARPTable::handle_arp_packet(NetworkInterface& interface, BAN::ConstByteSpan buffer) { + if (buffer.size() < sizeof(ARPPacket)) + { + dwarnln_if(DEBUG_ARP, "Too small ARP packet"); + return {}; + } + + const auto& packet = buffer.as(); + if (packet.ptype != EtherType::IPv4) { dprintln("Non IPv4 arp packet?"); @@ -112,23 +100,24 @@ namespace Kernel { if (packet.tpa == interface.get_ipv4_address()) { - ARPPacket arp_reply; - arp_reply.htype = 0x0001; - arp_reply.ptype = EtherType::IPv4; - arp_reply.hlen = 0x06; - arp_reply.plen = 0x04; - arp_reply.oper = ARPOperation::Reply; - arp_reply.sha = interface.get_mac_address(); - arp_reply.spa = interface.get_ipv4_address(); - arp_reply.tha = packet.sha; - arp_reply.tpa = packet.spa; + const ARPPacket arp_reply { + .htype = 0x0001, + .ptype = EtherType::IPv4, + .hlen = 0x06, + .plen = 0x04, + .oper = ARPOperation::Reply, + .sha = interface.get_mac_address(), + .spa = interface.get_ipv4_address(), + .tha = packet.sha, + .tpa = packet.spa, + }; TRY(interface.send_bytes(packet.sha, EtherType::ARP, BAN::ConstByteSpan::from(arp_reply))); } break; } case ARPOperation::Reply: { - SpinLockGuard _(m_table_lock); + SpinLockGuard _(m_arp_table_lock); auto it = m_arp_table.find(packet.spa); if (it != m_arp_table.end()) @@ -154,48 +143,4 @@ namespace Kernel return {}; } - void ARPTable::packet_handle_task() - { - for (;;) - { - PendingArpPacket pending = ({ - SpinLockGuard guard(m_pending_lock); - while (m_pending_packets.empty()) - { - SpinLockGuardAsMutex smutex(guard); - m_pending_thread_blocker.block_indefinite(&smutex); - } - - auto packet = m_pending_packets.front(); - m_pending_packets.pop(); - - packet; - }); - - if (auto ret = handle_arp_packet(pending.interface, 
pending.packet); ret.is_error()) - dwarnln("{}", ret.error()); - } - } - - void ARPTable::add_arp_packet(NetworkInterface& interface, BAN::ConstByteSpan buffer) - { - if (buffer.size() < sizeof(ARPPacket)) - { - dwarnln_if(DEBUG_ARP, "ARP packet too small"); - return; - } - auto& arp_packet = buffer.as(); - - SpinLockGuard _(m_pending_lock); - - if (m_pending_packets.full()) - { - dwarnln_if(DEBUG_ARP, "ARP packet queue full"); - return; - } - - m_pending_packets.push({ .interface = interface, .packet = arp_packet }); - m_pending_thread_blocker.unblock(); - } - } diff --git a/kernel/kernel/Networking/E1000/E1000.cpp b/kernel/kernel/Networking/E1000/E1000.cpp index 5d9cc368..607bbe51 100644 --- a/kernel/kernel/Networking/E1000/E1000.cpp +++ b/kernel/kernel/Networking/E1000/E1000.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -57,6 +58,11 @@ namespace Kernel E1000::~E1000() { + m_thread_should_die = true; + m_thread_blocker.unblock(); + + while (!m_thread_is_dead) + Processor::yield(); } BAN::ErrorOr E1000::initialize() @@ -84,6 +90,16 @@ namespace Kernel dprintln(" link speed: {} Mbps", speed); } + auto* thread = TRY(Thread::create_kernel([](void* e1000_ptr) { + static_cast(e1000_ptr)->receive_thread(); + }, this)); + if (auto ret = Processor::scheduler().add_thread(thread); ret.is_error()) + { + delete thread; + return ret.release_error(); + } + m_thread_is_dead = false; + return {}; } @@ -259,10 +275,8 @@ namespace Kernel return {}; } - BAN::ErrorOr E1000::send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan buffer) + BAN::ErrorOr E1000::send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::Span payload) { - ASSERT(buffer.size() + sizeof(EthernetHeader) <= E1000_TX_BUFFER_SIZE); - SpinLockGuard _(m_lock); size_t tx_current = read32(REG_TDT) % E1000_TX_DESCRIPTOR_COUNT; @@ -274,48 +288,75 @@ namespace Kernel ethernet_header.src_mac = get_mac_address(); ethernet_header.ether_type = 
protocol; - memcpy(tx_buffer + sizeof(EthernetHeader), buffer.data(), buffer.size()); + size_t packet_size = sizeof(EthernetHeader); + for (const auto& buffer : payload) + { + ASSERT(packet_size + buffer.size() <= E1000_TX_BUFFER_SIZE); /* a frame that exactly fills the TX buffer is valid */ + memcpy(tx_buffer + packet_size, buffer.data(), buffer.size()); + packet_size += buffer.size(); + } auto& descriptor = reinterpret_cast(m_tx_descriptor_region->vaddr())[tx_current]; - descriptor.length = sizeof(EthernetHeader) + buffer.size(); + descriptor.length = packet_size; descriptor.status = 0; descriptor.cmd = CMD_EOP | CMD_IFCS | CMD_RS; + // FIXME: there isnt really any reason to wait for transmission write32(REG_TDT, (tx_current + 1) % E1000_TX_DESCRIPTOR_COUNT); while (descriptor.status == 0) continue; - dprintln_if(DEBUG_E1000, "sent {} bytes", sizeof(EthernetHeader) + buffer.size()); + dprintln_if(DEBUG_E1000, "sent {} bytes", packet_size); return {}; } + void E1000::receive_thread() + { + SpinLockGuard _(m_lock); + + while (!m_thread_should_die) + { + for (;;) + { + const uint32_t rx_current = (read32(REG_RDT0) + 1) % E1000_RX_DESCRIPTOR_COUNT; + + auto& descriptor = reinterpret_cast(m_rx_descriptor_region->vaddr())[rx_current]; + if (!(descriptor.status & 1)) + break; + ASSERT(descriptor.length <= E1000_RX_BUFFER_SIZE); + + dprintln_if(DEBUG_E1000, "got {} bytes", (uint16_t)descriptor.length); + + m_lock.unlock(InterruptState::Enabled); + + NetworkManager::get().on_receive(*this, BAN::ConstByteSpan { + reinterpret_cast(m_rx_buffer_region->vaddr() + rx_current * E1000_RX_BUFFER_SIZE), + descriptor.length + }); + + m_lock.lock(); + + descriptor.status = 0; + write32(REG_RDT0, rx_current); + } + + SpinLockAsMutex smutex(m_lock, InterruptState::Enabled); + m_thread_blocker.block_indefinite(&smutex); + } + + m_thread_is_dead = true; + } + void E1000::handle_irq() { const uint32_t icr = read32(REG_ICR); - if (!(icr & (ICR_RxQ0 | ICR_RXT0))) - return; write32(REG_ICR, icr); - SpinLockGuard _(m_lock); - - for (;;) { - 
uint32_t rx_current = (read32(REG_RDT0) + 1) % E1000_RX_DESCRIPTOR_COUNT; - - auto& descriptor = reinterpret_cast(m_rx_descriptor_region->vaddr())[rx_current]; - if (!(descriptor.status & 1)) - break; - ASSERT(descriptor.length <= E1000_RX_BUFFER_SIZE); - - dprintln_if(DEBUG_E1000, "got {} bytes", (uint16_t)descriptor.length); - - NetworkManager::get().on_receive(*this, BAN::ConstByteSpan { - reinterpret_cast(m_rx_buffer_region->vaddr() + rx_current * E1000_RX_BUFFER_SIZE), - descriptor.length - }); - - descriptor.status = 0; - write32(REG_RDT0, rx_current); + if (icr & (ICR_RxQ0 | ICR_RXT0)) + { + SpinLockGuard _(m_lock); + m_thread_blocker.unblock(); } } diff --git a/kernel/kernel/Networking/IPv4Layer.cpp b/kernel/kernel/Networking/IPv4Layer.cpp index 3061a4fb..6b0ea71c 100644 --- a/kernel/kernel/Networking/IPv4Layer.cpp +++ b/kernel/kernel/Networking/IPv4Layer.cpp @@ -21,50 +21,26 @@ namespace Kernel BAN::ErrorOr> IPv4Layer::create() { auto ipv4_manager = TRY(BAN::UniqPtr::create()); - ipv4_manager->m_thread = TRY(Thread::create_kernel( - [](void* ipv4_manager_ptr) - { - auto& ipv4_manager = *reinterpret_cast(ipv4_manager_ptr); - ipv4_manager.packet_handle_task(); - }, ipv4_manager.ptr() - )); - TRY(Processor::scheduler().add_thread(ipv4_manager->m_thread)); - ipv4_manager->m_pending_packet_buffer = TRY(VirtualRange::create_to_vaddr_range( - PageTable::kernel(), - KERNEL_OFFSET, - ~(uintptr_t)0, - pending_packet_buffer_size, - PageTable::Flags::ReadWrite | PageTable::Flags::Present, - true, false - )); ipv4_manager->m_arp_table = TRY(ARPTable::create()); return ipv4_manager; } - IPv4Layer::IPv4Layer() - { } - - IPv4Layer::~IPv4Layer() + static IPv4Header get_ipv4_header(size_t packet_size, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol) { - if (m_thread) - m_thread->add_signal(SIGKILL, {}); - m_thread = nullptr; - } - - void IPv4Layer::add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t 
protocol) const - { - auto& header = packet.as(); - header.version_IHL = 0x45; - header.DSCP_ECN = 0x00; - header.total_length = packet.size(); - header.identification = 1; - header.flags_frament = 0x00; - header.time_to_live = 0x40; - header.protocol = protocol; - header.src_address = src_ipv4; - header.dst_address = dst_ipv4; - header.checksum = 0; - header.checksum = calculate_internet_checksum(BAN::ConstByteSpan::from(header), {}); + IPv4Header header { + .version_IHL = 0x45, + .DSCP_ECN = 0x00, + .total_length = packet_size, + .identification = 1, + .flags_frament = 0x00, + .time_to_live = 0x40, + .protocol = protocol, + .checksum = 0, + .src_address = src_ipv4, + .dst_address = dst_ipv4, + }; + header.checksum = calculate_internet_checksum(BAN::ConstByteSpan::from(header)); + return header; } void IPv4Layer::unbind_socket(uint16_t port) @@ -204,7 +180,7 @@ namespace Kernel return {}; } - BAN::ErrorOr IPv4Layer::sendto(NetworkSocket& socket, BAN::ConstByteSpan buffer, const sockaddr* address, socklen_t address_len) + BAN::ErrorOr IPv4Layer::sendto(NetworkSocket& socket, BAN::ConstByteSpan payload, const sockaddr* address, socklen_t address_len) { if (address->sa_family != AF_INET) return BAN::Error::from_errno(EINVAL); @@ -233,43 +209,61 @@ namespace Kernel return BAN::Error::from_errno(EADDRNOTAVAIL); } - BAN::Vector packet_buffer; - TRY(packet_buffer.resize(buffer.size() + sizeof(IPv4Header) + socket.protocol_header_size())); - auto packet = BAN::ByteSpan { packet_buffer.span() }; - - auto pseudo_header = PseudoHeader { - .src_ipv4 = interface->get_ipv4_address(), - .dst_ipv4 = dst_ipv4, - .protocol = socket.protocol() - }; - - memcpy( - packet.slice(sizeof(IPv4Header)).slice(socket.protocol_header_size()).data(), - buffer.data(), - buffer.size() - ); - socket.add_protocol_header( - packet.slice(sizeof(IPv4Header)), - dst_port, - pseudo_header - ); - add_ipv4_header( - packet, + const auto ipv4_header = get_ipv4_header( + sizeof(IPv4Header) + 
socket.protocol_header_size() + payload.size(), interface->get_ipv4_address(), dst_ipv4, socket.protocol() ); - TRY(interface->send_bytes(dst_mac, EtherType::IPv4, packet)); + const auto pseudo_header = PseudoHeader { + .src_ipv4 = interface->get_ipv4_address(), + .dst_ipv4 = dst_ipv4, + .protocol = socket.protocol(), + .length = socket.protocol_header_size() + payload.size() + }; - return buffer.size(); + uint8_t protocol_header_buffer[32]; + ASSERT(socket.protocol_header_size() <= sizeof(protocol_header_buffer)); /* a header that exactly fills the buffer still fits */ + + auto protocol_header = BAN::ByteSpan::from(protocol_header_buffer).slice(0, socket.protocol_header_size()); + socket.get_protocol_header(protocol_header, payload, dst_port, pseudo_header); + + BAN::ConstByteSpan buffers[] { + BAN::ConstByteSpan::from(ipv4_header), + protocol_header, + payload, + }; + + TRY(interface->send_bytes(dst_mac, EtherType::IPv4, { buffers, sizeof(buffers) / sizeof(*buffers) })); + + return payload.size(); } - BAN::ErrorOr IPv4Layer::handle_ipv4_packet(NetworkInterface& interface, BAN::ByteSpan packet) + BAN::ErrorOr IPv4Layer::handle_ipv4_packet(NetworkInterface& interface, BAN::ConstByteSpan packet) { - ASSERT(packet.size() >= sizeof(IPv4Header)); + if (packet.size() < sizeof(IPv4Header)) + { + dwarnln_if(DEBUG_IPV4, "Too small IPv4 packet"); + return {}; + } + auto& ipv4_header = packet.as(); - auto ipv4_data = packet.slice(sizeof(IPv4Header)); + if (calculate_internet_checksum(BAN::ConstByteSpan::from(ipv4_header)) != 0) + { + dwarnln_if(DEBUG_IPV4, "IPv4 packet checksum failed"); + return {}; + } + if (ipv4_header.total_length > packet.size() || ipv4_header.total_length > interface.payload_mtu() || ipv4_header.total_length < sizeof(IPv4Header)) + { + if (ipv4_header.flags_frament & IPv4Flags::DF) + dwarnln_if(DEBUG_IPV4, "Invalid IPv4 packet"); + else + dwarnln_if(DEBUG_IPV4, "IPv4 fragmentation not supported"); + return {}; + } + + auto ipv4_data = packet.slice(0, ipv4_header.total_length).slice(sizeof(IPv4Header)); auto 
src_ipv4 = ipv4_header.src_address; @@ -292,14 +286,33 @@ namespace Kernel { auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(interface, src_ipv4)); - auto& reply_icmp_header = ipv4_data.as(); - reply_icmp_header.type = ICMPType::EchoReply; - reply_icmp_header.checksum = 0; - reply_icmp_header.checksum = calculate_internet_checksum(ipv4_data, {}); + auto send_ipv4_header = get_ipv4_header( + sizeof(IPv4Header) + ipv4_data.size(), /* total_length covers the IPv4 header as well as the ICMP data */ + interface.get_ipv4_address(), + src_ipv4, + NetworkProtocol::ICMP + ); - add_ipv4_header(packet, interface.get_ipv4_address(), src_ipv4, NetworkProtocol::ICMP); + ICMPHeader send_icmp_header { + .type = ICMPType::EchoReply, + .code = icmp_header.code, + .checksum = 0, + .rest = icmp_header.rest, + }; + + auto send_payload = ipv4_data.slice(sizeof(ICMPHeader)); + + const BAN::ConstByteSpan send_buffers[] { + BAN::ConstByteSpan::from(send_ipv4_header), + BAN::ConstByteSpan::from(send_icmp_header), + send_payload + }; + auto send_buffers_span = BAN::Span { send_buffers, sizeof(send_buffers) / sizeof(*send_buffers) }; + + send_icmp_header.checksum = calculate_internet_checksum(send_buffers_span.slice(1)); + + TRY(interface.send_bytes(dst_mac, EtherType::IPv4, send_buffers_span)); - TRY(interface.send_bytes(dst_mac, EtherType::IPv4, packet)); break; } case ICMPType::DestinationUnreachable: @@ -381,80 +394,4 @@ namespace Kernel return {}; } - void IPv4Layer::packet_handle_task() - { - for (;;) - { - PendingIPv4Packet pending = ({ - SpinLockGuard guard(m_pending_lock); - while (m_pending_packets.empty()) - { - SpinLockGuardAsMutex smutex(guard); - m_pending_thread_blocker.block_indefinite(&smutex); - } - - auto packet = m_pending_packets.front(); - m_pending_packets.pop(); - - packet; - }); - - uint8_t* buffer_start = reinterpret_cast(m_pending_packet_buffer->vaddr()); - const size_t ipv4_packet_size = reinterpret_cast(buffer_start)->total_length; - - if (auto ret = handle_ipv4_packet(pending.interface, BAN::ByteSpan(buffer_start, ipv4_packet_size)); 
ret.is_error()) - dwarnln_if(DEBUG_IPV4, "{}", ret.error()); - - SpinLockGuard _(m_pending_lock); - m_pending_total_size -= ipv4_packet_size; - if (m_pending_total_size) - memmove(buffer_start, buffer_start + ipv4_packet_size, m_pending_total_size); - } - } - - void IPv4Layer::add_ipv4_packet(NetworkInterface& interface, BAN::ConstByteSpan buffer) - { - if (buffer.size() < sizeof(IPv4Header)) - { - dwarnln_if(DEBUG_IPV4, "IPv4 packet too small"); - return; - } - - SpinLockGuard _(m_pending_lock); - - if (m_pending_packets.full()) - { - dwarnln_if(DEBUG_IPV4, "IPv4 packet queue full"); - return; - } - - if (m_pending_total_size + buffer.size() > m_pending_packet_buffer->size()) - { - dwarnln_if(DEBUG_IPV4, "IPv4 packet queue full"); - return; - } - - auto& ipv4_header = buffer.as(); - if (calculate_internet_checksum(BAN::ConstByteSpan::from(ipv4_header), {}) != 0) - { - dwarnln_if(DEBUG_IPV4, "Invalid IPv4 packet"); - return; - } - if (ipv4_header.total_length > buffer.size() || ipv4_header.total_length > interface.payload_mtu()) - { - if (ipv4_header.flags_frament & IPv4Flags::DF) - dwarnln_if(DEBUG_IPV4, "Invalid IPv4 packet"); - else - dwarnln_if(DEBUG_IPV4, "IPv4 fragmentation not supported"); - return; - } - - uint8_t* buffer_start = reinterpret_cast(m_pending_packet_buffer->vaddr()); - memcpy(buffer_start + m_pending_total_size, buffer.data(), ipv4_header.total_length); - m_pending_total_size += ipv4_header.total_length; - - m_pending_packets.push({ .interface = interface }); - m_pending_thread_blocker.unblock(); - } - } diff --git a/kernel/kernel/Networking/Loopback.cpp b/kernel/kernel/Networking/Loopback.cpp index 9c1b6e58..0fc51f6e 100644 --- a/kernel/kernel/Networking/Loopback.cpp +++ b/kernel/kernel/Networking/Loopback.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -10,40 +11,121 @@ namespace Kernel if (loopback_ptr == nullptr) return BAN::Error::from_errno(ENOMEM); auto loopback = BAN::RefPtr::adopt(loopback_ptr); + loopback->m_buffer = 
TRY(VirtualRange::create_to_vaddr_range( PageTable::kernel(), KERNEL_OFFSET, BAN::numeric_limits::max(), - buffer_size, + buffer_size * buffer_count, PageTable::Flags::ReadWrite | PageTable::Flags::Present, true, false )); + + auto* thread = TRY(Thread::create_kernel([](void* loopback_ptr) { + static_cast(loopback_ptr)->receive_thread(); + }, loopback_ptr)); + if (auto ret = Processor::scheduler().add_thread(thread); ret.is_error()) + { + delete thread; + return ret.release_error(); + } + loopback->m_thread_is_dead = false; + loopback->set_ipv4_address({ 127, 0, 0, 1 }); loopback->set_netmask({ 255, 0, 0, 0 }); + + for (size_t i = 0; i < buffer_count; i++) + { + loopback->m_descriptors[i] = { + .addr = reinterpret_cast(loopback->m_buffer->vaddr()) + i * buffer_size, + .size = 0, + .state = 0, + }; + } + return loopback; } - BAN::ErrorOr LoopbackInterface::send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan buffer) + LoopbackInterface::~LoopbackInterface() { - ASSERT(buffer.size() + sizeof(EthernetHeader) <= buffer_size); + m_thread_should_die = true; + m_thread_blocker.unblock(); - SpinLockGuard _(m_buffer_lock); + while (!m_thread_is_dead) + Processor::yield(); + } - uint8_t* buffer_vaddr = reinterpret_cast(m_buffer->vaddr()); + BAN::ErrorOr LoopbackInterface::send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::Span payload) + { + auto& descriptor = + [&]() -> Descriptor& + { + LockGuard _(m_buffer_lock); + for (;;) + { + auto& descriptor = m_descriptors[m_buffer_head]; + if (descriptor.state == 0) + { + m_buffer_head = (m_buffer_head + 1) % buffer_count; + descriptor.state = 1; + return descriptor; + } + m_thread_blocker.block_indefinite(&m_buffer_lock); + } + }(); - auto& ethernet_header = *reinterpret_cast(buffer_vaddr); + auto& ethernet_header = *reinterpret_cast(descriptor.addr); ethernet_header.dst_mac = destination; ethernet_header.src_mac = get_mac_address(); ethernet_header.ether_type = protocol; - 
memcpy(buffer_vaddr + sizeof(EthernetHeader), buffer.data(), buffer.size()); + size_t packet_size = sizeof(EthernetHeader); + for (const auto& buffer : payload) + { + ASSERT(packet_size + buffer.size() <= buffer_size); + memcpy(descriptor.addr + packet_size, buffer.data(), buffer.size()); + packet_size += buffer.size(); + } - NetworkManager::get().on_receive(*this, BAN::ConstByteSpan { - buffer_vaddr, - buffer.size() + sizeof(EthernetHeader) - }); + LockGuard _(m_buffer_lock); + descriptor.size = packet_size; + descriptor.state = 2; + m_thread_blocker.unblock(); return {}; } + void LoopbackInterface::receive_thread() + { + LockGuard _(m_buffer_lock); + + while (!m_thread_should_die) + { + for (;;) + { + auto& descriptor = m_descriptors[m_buffer_tail]; + if (descriptor.state != 2) + break; + m_buffer_tail = (m_buffer_tail + 1) % buffer_count; + + m_buffer_lock.unlock(); + + NetworkManager::get().on_receive(*this, { + descriptor.addr, + descriptor.size, + }); + + m_buffer_lock.lock(); + + descriptor.size = 0; + descriptor.state = 0; + m_thread_blocker.unblock(); + } + + m_thread_blocker.block_indefinite(&m_buffer_lock); + } + + m_thread_is_dead = true; + } + } diff --git a/kernel/kernel/Networking/NetworkLayer.cpp b/kernel/kernel/Networking/NetworkLayer.cpp index 05bea4da..75ab9e7a 100644 --- a/kernel/kernel/Networking/NetworkLayer.cpp +++ b/kernel/kernel/Networking/NetworkLayer.cpp @@ -3,15 +3,28 @@ namespace Kernel { - uint16_t calculate_internet_checksum(BAN::ConstByteSpan packet, const PseudoHeader& pseudo_header) + uint16_t calculate_internet_checksum(BAN::ConstByteSpan buffer) + { + return calculate_internet_checksum({ &buffer, 1 }); + } + + uint16_t calculate_internet_checksum(BAN::Span buffers) { uint32_t checksum = 0; - for (size_t i = 0; i < sizeof(pseudo_header) / sizeof(uint16_t); i++) - checksum += BAN::host_to_network_endian(reinterpret_cast(&pseudo_header)[i]); - for (size_t i = 0; i < packet.size() / sizeof(uint16_t); i++) - checksum += 
BAN::host_to_network_endian(reinterpret_cast(packet.data())[i]); - if (packet.size() % 2) - checksum += (uint16_t)packet[packet.size() - 1] << 8; + + for (size_t i = 0; i < buffers.size(); i++) + { + auto buffer = buffers[i]; + + const uint16_t* buffer_u16 = reinterpret_cast(buffer.data()); + for (size_t j = 0; j < buffer.size() / 2; j++) + checksum += BAN::host_to_network_endian(buffer_u16[j]); + if (buffer.size() % 2 == 0) + continue; + ASSERT(i == buffers.size() - 1); + checksum += buffer[buffer.size() - 1] << 8; + } + while (checksum >> 16) checksum = (checksum >> 16) + (checksum & 0xFFFF); return ~(uint16_t)checksum; diff --git a/kernel/kernel/Networking/NetworkManager.cpp b/kernel/kernel/Networking/NetworkManager.cpp index 838d820e..6cc88b99 100644 --- a/kernel/kernel/Networking/NetworkManager.cpp +++ b/kernel/kernel/Networking/NetworkManager.cpp @@ -154,18 +154,18 @@ namespace Kernel return; auto ethernet_header = packet.as(); + auto packet_data = packet.slice(sizeof(EthernetHeader)); + switch (ethernet_header.ether_type) { case EtherType::ARP: - { - m_ipv4_layer->arp_table().add_arp_packet(interface, packet.slice(sizeof(EthernetHeader))); + if (auto ret = m_ipv4_layer->arp_table().handle_arp_packet(interface, packet_data); ret.is_error()) + dwarnln("ARP: {}", ret.error()); break; - } case EtherType::IPv4: - { - m_ipv4_layer->add_ipv4_packet(interface, packet.slice(sizeof(EthernetHeader))); + if (auto ret = m_ipv4_layer->handle_ipv4_packet(interface, packet_data); ret.is_error()) + dwarnln("IPv4; {}", ret.error()); break; - } default: dprintln_if(DEBUG_ETHERTYPE, "Unknown EtherType 0x{4H}", (uint16_t)ethernet_header.ether_type); break; diff --git a/kernel/kernel/Networking/RTL8169/RTL8169.cpp b/kernel/kernel/Networking/RTL8169/RTL8169.cpp index 14fedddf..8b6a6342 100644 --- a/kernel/kernel/Networking/RTL8169/RTL8169.cpp +++ b/kernel/kernel/Networking/RTL8169/RTL8169.cpp @@ -7,6 +7,9 @@ namespace Kernel { + // each buffer is 7440 bytes + padding = 8192 + 
constexpr size_t s_buffer_size = 8192; + bool RTL8169::probe(PCI::Device& pci_device) { if (pci_device.vendor_id() != 0x10ec) @@ -68,9 +71,28 @@ namespace Kernel // lock config registers m_io_bar_region->write8(RTL8169_IO_9346CR, RTL8169_9346CR_MODE_NORMAL); + auto* thread = TRY(Thread::create_kernel([](void* rtl8169_ptr) { + static_cast(rtl8169_ptr)->receive_thread(); + }, this)); + if (auto ret = Processor::scheduler().add_thread(thread); ret.is_error()) + { + delete thread; + return ret.release_error(); + } + m_thread_is_dead = false; + return {}; } + RTL8169::~RTL8169() + { + m_thread_should_die = true; + m_thread_blocker.unblock(); + + while (!m_thread_is_dead) + Processor::yield(); + } + BAN::ErrorOr RTL8169::reset() { m_io_bar_region->write8(RTL8169_IO_CR, RTL8169_CR_RST); @@ -85,15 +107,12 @@ namespace Kernel BAN::ErrorOr RTL8169::initialize_rx() { - // each buffer is 7440 bytes + padding = 8192 - constexpr size_t buffer_size = 2 * PAGE_SIZE; - - m_rx_buffer_region = TRY(DMARegion::create(m_rx_descriptor_count * buffer_size)); + m_rx_buffer_region = TRY(DMARegion::create(m_rx_descriptor_count * s_buffer_size)); m_rx_descriptor_region = TRY(DMARegion::create(m_rx_descriptor_count * sizeof(RTL8169Descriptor))); for (size_t i = 0; i < m_rx_descriptor_count; i++) { - const paddr_t rx_buffer_paddr = m_rx_buffer_region->paddr() + i * buffer_size; + const paddr_t rx_buffer_paddr = m_rx_buffer_region->paddr() + i * s_buffer_size; uint32_t command = 0x1FF8 | RTL8169_DESC_CMD_OWN; if (i == m_rx_descriptor_count - 1) @@ -120,21 +139,17 @@ namespace Kernel // configure max rx packet size m_io_bar_region->write16(RTL8169_IO_RMS, RTL8169_RMS_MAX); - return {}; } BAN::ErrorOr RTL8169::initialize_tx() { - // each buffer is 7440 bytes + padding = 8192 - constexpr size_t buffer_size = 2 * PAGE_SIZE; - - m_tx_buffer_region = TRY(DMARegion::create(m_tx_descriptor_count * buffer_size)); + m_tx_buffer_region = TRY(DMARegion::create(m_tx_descriptor_count * s_buffer_size)); 
m_tx_descriptor_region = TRY(DMARegion::create(m_tx_descriptor_count * sizeof(RTL8169Descriptor))); for (size_t i = 0; i < m_tx_descriptor_count; i++) { - const paddr_t tx_buffer_paddr = m_tx_buffer_region->paddr() + i * buffer_size; + const paddr_t tx_buffer_paddr = m_tx_buffer_region->paddr() + i * s_buffer_size; uint32_t command = 0; if (i == m_tx_descriptor_count - 1) @@ -194,14 +209,8 @@ namespace Kernel return 0; } - BAN::ErrorOr RTL8169::send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan buffer) + BAN::ErrorOr RTL8169::send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::Span payload) { - constexpr size_t buffer_size = 8192; - - const uint16_t packet_size = sizeof(EthernetHeader) + buffer.size(); - if (packet_size > buffer_size) - return BAN::Error::from_errno(EINVAL); - if (!link_up()) return BAN::Error::from_errno(EADDRNOTAVAIL); @@ -219,14 +228,20 @@ namespace Kernel m_lock.unlock(state); - auto* tx_buffer = reinterpret_cast(m_tx_buffer_region->vaddr() + tx_current * buffer_size); + auto* tx_buffer = reinterpret_cast(m_tx_buffer_region->vaddr() + tx_current * s_buffer_size); // write packet auto& ethernet_header = *reinterpret_cast(tx_buffer); ethernet_header.dst_mac = destination; ethernet_header.src_mac = get_mac_address(); ethernet_header.ether_type = protocol; - memcpy(tx_buffer + sizeof(EthernetHeader), buffer.data(), buffer.size()); + + size_t packet_size = sizeof(EthernetHeader); + for (const auto& buffer : payload) + { + memcpy(tx_buffer + packet_size, buffer.data(), buffer.size()); + packet_size += buffer.size(); + } // give packet ownership to NIC uint32_t command = packet_size | RTL8169_DESC_CMD_OWN | RTL8169_DESC_CMD_LS | RTL8169_DESC_CMD_FS; @@ -240,6 +255,50 @@ namespace Kernel return {}; } + void RTL8169::receive_thread() + { + SpinLockGuard _(m_lock); + + while (!m_thread_should_die) + { + for (;;) + { + auto& descriptor = reinterpret_cast(m_rx_descriptor_region->vaddr())[m_rx_current]; + if 
(descriptor.command & RTL8169_DESC_CMD_OWN) + break; + + // packet buffer can only hold single packet, so we should not receive any multi-descriptor packets + ASSERT((descriptor.command & RTL8169_DESC_CMD_LS) && (descriptor.command & RTL8169_DESC_CMD_FS)); + + const uint16_t packet_length = descriptor.command & 0x3FFF; + if (packet_length > s_buffer_size) + dwarnln("Got {} bytes to {} byte buffer", packet_length, s_buffer_size); + else if (descriptor.command & (1u << 21)) + ; // descriptor has an error + else + { + m_lock.unlock(InterruptState::Enabled); + + NetworkManager::get().on_receive(*this, BAN::ConstByteSpan { + reinterpret_cast(m_rx_buffer_region->vaddr() + m_rx_current * s_buffer_size), + packet_length + }); + + m_lock.lock(); + } + + m_rx_current = (m_rx_current + 1) % m_rx_descriptor_count; + + descriptor.command = descriptor.command | RTL8169_DESC_CMD_OWN; + } + + SpinLockAsMutex smutex(m_lock, InterruptState::Enabled); + m_thread_blocker.block_indefinite(&smutex); + } + + m_thread_is_dead = true; + } + void RTL8169::handle_irq() { const uint16_t interrupt_status = m_io_bar_region->read16(RTL8169_IO_ISR); @@ -251,7 +310,7 @@ namespace Kernel dprintln("link status -> {}", m_link_up.load()); } - if (interrupt_status & RTL8169_IR_TOK) + if (interrupt_status & (RTL8169_IR_TOK | RTL8169_IR_ROK)) { SpinLockGuard _(m_lock); m_thread_blocker.unblock(); @@ -266,38 +325,6 @@ namespace Kernel if (interrupt_status & RTL8169_IR_FVOW) dwarnln("Rx FIFO overflow"); // dont log TDU is sent after each sent packet - - if (!(interrupt_status & RTL8169_IR_ROK)) - return; - - constexpr size_t buffer_size = 8192; - - for (;;) - { - auto& descriptor = reinterpret_cast(m_rx_descriptor_region->vaddr())[m_rx_current]; - if (descriptor.command & RTL8169_DESC_CMD_OWN) - break; - - // packet buffer can only hold single packet, so we should not receive any multi-descriptor packets - ASSERT((descriptor.command & RTL8169_DESC_CMD_LS) && (descriptor.command & RTL8169_DESC_CMD_FS)); - - 
const uint16_t packet_length = descriptor.command & 0x3FFF; - if (packet_length > buffer_size) - dwarnln("Got {} bytes to {} byte buffer", packet_length, buffer_size); - else if (descriptor.command & (1u << 21)) - ; // descriptor has an error - else - { - NetworkManager::get().on_receive(*this, BAN::ConstByteSpan { - reinterpret_cast(m_rx_buffer_region->vaddr() + m_rx_current * buffer_size), - packet_length - }); - } - - m_rx_current = (m_rx_current + 1) % m_rx_descriptor_count; - - descriptor.command = descriptor.command | RTL8169_DESC_CMD_OWN; - } } } diff --git a/kernel/kernel/Networking/TCPSocket.cpp b/kernel/kernel/Networking/TCPSocket.cpp index cd83846a..b699518b 100644 --- a/kernel/kernel/Networking/TCPSocket.cpp +++ b/kernel/kernel/Networking/TCPSocket.cpp @@ -524,24 +524,33 @@ namespace Kernel return result; } - void TCPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader pseudo_header) + void TCPSocket::get_protocol_header(BAN::ByteSpan header_buffer, BAN::ConstByteSpan payload, uint16_t dst_port, PseudoHeader pseudo_header) { ASSERT(m_next_flags); ASSERT(m_mutex.locker() == Thread::current().tid()); - - auto& header = packet.as(); - memset(&header, 0, sizeof(TCPHeader)); - memset(header.options, TCPOption::End, m_tcp_options_bytes); + ASSERT(header_buffer.size() == protocol_header_size()); m_last_sent_window_size = m_recv_window.buffer->size() - m_recv_window.data_size; + if (m_should_send_zero_window) + m_last_sent_window_size = 0; + + m_should_send_ack = false; + m_should_send_zero_window = false; + + auto& header = header_buffer.as(); + header = { + .src_port = bound_port(), + .dst_port = dst_port, + .seq_number = m_send_window.current_seq + m_send_window.has_ghost_byte, + .ack_number = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte, + .data_offset = (sizeof(TCPHeader) + m_tcp_options_bytes) / sizeof(uint32_t), + .flags = m_next_flags, + .window_size = BAN::Math::min(0xFFFF, 
m_last_sent_window_size >> m_recv_window.scale_shift), + .checksum = 0, + .urgent_pointer = 0, + }; + memset(header.options, 0, m_tcp_options_bytes); - header.src_port = bound_port(); - header.dst_port = dst_port; - header.seq_number = m_send_window.current_seq + m_send_window.has_ghost_byte; - header.ack_number = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte; - header.data_offset = (sizeof(TCPHeader) + m_tcp_options_bytes) / sizeof(uint32_t); - header.window_size = BAN::Math::min(0xFFFF, m_last_sent_window_size >> m_recv_window.scale_shift); - header.flags = m_next_flags; if (header.flags & FIN) m_send_window.has_ghost_byte = true; m_next_flags = 0; @@ -566,10 +575,12 @@ namespace Kernel m_send_window.current_seq = m_send_window.start_seq; } - pseudo_header.extra = packet.size(); - header.checksum = calculate_internet_checksum(packet, pseudo_header); - - m_should_send_ack = false; + const BAN::ConstByteSpan buffers[] { + BAN::ConstByteSpan::from(pseudo_header), + header_buffer, + payload, + }; + header.checksum = calculate_internet_checksum({ buffers, sizeof(buffers) / sizeof(*buffers) }); dprintln_if(DEBUG_TCP, "sending {} {8b}", (uint8_t)m_state, header.flags); dprintln_if(DEBUG_TCP, " ack {}", (uint32_t)header.ack_number); @@ -603,14 +614,17 @@ namespace Kernel auto interface = interface_or_error.release_value(); auto& addr_in = *reinterpret_cast(sender); - checksum = calculate_internet_checksum(buffer, - PseudoHeader { - .src_ipv4 = BAN::IPv4Address(addr_in.sin_addr.s_addr), - .dst_ipv4 = interface->get_ipv4_address(), - .protocol = NetworkProtocol::TCP, - .extra = buffer.size() - } - ); + const PseudoHeader pseudo_header { + .src_ipv4 = BAN::IPv4Address(addr_in.sin_addr.s_addr), + .dst_ipv4 = interface->get_ipv4_address(), + .protocol = NetworkProtocol::TCP, + .length = buffer.size(), + }; + const BAN::ConstByteSpan buffers[] { + BAN::ConstByteSpan::from(pseudo_header), + buffer + }; + checksum = 
calculate_internet_checksum({ buffers, sizeof(buffers) / sizeof(*buffers) }); } else { @@ -757,9 +771,9 @@ namespace Kernel break; } - // TODO: even without SACKs, if other end sends seq [0, 1000] and our current seq is 100, we should accept - // packet with seq [100, 1000] - if (header.seq_number != m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte) + const uint32_t expected_seq = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte; + + if (header.seq_number > expected_seq) dprintln_if(DEBUG_TCP, "Missing packets"); else if (check_payload) { @@ -770,7 +784,19 @@ namespace Kernel m_send_window.current_ack = header.ack_number; auto payload = buffer.slice(header.data_offset * sizeof(uint32_t)); - if (payload.size() > 0 && m_recv_window.data_size < m_recv_window.buffer->size()) + + if (header.seq_number < expected_seq) + { + const uint32_t already_received_bytes = expected_seq - header.seq_number; + if (already_received_bytes <= payload.size()) + payload = payload.slice(already_received_bytes); + else + payload = {}; + } + + const bool can_receive_new_data = (payload.size() > 0 && m_recv_window.data_size < m_recv_window.buffer->size()); + + if (can_receive_new_data) { auto* recv_base = reinterpret_cast(m_recv_window.buffer->vaddr()); @@ -787,12 +813,12 @@ namespace Kernel epoll_notify(EPOLLIN); dprintln_if(DEBUG_TCP, "Received {} bytes", nrecv); - - m_should_send_ack = true; } // make sure zero window is reported - if (m_next_flags == 0 && m_last_sent_window_size > 0 && m_recv_window.data_size == m_recv_window.buffer->size()) + if (m_last_sent_window_size > 0 && m_recv_window.data_size == m_recv_window.buffer->size()) + m_should_send_zero_window = true; + else if (can_receive_new_data) m_should_send_ack = true; } @@ -915,7 +941,9 @@ namespace Kernel const bool should_retransmit = m_send_window.had_zero_window || (m_send_window.sent_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms); 
- if (m_send_window.sent_size < m_send_window.scaled_size() && (should_retransmit || m_send_window.data_size > m_send_window.sent_size)) + const bool can_send_new_data = (m_send_window.data_size > m_send_window.sent_size && m_send_window.sent_size < m_send_window.scaled_size()); + + if (m_send_window.scaled_size() > 0 && (should_retransmit || can_send_new_data)) { m_send_window.had_zero_window = false; @@ -927,7 +955,7 @@ namespace Kernel const size_t total_send = BAN::Math::min( m_send_window.data_size - send_start_offset, - m_send_window.scaled_size() - m_send_window.sent_size + m_send_window.scaled_size() - send_start_offset ); m_send_window.current_seq = m_send_window.start_seq + send_start_offset; @@ -961,15 +989,18 @@ namespace Kernel continue; } - if (m_should_send_ack) + if (const size_t ack_count = m_should_send_ack + m_should_send_zero_window) { ASSERT(m_connection_info.has_value()); auto* target_address = reinterpret_cast(&m_connection_info->address); auto target_address_len = m_connection_info->address_len; - m_next_flags = ACK; - if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error()) - dwarnln("{}", ret.error()); + for (size_t i = 0; i < ack_count; i++) + { + m_next_flags = ACK; + if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error()) + dwarnln("{}", ret.error()); + } } m_thread_blocker.unblock(); diff --git a/kernel/kernel/Networking/UDPSocket.cpp b/kernel/kernel/Networking/UDPSocket.cpp index 3454e070..c32f4815 100644 --- a/kernel/kernel/Networking/UDPSocket.cpp +++ b/kernel/kernel/Networking/UDPSocket.cpp @@ -35,13 +35,26 @@ namespace Kernel m_address_len = 0; } - void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) + void UDPSocket::get_protocol_header(BAN::ByteSpan header_buffer, BAN::ConstByteSpan payload, uint16_t dst_port, PseudoHeader pseudo_header) { - auto& header = packet.as(); - header.src_port = bound_port(); 
- header.dst_port = dst_port; - header.length = packet.size(); - header.checksum = 0; + ASSERT(header_buffer.size() == protocol_header_size()); + + auto& header = header_buffer.as(); + header = { + .src_port = bound_port(), + .dst_port = dst_port, + .length = protocol_header_size() + payload.size(), + .checksum = 0, + }; + + const BAN::ConstByteSpan buffers[] { + BAN::ConstByteSpan::from(pseudo_header), + header_buffer, + payload, + }; + header.checksum = calculate_internet_checksum({ buffers, sizeof(buffers) / sizeof(*buffers) }); + if (header.checksum == 0) + header.checksum = 0xFFFF; } void UDPSocket::receive_packet(BAN::ConstByteSpan packet, const sockaddr* sender, socklen_t sender_len)