Kernel: Optimize networking code

Remove buffering from network layer and rework loopback interface.
loopback now has a separate recieve thread to allow concurrent sends and
prevent deadlocks
This commit is contained in:
Bananymous 2026-02-27 19:08:08 +02:00
parent ff378e4538
commit 9ddf19f605
20 changed files with 563 additions and 476 deletions

View File

@ -31,35 +31,18 @@ namespace Kernel
public: public:
static BAN::ErrorOr<BAN::UniqPtr<ARPTable>> create(); static BAN::ErrorOr<BAN::UniqPtr<ARPTable>> create();
~ARPTable();
BAN::ErrorOr<BAN::MACAddress> get_mac_from_ipv4(NetworkInterface&, BAN::IPv4Address); BAN::ErrorOr<BAN::MACAddress> get_mac_from_ipv4(NetworkInterface&, BAN::IPv4Address);
void add_arp_packet(NetworkInterface&, BAN::ConstByteSpan); BAN::ErrorOr<void> handle_arp_packet(NetworkInterface&, BAN::ConstByteSpan);
private: private:
ARPTable(); ARPTable() = default;
void packet_handle_task();
BAN::ErrorOr<void> handle_arp_packet(NetworkInterface&, const ARPPacket&);
private: private:
struct PendingArpPacket SpinLock m_arp_table_lock;
{
NetworkInterface& interface;
ARPPacket packet;
};
private:
SpinLock m_table_lock;
SpinLock m_pending_lock;
BAN::HashMap<BAN::IPv4Address, BAN::MACAddress> m_arp_table; BAN::HashMap<BAN::IPv4Address, BAN::MACAddress> m_arp_table;
Thread* m_thread { nullptr };
BAN::CircularQueue<PendingArpPacket, 128> m_pending_packets;
ThreadBlocker m_pending_thread_blocker;
friend class BAN::UniqPtr<ARPTable>; friend class BAN::UniqPtr<ARPTable>;
}; };

View File

@ -23,14 +23,14 @@ namespace Kernel
static BAN::ErrorOr<BAN::RefPtr<E1000>> create(PCI::Device&); static BAN::ErrorOr<BAN::RefPtr<E1000>> create(PCI::Device&);
~E1000(); ~E1000();
virtual BAN::MACAddress get_mac_address() const override { return m_mac_address; } BAN::MACAddress get_mac_address() const override { return m_mac_address; }
virtual bool link_up() override { return m_link_up; } bool link_up() override { return m_link_up; }
virtual int link_speed() override; int link_speed() override;
virtual size_t payload_mtu() const override { return E1000_RX_BUFFER_SIZE - sizeof(EthernetHeader); } size_t payload_mtu() const override { return E1000_RX_BUFFER_SIZE - sizeof(EthernetHeader); }
virtual void handle_irq() final override; void handle_irq() final override;
protected: protected:
E1000(PCI::Device& pci_device) E1000(PCI::Device& pci_device)
@ -45,12 +45,12 @@ namespace Kernel
uint32_t read32(uint16_t reg); uint32_t read32(uint16_t reg);
void write32(uint16_t reg, uint32_t value); void write32(uint16_t reg, uint32_t value);
virtual BAN::ErrorOr<void> send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan) override; BAN::ErrorOr<void> send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::Span<const BAN::ConstByteSpan> payload) override;
virtual bool can_read_impl() const override { return false; } bool can_read_impl() const override { return false; }
virtual bool can_write_impl() const override { return false; } bool can_write_impl() const override { return false; }
virtual bool has_error_impl() const override { return false; } bool has_error_impl() const override { return false; }
virtual bool has_hungup_impl() const override { return false; } bool has_hungup_impl() const override { return false; }
private: private:
BAN::ErrorOr<void> read_mac_address(); BAN::ErrorOr<void> read_mac_address();
@ -61,7 +61,7 @@ namespace Kernel
void enable_link(); void enable_link();
BAN::ErrorOr<void> enable_interrupt(); BAN::ErrorOr<void> enable_interrupt();
void handle_receive(); void receive_thread();
protected: protected:
PCI::Device& m_pci_device; PCI::Device& m_pci_device;
@ -75,6 +75,10 @@ namespace Kernel
BAN::UniqPtr<DMARegion> m_tx_descriptor_region; BAN::UniqPtr<DMARegion> m_tx_descriptor_region;
SpinLock m_lock; SpinLock m_lock;
bool m_thread_should_die { false };
BAN::Atomic<bool> m_thread_is_dead { true };
ThreadBlocker m_thread_blocker;
BAN::MACAddress m_mac_address {}; BAN::MACAddress m_mac_address {};
bool m_link_up { false }; bool m_link_up { false };

View File

@ -12,8 +12,8 @@ namespace Kernel
static BAN::ErrorOr<BAN::RefPtr<E1000E>> create(PCI::Device&); static BAN::ErrorOr<BAN::RefPtr<E1000E>> create(PCI::Device&);
protected: protected:
virtual void detect_eeprom() override; void detect_eeprom() override;
virtual uint32_t eeprom_read(uint8_t addr) override; uint32_t eeprom_read(uint8_t addr) override;
private: private:
E1000E(PCI::Device& pci_device) E1000E(PCI::Device& pci_device)

View File

@ -38,11 +38,10 @@ namespace Kernel
public: public:
static BAN::ErrorOr<BAN::UniqPtr<IPv4Layer>> create(); static BAN::ErrorOr<BAN::UniqPtr<IPv4Layer>> create();
~IPv4Layer();
ARPTable& arp_table() { return *m_arp_table; } ARPTable& arp_table() { return *m_arp_table; }
void add_ipv4_packet(NetworkInterface&, BAN::ConstByteSpan); BAN::ErrorOr<void> handle_ipv4_packet(NetworkInterface&, BAN::ConstByteSpan);
virtual void unbind_socket(uint16_t port) override; virtual void unbind_socket(uint16_t port) override;
virtual BAN::ErrorOr<void> bind_socket_with_target(BAN::RefPtr<NetworkSocket>, const sockaddr* target_address, socklen_t target_address_len) override; virtual BAN::ErrorOr<void> bind_socket_with_target(BAN::RefPtr<NetworkSocket>, const sockaddr* target_address, socklen_t target_address_len) override;
@ -55,35 +54,15 @@ namespace Kernel
virtual size_t header_size() const override { return sizeof(IPv4Header); } virtual size_t header_size() const override { return sizeof(IPv4Header); }
private: private:
IPv4Layer(); IPv4Layer() = default;
void add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol) const;
BAN::ErrorOr<in_port_t> find_free_port(); BAN::ErrorOr<in_port_t> find_free_port();
void packet_handle_task();
BAN::ErrorOr<void> handle_ipv4_packet(NetworkInterface&, BAN::ByteSpan);
private: private:
struct PendingIPv4Packet BAN::UniqPtr<ARPTable> m_arp_table;
{
NetworkInterface& interface;
};
private: RecursiveSpinLock m_bound_socket_lock;
RecursiveSpinLock m_bound_socket_lock; BAN::HashMap<int, BAN::WeakPtr<NetworkSocket>> m_bound_sockets;
BAN::UniqPtr<ARPTable> m_arp_table;
Thread* m_thread { nullptr };
static constexpr size_t pending_packet_buffer_size = 128 * PAGE_SIZE;
BAN::UniqPtr<VirtualRange> m_pending_packet_buffer;
BAN::CircularQueue<PendingIPv4Packet, 128> m_pending_packets;
ThreadBlocker m_pending_thread_blocker;
SpinLock m_pending_lock;
size_t m_pending_total_size { 0 };
BAN::HashMap<int, BAN::WeakPtr<NetworkSocket>> m_bound_sockets;
friend class BAN::UniqPtr<IPv4Layer>; friend class BAN::UniqPtr<IPv4Layer>;
}; };

View File

@ -9,6 +9,7 @@ namespace Kernel
{ {
public: public:
static constexpr size_t buffer_size = BAN::numeric_limits<uint16_t>::max() + 1; static constexpr size_t buffer_size = BAN::numeric_limits<uint16_t>::max() + 1;
static constexpr size_t buffer_count = 32;
public: public:
static BAN::ErrorOr<BAN::RefPtr<LoopbackInterface>> create(); static BAN::ErrorOr<BAN::RefPtr<LoopbackInterface>> create();
@ -24,8 +25,9 @@ namespace Kernel
LoopbackInterface() LoopbackInterface()
: NetworkInterface(Type::Loopback) : NetworkInterface(Type::Loopback)
{} {}
~LoopbackInterface();
BAN::ErrorOr<void> send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan) override; BAN::ErrorOr<void> send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::Span<const BAN::ConstByteSpan> payload) override;
bool can_read_impl() const override { return false; } bool can_read_impl() const override { return false; }
bool can_write_impl() const override { return false; } bool can_write_impl() const override { return false; }
@ -33,8 +35,27 @@ namespace Kernel
bool has_hungup_impl() const override { return false; } bool has_hungup_impl() const override { return false; }
private: private:
SpinLock m_buffer_lock; void receive_thread();
private:
struct Descriptor
{
uint8_t* addr;
uint32_t size;
uint8_t state;
};
private:
Mutex m_buffer_lock;
BAN::UniqPtr<VirtualRange> m_buffer; BAN::UniqPtr<VirtualRange> m_buffer;
uint32_t m_buffer_tail { 0 };
uint32_t m_buffer_head { 0 };
Descriptor m_descriptors[buffer_count] {};
bool m_thread_should_die { false };
BAN::Atomic<bool> m_thread_is_dead { true };
ThreadBlocker m_thread_blocker;
}; };
} }

