diff --git a/kernel/include/kernel/FS/Inode.h b/kernel/include/kernel/FS/Inode.h index 5f5ba9f6cd..e21f742de0 100644 --- a/kernel/include/kernel/FS/Inode.h +++ b/kernel/include/kernel/FS/Inode.h @@ -102,6 +102,7 @@ namespace Kernel // Socket API BAN::ErrorOr bind(const sockaddr* address, socklen_t address_len); BAN::ErrorOr sendto(const sys_sendto_t*); + BAN::ErrorOr recvfrom(sys_recvfrom_t*); // General API BAN::ErrorOr read(off_t, BAN::ByteSpan buffer); @@ -126,7 +127,8 @@ namespace Kernel // Socket API virtual BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } - virtual BAN::ErrorOr sendto_impl(const sys_sendto_t*) { return BAN::Error::from_errno(ENOTSUP); } + virtual BAN::ErrorOr sendto_impl(const sys_sendto_t*) { return BAN::Error::from_errno(ENOTSUP); } + virtual BAN::ErrorOr recvfrom_impl(sys_recvfrom_t*) { return BAN::Error::from_errno(ENOTSUP); } // General API virtual BAN::ErrorOr read_impl(off_t, BAN::ByteSpan) { return BAN::Error::from_errno(ENOTSUP); } diff --git a/kernel/include/kernel/Networking/E1000/E1000.h b/kernel/include/kernel/Networking/E1000/E1000.h index dad13f5273..2fba6e122e 100644 --- a/kernel/include/kernel/Networking/E1000/E1000.h +++ b/kernel/include/kernel/Networking/E1000/E1000.h @@ -23,7 +23,7 @@ namespace Kernel static BAN::ErrorOr> create(PCI::Device&); ~E1000(); - virtual uint8_t* get_mac_address() override { return m_mac_address; } + virtual BAN::MACAddress get_mac_address() const override { return m_mac_address; } virtual bool link_up() override { return m_link_up; } virtual int link_speed() override; @@ -66,11 +66,7 @@ namespace Kernel BAN::UniqPtr m_rx_descriptor_region; BAN::UniqPtr m_tx_descriptor_region; - uint8_t m_mac_address[6] {}; - uint16_t m_rx_current {}; - uint16_t m_tx_current {}; - void* m_rx_buffers[E1000_RX_DESCRIPTOR_COUNT] {}; - void* m_tx_buffers[E1000_TX_DESCRIPTOR_COUNT] {}; + BAN::MACAddress m_mac_address {}; bool m_link_up { false }; friend class BAN::RefPtr; diff --git a/kernel/include/kernel/Networking/IPv4.h b/kernel/include/kernel/Networking/IPv4.h index e111cfc821..f893f53bc9 100644 --- a/kernel/include/kernel/Networking/IPv4.h +++ b/kernel/include/kernel/Networking/IPv4.h @@ -1,10 +1,38 @@ #pragma once +#include +#include +#include #include namespace Kernel { - BAN::ErrorOr add_ipv4_header(BAN::Vector&, uint32_t src_ipv4, uint32_t dst_ipv4, uint8_t protocol); + struct IPv4Header + { + uint8_t version_IHL; + uint8_t DSCP_ECN; + BAN::NetworkEndian total_length { 0 }; + BAN::NetworkEndian identification { 0 }; + BAN::NetworkEndian flags_frament { 0 }; + uint8_t time_to_live; + uint8_t protocol; + BAN::NetworkEndian checksum { 0 }; + BAN::IPv4Address src_address; + BAN::IPv4Address dst_address; + + constexpr uint16_t calculate_checksum() const + { + return 0xFFFF + - (((uint16_t)version_IHL << 8) | DSCP_ECN) + - total_length + - identification + - flags_frament + - (((uint16_t)time_to_live << 8) | protocol); + } + }; + static_assert(sizeof(IPv4Header) == 20); + + void add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol); } diff --git a/kernel/include/kernel/Networking/NetworkInterface.h b/kernel/include/kernel/Networking/NetworkInterface.h index 3312c9d894..5224fd6f88 100644 --- a/kernel/include/kernel/Networking/NetworkInterface.h +++ b/kernel/include/kernel/Networking/NetworkInterface.h @@ -2,7 +2,9 @@ #include #include +#include #include +#include namespace Kernel { @@ -19,13 +21,14 @@ namespace Kernel NetworkInterface(); virtual ~NetworkInterface() {} - virtual uint8_t* get_mac_address() = 0; - uint32_t get_ipv4_address() const { return m_ipv4_address; } + virtual BAN::MACAddress get_mac_address() const = 0; + BAN::IPv4Address get_ipv4_address() const { return m_ipv4_address; } virtual bool link_up() = 0; virtual int link_speed() = 0; - BAN::ErrorOr add_interface_header(BAN::Vector&, uint8_t destination_mac[6]); + size_t interface_header_size() const; + void add_interface_header(BAN::ByteSpan packet, BAN::MACAddress destination); virtual dev_t rdev() const override { return m_rdev; } virtual BAN::StringView name() const override { return m_name; } @@ -38,7 +41,7 @@ namespace Kernel const dev_t m_rdev; char m_name[10]; - uint32_t m_ipv4_address {}; + BAN::IPv4Address m_ipv4_address { 0 }; }; } diff --git a/kernel/include/kernel/Networking/NetworkManager.h b/kernel/include/kernel/Networking/NetworkManager.h index 3777262aa1..047f0089cc 100644 --- a/kernel/include/kernel/Networking/NetworkManager.h +++ b/kernel/include/kernel/Networking/NetworkManager.h @@ -32,6 +32,8 @@ namespace Kernel BAN::ErrorOr> create_socket(SocketType, mode_t, uid_t, gid_t); + void on_receive(BAN::ConstByteSpan); + private: NetworkManager(); diff --git a/kernel/include/kernel/Networking/NetworkSocket.h b/kernel/include/kernel/Networking/NetworkSocket.h index f17ae34a17..05fcd3dcbf 100644 --- a/kernel/include/kernel/Networking/NetworkSocket.h +++ b/kernel/include/kernel/Networking/NetworkSocket.h @@ -4,6 +4,8 @@ #include #include +#include + namespace Kernel { @@ -16,16 +18,22 @@ namespace Kernel void bind_interface_and_port(NetworkInterface*, uint16_t port); ~NetworkSocket(); - virtual BAN::ErrorOr add_protocol_header(BAN::Vector&, uint16_t src_port, uint16_t dst_port) = 0; + virtual size_t protocol_header_size() const = 0; + virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t src_port, uint16_t dst_port) = 0; virtual uint8_t protocol() const = 0; + virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_address, uint16_t sender_port) = 0; + protected: NetworkSocket(mode_t mode, uid_t uid, gid_t gid); + virtual BAN::ErrorOr read_packet(BAN::ByteSpan, sockaddr_in* sender_address) = 0; + virtual void on_close_impl() override; virtual BAN::ErrorOr bind_impl(const sockaddr* address, socklen_t address_len) override; virtual BAN::ErrorOr sendto_impl(const sys_sendto_t*) override; + virtual BAN::ErrorOr recvfrom_impl(sys_recvfrom_t*) override; protected: NetworkInterface* m_interface = nullptr; diff --git a/kernel/include/kernel/Networking/UDPSocket.h b/kernel/include/kernel/Networking/UDPSocket.h index a93dea7a6c..785660ce63 100644 --- a/kernel/include/kernel/Networking/UDPSocket.h +++ b/kernel/include/kernel/Networking/UDPSocket.h @@ -1,23 +1,54 @@ #pragma once +#include +#include +#include #include #include +#include namespace Kernel { + struct UDPHeader + { + BAN::NetworkEndian src_port; + BAN::NetworkEndian dst_port; + BAN::NetworkEndian length; + BAN::NetworkEndian checksum; + }; + static_assert(sizeof(UDPHeader) == 8); + class UDPSocket final : public NetworkSocket { public: static BAN::ErrorOr> create(mode_t, uid_t, gid_t); - virtual BAN::ErrorOr add_protocol_header(BAN::Vector&, uint16_t src_port, uint16_t dst_port) override; + virtual size_t protocol_header_size() const override { return sizeof(UDPHeader); } + virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t src_port, uint16_t dst_port) override; virtual uint8_t protocol() const override { return 0x11; } + protected: + virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_addr, uint16_t sender_port) override; + virtual BAN::ErrorOr read_packet(BAN::ByteSpan, sockaddr_in* sender_address) override; + private: UDPSocket(mode_t, uid_t, gid_t); + struct PacketInfo + { + BAN::IPv4Address sender_addr; + uint16_t sender_port; + size_t packet_size; + }; + private: + static constexpr size_t packet_buffer_size = 10 * PAGE_SIZE; + BAN::UniqPtr m_packet_buffer; + BAN::CircularQueue m_packets; + size_t m_packet_total_size { 0 }; + Semaphore m_semaphore; + friend class BAN::RefPtr; }; diff --git a/kernel/include/kernel/Process.h b/kernel/include/kernel/Process.h index 815dbda6d6..e7437a2d53 100644 --- a/kernel/include/kernel/Process.h +++ b/kernel/include/kernel/Process.h @@ -115,6 +115,7 @@ namespace Kernel BAN::ErrorOr sys_socket(int domain, int type, int protocol); BAN::ErrorOr sys_bind(int socket, const sockaddr* address, socklen_t address_len); BAN::ErrorOr sys_sendto(const sys_sendto_t*); + BAN::ErrorOr sys_recvfrom(sys_recvfrom_t*); BAN::ErrorOr sys_pipe(int fildes[2]); BAN::ErrorOr sys_dup(int fildes); diff --git a/kernel/kernel/FS/Inode.cpp b/kernel/kernel/FS/Inode.cpp index c32a47a104..589f3eaabc 100644 --- a/kernel/kernel/FS/Inode.cpp +++ b/kernel/kernel/FS/Inode.cpp @@ -132,6 +132,14 @@ namespace Kernel return sendto_impl(arguments); }; + BAN::ErrorOr Inode::recvfrom(sys_recvfrom_t* arguments) + { + LockGuard _(m_lock); + if (!mode().ifsock()) + return BAN::Error::from_errno(ENOTSOCK); + return recvfrom_impl(arguments); + }; + BAN::ErrorOr Inode::read(off_t offset, BAN::ByteSpan buffer) { LockGuard _(m_lock); diff --git a/kernel/kernel/Networking/E1000/E1000.cpp b/kernel/kernel/Networking/E1000/E1000.cpp index a79af51b4e..4e0732df34 100644 --- a/kernel/kernel/Networking/E1000/E1000.cpp +++ b/kernel/kernel/Networking/E1000/E1000.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #define DEBUG_E1000 1 @@ -69,14 +70,7 @@ namespace Kernel TRY(read_mac_address()); #if DEBUG_E1000 dprintln("E1000 at PCI {}:{}.{}", m_pci_device.bus(), m_pci_device.dev(), m_pci_device.func()); - dprintln(" MAC: {2H}:{2H}:{2H}:{2H}:{2H}:{2H}", - m_mac_address[0], - m_mac_address[1], - m_mac_address[2], - m_mac_address[3], - m_mac_address[4], - m_mac_address[5] - ); + dprintln(" MAC: {}", m_mac_address); #endif TRY(initialize_rx()); @@ -141,16 +135,16 @@ namespace Kernel if (m_has_eerprom) { uint32_t temp = eeprom_read(0); - m_mac_address[0] = temp; - m_mac_address[1] = temp >> 8; + m_mac_address.address[0] = temp; + m_mac_address.address[1] = temp >> 8; temp = eeprom_read(1); - m_mac_address[2] = temp; - m_mac_address[3] = temp >> 8; + m_mac_address.address[2] = temp; + m_mac_address.address[3] = temp >> 8; temp = eeprom_read(2); - m_mac_address[4] = temp; - m_mac_address[5] = temp >> 8; + m_mac_address.address[4] = temp; + m_mac_address.address[5] = temp >> 8; return {}; } @@ -162,7 +156,7 @@ namespace Kernel } for (int i = 0; i < 6; i++) - m_mac_address[i] = (uint8_t)read32(0x5400 + i * 8); + m_mac_address.address[i] = (uint8_t)read32(0x5400 + i * 8); return {}; } @@ -175,7 +169,6 @@ namespace Kernel auto* rx_descriptors = reinterpret_cast(m_rx_descriptor_region->vaddr()); for (size_t i = 0; i < E1000_RX_DESCRIPTOR_COUNT; i++) { - m_rx_buffers[i] = reinterpret_cast(m_rx_buffer_region->vaddr() + E1000_RX_BUFFER_SIZE * i); rx_descriptors[i].addr = m_rx_buffer_region->paddr() + E1000_RX_BUFFER_SIZE * i; rx_descriptors[i].status = 0; } @@ -209,7 +202,6 @@ namespace Kernel auto* tx_descriptors = reinterpret_cast(m_tx_descriptor_region->vaddr()); for (size_t i = 0; i < E1000_TX_DESCRIPTOR_COUNT; i++) { - m_tx_buffers[i] = reinterpret_cast(m_tx_buffer_region->vaddr() + E1000_TX_BUFFER_SIZE * i); tx_descriptors[i].addr = m_tx_buffer_region->paddr() + E1000_TX_BUFFER_SIZE * i; tx_descriptors[i].cmd = 0; } @@ -300,8 +292,10 @@ namespace Kernel break; ASSERT_LTE((uint16_t)descriptor.length, E1000_RX_BUFFER_SIZE); - // FIXME: do something with the packet :) - dprintln("got {} byte packet", (uint16_t)descriptor.length); + NetworkManager::get().on_receive(BAN::ConstByteSpan { + reinterpret_cast(m_rx_buffer_region->vaddr() + rx_current * E1000_RX_BUFFER_SIZE), + descriptor.length + }); descriptor.status = 0; write32(REG_RDT0, rx_current); diff --git a/kernel/kernel/Networking/IPv4.cpp b/kernel/kernel/Networking/IPv4.cpp index ccd8fa4c2f..e9d76a0baa 100644 --- a/kernel/kernel/Networking/IPv4.cpp +++ b/kernel/kernel/Networking/IPv4.cpp @@ -4,50 +4,19 @@ namespace Kernel { - - struct IPv4Header + void add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol) { - uint8_t version_IHL; - uint8_t DSCP_ECN; - BAN::NetworkEndian total_length; - BAN::NetworkEndian identification; - BAN::NetworkEndian flags_frament; - uint8_t time_to_live; - uint8_t protocol; - BAN::NetworkEndian header_checksum; - BAN::NetworkEndian src_address; - BAN::NetworkEndian dst_address; - - uint16_t checksum() const - { - return 0xFFFF - - (((uint16_t)version_IHL << 8) | DSCP_ECN) - - total_length - - identification - - flags_frament - - (((uint16_t)time_to_live << 8) | protocol); - } - }; - static_assert(sizeof(IPv4Header) == 20); - - BAN::ErrorOr add_ipv4_header(BAN::Vector& packet, uint32_t src_ipv4, uint32_t dst_ipv4, uint8_t protocol) - { - TRY(packet.resize(packet.size() + sizeof(IPv4Header))); - memmove(packet.data() + sizeof(IPv4Header), packet.data(), packet.size() - sizeof(IPv4Header)); - - auto* header = reinterpret_cast(packet.data()); - header->version_IHL = 0x45; - header->DSCP_ECN = 0x10; - header->total_length = packet.size(); - header->identification = 1; - header->flags_frament = 0x00; - header->time_to_live = 0x40; - header->protocol = protocol; - header->header_checksum = header->checksum(); - header->src_address = src_ipv4; - header->dst_address = dst_ipv4; - - return {}; + auto& header = packet.as(); + header.version_IHL = 0x45; + header.DSCP_ECN = 0x10; + header.total_length = packet.size(); + header.identification = 1; + header.flags_frament = 0x00; + header.time_to_live = 0x40; + header.protocol = protocol; + header.checksum = header.calculate_checksum(); + header.src_address = src_ipv4; + header.dst_address = dst_ipv4; } } diff --git a/kernel/kernel/Networking/NetworkInterface.cpp b/kernel/kernel/Networking/NetworkInterface.cpp index 6ca4bc57ab..fa0af961c9 100644 --- a/kernel/kernel/Networking/NetworkInterface.cpp +++ b/kernel/kernel/Networking/NetworkInterface.cpp @@ -10,8 +10,8 @@ namespace Kernel struct EthernetHeader { - uint8_t dst_mac[6]; - uint8_t src_mac[6]; + BAN::MACAddress dst_mac; + BAN::MACAddress src_mac; BAN::NetworkEndian ether_type; }; static_assert(sizeof(EthernetHeader) == 14); @@ -40,19 +40,19 @@ namespace Kernel m_name[3] = minor(m_rdev) + '0'; } - BAN::ErrorOr NetworkInterface::add_interface_header(BAN::Vector& packet, uint8_t destination_mac[6]) + size_t NetworkInterface::interface_header_size() const { ASSERT(m_type == Type::Ethernet); + return sizeof(EthernetHeader); + } - TRY(packet.resize(packet.size() + sizeof(EthernetHeader))); - memmove(packet.data() + sizeof(EthernetHeader), packet.data(), packet.size() - sizeof(EthernetHeader)); - - auto* header = reinterpret_cast(packet.data()); - memcpy(header->dst_mac, destination_mac, 6); - memcpy(header->src_mac, get_mac_address(), 6); - header->ether_type = 0x0800; // ipv4 - - return {}; + void NetworkInterface::add_interface_header(BAN::ByteSpan packet, BAN::MACAddress destination) + { + ASSERT(m_type == Type::Ethernet); + auto& header = packet.as(); + header.dst_mac = destination; + header.src_mac = get_mac_address(); + header.ether_type = 0x0800; } } diff --git a/kernel/kernel/Networking/NetworkManager.cpp b/kernel/kernel/Networking/NetworkManager.cpp index 7c350f2aac..3353498279 100644 --- a/kernel/kernel/Networking/NetworkManager.cpp +++ b/kernel/kernel/Networking/NetworkManager.cpp @@ -1,7 +1,9 @@ +#include #include #include #include #include +#include #include #include @@ -106,4 +108,27 @@ namespace Kernel return {}; } + void NetworkManager::on_receive(BAN::ConstByteSpan packet) + { + // FIXME: properly handle packet parsing + + auto ipv4 = packet.slice(14); + auto& ipv4_header = ipv4.as(); + auto src_ipv4 = ipv4_header.src_address; + + auto udp = ipv4.slice(20); + auto& udp_header = udp.as(); + uint16_t src_port = udp_header.src_port; + uint16_t dst_port = udp_header.dst_port; + + if (!m_bound_sockets.contains(dst_port)) + { + dprintln("no one is listening on port {}", dst_port); + return; + } + + auto raw = udp.slice(8); + m_bound_sockets[dst_port].lock()->add_packet(raw, src_ipv4, src_port); + } + } diff --git a/kernel/kernel/Networking/NetworkSocket.cpp b/kernel/kernel/Networking/NetworkSocket.cpp index 79c888d78e..1a6bb6640b 100644 --- a/kernel/kernel/Networking/NetworkSocket.cpp +++ b/kernel/kernel/Networking/NetworkSocket.cpp @@ -65,18 +65,62 @@ namespace Kernel return BAN::Error::from_errno(EINVAL); } - static uint8_t dest_mac[6] { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF }; + static BAN::MACAddress dest_mac {{ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF }}; + + const size_t interface_header_offset = 0; + const size_t interface_header_size = m_interface->interface_header_size(); + + const size_t ipv4_header_offset = interface_header_offset + interface_header_size; + const size_t ipv4_header_size = sizeof(IPv4Header); + + const size_t protocol_header_offset = ipv4_header_offset + ipv4_header_size; + const size_t protocol_header_size = this->protocol_header_size(); + + const size_t payload_offset = protocol_header_offset + protocol_header_size; + const size_t payload_size = message.size(); BAN::Vector full_packet; - TRY(full_packet.resize(message.size())); - memcpy(full_packet.data(), message.data(), message.size()); - TRY(add_protocol_header(full_packet, m_port, destination->sin_port)); - TRY(add_ipv4_header(full_packet, m_interface->get_ipv4_address(), destination->sin_addr.s_addr, protocol())); - TRY(m_interface->add_interface_header(full_packet, dest_mac)); + TRY(full_packet.resize(payload_offset + payload_size)); - TRY(m_interface->send_raw_bytes(BAN::ConstByteSpan { full_packet.span() })); + BAN::ByteSpan packet_bytespan { full_packet.span() }; + + memcpy(full_packet.data() + payload_offset, message.data(), payload_size); + add_protocol_header(packet_bytespan.slice(protocol_header_offset), m_port, destination->sin_port); + add_ipv4_header(packet_bytespan.slice(ipv4_header_offset), m_interface->get_ipv4_address(), destination->sin_addr.s_addr, protocol()); + m_interface->add_interface_header(packet_bytespan.slice(interface_header_offset), dest_mac); + TRY(m_interface->send_raw_bytes(packet_bytespan)); return arguments->length; } + BAN::ErrorOr NetworkSocket::recvfrom_impl(sys_recvfrom_t* arguments) + { + sockaddr_in* sender_addr = nullptr; + if (arguments->address) + { + ASSERT(arguments->address_len); + if (*arguments->address_len < (socklen_t)sizeof(sockaddr_in)) + *arguments->address_len = 0; + else + { + sender_addr = reinterpret_cast(arguments->address); + *arguments->address_len = sizeof(sockaddr_in); + } + } + + if (!m_interface) + { + dprintln("No interface bound"); + return BAN::Error::from_errno(EINVAL); + } + + if (m_port == PORT_NONE) + { + dprintln("No port bound"); + return BAN::Error::from_errno(EINVAL); + } + + return TRY(read_packet(BAN::ByteSpan { reinterpret_cast(arguments->buffer), arguments->length }, sender_addr)); + } + } diff --git a/kernel/kernel/Networking/UDPSocket.cpp b/kernel/kernel/Networking/UDPSocket.cpp index a14a42ddc6..3fea655a70 100644 --- a/kernel/kernel/Networking/UDPSocket.cpp +++ b/kernel/kernel/Networking/UDPSocket.cpp @@ -1,39 +1,101 @@ -#include +#include #include +#include namespace Kernel { - struct UDPHeader - { - BAN::NetworkEndian src_port; - BAN::NetworkEndian dst_port; - BAN::NetworkEndian length; - BAN::NetworkEndian checksum; - }; - static_assert(sizeof(UDPHeader) == 8); - BAN::ErrorOr> UDPSocket::create(mode_t mode, uid_t uid, gid_t gid) { - return TRY(BAN::RefPtr::create(mode, uid, gid)); + auto socket = TRY(BAN::RefPtr::create(mode, uid, gid)); + socket->m_packet_buffer = TRY(VirtualRange::create_to_vaddr_range( + PageTable::kernel(), + KERNEL_OFFSET, + ~(uintptr_t)0, + packet_buffer_size, + PageTable::Flags::ReadWrite | PageTable::Flags::Present, + true + )); + return socket; } UDPSocket::UDPSocket(mode_t mode, uid_t uid, gid_t gid) : NetworkSocket(mode, uid, gid) { } - BAN::ErrorOr UDPSocket::add_protocol_header(BAN::Vector& packet, uint16_t src_port, uint16_t dst_port) + void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t src_port, uint16_t dst_port) { - TRY(packet.resize(packet.size() + sizeof(UDPHeader))); - memmove(packet.data() + sizeof(UDPHeader), packet.data(), packet.size() - sizeof(UDPHeader)); + auto& header = packet.as(); + header.src_port = src_port; + header.dst_port = dst_port; + header.length = packet.size(); + header.checksum = 0; + } - auto* header = reinterpret_cast(packet.data()); - header->src_port = src_port; - header->dst_port = dst_port; - header->length = packet.size(); - header->checksum = 0; + void UDPSocket::add_packet(BAN::ConstByteSpan packet, BAN::IPv4Address sender_addr, uint16_t sender_port) + { + CriticalScope _; - return {}; + if (m_packets.full()) + { + dprintln("Packet buffer full, dropping packet"); + return; + } + + if (!m_packets.empty() && m_packet_total_size > m_packet_buffer->size()) + { + dprintln("Packet buffer full, dropping packet"); + return; + } + + void* buffer = reinterpret_cast(m_packet_buffer->vaddr() + m_packet_total_size); + memcpy(buffer, packet.data(), packet.size()); + + m_packets.push(PacketInfo { + .sender_addr = sender_addr, + .sender_port = sender_port, + .packet_size = packet.size() + }); + m_packet_total_size += packet.size(); + + m_semaphore.unblock(); + } + + BAN::ErrorOr UDPSocket::read_packet(BAN::ByteSpan buffer, sockaddr_in* sender_addr) + { + while (m_packets.empty()) + TRY(Thread::current().block_or_eintr(m_semaphore)); + + CriticalScope _; + if (m_packets.empty()) + return read_packet(buffer, sender_addr); + + auto packet_info = m_packets.front(); + m_packets.pop(); + + size_t nread = BAN::Math::min(packet_info.packet_size, buffer.size()); + + memcpy( + buffer.data(), + (const void*)m_packet_buffer->vaddr(), + nread + ); + memmove( + (void*)m_packet_buffer->vaddr(), + (void*)(m_packet_buffer->vaddr() + packet_info.packet_size), + m_packet_total_size - packet_info.packet_size + ); + + m_packet_total_size -= packet_info.packet_size; + + if (sender_addr) + { + sender_addr->sin_family = AF_INET; + sender_addr->sin_port = packet_info.sender_port; + sender_addr->sin_addr.s_addr = packet_info.sender_addr.as_u32(); + } + + return nread; } } diff --git a/kernel/kernel/Process.cpp b/kernel/kernel/Process.cpp index 7144a4f86b..82b328a221 100644 --- a/kernel/kernel/Process.cpp +++ b/kernel/kernel/Process.cpp @@ -915,7 +915,6 @@ namespace Kernel return 0; } - BAN::ErrorOr Process::sys_sendto(const sys_sendto_t* arguments) { LockGuard _(m_lock); @@ -930,6 +929,29 @@ namespace Kernel return TRY(inode->sendto(arguments)); } + BAN::ErrorOr Process::sys_recvfrom(sys_recvfrom_t* arguments) + { + if (arguments->address && !arguments->address_len) + return BAN::Error::from_errno(EINVAL); + if (!arguments->address && arguments->address_len) + return BAN::Error::from_errno(EINVAL); + + LockGuard _(m_lock); + TRY(validate_pointer_access(arguments, sizeof(sys_recvfrom_t))); + TRY(validate_pointer_access(arguments->buffer, arguments->length)); + if (arguments->address) + { + TRY(validate_pointer_access(arguments->address_len, sizeof(*arguments->address_len))); + TRY(validate_pointer_access(arguments->address, *arguments->address_len)); + } + + auto inode = TRY(m_open_file_descriptors.inode_of(arguments->socket)); + if (!inode->mode().ifsock()) + return BAN::Error::from_errno(ENOTSOCK); + + return TRY(inode->recvfrom(arguments)); + } + BAN::ErrorOr Process::sys_pipe(int fildes[2]) { LockGuard _(m_lock); diff --git a/kernel/kernel/Syscall.cpp b/kernel/kernel/Syscall.cpp index 45fab33d40..c2f59cefba 100644 --- a/kernel/kernel/Syscall.cpp +++ b/kernel/kernel/Syscall.cpp @@ -222,6 +222,9 @@ namespace Kernel case SYS_SENDTO: ret = Process::current().sys_sendto((const sys_sendto_t*)arg1); break; + case SYS_RECVFROM: + ret = Process::current().sys_recvfrom((sys_recvfrom_t*)arg1); + break; default: dwarnln("Unknown syscall {}", syscall); break; diff --git a/libc/include/sys/socket.h b/libc/include/sys/socket.h index 43e8b1b893..805fa058c6 100644 --- a/libc/include/sys/socket.h +++ b/libc/include/sys/socket.h @@ -115,6 +115,16 @@ struct sys_sendto_t socklen_t dest_len; }; +struct sys_recvfrom_t +{ + int socket; + void* buffer; + size_t length; + int flags; + struct sockaddr* address; + socklen_t* address_len; +}; + int accept(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len); int bind(int socket, const struct sockaddr* address, socklen_t address_len); int connect(int socket, const struct sockaddr* address, socklen_t address_len); diff --git a/libc/include/sys/syscall.h b/libc/include/sys/syscall.h index 39e7da66c6..4fb51cf58f 100644 --- a/libc/include/sys/syscall.h +++ b/libc/include/sys/syscall.h @@ -66,6 +66,7 @@ __BEGIN_DECLS #define SYS_SOCKET 65 #define SYS_BIND 66 #define SYS_SENDTO 67 +#define SYS_RECVFROM 68 __END_DECLS diff --git a/libc/sys/socket.cpp b/libc/sys/socket.cpp index ce05432fed..241a1c4bba 100644 --- a/libc/sys/socket.cpp +++ b/libc/sys/socket.cpp @@ -7,6 +7,20 @@ int bind(int socket, const struct sockaddr* address, socklen_t address_len) return syscall(SYS_BIND, socket, address, address_len); } +ssize_t recvfrom(int socket, void* __restrict buffer, size_t length, int flags, struct sockaddr* __restrict address, socklen_t* __restrict address_len) +{ + sys_recvfrom_t arguments { + .socket = socket, + .buffer = buffer, + .length = length, + .flags = flags, + .address = address, + .address_len = address_len + }; + return syscall(SYS_RECVFROM, &arguments); +} + + ssize_t sendto(int socket, const void* message, size_t length, int flags, const struct sockaddr* dest_addr, socklen_t dest_len) { sys_sendto_t arguments {