From ff49d8b84fbc0131c0915c39f20d0d4b3e1ba3c2 Mon Sep 17 00:00:00 2001 From: Bananymous Date: Fri, 9 Feb 2024 17:05:07 +0200 Subject: [PATCH] Kernel: Cleanup OSI layer overlapping --- kernel/include/kernel/FS/Inode.h | 16 +-- kernel/include/kernel/Networking/IPv4Layer.h | 9 +- .../include/kernel/Networking/NetworkLayer.h | 5 +- .../include/kernel/Networking/NetworkSocket.h | 14 +- kernel/include/kernel/Networking/UDPSocket.h | 12 +- .../include/kernel/Networking/UNIX/Socket.h | 4 +- kernel/kernel/FS/Inode.cpp | 8 +- kernel/kernel/Networking/IPv4Layer.cpp | 123 ++++++++++++------ kernel/kernel/Networking/NetworkSocket.cpp | 62 +-------- kernel/kernel/Networking/UDPSocket.cpp | 65 ++++++--- kernel/kernel/Networking/UNIX/Socket.cpp | 33 ++--- kernel/kernel/Process.cpp | 6 +- 12 files changed, 185 insertions(+), 172 deletions(-) diff --git a/kernel/include/kernel/FS/Inode.h b/kernel/include/kernel/FS/Inode.h index d4c9b57715..80bd97fea3 100644 --- a/kernel/include/kernel/FS/Inode.h +++ b/kernel/include/kernel/FS/Inode.h @@ -104,8 +104,8 @@ namespace Kernel BAN::ErrorOr bind(const sockaddr* address, socklen_t address_len); BAN::ErrorOr connect(const sockaddr* address, socklen_t address_len); BAN::ErrorOr listen(int backlog); - BAN::ErrorOr sendto(const sys_sendto_t*); - BAN::ErrorOr recvfrom(sys_recvfrom_t*); + BAN::ErrorOr sendto(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len); + BAN::ErrorOr recvfrom(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len); // General API BAN::ErrorOr read(off_t, BAN::ByteSpan buffer); @@ -131,12 +131,12 @@ namespace Kernel virtual BAN::ErrorOr link_target_impl() { return BAN::Error::from_errno(ENOTSUP); } // Socket API - virtual BAN::ErrorOr accept_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); } - virtual BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } - virtual BAN::ErrorOr listen_impl(int) { return BAN::Error::from_errno(ENOTSUP); } - virtual BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } - virtual BAN::ErrorOr sendto_impl(const sys_sendto_t*) { return BAN::Error::from_errno(ENOTSUP); } - virtual BAN::ErrorOr recvfrom_impl(sys_recvfrom_t*) { return BAN::Error::from_errno(ENOTSUP); } + virtual BAN::ErrorOr accept_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); } + virtual BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } + virtual BAN::ErrorOr listen_impl(int) { return BAN::Error::from_errno(ENOTSUP); } + virtual BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } + virtual BAN::ErrorOr sendto_impl(BAN::ConstByteSpan, const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } + virtual BAN::ErrorOr recvfrom_impl(BAN::ByteSpan, sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); } // General API virtual BAN::ErrorOr read_impl(off_t, BAN::ByteSpan) { return BAN::Error::from_errno(ENOTSUP); } diff --git a/kernel/include/kernel/Networking/IPv4Layer.h b/kernel/include/kernel/Networking/IPv4Layer.h index 4a33b8ab9e..e12512955e 100644 --- a/kernel/include/kernel/Networking/IPv4Layer.h +++ b/kernel/include/kernel/Networking/IPv4Layer.h @@ -32,7 +32,7 @@ namespace Kernel }; static_assert(sizeof(IPv4Header) == 20); - class IPv4Layer : public NetworkLayer + class IPv4Layer final : public NetworkLayer { BAN_NON_COPYABLE(IPv4Layer); BAN_NON_MOVABLE(IPv4Layer); @@ -45,8 +45,9 @@ namespace Kernel void add_ipv4_packet(NetworkInterface&, BAN::ConstByteSpan); - virtual void unbind_socket(uint16_t port, BAN::RefPtr) override; - virtual BAN::ErrorOr bind_socket(uint16_t port, BAN::RefPtr) override; + virtual void unbind_socket(BAN::RefPtr, uint16_t port) override; + virtual BAN::ErrorOr bind_socket_to_unused(BAN::RefPtr, const sockaddr* send_address, socklen_t send_address_len) override; + virtual BAN::ErrorOr bind_socket_to_address(BAN::RefPtr, const sockaddr* address, socklen_t address_len) override; virtual BAN::ErrorOr sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) override; @@ -65,7 +66,7 @@ namespace Kernel }; private: - SpinLock m_lock; + RecursiveSpinLock m_lock; BAN::UniqPtr m_arp_table; Process* m_process { nullptr }; diff --git a/kernel/include/kernel/Networking/NetworkLayer.h b/kernel/include/kernel/Networking/NetworkLayer.h index 7f970f51c3..79f4cfb7f3 100644 --- a/kernel/include/kernel/Networking/NetworkLayer.h +++ b/kernel/include/kernel/Networking/NetworkLayer.h @@ -22,8 +22,9 @@ namespace Kernel public: virtual ~NetworkLayer() {} - virtual void unbind_socket(uint16_t port, BAN::RefPtr) = 0; - virtual BAN::ErrorOr bind_socket(uint16_t port, BAN::RefPtr) = 0; + virtual void unbind_socket(BAN::RefPtr, uint16_t port) = 0; + virtual BAN::ErrorOr bind_socket_to_unused(BAN::RefPtr, const sockaddr* send_address, socklen_t send_address_len) = 0; + virtual BAN::ErrorOr bind_socket_to_address(BAN::RefPtr, const sockaddr* address, socklen_t address_len) = 0; virtual BAN::ErrorOr sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) = 0; diff --git a/kernel/include/kernel/Networking/NetworkSocket.h b/kernel/include/kernel/Networking/NetworkSocket.h index cc9cf492e2..8b5f2f0765 100644 --- a/kernel/include/kernel/Networking/NetworkSocket.h +++ b/kernel/include/kernel/Networking/NetworkSocket.h @@ -6,8 +6,6 @@ #include #include -#include - namespace Kernel { @@ -35,25 +33,21 @@ namespace Kernel virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) = 0; virtual NetworkProtocol protocol() const = 0; - virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_address, uint16_t sender_port) = 0; + virtual void receive_packet(BAN::ConstByteSpan, const sockaddr_storage& sender) = 0; + + bool is_bound() const { return m_interface != nullptr; } protected: NetworkSocket(NetworkLayer&, ino_t, const TmpInodeInfo&); - virtual BAN::ErrorOr read_packet(BAN::ByteSpan, sockaddr_in* sender_address) = 0; - virtual void on_close_impl() override; - virtual BAN::ErrorOr bind_impl(const sockaddr* address, socklen_t address_len) override; - virtual BAN::ErrorOr sendto_impl(const sys_sendto_t*) override; - virtual BAN::ErrorOr recvfrom_impl(sys_recvfrom_t*) override; - virtual BAN::ErrorOr ioctl_impl(int request, void* arg) override; protected: NetworkLayer& m_network_layer; NetworkInterface* m_interface = nullptr; - uint16_t m_port = PORT_NONE; + uint16_t m_port { PORT_NONE }; }; } diff --git a/kernel/include/kernel/Networking/UDPSocket.h b/kernel/include/kernel/Networking/UDPSocket.h index 0aff1e389c..17a8ded1d6 100644 --- a/kernel/include/kernel/Networking/UDPSocket.h +++ b/kernel/include/kernel/Networking/UDPSocket.h @@ -30,23 +30,25 @@ namespace Kernel virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) override; protected: - virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_addr, uint16_t sender_port) override; - virtual BAN::ErrorOr read_packet(BAN::ByteSpan, sockaddr_in* sender_address) override; + virtual void receive_packet(BAN::ConstByteSpan, const sockaddr_storage& sender) override; + + virtual BAN::ErrorOr bind_impl(const sockaddr* address, socklen_t address_len) override; + virtual BAN::ErrorOr sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) override; + virtual BAN::ErrorOr recvfrom_impl(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len) override; private: UDPSocket(NetworkLayer&, ino_t, const TmpInodeInfo&); struct PacketInfo { - BAN::IPv4Address sender_addr; - uint16_t sender_port; + sockaddr_storage sender; size_t packet_size; }; private: static constexpr size_t packet_buffer_size = 10 * PAGE_SIZE; BAN::UniqPtr m_packet_buffer; - BAN::CircularQueue m_packets; + BAN::CircularQueue m_packets; size_t m_packet_total_size { 0 }; SpinLock m_packet_lock; Semaphore m_packet_semaphore; diff --git a/kernel/include/kernel/Networking/UNIX/Socket.h b/kernel/include/kernel/Networking/UNIX/Socket.h index 236aea3e95..e1769a5624 100644 --- a/kernel/include/kernel/Networking/UNIX/Socket.h +++ b/kernel/include/kernel/Networking/UNIX/Socket.h @@ -23,8 +23,8 @@ namespace Kernel virtual BAN::ErrorOr connect_impl(const sockaddr*, socklen_t) override; virtual BAN::ErrorOr listen_impl(int) override; virtual BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) override; - virtual BAN::ErrorOr sendto_impl(const sys_sendto_t*) override; - virtual BAN::ErrorOr recvfrom_impl(sys_recvfrom_t*) override; + virtual BAN::ErrorOr sendto_impl(BAN::ConstByteSpan, const sockaddr*, socklen_t) override; + virtual BAN::ErrorOr recvfrom_impl(BAN::ByteSpan, sockaddr*, socklen_t*) override; private: UnixDomainSocket(SocketType, ino_t, const TmpInodeInfo&); diff --git a/kernel/kernel/FS/Inode.cpp b/kernel/kernel/FS/Inode.cpp index a9dc22ef7f..451f7fd21d 100644 --- a/kernel/kernel/FS/Inode.cpp +++ b/kernel/kernel/FS/Inode.cpp @@ -148,20 +148,20 @@ namespace Kernel return listen_impl(backlog); } - BAN::ErrorOr Inode::sendto(const sys_sendto_t* arguments) + BAN::ErrorOr Inode::sendto(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) { LockGuard _(m_lock); if (!mode().ifsock()) return BAN::Error::from_errno(ENOTSOCK); - return sendto_impl(arguments); + return sendto_impl(message, address, address_len); }; - BAN::ErrorOr Inode::recvfrom(sys_recvfrom_t* arguments) + BAN::ErrorOr Inode::recvfrom(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len) { LockGuard _(m_lock); if (!mode().ifsock()) return BAN::Error::from_errno(ENOTSOCK); - return recvfrom_impl(arguments); + return recvfrom_impl(buffer, address, address_len); }; BAN::ErrorOr Inode::read(off_t offset, BAN::ByteSpan buffer) diff --git a/kernel/kernel/Networking/IPv4Layer.cpp b/kernel/kernel/Networking/IPv4Layer.cpp index cf9bb13bb1..6a12f7a09c 100644 --- a/kernel/kernel/Networking/IPv4Layer.cpp +++ b/kernel/kernel/Networking/IPv4Layer.cpp @@ -66,7 +66,7 @@ namespace Kernel header.checksum = calculate_internet_checksum(BAN::ConstByteSpan::from(header), {}); } - void IPv4Layer::unbind_socket(uint16_t port, BAN::RefPtr socket) + void IPv4Layer::unbind_socket(BAN::RefPtr socket, uint16_t port) { LockGuard _(m_lock); if (m_bound_sockets.contains(port)) @@ -78,29 +78,52 @@ namespace Kernel NetworkManager::get().TmpFileSystem::remove_from_cache(socket); } - BAN::ErrorOr IPv4Layer::bind_socket(uint16_t port, BAN::RefPtr socket) + BAN::ErrorOr IPv4Layer::bind_socket_to_unused(BAN::RefPtr socket, const sockaddr* address, socklen_t address_len) + { + if (!address || address_len < (socklen_t)sizeof(sockaddr_in)) + return BAN::Error::from_errno(EINVAL); + if (address->sa_family != AF_INET) + return BAN::Error::from_errno(EAFNOSUPPORT); + auto& sockaddr_in = *reinterpret_cast(address); + + LockGuard _(m_lock); + + uint16_t port = NetworkSocket::PORT_NONE; + for (uint32_t temp = 0xC000; temp < 0xFFFF; temp++) + { + if (!m_bound_sockets.contains(temp)) + { + port = temp; + break; + } + } + if (port == NetworkSocket::PORT_NONE) + { + dwarnln("No ports available"); + return BAN::Error::from_errno(EAGAIN); + } + + struct sockaddr_in target; + target.sin_family = AF_INET; + target.sin_port = BAN::host_to_network_endian(port); + target.sin_addr.s_addr = sockaddr_in.sin_addr.s_addr; + return bind_socket_to_address(socket, (sockaddr*)&target, sizeof(sockaddr_in)); + } + + BAN::ErrorOr IPv4Layer::bind_socket_to_address(BAN::RefPtr socket, const sockaddr* address, socklen_t address_len) { if (NetworkManager::get().interfaces().empty()) return BAN::Error::from_errno(EADDRNOTAVAIL); - LockGuard _(m_lock); + if (!address || address_len < (socklen_t)sizeof(sockaddr_in)) + return BAN::Error::from_errno(EINVAL); + if (address->sa_family != AF_INET) + return BAN::Error::from_errno(EAFNOSUPPORT); - if (port == NetworkSocket::PORT_NONE) - { - for (uint32_t temp = 0xC000; temp < 0xFFFF; temp++) - { - if (!m_bound_sockets.contains(temp)) - { - port = temp; - break; - } - } - if (port == NetworkSocket::PORT_NONE) - { - dwarnln("No ports available"); - return BAN::Error::from_errno(EAGAIN); - } - } + auto& sockaddr_in = *reinterpret_cast(address); + uint16_t port = BAN::host_to_network_endian(sockaddr_in.sin_port); + + LockGuard _(m_lock); if (m_bound_sockets.contains(port)) return BAN::Error::from_errno(EADDRINUSE); @@ -163,6 +186,10 @@ namespace Kernel auto ipv4_data = packet.slice(sizeof(IPv4Header)); auto src_ipv4 = ipv4_header.src_address; + + uint16_t dst_port = NetworkSocket::PORT_NONE; + uint16_t src_port = NetworkSocket::PORT_NONE; + switch (ipv4_header.protocol) { case NetworkProtocol::ICMP: @@ -188,37 +215,53 @@ namespace Kernel dprintln("Unhandleded ICMP packet (type {2H})", icmp_header.type); break; } - break; + return {}; } case NetworkProtocol::UDP: { auto& udp_header = ipv4_data.as(); - uint16_t src_port = udp_header.src_port; - uint16_t dst_port = udp_header.dst_port; - - LockGuard _(m_lock); - - if (!m_bound_sockets.contains(dst_port)) - { - dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port); - return {}; - } - auto socket = m_bound_sockets[dst_port].lock(); - if (!socket) - { - dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port); - return {}; - } - - auto udp_data = ipv4_data.slice(sizeof(UDPHeader)); - socket->add_packet(udp_data, src_ipv4, src_port); + dst_port = udp_header.dst_port; + src_port = udp_header.src_port; break; } default: dprintln_if(DEBUG_IPV4, "Unknown network protocol 0x{2H}", ipv4_header.protocol); - break; + return {}; } + ASSERT(dst_port != NetworkSocket::PORT_NONE); + ASSERT(src_port != NetworkSocket::PORT_NONE); + + BAN::RefPtr bound_socket; + + { + LockGuard _(m_lock); + if (!m_bound_sockets.contains(dst_port)) + { + dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port); + return {}; + } + bound_socket = m_bound_sockets[dst_port].lock(); + } + + if (!bound_socket) + { + dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port); + return {}; + } + + if (bound_socket->protocol() != ipv4_header.protocol) + { + dprintln_if(DEBUG_IPV4, "got data with wrong protocol ({}) on port {} (bound as {})", ipv4_header.protocol, dst_port, (uint8_t)bound_socket->protocol()); + return {}; + } + + sockaddr_in sender; + sender.sin_family = AF_INET; + sender.sin_port = BAN::NetworkEndian(src_port); + sender.sin_addr.s_addr = src_ipv4.raw; + bound_socket->receive_packet(ipv4_data, *reinterpret_cast(&sender)); + return {}; } diff --git a/kernel/kernel/Networking/NetworkSocket.cpp b/kernel/kernel/Networking/NetworkSocket.cpp index eee1971392..19e76e79e0 100644 --- a/kernel/kernel/Networking/NetworkSocket.cpp +++ b/kernel/kernel/Networking/NetworkSocket.cpp @@ -17,8 +17,12 @@ namespace Kernel void NetworkSocket::on_close_impl() { - if (m_interface) - m_network_layer.unbind_socket(m_port, this); + if (is_bound()) + { + m_network_layer.unbind_socket(this, m_port); + m_interface = nullptr; + m_port = PORT_NONE; + } } void NetworkSocket::bind_interface_and_port(NetworkInterface* interface, uint16_t port) @@ -29,60 +33,6 @@ namespace Kernel m_port = port; } - BAN::ErrorOr NetworkSocket::bind_impl(const sockaddr* address, socklen_t address_len) - { - if (m_interface || address_len != sizeof(sockaddr_in)) - return BAN::Error::from_errno(EINVAL); - auto* addr_in = reinterpret_cast(address); - uint16_t dst_port = BAN::host_to_network_endian(addr_in->sin_port); - return m_network_layer.bind_socket(dst_port, this); - } - - BAN::ErrorOr NetworkSocket::sendto_impl(const sys_sendto_t* arguments) - { - if (arguments->flags) - { - dprintln("flags not supported"); - return BAN::Error::from_errno(ENOTSUP); - } - - if (!m_interface) - TRY(m_network_layer.bind_socket(PORT_NONE, this)); - - auto buffer = BAN::ConstByteSpan { reinterpret_cast(arguments->message), arguments->length }; - return TRY(m_network_layer.sendto(*this, buffer, arguments->dest_addr, arguments->dest_len)); - } - - BAN::ErrorOr NetworkSocket::recvfrom_impl(sys_recvfrom_t* arguments) - { - sockaddr_in* sender_addr = nullptr; - if (arguments->address) - { - ASSERT(arguments->address_len); - if (*arguments->address_len < (socklen_t)sizeof(sockaddr_in)) - *arguments->address_len = 0; - else - { - sender_addr = reinterpret_cast(arguments->address); - *arguments->address_len = sizeof(sockaddr_in); - } - } - - if (!m_interface) - { - dprintln("No interface bound"); - return BAN::Error::from_errno(EINVAL); - } - - if (m_port == PORT_NONE) - { - dprintln("No port bound"); - return BAN::Error::from_errno(EINVAL); - } - - return TRY(read_packet(BAN::ByteSpan { reinterpret_cast(arguments->buffer), arguments->length }, sender_addr)); - } - BAN::ErrorOr NetworkSocket::ioctl_impl(int request, void* arg) { if (!arg) diff --git a/kernel/kernel/Networking/UDPSocket.cpp b/kernel/kernel/Networking/UDPSocket.cpp index aa8b29d903..787618d076 100644 --- a/kernel/kernel/Networking/UDPSocket.cpp +++ b/kernel/kernel/Networking/UDPSocket.cpp @@ -33,8 +33,11 @@ namespace Kernel header.checksum = 0; } - void UDPSocket::add_packet(BAN::ConstByteSpan packet, BAN::IPv4Address sender_addr, uint16_t sender_port) + void UDPSocket::receive_packet(BAN::ConstByteSpan packet, const sockaddr_storage& sender) { + //auto& header = packet.as(); + auto payload = packet.slice(sizeof(UDPHeader)); + LockGuard _(m_packet_lock); if (m_packets.full()) @@ -43,60 +46,82 @@ namespace Kernel return; } - if (!m_packets.empty() && m_packet_total_size > m_packet_buffer->size()) + if (m_packet_total_size + payload.size() > m_packet_buffer->size()) { dprintln("Packet buffer full, dropping packet"); return; } void* buffer = reinterpret_cast(m_packet_buffer->vaddr() + m_packet_total_size); - memcpy(buffer, packet.data(), packet.size()); + memcpy(buffer, payload.data(), payload.size()); - m_packets.push(PacketInfo { - .sender_addr = sender_addr, - .sender_port = sender_port, - .packet_size = packet.size() + m_packets.emplace(PacketInfo { + .sender = sender, + .packet_size = payload.size() }); - m_packet_total_size += packet.size(); + m_packet_total_size += payload.size(); m_packet_semaphore.unblock(); } - BAN::ErrorOr UDPSocket::read_packet(BAN::ByteSpan buffer, sockaddr_in* sender_addr) + BAN::ErrorOr UDPSocket::bind_impl(const sockaddr* address, socklen_t address_len) { - while (m_packets.empty()) - TRY(Thread::current().block_or_eintr_indefinite(m_packet_semaphore)); + if (is_bound()) + return BAN::Error::from_errno(EINVAL); + return m_network_layer.bind_socket_to_address(this, address, address_len); + } + + BAN::ErrorOr UDPSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len) + { + if (!is_bound()) + { + dprintln("No interface bound"); + return BAN::Error::from_errno(EINVAL); + } + ASSERT(m_port != PORT_NONE); LockGuard _(m_packet_lock); - if (m_packets.empty()) - return read_packet(buffer, sender_addr); + + while (m_packets.empty()) + { + LockFreeGuard free(m_packet_lock); + TRY(Thread::current().block_or_eintr_indefinite(m_packet_semaphore)); + } auto packet_info = m_packets.front(); m_packets.pop(); size_t nread = BAN::Math::min(packet_info.packet_size, buffer.size()); + uint8_t* packet_buffer = reinterpret_cast(m_packet_buffer->vaddr()); memcpy( buffer.data(), - (const void*)m_packet_buffer->vaddr(), + packet_buffer, nread ); memmove( - (void*)m_packet_buffer->vaddr(), - (void*)(m_packet_buffer->vaddr() + packet_info.packet_size), + packet_buffer, + packet_buffer + packet_info.packet_size, m_packet_total_size - packet_info.packet_size ); m_packet_total_size -= packet_info.packet_size; - if (sender_addr) + if (address && address_len) { - sender_addr->sin_family = AF_INET; - sender_addr->sin_port = BAN::NetworkEndian(packet_info.sender_port); - sender_addr->sin_addr.s_addr = packet_info.sender_addr.raw; + if (*address_len > (socklen_t)sizeof(sockaddr_storage)) + *address_len = sizeof(sockaddr_storage); + memcpy(address, &packet_info.sender, *address_len); } return nread; } + BAN::ErrorOr UDPSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) + { + if (!is_bound()) + TRY(m_network_layer.bind_socket_to_unused(this, address, address_len)); + return TRY(m_network_layer.sendto(*this, message, address, address_len)); + } + } diff --git a/kernel/kernel/Networking/UNIX/Socket.cpp b/kernel/kernel/Networking/UNIX/Socket.cpp index 3e953286aa..0646b7f3fe 100644 --- a/kernel/kernel/Networking/UNIX/Socket.cpp +++ b/kernel/kernel/Networking/UNIX/Socket.cpp @@ -248,29 +248,27 @@ namespace Kernel return {}; } - BAN::ErrorOr UnixDomainSocket::sendto_impl(const sys_sendto_t* arguments) + BAN::ErrorOr UnixDomainSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) { - if (arguments->flags) - return BAN::Error::from_errno(ENOTSUP); - if (arguments->length > s_packet_buffer_size) + if (message.size() > s_packet_buffer_size) return BAN::Error::from_errno(ENOBUFS); if (m_info.has()) { auto& connection_info = m_info.get(); - if (arguments->dest_addr) + if (address) return BAN::Error::from_errno(EISCONN); auto target = connection_info.connection.lock(); if (!target) return BAN::Error::from_errno(ENOTCONN); - TRY(target->add_packet({ reinterpret_cast(arguments->message), arguments->length })); - return arguments->length; + TRY(target->add_packet(message)); + return message.size(); } else { BAN::String canonical_path; - if (!arguments->dest_addr) + if (!address) { auto& connectionless_info = m_info.get(); if (connectionless_info.peer_address.empty()) @@ -279,9 +277,9 @@ namespace Kernel } else { - if (arguments->dest_len != sizeof(sockaddr_un)) + if (address_len != sizeof(sockaddr_un)) return BAN::Error::from_errno(EINVAL); - auto& sockaddr_un = *reinterpret_cast(arguments->dest_addr); + auto& sockaddr_un = *reinterpret_cast(address); if (sockaddr_un.sun_family != AF_UNIX) return BAN::Error::from_errno(EAFNOSUPPORT); @@ -301,16 +299,13 @@ namespace Kernel auto target = s_bound_sockets[canonical_path].lock(); if (!target) return BAN::Error::from_errno(EDESTADDRREQ); - TRY(target->add_packet({ reinterpret_cast(arguments->message), arguments->length })); - return arguments->length; + TRY(target->add_packet(message)); + return message.size(); } } - BAN::ErrorOr UnixDomainSocket::recvfrom_impl(sys_recvfrom_t* arguments) + BAN::ErrorOr UnixDomainSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr*, socklen_t*) { - if (arguments->flags) - return BAN::Error::from_errno(ENOTSUP); - if (m_info.has()) { auto& connection_info = m_info.get(); @@ -328,14 +323,14 @@ namespace Kernel size_t nread = 0; if (is_streaming()) - nread = BAN::Math::min(arguments->length, m_packet_size_total); + nread = BAN::Math::min(buffer.size(), m_packet_size_total); else { - nread = BAN::Math::min(arguments->length, m_packet_sizes.front()); + nread = BAN::Math::min(buffer.size(), m_packet_sizes.front()); m_packet_sizes.pop(); } - memcpy(arguments->buffer, packet_buffer, nread); + memcpy(buffer.data(), packet_buffer, nread); memmove(packet_buffer, packet_buffer + nread, m_packet_size_total - nread); m_packet_size_total -= nread; diff --git a/kernel/kernel/Process.cpp b/kernel/kernel/Process.cpp index 9b514be73c..fb2b669f12 100644 --- a/kernel/kernel/Process.cpp +++ b/kernel/kernel/Process.cpp @@ -983,7 +983,8 @@ namespace Kernel if (!inode->mode().ifsock()) return BAN::Error::from_errno(ENOTSOCK); - return TRY(inode->sendto(arguments)); + BAN::ConstByteSpan message { reinterpret_cast(arguments->message), arguments->length }; + return TRY(inode->sendto(message, arguments->dest_addr, arguments->dest_len)); } BAN::ErrorOr Process::sys_recvfrom(sys_recvfrom_t* arguments) @@ -1006,7 +1007,8 @@ namespace Kernel if (!inode->mode().ifsock()) return BAN::Error::from_errno(ENOTSOCK); - return TRY(inode->recvfrom(arguments)); + BAN::ByteSpan buffer { reinterpret_cast(arguments->buffer), arguments->length }; + return TRY(inode->recvfrom(buffer, arguments->address, arguments->address_len)); } BAN::ErrorOr Process::sys_ioctl(int fildes, int request, void* arg)