View File

@ -60,7 +60,11 @@ namespace Kernel
virtual dev_t rdev() const override { return m_rdev; } virtual dev_t rdev() const override { return m_rdev; }
virtual BAN::StringView name() const override { return m_name; } virtual BAN::StringView name() const override { return m_name; }
virtual BAN::ErrorOr<void> send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan) = 0; BAN::ErrorOr<void> send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan payload)
{
return send_bytes(destination, protocol, { &payload, 1 });
}
virtual BAN::ErrorOr<void> send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::Span<const BAN::ConstByteSpan> payload) = 0;
private: private:
const Type m_type; const Type m_type;

View File

@ -11,7 +11,7 @@ namespace Kernel
BAN::IPv4Address src_ipv4 { 0 }; BAN::IPv4Address src_ipv4 { 0 };
BAN::IPv4Address dst_ipv4 { 0 }; BAN::IPv4Address dst_ipv4 { 0 };
BAN::NetworkEndian<uint16_t> protocol { 0 }; BAN::NetworkEndian<uint16_t> protocol { 0 };
BAN::NetworkEndian<uint16_t> extra { 0 }; BAN::NetworkEndian<uint16_t> length { 0 };
}; };
static_assert(sizeof(PseudoHeader) == 12); static_assert(sizeof(PseudoHeader) == 12);
@ -36,6 +36,7 @@ namespace Kernel
NetworkLayer() = default; NetworkLayer() = default;
}; };
uint16_t calculate_internet_checksum(BAN::ConstByteSpan packet, const PseudoHeader& pseudo_header); uint16_t calculate_internet_checksum(BAN::ConstByteSpan buffer);
uint16_t calculate_internet_checksum(BAN::Span<const BAN::ConstByteSpan> buffers);
} }

View File

@ -32,7 +32,7 @@ namespace Kernel
BAN::ErrorOr<BAN::RefPtr<NetworkInterface>> interface(const sockaddr* target, socklen_t target_len); BAN::ErrorOr<BAN::RefPtr<NetworkInterface>> interface(const sockaddr* target, socklen_t target_len);
virtual size_t protocol_header_size() const = 0; virtual size_t protocol_header_size() const = 0;
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) = 0; virtual void get_protocol_header(BAN::ByteSpan header, BAN::ConstByteSpan payload, uint16_t dst_port, PseudoHeader) = 0;
virtual NetworkProtocol protocol() const = 0; virtual NetworkProtocol protocol() const = 0;
virtual void receive_packet(BAN::ConstByteSpan, const sockaddr* sender, socklen_t sender_len) = 0; virtual void receive_packet(BAN::ConstByteSpan, const sockaddr* sender, socklen_t sender_len) = 0;

View File

@ -29,9 +29,11 @@ namespace Kernel
: NetworkInterface(Type::Ethernet) : NetworkInterface(Type::Ethernet)
, m_pci_device(pci_device) , m_pci_device(pci_device)
{ } { }
~RTL8169();
BAN::ErrorOr<void> initialize(); BAN::ErrorOr<void> initialize();
virtual BAN::ErrorOr<void> send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan) override; virtual BAN::ErrorOr<void> send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::Span<const BAN::ConstByteSpan>) override;
virtual bool can_read_impl() const override { return false; } virtual bool can_read_impl() const override { return false; }
virtual bool can_write_impl() const override { return false; } virtual bool can_write_impl() const override { return false; }
@ -47,7 +49,7 @@ namespace Kernel
void enable_link(); void enable_link();
BAN::ErrorOr<void> enable_interrupt(); BAN::ErrorOr<void> enable_interrupt();
void handle_receive(); void receive_thread();
protected: protected:
PCI::Device& m_pci_device; PCI::Device& m_pci_device;
@ -63,6 +65,9 @@ namespace Kernel
BAN::UniqPtr<DMARegion> m_tx_descriptor_region; BAN::UniqPtr<DMARegion> m_tx_descriptor_region;
SpinLock m_lock; SpinLock m_lock;
bool m_thread_should_die { false };
BAN::Atomic<bool> m_thread_is_dead { true };
ThreadBlocker m_thread_blocker; ThreadBlocker m_thread_blocker;
uint32_t m_rx_current { 0 }; uint32_t m_rx_current { 0 };

View File

@ -50,30 +50,30 @@ namespace Kernel
static BAN::ErrorOr<BAN::RefPtr<TCPSocket>> create(NetworkLayer&, const Info&); static BAN::ErrorOr<BAN::RefPtr<TCPSocket>> create(NetworkLayer&, const Info&);
~TCPSocket(); ~TCPSocket();
virtual NetworkProtocol protocol() const override { return NetworkProtocol::TCP; } NetworkProtocol protocol() const override { return NetworkProtocol::TCP; }
virtual size_t protocol_header_size() const override { return sizeof(TCPHeader) + m_tcp_options_bytes; } size_t protocol_header_size() const override { return sizeof(TCPHeader) + m_tcp_options_bytes; }
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) override; void get_protocol_header(BAN::ByteSpan header, BAN::ConstByteSpan payload, uint16_t dst_port, PseudoHeader) override;
protected: protected:
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*, int) override; BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*, int) override;
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override; BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<void> listen_impl(int) override; BAN::ErrorOr<void> listen_impl(int) override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) override; BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<size_t> recvmsg_impl(msghdr& message, int flags) override; BAN::ErrorOr<size_t> recvmsg_impl(msghdr& message, int flags) override;
virtual BAN::ErrorOr<size_t> sendmsg_impl(const msghdr& message, int flags) override; BAN::ErrorOr<size_t> sendmsg_impl(const msghdr& message, int flags) override;
virtual BAN::ErrorOr<void> getpeername_impl(sockaddr*, socklen_t*) override; BAN::ErrorOr<void> getpeername_impl(sockaddr*, socklen_t*) override;
virtual BAN::ErrorOr<void> getsockopt_impl(int, int, void*, socklen_t*) override; BAN::ErrorOr<void> getsockopt_impl(int, int, void*, socklen_t*) override;
virtual BAN::ErrorOr<void> setsockopt_impl(int, int, const void*, socklen_t) override; BAN::ErrorOr<void> setsockopt_impl(int, int, const void*, socklen_t) override;
virtual BAN::ErrorOr<long> ioctl_impl(int, void*) override; BAN::ErrorOr<long> ioctl_impl(int, void*) override;
virtual void receive_packet(BAN::ConstByteSpan, const sockaddr* sender, socklen_t sender_len) override; void receive_packet(BAN::ConstByteSpan, const sockaddr* sender, socklen_t sender_len) override;
virtual bool can_read_impl() const override; bool can_read_impl() const override;
virtual bool can_write_impl() const override; bool can_write_impl() const override;
virtual bool has_error_impl() const override { return false; } bool has_error_impl() const override { return false; }
virtual bool has_hungup_impl() const override; bool has_hungup_impl() const override;
private: private:
enum class State enum class State
@ -181,6 +181,7 @@ namespace Kernel
bool m_no_delay { false }; bool m_no_delay { false };
bool m_should_send_ack { false }; bool m_should_send_ack { false };
bool m_should_send_zero_window { false };
uint64_t m_time_wait_start_ms { 0 }; uint64_t m_time_wait_start_ms { 0 };

View File

@ -25,28 +25,28 @@ namespace Kernel
public: public:
static BAN::ErrorOr<BAN::RefPtr<UDPSocket>> create(NetworkLayer&, const Socket::Info&); static BAN::ErrorOr<BAN::RefPtr<UDPSocket>> create(NetworkLayer&, const Socket::Info&);
virtual NetworkProtocol protocol() const override { return NetworkProtocol::UDP; } NetworkProtocol protocol() const override { return NetworkProtocol::UDP; }
virtual size_t protocol_header_size() const override { return sizeof(UDPHeader); } size_t protocol_header_size() const override { return sizeof(UDPHeader); }
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) override; void get_protocol_header(BAN::ByteSpan header, BAN::ConstByteSpan payload, uint16_t dst_port, PseudoHeader) override;
protected: protected:
virtual void receive_packet(BAN::ConstByteSpan, const sockaddr* sender, socklen_t sender_len) override; void receive_packet(BAN::ConstByteSpan, const sockaddr* sender, socklen_t sender_len) override;
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override; BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr* address, socklen_t address_len) override; BAN::ErrorOr<void> bind_impl(const sockaddr* address, socklen_t address_len) override;
virtual BAN::ErrorOr<size_t> recvmsg_impl(msghdr& message, int flags) override; BAN::ErrorOr<size_t> recvmsg_impl(msghdr& message, int flags) override;
virtual BAN::ErrorOr<size_t> sendmsg_impl(const msghdr& message, int flags) override; BAN::ErrorOr<size_t> sendmsg_impl(const msghdr& message, int flags) override;
virtual BAN::ErrorOr<void> getpeername_impl(sockaddr*, socklen_t*) override { return BAN::Error::from_errno(ENOTCONN); } BAN::ErrorOr<void> getpeername_impl(sockaddr*, socklen_t*) override { return BAN::Error::from_errno(ENOTCONN); }
virtual BAN::ErrorOr<void> getsockopt_impl(int, int, void*, socklen_t*) override; BAN::ErrorOr<void> getsockopt_impl(int, int, void*, socklen_t*) override;
virtual BAN::ErrorOr<void> setsockopt_impl(int, int, const void*, socklen_t) override; BAN::ErrorOr<void> setsockopt_impl(int, int, const void*, socklen_t) override;
virtual BAN::ErrorOr<long> ioctl_impl(int, void*) override; BAN::ErrorOr<long> ioctl_impl(int, void*) override;
virtual bool can_read_impl() const override { return !m_packets.empty(); } bool can_read_impl() const override { return !m_packets.empty(); }
virtual bool can_write_impl() const override { return true; } bool can_write_impl() const override { return true; }
virtual bool has_error_impl() const override { return false; } bool has_error_impl() const override { return false; }
virtual bool has_hungup_impl() const override { return false; } bool has_hungup_impl() const override { return false; }
private: private:
UDPSocket(NetworkLayer&, const Socket::Info&); UDPSocket(NetworkLayer&, const Socket::Info&);

View File

@ -17,27 +17,7 @@ namespace Kernel
BAN::ErrorOr<BAN::UniqPtr<ARPTable>> ARPTable::create() BAN::ErrorOr<BAN::UniqPtr<ARPTable>> ARPTable::create()
{ {
auto arp_table = TRY(BAN::UniqPtr<ARPTable>::create()); return TRY(BAN::UniqPtr<ARPTable>::create());
arp_table->m_thread = TRY(Thread::create_kernel(
[](void* arp_table_ptr)
{
auto& arp_table = *reinterpret_cast<ARPTable*>(arp_table_ptr);
arp_table.packet_handle_task();
}, arp_table.ptr()
));
TRY(Processor::scheduler().add_thread(arp_table->m_thread));
return arp_table;
}
ARPTable::ARPTable()
{
}
ARPTable::~ARPTable()
{
if (m_thread)
m_thread->add_signal(SIGKILL, {});
m_thread = nullptr;
} }
BAN::ErrorOr<BAN::MACAddress> ARPTable::get_mac_from_ipv4(NetworkInterface& interface, BAN::IPv4Address ipv4_address) BAN::ErrorOr<BAN::MACAddress> ARPTable::get_mac_from_ipv4(NetworkInterface& interface, BAN::IPv4Address ipv4_address)
@ -64,7 +44,7 @@ namespace Kernel
ipv4_address = interface.get_gateway(); ipv4_address = interface.get_gateway();
{ {
SpinLockGuard _(m_table_lock); SpinLockGuard _(m_arp_table_lock);
auto it = m_arp_table.find(ipv4_address); auto it = m_arp_table.find(ipv4_address);
if (it != m_arp_table.end()) if (it != m_arp_table.end())
return it->value; return it->value;
@ -87,7 +67,7 @@ namespace Kernel
while (SystemTimer::get().ms_since_boot() < timeout) while (SystemTimer::get().ms_since_boot() < timeout)
{ {
{ {
SpinLockGuard _(m_table_lock); SpinLockGuard _(m_arp_table_lock);
auto it = m_arp_table.find(ipv4_address); auto it = m_arp_table.find(ipv4_address);
if (it != m_arp_table.end()) if (it != m_arp_table.end())
return it->value; return it->value;
@ -98,8 +78,16 @@ namespace Kernel
return BAN::Error::from_errno(ETIMEDOUT); return BAN::Error::from_errno(ETIMEDOUT);
} }
BAN::ErrorOr<void> ARPTable::handle_arp_packet(NetworkInterface& interface, const ARPPacket& packet) BAN::ErrorOr<void> ARPTable::handle_arp_packet(NetworkInterface& interface, BAN::ConstByteSpan buffer)
{ {
if (buffer.size() < sizeof(ARPPacket))
{
dwarnln_if(DEBUG_ARP, "Too small ARP packet");
return {};
}
const auto& packet = buffer.as<const ARPPacket>();
if (packet.ptype != EtherType::IPv4) if (packet.ptype != EtherType::IPv4)
{ {
dprintln("Non IPv4 arp packet?"); dprintln("Non IPv4 arp packet?");
@ -112,23 +100,24 @@ namespace Kernel
{ {
if (packet.tpa == interface.get_ipv4_address()) if (packet.tpa == interface.get_ipv4_address())
{ {
ARPPacket arp_reply; const ARPPacket arp_reply {
arp_reply.htype = 0x0001; .htype = 0x0001,
arp_reply.ptype = EtherType::IPv4; .ptype = EtherType::IPv4,
arp_reply.hlen = 0x06; .hlen = 0x06,
arp_reply.plen = 0x04; .plen = 0x04,
arp_reply.oper = ARPOperation::Reply; .oper = ARPOperation::Reply,
arp_reply.sha = interface.get_mac_address(); .sha = interface.get_mac_address(),
arp_reply.spa = interface.get_ipv4_address(); .spa = interface.get_ipv4_address(),
arp_reply.tha = packet.sha; .tha = packet.sha,
arp_reply.tpa = packet.spa; .tpa = packet.spa,
};
TRY(interface.send_bytes(packet.sha, EtherType::ARP, BAN::ConstByteSpan::from(arp_reply))); TRY(interface.send_bytes(packet.sha, EtherType::ARP, BAN::ConstByteSpan::from(arp_reply)));
} }
break; break;
} }
case ARPOperation::Reply: case ARPOperation::Reply:
{ {
SpinLockGuard _(m_table_lock); SpinLockGuard _(m_arp_table_lock);
auto it = m_arp_table.find(packet.spa); auto it = m_arp_table.find(packet.spa);
if (it != m_arp_table.end()) if (it != m_arp_table.end())
@ -154,48 +143,4 @@ namespace Kernel
return {}; return {};
} }
void ARPTable::packet_handle_task()
{
for (;;)
{
PendingArpPacket pending = ({
SpinLockGuard guard(m_pending_lock);
while (m_pending_packets.empty())
{
SpinLockGuardAsMutex smutex(guard);
m_pending_thread_blocker.block_indefinite(&smutex);
}
auto packet = m_pending_packets.front();
m_pending_packets.pop();
packet;
});
if (auto ret = handle_arp_packet(pending.interface, pending.packet); ret.is_error())
dwarnln("{}", ret.error());
}
}
void ARPTable::add_arp_packet(NetworkInterface& interface, BAN::ConstByteSpan buffer)
{
if (buffer.size() < sizeof(ARPPacket))
{
dwarnln_if(DEBUG_ARP, "ARP packet too small");
return;
}
auto& arp_packet = buffer.as<const ARPPacket>();
SpinLockGuard _(m_pending_lock);
if (m_pending_packets.full())
{
dwarnln_if(DEBUG_ARP, "ARP packet queue full");
return;
}
m_pending_packets.push({ .interface = interface, .packet = arp_packet });
m_pending_thread_blocker.unblock();
}
} }

View File

@ -1,6 +1,7 @@
#include <kernel/IDT.h> #include <kernel/IDT.h>
#include <kernel/InterruptController.h> #include <kernel/InterruptController.h>
#include <kernel/IO.h> #include <kernel/IO.h>
#include <kernel/Lock/SpinLockAsMutex.h>
#include <kernel/Memory/PageTable.h> #include <kernel/Memory/PageTable.h>
#include <kernel/MMIO.h> #include <kernel/MMIO.h>
#include <kernel/Networking/E1000/E1000.h> #include <kernel/Networking/E1000/E1000.h>
@ -57,6 +58,11 @@ namespace Kernel
E1000::~E1000() E1000::~E1000()
{ {
m_thread_should_die = true;
m_thread_blocker.unblock();
while (!m_thread_is_dead)
Processor::yield();
} }
BAN::ErrorOr<void> E1000::initialize() BAN::ErrorOr<void> E1000::initialize()
@ -84,6 +90,16 @@ namespace Kernel
dprintln(" link speed: {} Mbps", speed); dprintln(" link speed: {} Mbps", speed);
} }
auto* thread = TRY(Thread::create_kernel([](void* e1000_ptr) {
static_cast<E1000*>(e1000_ptr)->receive_thread();
}, this));
if (auto ret = Processor::scheduler().add_thread(thread); ret.is_error())
{
delete thread;
return ret.release_error();
}
m_thread_is_dead = false;
return {}; return {};
} }
@ -259,10 +275,8 @@ namespace Kernel
return {}; return {};
} }
BAN::ErrorOr<void> E1000::send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan buffer) BAN::ErrorOr<void> E1000::send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::Span<const BAN::ConstByteSpan> payload)
{ {
ASSERT(buffer.size() + sizeof(EthernetHeader) <= E1000_TX_BUFFER_SIZE);
SpinLockGuard _(m_lock); SpinLockGuard _(m_lock);
size_t tx_current = read32(REG_TDT) % E1000_TX_DESCRIPTOR_COUNT; size_t tx_current = read32(REG_TDT) % E1000_TX_DESCRIPTOR_COUNT;
@ -274,48 +288,75 @@ namespace Kernel
ethernet_header.src_mac = get_mac_address(); ethernet_header.src_mac = get_mac_address();
ethernet_header.ether_type = protocol; ethernet_header.ether_type = protocol;
memcpy(tx_buffer + sizeof(EthernetHeader), buffer.data(), buffer.size()); size_t packet_size = sizeof(EthernetHeader);
for (const auto& buffer : payload)
{
ASSERT(packet_size + buffer.size() < E1000_TX_BUFFER_SIZE);
memcpy(tx_buffer + packet_size, buffer.data(), buffer.size());
packet_size += buffer.size();
}
auto& descriptor = reinterpret_cast<volatile e1000_tx_desc*>(m_tx_descriptor_region->vaddr())[tx_current]; auto& descriptor = reinterpret_cast<volatile e1000_tx_desc*>(m_tx_descriptor_region->vaddr())[tx_current];
descriptor.length = sizeof(EthernetHeader) + buffer.size(); descriptor.length = packet_size;
descriptor.status = 0; descriptor.status = 0;
descriptor.cmd = CMD_EOP | CMD_IFCS | CMD_RS; descriptor.cmd = CMD_EOP | CMD_IFCS | CMD_RS;
// FIXME: there isnt really any reason to wait for transmission
write32(REG_TDT, (tx_current + 1) % E1000_TX_DESCRIPTOR_COUNT); write32(REG_TDT, (tx_current + 1) % E1000_TX_DESCRIPTOR_COUNT);
while (descriptor.status == 0) while (descriptor.status == 0)
continue; continue;
dprintln_if(DEBUG_E1000, "sent {} bytes", sizeof(EthernetHeader) + buffer.size()); dprintln_if(DEBUG_E1000, "sent {} bytes", packet_size);
return {}; return {};
} }
void E1000::receive_thread()
{
SpinLockGuard _(m_lock);
while (!m_thread_should_die)
{
for (;;)
{
const uint32_t rx_current = (read32(REG_RDT0) + 1) % E1000_RX_DESCRIPTOR_COUNT;
auto& descriptor = reinterpret_cast<volatile e1000_rx_desc*>(m_rx_descriptor_region->vaddr())[rx_current];
if (!(descriptor.status & 1))
break;
ASSERT(descriptor.length <= E1000_RX_BUFFER_SIZE);
dprintln_if(DEBUG_E1000, "got {} bytes", (uint16_t)descriptor.length);
m_lock.unlock(InterruptState::Enabled);
NetworkManager::get().on_receive(*this, BAN::ConstByteSpan {
reinterpret_cast<const uint8_t*>(m_rx_buffer_region->vaddr() + rx_current * E1000_RX_BUFFER_SIZE),
descriptor.length
});
m_lock.lock();
descriptor.status = 0;
write32(REG_RDT0, rx_current);
}
SpinLockAsMutex smutex(m_lock, InterruptState::Enabled);
m_thread_blocker.block_indefinite(&smutex);
}
m_thread_is_dead = true;
}
void E1000::handle_irq() void E1000::handle_irq()
{ {
const uint32_t icr = read32(REG_ICR); const uint32_t icr = read32(REG_ICR);
if (!(icr & (ICR_RxQ0 | ICR_RXT0)))
return;
write32(REG_ICR, icr); write32(REG_ICR, icr);
SpinLockGuard _(m_lock); if (icr & (ICR_RxQ0 | ICR_RXT0))
{
for (;;) { SpinLockGuard _(m_lock);
uint32_t rx_current = (read32(REG_RDT0) + 1) % E1000_RX_DESCRIPTOR_COUNT; m_thread_blocker.unblock();
auto& descriptor = reinterpret_cast<volatile e1000_rx_desc*>(m_rx_descriptor_region->vaddr())[rx_current];
if (!(descriptor.status & 1))
break;
ASSERT(descriptor.length <= E1000_RX_BUFFER_SIZE);
dprintln_if(DEBUG_E1000, "got {} bytes", (uint16_t)descriptor.length);
NetworkManager::get().on_receive(*this, BAN::ConstByteSpan {
reinterpret_cast<const uint8_t*>(m_rx_buffer_region->vaddr() + rx_current * E1000_RX_BUFFER_SIZE),
descriptor.length
});
descriptor.status = 0;
write32(REG_RDT0, rx_current);
} }
} }

View File

@ -21,50 +21,26 @@ namespace Kernel
BAN::ErrorOr<BAN::UniqPtr<IPv4Layer>> IPv4Layer::create() BAN::ErrorOr<BAN::UniqPtr<IPv4Layer>> IPv4Layer::create()
{ {
auto ipv4_manager = TRY(BAN::UniqPtr<IPv4Layer>::create()); auto ipv4_manager = TRY(BAN::UniqPtr<IPv4Layer>::create());
ipv4_manager->m_thread = TRY(Thread::create_kernel(
[](void* ipv4_manager_ptr)
{
auto& ipv4_manager = *reinterpret_cast<IPv4Layer*>(ipv4_manager_ptr);
ipv4_manager.packet_handle_task();
}, ipv4_manager.ptr()
));
TRY(Processor::scheduler().add_thread(ipv4_manager->m_thread));
ipv4_manager->m_pending_packet_buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
~(uintptr_t)0,
pending_packet_buffer_size,
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
true, false
));
ipv4_manager->m_arp_table = TRY(ARPTable::create()); ipv4_manager->m_arp_table = TRY(ARPTable::create());
return ipv4_manager; return ipv4_manager;
} }
IPv4Layer::IPv4Layer() static IPv4Header get_ipv4_header(size_t packet_size, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol)
{ }
IPv4Layer::~IPv4Layer()
{ {
if (m_thread) IPv4Header header {
m_thread->add_signal(SIGKILL, {}); .version_IHL = 0x45,
m_thread = nullptr; .DSCP_ECN = 0x00,
} .total_length = packet_size,
.identification = 1,
void IPv4Layer::add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol) const .flags_frament = 0x00,
{ .time_to_live = 0x40,
auto& header = packet.as<IPv4Header>(); .protocol = protocol,
header.version_IHL = 0x45; .checksum = 0,
header.DSCP_ECN = 0x00; .src_address = src_ipv4,
header.total_length = packet.size(); .dst_address = dst_ipv4,
header.identification = 1; };
header.flags_frament = 0x00; header.checksum = calculate_internet_checksum(BAN::ConstByteSpan::from(header));
header.time_to_live = 0x40; return header;
header.protocol = protocol;
header.src_address = src_ipv4;
header.dst_address = dst_ipv4;
header.checksum = 0;
header.checksum = calculate_internet_checksum(BAN::ConstByteSpan::from(header), {});
} }
void IPv4Layer::unbind_socket(uint16_t port) void IPv4Layer::unbind_socket(uint16_t port)
@ -204,7 +180,7 @@ namespace Kernel
return {}; return {};
} }
BAN::ErrorOr<size_t> IPv4Layer::sendto(NetworkSocket& socket, BAN::ConstByteSpan buffer, const sockaddr* address, socklen_t address_len) BAN::ErrorOr<size_t> IPv4Layer::sendto(NetworkSocket& socket, BAN::ConstByteSpan payload, const sockaddr* address, socklen_t address_len)
{ {
if (address->sa_family != AF_INET) if (address->sa_family != AF_INET)
return BAN::Error::from_errno(EINVAL); return BAN::Error::from_errno(EINVAL);
@ -233,43 +209,61 @@ namespace Kernel
return BAN::Error::from_errno(EADDRNOTAVAIL); return BAN::Error::from_errno(EADDRNOTAVAIL);
} }
BAN::Vector<uint8_t> packet_buffer; const auto ipv4_header = get_ipv4_header(
TRY(packet_buffer.resize(buffer.size() + sizeof(IPv4Header) + socket.protocol_header_size())); sizeof(IPv4Header) + socket.protocol_header_size() + payload.size(),
auto packet = BAN::ByteSpan { packet_buffer.span() };
auto pseudo_header = PseudoHeader {
.src_ipv4 = interface->get_ipv4_address(),
.dst_ipv4 = dst_ipv4,
.protocol = socket.protocol()
};
memcpy(
packet.slice(sizeof(IPv4Header)).slice(socket.protocol_header_size()).data(),
buffer.data(),
buffer.size()
);
socket.add_protocol_header(
packet.slice(sizeof(IPv4Header)),
dst_port,
pseudo_header
);
add_ipv4_header(
packet,
interface->get_ipv4_address(), interface->get_ipv4_address(),
dst_ipv4, dst_ipv4,
socket.protocol() socket.protocol()
); );
TRY(interface->send_bytes(dst_mac, EtherType::IPv4, packet)); const auto pseudo_header = PseudoHeader {
.src_ipv4 = interface->get_ipv4_address(),
.dst_ipv4 = dst_ipv4,
.protocol = socket.protocol(),
.length = socket.protocol_header_size() + payload.size()
};
return buffer.size(); uint8_t protocol_header_buffer[32];
ASSERT(socket.protocol_header_size() < sizeof(protocol_header_buffer));
auto protocol_header = BAN::ByteSpan::from(protocol_header_buffer).slice(0, socket.protocol_header_size());
socket.get_protocol_header(protocol_header, payload, dst_port, pseudo_header);
BAN::ConstByteSpan buffers[] {
BAN::ConstByteSpan::from(ipv4_header),
protocol_header,
payload,
};
TRY(interface->send_bytes(dst_mac, EtherType::IPv4, { buffers, sizeof(buffers) / sizeof(*buffers) }));
return payload.size();
} }
BAN::ErrorOr<void> IPv4Layer::handle_ipv4_packet(NetworkInterface& interface, BAN::ByteSpan packet) BAN::ErrorOr<void> IPv4Layer::handle_ipv4_packet(NetworkInterface& interface, BAN::ConstByteSpan packet)
{ {
ASSERT(packet.size() >= sizeof(IPv4Header)); if (packet.size() < sizeof(IPv4Header))
{
dwarnln_if(DEBUG_IPV4, "Too small IPv4 packet");
return {};
}
auto& ipv4_header = packet.as<const IPv4Header>(); auto& ipv4_header = packet.as<const IPv4Header>();
auto ipv4_data = packet.slice(sizeof(IPv4Header)); if (calculate_internet_checksum(BAN::ConstByteSpan::from(ipv4_header)) != 0)
{
dwarnln_if(DEBUG_IPV4, "IPv4 packet checksum failed");
return {};
}
if (ipv4_header.total_length > packet.size() || ipv4_header.total_length > interface.payload_mtu() || ipv4_header.total_length < sizeof(IPv4Header))
{
if (ipv4_header.flags_frament & IPv4Flags::DF)
dwarnln_if(DEBUG_IPV4, "Invalid IPv4 packet");
else
dwarnln_if(DEBUG_IPV4, "IPv4 fragmentation not supported");
return {};
}
auto ipv4_data = packet.slice(0, ipv4_header.total_length).slice(sizeof(IPv4Header));
auto src_ipv4 = ipv4_header.src_address; auto src_ipv4 = ipv4_header.src_address;
@ -292,14 +286,33 @@ namespace Kernel
{ {
auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(interface, src_ipv4)); auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(interface, src_ipv4));
auto& reply_icmp_header = ipv4_data.as<ICMPHeader>(); auto send_ipv4_header = get_ipv4_header(
reply_icmp_header.type = ICMPType::EchoReply; ipv4_data.size(),
reply_icmp_header.checksum = 0; interface.get_ipv4_address(),
reply_icmp_header.checksum = calculate_internet_checksum(ipv4_data, {}); src_ipv4,
NetworkProtocol::ICMP
);
add_ipv4_header(packet, interface.get_ipv4_address(), src_ipv4, NetworkProtocol::ICMP); ICMPHeader send_icmp_header {
.type = ICMPType::EchoReply,
.code = icmp_header.code,
.checksum = 0,
.rest = icmp_header.rest,
};
auto send_payload = ipv4_data.slice(sizeof(ICMPHeader));
const BAN::ConstByteSpan send_buffers[] {
BAN::ConstByteSpan::from(send_ipv4_header),
BAN::ConstByteSpan::from(send_icmp_header),
send_payload
};
auto send_buffers_span = BAN::Span { send_buffers, sizeof(send_buffers) / sizeof(*send_buffers) };
send_icmp_header.checksum = calculate_internet_checksum(send_buffers_span.slice(1));
TRY(interface.send_bytes(dst_mac, EtherType::IPv4, send_buffers_span));
TRY(interface.send_bytes(dst_mac, EtherType::IPv4, packet));
break; break;
} }
case ICMPType::DestinationUnreachable: case ICMPType::DestinationUnreachable:
@ -381,80 +394,4 @@ namespace Kernel
return {}; return {};
} }
void IPv4Layer::packet_handle_task()
{
for (;;)
{
PendingIPv4Packet pending = ({
SpinLockGuard guard(m_pending_lock);
while (m_pending_packets.empty())
{
SpinLockGuardAsMutex smutex(guard);
m_pending_thread_blocker.block_indefinite(&smutex);
}
auto packet = m_pending_packets.front();
m_pending_packets.pop();
packet;
});
uint8_t* buffer_start = reinterpret_cast<uint8_t*>(m_pending_packet_buffer->vaddr());
const size_t ipv4_packet_size = reinterpret_cast<const IPv4Header*>(buffer_start)->total_length;
if (auto ret = handle_ipv4_packet(pending.interface, BAN::ByteSpan(buffer_start, ipv4_packet_size)); ret.is_error())
dwarnln_if(DEBUG_IPV4, "{}", ret.error());
SpinLockGuard _(m_pending_lock);
m_pending_total_size -= ipv4_packet_size;
if (m_pending_total_size)
memmove(buffer_start, buffer_start + ipv4_packet_size, m_pending_total_size);
}
}
void IPv4Layer::add_ipv4_packet(NetworkInterface& interface, BAN::ConstByteSpan buffer)
{
if (buffer.size() < sizeof(IPv4Header))
{
dwarnln_if(DEBUG_IPV4, "IPv4 packet too small");
return;
}
SpinLockGuard _(m_pending_lock);
if (m_pending_packets.full())
{
dwarnln_if(DEBUG_IPV4, "IPv4 packet queue full");
return;
}
if (m_pending_total_size + buffer.size() > m_pending_packet_buffer->size())
{
dwarnln_if(DEBUG_IPV4, "IPv4 packet queue full");
return;
}
auto& ipv4_header = buffer.as<const IPv4Header>();
if (calculate_internet_checksum(BAN::ConstByteSpan::from(ipv4_header), {}) != 0)
{
dwarnln_if(DEBUG_IPV4, "Invalid IPv4 packet");
return;
}
if (ipv4_header.total_length > buffer.size() || ipv4_header.total_length > interface.payload_mtu())
{
if (ipv4_header.flags_frament & IPv4Flags::DF)
dwarnln_if(DEBUG_IPV4, "Invalid IPv4 packet");
else
dwarnln_if(DEBUG_IPV4, "IPv4 fragmentation not supported");
return;
}
uint8_t* buffer_start = reinterpret_cast<uint8_t*>(m_pending_packet_buffer->vaddr());
memcpy(buffer_start + m_pending_total_size, buffer.data(), ipv4_header.total_length);
m_pending_total_size += ipv4_header.total_length;
m_pending_packets.push({ .interface = interface });
m_pending_thread_blocker.unblock();
}
} }

View File

@ -1,3 +1,4 @@
#include <kernel/Lock/LockGuard.h>
#include <kernel/Networking/Loopback.h> #include <kernel/Networking/Loopback.h>
#include <kernel/Networking/NetworkManager.h> #include <kernel/Networking/NetworkManager.h>
@ -10,40 +11,121 @@ namespace Kernel
if (loopback_ptr == nullptr) if (loopback_ptr == nullptr)
return BAN::Error::from_errno(ENOMEM); return BAN::Error::from_errno(ENOMEM);
auto loopback = BAN::RefPtr<LoopbackInterface>::adopt(loopback_ptr); auto loopback = BAN::RefPtr<LoopbackInterface>::adopt(loopback_ptr);
loopback->m_buffer = TRY(VirtualRange::create_to_vaddr_range( loopback->m_buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(), PageTable::kernel(),
KERNEL_OFFSET, KERNEL_OFFSET,
BAN::numeric_limits<vaddr_t>::max(), BAN::numeric_limits<vaddr_t>::max(),
buffer_size, buffer_size * buffer_count,
PageTable::Flags::ReadWrite | PageTable::Flags::Present, PageTable::Flags::ReadWrite | PageTable::Flags::Present,
true, false true, false
)); ));
auto* thread = TRY(Thread::create_kernel([](void* loopback_ptr) {
static_cast<LoopbackInterface*>(loopback_ptr)->receive_thread();
}, loopback_ptr));
if (auto ret = Processor::scheduler().add_thread(thread); ret.is_error())
{
delete thread;
return ret.release_error();
}
loopback->m_thread_is_dead = false;
loopback->set_ipv4_address({ 127, 0, 0, 1 }); loopback->set_ipv4_address({ 127, 0, 0, 1 });
loopback->set_netmask({ 255, 0, 0, 0 }); loopback->set_netmask({ 255, 0, 0, 0 });
for (size_t i = 0; i < buffer_count; i++)
{
loopback->m_descriptors[i] = {
.addr = reinterpret_cast<uint8_t*>(loopback->m_buffer->vaddr()) + i * buffer_size,
.size = 0,
.state = 0,
};
}
return loopback; return loopback;
} }
BAN::ErrorOr<void> LoopbackInterface::send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan buffer) LoopbackInterface::~LoopbackInterface()
{ {
ASSERT(buffer.size() + sizeof(EthernetHeader) <= buffer_size); m_thread_should_die = true;
m_thread_blocker.unblock();
SpinLockGuard _(m_buffer_lock); while (!m_thread_is_dead)
Processor::yield();
}
uint8_t* buffer_vaddr = reinterpret_cast<uint8_t*>(m_buffer->vaddr()); BAN::ErrorOr<void> LoopbackInterface::send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::Span<const BAN::ConstByteSpan> payload)
{
auto& descriptor =
[&]() -> Descriptor&
{
LockGuard _(m_buffer_lock);
for (;;)
{
auto& descriptor = m_descriptors[m_buffer_head];
if (descriptor.state == 0)
{
m_buffer_head = (m_buffer_head + 1) % buffer_count;
descriptor.state = 1;
return descriptor;
}
m_thread_blocker.block_indefinite(&m_buffer_lock);
}
}();
auto& ethernet_header = *reinterpret_cast<EthernetHeader*>(buffer_vaddr); auto& ethernet_header = *reinterpret_cast<EthernetHeader*>(descriptor.addr);
ethernet_header.dst_mac = destination; ethernet_header.dst_mac = destination;
ethernet_header.src_mac = get_mac_address(); ethernet_header.src_mac = get_mac_address();
ethernet_header.ether_type = protocol; ethernet_header.ether_type = protocol;
memcpy(buffer_vaddr + sizeof(EthernetHeader), buffer.data(), buffer.size()); size_t packet_size = sizeof(EthernetHeader);
for (const auto& buffer : payload)
{
ASSERT(packet_size + buffer.size() <= buffer_size);
memcpy(descriptor.addr + packet_size, buffer.data(), buffer.size());
packet_size += buffer.size();
}
NetworkManager::get().on_receive(*this, BAN::ConstByteSpan { LockGuard _(m_buffer_lock);
buffer_vaddr, descriptor.size = packet_size;
buffer.size() + sizeof(EthernetHeader) descriptor.state = 2;
}); m_thread_blocker.unblock();
return {}; return {};
} }
void LoopbackInterface::receive_thread()
{
LockGuard _(m_buffer_lock);
while (!m_thread_should_die)
{
for (;;)
{
auto& descriptor = m_descriptors[m_buffer_tail];
if (descriptor.state != 2)
break;
m_buffer_tail = (m_buffer_tail + 1) % buffer_count;
m_buffer_lock.unlock();
NetworkManager::get().on_receive(*this, {
descriptor.addr,
descriptor.size,
});
m_buffer_lock.lock();
descriptor.size = 0;
descriptor.state = 0;
m_thread_blocker.unblock();
}
m_thread_blocker.block_indefinite(&m_buffer_lock);
}
m_thread_is_dead = true;
}
} }

View File

@ -3,15 +3,28 @@
namespace Kernel namespace Kernel
{ {
uint16_t calculate_internet_checksum(BAN::ConstByteSpan packet, const PseudoHeader& pseudo_header) uint16_t calculate_internet_checksum(BAN::ConstByteSpan buffer)
{
return calculate_internet_checksum({ &buffer, 1 });
}
uint16_t calculate_internet_checksum(BAN::Span<const BAN::ConstByteSpan> buffers)
{ {
uint32_t checksum = 0; uint32_t checksum = 0;
for (size_t i = 0; i < sizeof(pseudo_header) / sizeof(uint16_t); i++)
checksum += BAN::host_to_network_endian(reinterpret_cast<const uint16_t*>(&pseudo_header)[i]); for (size_t i = 0; i < buffers.size(); i++)
for (size_t i = 0; i < packet.size() / sizeof(uint16_t); i++) {
checksum += BAN::host_to_network_endian(reinterpret_cast<const uint16_t*>(packet.data())[i]); auto buffer = buffers[i];
if (packet.size() % 2)
checksum += (uint16_t)packet[packet.size() - 1] << 8; const uint16_t* buffer_u16 = reinterpret_cast<const uint16_t*>(buffer.data());
for (size_t j = 0; j < buffer.size() / 2; j++)
checksum += BAN::host_to_network_endian(buffer_u16[j]);
if (buffer.size() % 2 == 0)
continue;
ASSERT(i == buffers.size() - 1);
checksum += buffer[buffer.size() - 1] << 8;
}
while (checksum >> 16) while (checksum >> 16)
checksum = (checksum >> 16) + (checksum & 0xFFFF); checksum = (checksum >> 16) + (checksum & 0xFFFF);
return ~(uint16_t)checksum; return ~(uint16_t)checksum;

View File

@ -154,18 +154,18 @@ namespace Kernel
return; return;
auto ethernet_header = packet.as<const EthernetHeader>(); auto ethernet_header = packet.as<const EthernetHeader>();
auto packet_data = packet.slice(sizeof(EthernetHeader));
switch (ethernet_header.ether_type) switch (ethernet_header.ether_type)
{ {
case EtherType::ARP: case EtherType::ARP:
{ if (auto ret = m_ipv4_layer->arp_table().handle_arp_packet(interface, packet_data); ret.is_error())
m_ipv4_layer->arp_table().add_arp_packet(interface, packet.slice(sizeof(EthernetHeader))); dwarnln("ARP: {}", ret.error());
break; break;
}
case EtherType::IPv4: case EtherType::IPv4:
{ if (auto ret = m_ipv4_layer->handle_ipv4_packet(interface, packet_data); ret.is_error())
m_ipv4_layer->add_ipv4_packet(interface, packet.slice(sizeof(EthernetHeader))); dwarnln("IPv4; {}", ret.error());
break; break;
}
default: default:
dprintln_if(DEBUG_ETHERTYPE, "Unknown EtherType 0x{4H}", (uint16_t)ethernet_header.ether_type); dprintln_if(DEBUG_ETHERTYPE, "Unknown EtherType 0x{4H}", (uint16_t)ethernet_header.ether_type);
break; break;

View File

@ -7,6 +7,9 @@
namespace Kernel namespace Kernel
{ {
// each buffer is 7440 bytes + padding = 8192
constexpr size_t s_buffer_size = 8192;
bool RTL8169::probe(PCI::Device& pci_device) bool RTL8169::probe(PCI::Device& pci_device)
{ {
if (pci_device.vendor_id() != 0x10ec) if (pci_device.vendor_id() != 0x10ec)
@ -68,9 +71,28 @@ namespace Kernel
// lock config registers // lock config registers
m_io_bar_region->write8(RTL8169_IO_9346CR, RTL8169_9346CR_MODE_NORMAL); m_io_bar_region->write8(RTL8169_IO_9346CR, RTL8169_9346CR_MODE_NORMAL);
auto* thread = TRY(Thread::create_kernel([](void* rtl8169_ptr) {
static_cast<RTL8169*>(rtl8169_ptr)->receive_thread();
}, this));
if (auto ret = Processor::scheduler().add_thread(thread); ret.is_error())
{
delete thread;
return ret.release_error();
}
m_thread_is_dead = false;
return {}; return {};
} }
RTL8169::~RTL8169()
{
m_thread_should_die = true;
m_thread_blocker.unblock();
while (!m_thread_is_dead)
Processor::yield();
}
BAN::ErrorOr<void> RTL8169::reset() BAN::ErrorOr<void> RTL8169::reset()
{ {
m_io_bar_region->write8(RTL8169_IO_CR, RTL8169_CR_RST); m_io_bar_region->write8(RTL8169_IO_CR, RTL8169_CR_RST);
@ -85,15 +107,12 @@ namespace Kernel
BAN::ErrorOr<void> RTL8169::initialize_rx() BAN::ErrorOr<void> RTL8169::initialize_rx()
{ {
// each buffer is 7440 bytes + padding = 8192 m_rx_buffer_region = TRY(DMARegion::create(m_rx_descriptor_count * s_buffer_size));
constexpr size_t buffer_size = 2 * PAGE_SIZE;
m_rx_buffer_region = TRY(DMARegion::create(m_rx_descriptor_count * buffer_size));
m_rx_descriptor_region = TRY(DMARegion::create(m_rx_descriptor_count * sizeof(RTL8169Descriptor))); m_rx_descriptor_region = TRY(DMARegion::create(m_rx_descriptor_count * sizeof(RTL8169Descriptor)));
for (size_t i = 0; i < m_rx_descriptor_count; i++) for (size_t i = 0; i < m_rx_descriptor_count; i++)
{ {
const paddr_t rx_buffer_paddr = m_rx_buffer_region->paddr() + i * buffer_size; const paddr_t rx_buffer_paddr = m_rx_buffer_region->paddr() + i * s_buffer_size;
uint32_t command = 0x1FF8 | RTL8169_DESC_CMD_OWN; uint32_t command = 0x1FF8 | RTL8169_DESC_CMD_OWN;
if (i == m_rx_descriptor_count - 1) if (i == m_rx_descriptor_count - 1)
@ -120,21 +139,17 @@ namespace Kernel
// configure max rx packet size // configure max rx packet size
m_io_bar_region->write16(RTL8169_IO_RMS, RTL8169_RMS_MAX); m_io_bar_region->write16(RTL8169_IO_RMS, RTL8169_RMS_MAX);
return {}; return {};
} }
BAN::ErrorOr<void> RTL8169::initialize_tx() BAN::ErrorOr<void> RTL8169::initialize_tx()
{ {
// each buffer is 7440 bytes + padding = 8192 m_tx_buffer_region = TRY(DMARegion::create(m_tx_descriptor_count * s_buffer_size));
constexpr size_t buffer_size = 2 * PAGE_SIZE;
m_tx_buffer_region = TRY(DMARegion::create(m_tx_descriptor_count * buffer_size));
m_tx_descriptor_region = TRY(DMARegion::create(m_tx_descriptor_count * sizeof(RTL8169Descriptor))); m_tx_descriptor_region = TRY(DMARegion::create(m_tx_descriptor_count * sizeof(RTL8169Descriptor)));
for (size_t i = 0; i < m_tx_descriptor_count; i++) for (size_t i = 0; i < m_tx_descriptor_count; i++)
{ {
const paddr_t tx_buffer_paddr = m_tx_buffer_region->paddr() + i * buffer_size; const paddr_t tx_buffer_paddr = m_tx_buffer_region->paddr() + i * s_buffer_size;
uint32_t command = 0; uint32_t command = 0;
if (i == m_tx_descriptor_count - 1) if (i == m_tx_descriptor_count - 1)
@ -194,14 +209,8 @@ namespace Kernel
return 0; return 0;
} }
BAN::ErrorOr<void> RTL8169::send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan buffer) BAN::ErrorOr<void> RTL8169::send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::Span<const BAN::ConstByteSpan> payload)
{ {
constexpr size_t buffer_size = 8192;
const uint16_t packet_size = sizeof(EthernetHeader) + buffer.size();
if (packet_size > buffer_size)
return BAN::Error::from_errno(EINVAL);
if (!link_up()) if (!link_up())
return BAN::Error::from_errno(EADDRNOTAVAIL); return BAN::Error::from_errno(EADDRNOTAVAIL);
@ -219,14 +228,20 @@ namespace Kernel
m_lock.unlock(state); m_lock.unlock(state);
auto* tx_buffer = reinterpret_cast<uint8_t*>(m_tx_buffer_region->vaddr() + tx_current * buffer_size); auto* tx_buffer = reinterpret_cast<uint8_t*>(m_tx_buffer_region->vaddr() + tx_current * s_buffer_size);
// write packet // write packet
auto& ethernet_header = *reinterpret_cast<EthernetHeader*>(tx_buffer); auto& ethernet_header = *reinterpret_cast<EthernetHeader*>(tx_buffer);
ethernet_header.dst_mac = destination; ethernet_header.dst_mac = destination;
ethernet_header.src_mac = get_mac_address(); ethernet_header.src_mac = get_mac_address();
ethernet_header.ether_type = protocol; ethernet_header.ether_type = protocol;
memcpy(tx_buffer + sizeof(EthernetHeader), buffer.data(), buffer.size());
size_t packet_size = sizeof(EthernetHeader);
for (const auto& buffer : payload)
{
memcpy(tx_buffer + packet_size, buffer.data(), buffer.size());
packet_size += buffer.size();
}
// give packet ownership to NIC // give packet ownership to NIC
uint32_t command = packet_size | RTL8169_DESC_CMD_OWN | RTL8169_DESC_CMD_LS | RTL8169_DESC_CMD_FS; uint32_t command = packet_size | RTL8169_DESC_CMD_OWN | RTL8169_DESC_CMD_LS | RTL8169_DESC_CMD_FS;
@ -240,6 +255,50 @@ namespace Kernel
return {}; return {};
} }
void RTL8169::receive_thread()
{
SpinLockGuard _(m_lock);
while (!m_thread_should_die)
{
for (;;)
{
auto& descriptor = reinterpret_cast<volatile RTL8169Descriptor*>(m_rx_descriptor_region->vaddr())[m_rx_current];
if (descriptor.command & RTL8169_DESC_CMD_OWN)
break;
// packet buffer can only hold single packet, so we should not receive any multi-descriptor packets
ASSERT((descriptor.command & RTL8169_DESC_CMD_LS) && (descriptor.command & RTL8169_DESC_CMD_FS));
const uint16_t packet_length = descriptor.command & 0x3FFF;
if (packet_length > s_buffer_size)
dwarnln("Got {} bytes to {} byte buffer", packet_length, s_buffer_size);
else if (descriptor.command & (1u << 21))
; // descriptor has an error
else
{
m_lock.unlock(InterruptState::Enabled);
NetworkManager::get().on_receive(*this, BAN::ConstByteSpan {
reinterpret_cast<const uint8_t*>(m_rx_buffer_region->vaddr() + m_rx_current * s_buffer_size),
packet_length
});
m_lock.lock();
}
m_rx_current = (m_rx_current + 1) % m_rx_descriptor_count;
descriptor.command = descriptor.command | RTL8169_DESC_CMD_OWN;
}
SpinLockAsMutex smutex(m_lock, InterruptState::Enabled);
m_thread_blocker.block_indefinite(&smutex);
}
m_thread_is_dead = true;
}
void RTL8169::handle_irq() void RTL8169::handle_irq()
{ {
const uint16_t interrupt_status = m_io_bar_region->read16(RTL8169_IO_ISR); const uint16_t interrupt_status = m_io_bar_region->read16(RTL8169_IO_ISR);
@ -251,7 +310,7 @@ namespace Kernel
dprintln("link status -> {}", m_link_up.load()); dprintln("link status -> {}", m_link_up.load());
} }
if (interrupt_status & RTL8169_IR_TOK) if (interrupt_status & (RTL8169_IR_TOK | RTL8169_IR_ROK))
{ {
SpinLockGuard _(m_lock); SpinLockGuard _(m_lock);
m_thread_blocker.unblock(); m_thread_blocker.unblock();
@ -266,38 +325,6 @@ namespace Kernel
if (interrupt_status & RTL8169_IR_FVOW) if (interrupt_status & RTL8169_IR_FVOW)
dwarnln("Rx FIFO overflow"); dwarnln("Rx FIFO overflow");
// dont log TDU is sent after each sent packet // dont log TDU is sent after each sent packet
if (!(interrupt_status & RTL8169_IR_ROK))
return;
constexpr size_t buffer_size = 8192;
for (;;)
{
auto& descriptor = reinterpret_cast<volatile RTL8169Descriptor*>(m_rx_descriptor_region->vaddr())[m_rx_current];
if (descriptor.command & RTL8169_DESC_CMD_OWN)
break;
// packet buffer can only hold single packet, so we should not receive any multi-descriptor packets
ASSERT((descriptor.command & RTL8169_DESC_CMD_LS) && (descriptor.command & RTL8169_DESC_CMD_FS));
const uint16_t packet_length = descriptor.command & 0x3FFF;
if (packet_length > buffer_size)
dwarnln("Got {} bytes to {} byte buffer", packet_length, buffer_size);
else if (descriptor.command & (1u << 21))
; // descriptor has an error
else
{
NetworkManager::get().on_receive(*this, BAN::ConstByteSpan {
reinterpret_cast<const uint8_t*>(m_rx_buffer_region->vaddr() + m_rx_current * buffer_size),
packet_length
});
}
m_rx_current = (m_rx_current + 1) % m_rx_descriptor_count;
descriptor.command = descriptor.command | RTL8169_DESC_CMD_OWN;
}
} }
} }

View File

@ -524,24 +524,33 @@ namespace Kernel
return result; return result;
} }
void TCPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader pseudo_header) void TCPSocket::get_protocol_header(BAN::ByteSpan header_buffer, BAN::ConstByteSpan payload, uint16_t dst_port, PseudoHeader pseudo_header)
{ {
ASSERT(m_next_flags); ASSERT(m_next_flags);
ASSERT(m_mutex.locker() == Thread::current().tid()); ASSERT(m_mutex.locker() == Thread::current().tid());
ASSERT(header_buffer.size() == protocol_header_size());
auto& header = packet.as<TCPHeader>();
memset(&header, 0, sizeof(TCPHeader));
memset(header.options, TCPOption::End, m_tcp_options_bytes);
m_last_sent_window_size = m_recv_window.buffer->size() - m_recv_window.data_size; m_last_sent_window_size = m_recv_window.buffer->size() - m_recv_window.data_size;
if (m_should_send_zero_window)
m_last_sent_window_size = 0;
m_should_send_ack = false;
m_should_send_zero_window = false;
auto& header = header_buffer.as<TCPHeader>();
header = {
.src_port = bound_port(),
.dst_port = dst_port,
.seq_number = m_send_window.current_seq + m_send_window.has_ghost_byte,
.ack_number = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte,
.data_offset = (sizeof(TCPHeader) + m_tcp_options_bytes) / sizeof(uint32_t),
.flags = m_next_flags,
.window_size = BAN::Math::min<size_t>(0xFFFF, m_last_sent_window_size >> m_recv_window.scale_shift),
.checksum = 0,
.urgent_pointer = 0,
};
memset(header.options, 0, m_tcp_options_bytes);
header.src_port = bound_port();
header.dst_port = dst_port;
header.seq_number = m_send_window.current_seq + m_send_window.has_ghost_byte;
header.ack_number = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte;
header.data_offset = (sizeof(TCPHeader) + m_tcp_options_bytes) / sizeof(uint32_t);
header.window_size = BAN::Math::min<size_t>(0xFFFF, m_last_sent_window_size >> m_recv_window.scale_shift);
header.flags = m_next_flags;
if (header.flags & FIN) if (header.flags & FIN)
m_send_window.has_ghost_byte = true; m_send_window.has_ghost_byte = true;
m_next_flags = 0; m_next_flags = 0;
@ -566,10 +575,12 @@ namespace Kernel
m_send_window.current_seq = m_send_window.start_seq; m_send_window.current_seq = m_send_window.start_seq;
} }
pseudo_header.extra = packet.size(); const BAN::ConstByteSpan buffers[] {
header.checksum = calculate_internet_checksum(packet, pseudo_header); BAN::ConstByteSpan::from(pseudo_header),
header_buffer,
m_should_send_ack = false; payload,
};
header.checksum = calculate_internet_checksum({ buffers, sizeof(buffers) / sizeof(*buffers) });
dprintln_if(DEBUG_TCP, "sending {} {8b}", (uint8_t)m_state, header.flags); dprintln_if(DEBUG_TCP, "sending {} {8b}", (uint8_t)m_state, header.flags);
dprintln_if(DEBUG_TCP, " ack {}", (uint32_t)header.ack_number); dprintln_if(DEBUG_TCP, " ack {}", (uint32_t)header.ack_number);
@ -603,14 +614,17 @@ namespace Kernel
auto interface = interface_or_error.release_value(); auto interface = interface_or_error.release_value();
auto& addr_in = *reinterpret_cast<const sockaddr_in*>(sender); auto& addr_in = *reinterpret_cast<const sockaddr_in*>(sender);
checksum = calculate_internet_checksum(buffer, const PseudoHeader pseudo_header {
PseudoHeader { .src_ipv4 = BAN::IPv4Address(addr_in.sin_addr.s_addr),
.src_ipv4 = BAN::IPv4Address(addr_in.sin_addr.s_addr), .dst_ipv4 = interface->get_ipv4_address(),
.dst_ipv4 = interface->get_ipv4_address(), .protocol = NetworkProtocol::TCP,
.protocol = NetworkProtocol::TCP, .length = buffer.size(),
.extra = buffer.size() };
} const BAN::ConstByteSpan buffers[] {
); BAN::ConstByteSpan::from(pseudo_header),
buffer
};
checksum = calculate_internet_checksum({ buffers, sizeof(buffers) / sizeof(*buffers) });
} }
else else
{ {
@ -757,9 +771,9 @@ namespace Kernel
break; break;
} }
// TODO: even without SACKs, if other end sends seq [0, 1000] and our current seq is 100, we should accept const uint32_t expected_seq = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte;
// packet with seq [100, 1000]
if (header.seq_number != m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte) if (header.seq_number > expected_seq)
dprintln_if(DEBUG_TCP, "Missing packets"); dprintln_if(DEBUG_TCP, "Missing packets");
else if (check_payload) else if (check_payload)
{ {
@ -770,7 +784,19 @@ namespace Kernel
m_send_window.current_ack = header.ack_number; m_send_window.current_ack = header.ack_number;
auto payload = buffer.slice(header.data_offset * sizeof(uint32_t)); auto payload = buffer.slice(header.data_offset * sizeof(uint32_t));
if (payload.size() > 0 && m_recv_window.data_size < m_recv_window.buffer->size())
if (header.seq_number < expected_seq)
{
const uint32_t already_received_bytes = expected_seq - header.seq_number;
if (already_received_bytes <= payload.size())
payload = payload.slice(already_received_bytes);
else
payload = {};
}
const bool can_receive_new_data = (payload.size() > 0 && m_recv_window.data_size < m_recv_window.buffer->size());
if (can_receive_new_data)
{ {
auto* recv_base = reinterpret_cast<uint8_t*>(m_recv_window.buffer->vaddr()); auto* recv_base = reinterpret_cast<uint8_t*>(m_recv_window.buffer->vaddr());
@ -787,12 +813,12 @@ namespace Kernel
epoll_notify(EPOLLIN); epoll_notify(EPOLLIN);
dprintln_if(DEBUG_TCP, "Received {} bytes", nrecv); dprintln_if(DEBUG_TCP, "Received {} bytes", nrecv);
m_should_send_ack = true;
} }
// make sure zero window is reported // make sure zero window is reported
if (m_next_flags == 0 && m_last_sent_window_size > 0 && m_recv_window.data_size == m_recv_window.buffer->size()) if (m_last_sent_window_size > 0 && m_recv_window.data_size == m_recv_window.buffer->size())
m_should_send_zero_window = true;
else if (can_receive_new_data)
m_should_send_ack = true; m_should_send_ack = true;
} }
@ -915,7 +941,9 @@ namespace Kernel
const bool should_retransmit = m_send_window.had_zero_window || (m_send_window.sent_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms); const bool should_retransmit = m_send_window.had_zero_window || (m_send_window.sent_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms);
if (m_send_window.sent_size < m_send_window.scaled_size() && (should_retransmit || m_send_window.data_size > m_send_window.sent_size)) const bool can_send_new_data = (m_send_window.data_size > m_send_window.sent_size && m_send_window.sent_size < m_send_window.scaled_size());
if (m_send_window.scaled_size() > 0 && (should_retransmit || can_send_new_data))
{ {
m_send_window.had_zero_window = false; m_send_window.had_zero_window = false;
@ -927,7 +955,7 @@ namespace Kernel
const size_t total_send = BAN::Math::min<size_t>( const size_t total_send = BAN::Math::min<size_t>(
m_send_window.data_size - send_start_offset, m_send_window.data_size - send_start_offset,
m_send_window.scaled_size() - m_send_window.sent_size m_send_window.scaled_size() - send_start_offset
); );
m_send_window.current_seq = m_send_window.start_seq + send_start_offset; m_send_window.current_seq = m_send_window.start_seq + send_start_offset;
@ -961,15 +989,18 @@ namespace Kernel
continue; continue;
} }
if (m_should_send_ack) if (const size_t ack_count = m_should_send_ack + m_should_send_zero_window)
{ {
ASSERT(m_connection_info.has_value()); ASSERT(m_connection_info.has_value());
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address); auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
auto target_address_len = m_connection_info->address_len; auto target_address_len = m_connection_info->address_len;
m_next_flags = ACK; for (size_t i = 0; i < ack_count; i++)
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error()) {
dwarnln("{}", ret.error()); m_next_flags = ACK;
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())
dwarnln("{}", ret.error());
}
} }
m_thread_blocker.unblock(); m_thread_blocker.unblock();

View File

@ -35,13 +35,26 @@ namespace Kernel
m_address_len = 0; m_address_len = 0;
} }
void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) void UDPSocket::get_protocol_header(BAN::ByteSpan header_buffer, BAN::ConstByteSpan payload, uint16_t dst_port, PseudoHeader pseudo_header)
{ {
auto& header = packet.as<UDPHeader>(); ASSERT(header_buffer.size() == protocol_header_size());
header.src_port = bound_port();
header.dst_port = dst_port; auto& header = header_buffer.as<UDPHeader>();
header.length = packet.size(); header = {
header.checksum = 0; .src_port = bound_port(),
.dst_port = dst_port,
.length = protocol_header_size() + payload.size(),
.checksum = 0,
};
const BAN::ConstByteSpan buffers[] {
BAN::ConstByteSpan::from(pseudo_header),
header_buffer,
payload,
};
header.checksum = calculate_internet_checksum({ buffers, sizeof(buffers) / sizeof(*buffers) });
if (header.checksum == 0)
header.checksum = 0xFFFF;
} }
void UDPSocket::receive_packet(BAN::ConstByteSpan packet, const sockaddr* sender, socklen_t sender_len) void UDPSocket::receive_packet(BAN::ConstByteSpan packet, const sockaddr* sender, socklen_t sender_len)