Kernel: Cleanup OSI layer overlapping

This commit is contained in:
Bananymous 2024-02-09 17:05:07 +02:00
parent 5d78cd3016
commit ff49d8b84f
12 changed files with 185 additions and 172 deletions

View File

@ -104,8 +104,8 @@ namespace Kernel
BAN::ErrorOr<void> bind(const sockaddr* address, socklen_t address_len); BAN::ErrorOr<void> bind(const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<void> connect(const sockaddr* address, socklen_t address_len); BAN::ErrorOr<void> connect(const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<void> listen(int backlog); BAN::ErrorOr<void> listen(int backlog);
BAN::ErrorOr<size_t> sendto(const sys_sendto_t*); BAN::ErrorOr<size_t> sendto(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<size_t> recvfrom(sys_recvfrom_t*); BAN::ErrorOr<size_t> recvfrom(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len);
// General API // General API
BAN::ErrorOr<size_t> read(off_t, BAN::ByteSpan buffer); BAN::ErrorOr<size_t> read(off_t, BAN::ByteSpan buffer);
@ -131,12 +131,12 @@ namespace Kernel
virtual BAN::ErrorOr<BAN::String> link_target_impl() { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<BAN::String> link_target_impl() { return BAN::Error::from_errno(ENOTSUP); }
// Socket API // Socket API
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> listen_impl(int) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<void> listen_impl(int) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<size_t> sendto_impl(const sys_sendto_t*) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<size_t> sendto_impl(BAN::ConstByteSpan, const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<size_t> recvfrom_impl(sys_recvfrom_t*) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<size_t> recvfrom_impl(BAN::ByteSpan, sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); }
// General API // General API
virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) { return BAN::Error::from_errno(ENOTSUP); }

View File

@ -32,7 +32,7 @@ namespace Kernel
}; };
static_assert(sizeof(IPv4Header) == 20); static_assert(sizeof(IPv4Header) == 20);
class IPv4Layer : public NetworkLayer class IPv4Layer final : public NetworkLayer
{ {
BAN_NON_COPYABLE(IPv4Layer); BAN_NON_COPYABLE(IPv4Layer);
BAN_NON_MOVABLE(IPv4Layer); BAN_NON_MOVABLE(IPv4Layer);
@ -45,8 +45,9 @@ namespace Kernel
void add_ipv4_packet(NetworkInterface&, BAN::ConstByteSpan); void add_ipv4_packet(NetworkInterface&, BAN::ConstByteSpan);
virtual void unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) override; virtual void unbind_socket(BAN::RefPtr<NetworkSocket>, uint16_t port) override;
virtual BAN::ErrorOr<void> bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) override; virtual BAN::ErrorOr<void> bind_socket_to_unused(BAN::RefPtr<NetworkSocket>, const sockaddr* send_address, socklen_t send_address_len) override;
virtual BAN::ErrorOr<void> bind_socket_to_address(BAN::RefPtr<NetworkSocket>, const sockaddr* address, socklen_t address_len) override;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) override; virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) override;
@ -65,7 +66,7 @@ namespace Kernel
}; };
private: private:
SpinLock m_lock; RecursiveSpinLock m_lock;
BAN::UniqPtr<ARPTable> m_arp_table; BAN::UniqPtr<ARPTable> m_arp_table;
Process* m_process { nullptr }; Process* m_process { nullptr };

View File

@ -22,8 +22,9 @@ namespace Kernel
public: public:
virtual ~NetworkLayer() {} virtual ~NetworkLayer() {}
virtual void unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) = 0; virtual void unbind_socket(BAN::RefPtr<NetworkSocket>, uint16_t port) = 0;
virtual BAN::ErrorOr<void> bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) = 0; virtual BAN::ErrorOr<void> bind_socket_to_unused(BAN::RefPtr<NetworkSocket>, const sockaddr* send_address, socklen_t send_address_len) = 0;
virtual BAN::ErrorOr<void> bind_socket_to_address(BAN::RefPtr<NetworkSocket>, const sockaddr* address, socklen_t address_len) = 0;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) = 0; virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) = 0;

View File

@ -6,8 +6,6 @@
#include <kernel/Networking/NetworkInterface.h> #include <kernel/Networking/NetworkInterface.h>
#include <kernel/Networking/NetworkLayer.h> #include <kernel/Networking/NetworkLayer.h>
#include <netinet/in.h>
namespace Kernel namespace Kernel
{ {
@ -35,25 +33,21 @@ namespace Kernel
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) = 0; virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) = 0;
virtual NetworkProtocol protocol() const = 0; virtual NetworkProtocol protocol() const = 0;
virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_address, uint16_t sender_port) = 0; virtual void receive_packet(BAN::ConstByteSpan, const sockaddr_storage& sender) = 0;
bool is_bound() const { return m_interface != nullptr; }
protected: protected:
NetworkSocket(NetworkLayer&, ino_t, const TmpInodeInfo&); NetworkSocket(NetworkLayer&, ino_t, const TmpInodeInfo&);
virtual BAN::ErrorOr<size_t> read_packet(BAN::ByteSpan, sockaddr_in* sender_address) = 0;
virtual void on_close_impl() override; virtual void on_close_impl() override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr* address, socklen_t address_len) override;
virtual BAN::ErrorOr<size_t> sendto_impl(const sys_sendto_t*) override;
virtual BAN::ErrorOr<size_t> recvfrom_impl(sys_recvfrom_t*) override;
virtual BAN::ErrorOr<long> ioctl_impl(int request, void* arg) override; virtual BAN::ErrorOr<long> ioctl_impl(int request, void* arg) override;
protected: protected:
NetworkLayer& m_network_layer; NetworkLayer& m_network_layer;
NetworkInterface* m_interface = nullptr; NetworkInterface* m_interface = nullptr;
uint16_t m_port = PORT_NONE; uint16_t m_port { PORT_NONE };
}; };
} }

View File

@ -30,23 +30,25 @@ namespace Kernel
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) override; virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) override;
protected: protected:
virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_addr, uint16_t sender_port) override; virtual void receive_packet(BAN::ConstByteSpan, const sockaddr_storage& sender) override;
virtual BAN::ErrorOr<size_t> read_packet(BAN::ByteSpan, sockaddr_in* sender_address) override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr* address, socklen_t address_len) override;
virtual BAN::ErrorOr<size_t> sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) override;
virtual BAN::ErrorOr<size_t> recvfrom_impl(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len) override;
private: private:
UDPSocket(NetworkLayer&, ino_t, const TmpInodeInfo&); UDPSocket(NetworkLayer&, ino_t, const TmpInodeInfo&);
struct PacketInfo struct PacketInfo
{ {
BAN::IPv4Address sender_addr; sockaddr_storage sender;
uint16_t sender_port;
size_t packet_size; size_t packet_size;
}; };
private: private:
static constexpr size_t packet_buffer_size = 10 * PAGE_SIZE; static constexpr size_t packet_buffer_size = 10 * PAGE_SIZE;
BAN::UniqPtr<VirtualRange> m_packet_buffer; BAN::UniqPtr<VirtualRange> m_packet_buffer;
BAN::CircularQueue<PacketInfo, 128> m_packets; BAN::CircularQueue<PacketInfo, 32> m_packets;
size_t m_packet_total_size { 0 }; size_t m_packet_total_size { 0 };
SpinLock m_packet_lock; SpinLock m_packet_lock;
Semaphore m_packet_semaphore; Semaphore m_packet_semaphore;

View File

@ -23,8 +23,8 @@ namespace Kernel
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override; virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<void> listen_impl(int) override; virtual BAN::ErrorOr<void> listen_impl(int) override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) override; virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<size_t> sendto_impl(const sys_sendto_t*) override; virtual BAN::ErrorOr<size_t> sendto_impl(BAN::ConstByteSpan, const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<size_t> recvfrom_impl(sys_recvfrom_t*) override; virtual BAN::ErrorOr<size_t> recvfrom_impl(BAN::ByteSpan, sockaddr*, socklen_t*) override;
private: private:
UnixDomainSocket(SocketType, ino_t, const TmpInodeInfo&); UnixDomainSocket(SocketType, ino_t, const TmpInodeInfo&);

View File

@ -148,20 +148,20 @@ namespace Kernel
return listen_impl(backlog); return listen_impl(backlog);
} }
BAN::ErrorOr<size_t> Inode::sendto(const sys_sendto_t* arguments) BAN::ErrorOr<size_t> Inode::sendto(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len)
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
if (!mode().ifsock()) if (!mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK); return BAN::Error::from_errno(ENOTSOCK);
return sendto_impl(arguments); return sendto_impl(message, address, address_len);
}; };
BAN::ErrorOr<size_t> Inode::recvfrom(sys_recvfrom_t* arguments) BAN::ErrorOr<size_t> Inode::recvfrom(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len)
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
if (!mode().ifsock()) if (!mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK); return BAN::Error::from_errno(ENOTSOCK);
return recvfrom_impl(arguments); return recvfrom_impl(buffer, address, address_len);
}; };
BAN::ErrorOr<size_t> Inode::read(off_t offset, BAN::ByteSpan buffer) BAN::ErrorOr<size_t> Inode::read(off_t offset, BAN::ByteSpan buffer)

View File

@ -66,7 +66,7 @@ namespace Kernel
header.checksum = calculate_internet_checksum(BAN::ConstByteSpan::from(header), {}); header.checksum = calculate_internet_checksum(BAN::ConstByteSpan::from(header), {});
} }
void IPv4Layer::unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket> socket) void IPv4Layer::unbind_socket(BAN::RefPtr<NetworkSocket> socket, uint16_t port)
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
if (m_bound_sockets.contains(port)) if (m_bound_sockets.contains(port))
@ -78,29 +78,52 @@ namespace Kernel
NetworkManager::get().TmpFileSystem::remove_from_cache(socket); NetworkManager::get().TmpFileSystem::remove_from_cache(socket);
} }
BAN::ErrorOr<void> IPv4Layer::bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket> socket) BAN::ErrorOr<void> IPv4Layer::bind_socket_to_unused(BAN::RefPtr<NetworkSocket> socket, const sockaddr* address, socklen_t address_len)
{
if (!address || address_len < (socklen_t)sizeof(sockaddr_in))
return BAN::Error::from_errno(EINVAL);
if (address->sa_family != AF_INET)
return BAN::Error::from_errno(EAFNOSUPPORT);
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(address);
LockGuard _(m_lock);
uint16_t port = NetworkSocket::PORT_NONE;
for (uint32_t temp = 0xC000; temp < 0xFFFF; temp++)
{
if (!m_bound_sockets.contains(temp))
{
port = temp;
break;
}
}
if (port == NetworkSocket::PORT_NONE)
{
dwarnln("No ports available");
return BAN::Error::from_errno(EAGAIN);
}
struct sockaddr_in target;
target.sin_family = AF_INET;
target.sin_port = BAN::host_to_network_endian(port);
target.sin_addr.s_addr = sockaddr_in.sin_addr.s_addr;
return bind_socket_to_address(socket, (sockaddr*)&target, sizeof(sockaddr_in));
}
BAN::ErrorOr<void> IPv4Layer::bind_socket_to_address(BAN::RefPtr<NetworkSocket> socket, const sockaddr* address, socklen_t address_len)
{ {
if (NetworkManager::get().interfaces().empty()) if (NetworkManager::get().interfaces().empty())
return BAN::Error::from_errno(EADDRNOTAVAIL); return BAN::Error::from_errno(EADDRNOTAVAIL);
LockGuard _(m_lock); if (!address || address_len < (socklen_t)sizeof(sockaddr_in))
return BAN::Error::from_errno(EINVAL);
if (address->sa_family != AF_INET)
return BAN::Error::from_errno(EAFNOSUPPORT);
if (port == NetworkSocket::PORT_NONE) auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(address);
{ uint16_t port = BAN::host_to_network_endian(sockaddr_in.sin_port);
for (uint32_t temp = 0xC000; temp < 0xFFFF; temp++)
{ LockGuard _(m_lock);
if (!m_bound_sockets.contains(temp))
{
port = temp;
break;
}
}
if (port == NetworkSocket::PORT_NONE)
{
dwarnln("No ports available");
return BAN::Error::from_errno(EAGAIN);
}
}
if (m_bound_sockets.contains(port)) if (m_bound_sockets.contains(port))
return BAN::Error::from_errno(EADDRINUSE); return BAN::Error::from_errno(EADDRINUSE);
@ -163,6 +186,10 @@ namespace Kernel
auto ipv4_data = packet.slice(sizeof(IPv4Header)); auto ipv4_data = packet.slice(sizeof(IPv4Header));
auto src_ipv4 = ipv4_header.src_address; auto src_ipv4 = ipv4_header.src_address;
uint16_t dst_port = NetworkSocket::PORT_NONE;
uint16_t src_port = NetworkSocket::PORT_NONE;
switch (ipv4_header.protocol) switch (ipv4_header.protocol)
{ {
case NetworkProtocol::ICMP: case NetworkProtocol::ICMP:
@ -188,37 +215,53 @@ namespace Kernel
dprintln("Unhandleded ICMP packet (type {2H})", icmp_header.type); dprintln("Unhandleded ICMP packet (type {2H})", icmp_header.type);
break; break;
} }
break; return {};
} }
case NetworkProtocol::UDP: case NetworkProtocol::UDP:
{ {
auto& udp_header = ipv4_data.as<const UDPHeader>(); auto& udp_header = ipv4_data.as<const UDPHeader>();
uint16_t src_port = udp_header.src_port; dst_port = udp_header.dst_port;
uint16_t dst_port = udp_header.dst_port; src_port = udp_header.src_port;
LockGuard _(m_lock);
if (!m_bound_sockets.contains(dst_port))
{
dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port);
return {};
}
auto socket = m_bound_sockets[dst_port].lock();
if (!socket)
{
dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port);
return {};
}
auto udp_data = ipv4_data.slice(sizeof(UDPHeader));
socket->add_packet(udp_data, src_ipv4, src_port);
break; break;
} }
default: default:
dprintln_if(DEBUG_IPV4, "Unknown network protocol 0x{2H}", ipv4_header.protocol); dprintln_if(DEBUG_IPV4, "Unknown network protocol 0x{2H}", ipv4_header.protocol);
break; return {};
} }
ASSERT(dst_port != NetworkSocket::PORT_NONE);
ASSERT(src_port != NetworkSocket::PORT_NONE);
BAN::RefPtr<Kernel::NetworkSocket> bound_socket;
{
LockGuard _(m_lock);
if (!m_bound_sockets.contains(dst_port))
{
dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port);
return {};
}
bound_socket = m_bound_sockets[dst_port].lock();
}
if (!bound_socket)
{
dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port);
return {};
}
if (bound_socket->protocol() != ipv4_header.protocol)
{
dprintln_if(DEBUG_IPV4, "got data with wrong protocol ({}) on port {} (bound as {})", ipv4_header.protocol, dst_port, (uint8_t)bound_socket->protocol());
return {};
}
sockaddr_in sender;
sender.sin_family = AF_INET;
sender.sin_port = BAN::NetworkEndian<uint16_t>(src_port);
sender.sin_addr.s_addr = src_ipv4.raw;
bound_socket->receive_packet(ipv4_data, *reinterpret_cast<const sockaddr_storage*>(&sender));
return {}; return {};
} }

View File

@ -17,8 +17,12 @@ namespace Kernel
void NetworkSocket::on_close_impl() void NetworkSocket::on_close_impl()
{ {
if (m_interface) if (is_bound())
m_network_layer.unbind_socket(m_port, this); {
m_network_layer.unbind_socket(this, m_port);
m_interface = nullptr;
m_port = PORT_NONE;
}
} }
void NetworkSocket::bind_interface_and_port(NetworkInterface* interface, uint16_t port) void NetworkSocket::bind_interface_and_port(NetworkInterface* interface, uint16_t port)
@ -29,60 +33,6 @@ namespace Kernel
m_port = port; m_port = port;
} }
BAN::ErrorOr<void> NetworkSocket::bind_impl(const sockaddr* address, socklen_t address_len)
{
if (m_interface || address_len != sizeof(sockaddr_in))
return BAN::Error::from_errno(EINVAL);
auto* addr_in = reinterpret_cast<const sockaddr_in*>(address);
uint16_t dst_port = BAN::host_to_network_endian(addr_in->sin_port);
return m_network_layer.bind_socket(dst_port, this);
}
BAN::ErrorOr<size_t> NetworkSocket::sendto_impl(const sys_sendto_t* arguments)
{
if (arguments->flags)
{
dprintln("flags not supported");
return BAN::Error::from_errno(ENOTSUP);
}
if (!m_interface)
TRY(m_network_layer.bind_socket(PORT_NONE, this));
auto buffer = BAN::ConstByteSpan { reinterpret_cast<const uint8_t*>(arguments->message), arguments->length };
return TRY(m_network_layer.sendto(*this, buffer, arguments->dest_addr, arguments->dest_len));
}
BAN::ErrorOr<size_t> NetworkSocket::recvfrom_impl(sys_recvfrom_t* arguments)
{
sockaddr_in* sender_addr = nullptr;
if (arguments->address)
{
ASSERT(arguments->address_len);
if (*arguments->address_len < (socklen_t)sizeof(sockaddr_in))
*arguments->address_len = 0;
else
{
sender_addr = reinterpret_cast<sockaddr_in*>(arguments->address);
*arguments->address_len = sizeof(sockaddr_in);
}
}
if (!m_interface)
{
dprintln("No interface bound");
return BAN::Error::from_errno(EINVAL);
}
if (m_port == PORT_NONE)
{
dprintln("No port bound");
return BAN::Error::from_errno(EINVAL);
}
return TRY(read_packet(BAN::ByteSpan { reinterpret_cast<uint8_t*>(arguments->buffer), arguments->length }, sender_addr));
}
BAN::ErrorOr<long> NetworkSocket::ioctl_impl(int request, void* arg) BAN::ErrorOr<long> NetworkSocket::ioctl_impl(int request, void* arg)
{ {
if (!arg) if (!arg)

View File

@ -33,8 +33,11 @@ namespace Kernel
header.checksum = 0; header.checksum = 0;
} }
void UDPSocket::add_packet(BAN::ConstByteSpan packet, BAN::IPv4Address sender_addr, uint16_t sender_port) void UDPSocket::receive_packet(BAN::ConstByteSpan packet, const sockaddr_storage& sender)
{ {
//auto& header = packet.as<const UDPHeader>();
auto payload = packet.slice(sizeof(UDPHeader));
LockGuard _(m_packet_lock); LockGuard _(m_packet_lock);
if (m_packets.full()) if (m_packets.full())
@ -43,60 +46,82 @@ namespace Kernel
return; return;
} }
if (!m_packets.empty() && m_packet_total_size > m_packet_buffer->size()) if (m_packet_total_size + payload.size() > m_packet_buffer->size())
{ {
dprintln("Packet buffer full, dropping packet"); dprintln("Packet buffer full, dropping packet");
return; return;
} }
void* buffer = reinterpret_cast<void*>(m_packet_buffer->vaddr() + m_packet_total_size); void* buffer = reinterpret_cast<void*>(m_packet_buffer->vaddr() + m_packet_total_size);
memcpy(buffer, packet.data(), packet.size()); memcpy(buffer, payload.data(), payload.size());
m_packets.push(PacketInfo { m_packets.emplace(PacketInfo {
.sender_addr = sender_addr, .sender = sender,
.sender_port = sender_port, .packet_size = payload.size()
.packet_size = packet.size()
}); });
m_packet_total_size += packet.size(); m_packet_total_size += payload.size();
m_packet_semaphore.unblock(); m_packet_semaphore.unblock();
} }
BAN::ErrorOr<size_t> UDPSocket::read_packet(BAN::ByteSpan buffer, sockaddr_in* sender_addr) BAN::ErrorOr<void> UDPSocket::bind_impl(const sockaddr* address, socklen_t address_len)
{ {
while (m_packets.empty()) if (is_bound())
TRY(Thread::current().block_or_eintr_indefinite(m_packet_semaphore)); return BAN::Error::from_errno(EINVAL);
return m_network_layer.bind_socket_to_address(this, address, address_len);
}
BAN::ErrorOr<size_t> UDPSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len)
{
if (!is_bound())
{
dprintln("No interface bound");
return BAN::Error::from_errno(EINVAL);
}
ASSERT(m_port != PORT_NONE);
LockGuard _(m_packet_lock); LockGuard _(m_packet_lock);
if (m_packets.empty())
return read_packet(buffer, sender_addr); while (m_packets.empty())
{
LockFreeGuard free(m_packet_lock);
TRY(Thread::current().block_or_eintr_indefinite(m_packet_semaphore));
}
auto packet_info = m_packets.front(); auto packet_info = m_packets.front();
m_packets.pop(); m_packets.pop();
size_t nread = BAN::Math::min<size_t>(packet_info.packet_size, buffer.size()); size_t nread = BAN::Math::min<size_t>(packet_info.packet_size, buffer.size());
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
memcpy( memcpy(
buffer.data(), buffer.data(),
(const void*)m_packet_buffer->vaddr(), packet_buffer,
nread nread
); );
memmove( memmove(
(void*)m_packet_buffer->vaddr(), packet_buffer,
(void*)(m_packet_buffer->vaddr() + packet_info.packet_size), packet_buffer + packet_info.packet_size,
m_packet_total_size - packet_info.packet_size m_packet_total_size - packet_info.packet_size
); );
m_packet_total_size -= packet_info.packet_size; m_packet_total_size -= packet_info.packet_size;
if (sender_addr) if (address && address_len)
{ {
sender_addr->sin_family = AF_INET; if (*address_len > (socklen_t)sizeof(sockaddr_storage))
sender_addr->sin_port = BAN::NetworkEndian(packet_info.sender_port); *address_len = sizeof(sockaddr_storage);
sender_addr->sin_addr.s_addr = packet_info.sender_addr.raw; memcpy(address, &packet_info.sender, *address_len);
} }
return nread; return nread;
} }
BAN::ErrorOr<size_t> UDPSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len)
{
if (!is_bound())
TRY(m_network_layer.bind_socket_to_unused(this, address, address_len));
return TRY(m_network_layer.sendto(*this, message, address, address_len));
}
} }

View File

@ -248,29 +248,27 @@ namespace Kernel
return {}; return {};
} }
BAN::ErrorOr<size_t> UnixDomainSocket::sendto_impl(const sys_sendto_t* arguments) BAN::ErrorOr<size_t> UnixDomainSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len)
{ {
if (arguments->flags) if (message.size() > s_packet_buffer_size)
return BAN::Error::from_errno(ENOTSUP);
if (arguments->length > s_packet_buffer_size)
return BAN::Error::from_errno(ENOBUFS); return BAN::Error::from_errno(ENOBUFS);
if (m_info.has<ConnectionInfo>()) if (m_info.has<ConnectionInfo>())
{ {
auto& connection_info = m_info.get<ConnectionInfo>(); auto& connection_info = m_info.get<ConnectionInfo>();
if (arguments->dest_addr) if (address)
return BAN::Error::from_errno(EISCONN); return BAN::Error::from_errno(EISCONN);
auto target = connection_info.connection.lock(); auto target = connection_info.connection.lock();
if (!target) if (!target)
return BAN::Error::from_errno(ENOTCONN); return BAN::Error::from_errno(ENOTCONN);
TRY(target->add_packet({ reinterpret_cast<const uint8_t*>(arguments->message), arguments->length })); TRY(target->add_packet(message));
return arguments->length; return message.size();
} }
else else
{ {
BAN::String canonical_path; BAN::String canonical_path;
if (!arguments->dest_addr) if (!address)
{ {
auto& connectionless_info = m_info.get<ConnectionlessInfo>(); auto& connectionless_info = m_info.get<ConnectionlessInfo>();
if (connectionless_info.peer_address.empty()) if (connectionless_info.peer_address.empty())
@ -279,9 +277,9 @@ namespace Kernel
} }
else else
{ {
if (arguments->dest_len != sizeof(sockaddr_un)) if (address_len != sizeof(sockaddr_un))
return BAN::Error::from_errno(EINVAL); return BAN::Error::from_errno(EINVAL);
auto& sockaddr_un = *reinterpret_cast<const struct sockaddr_un*>(arguments->dest_addr); auto& sockaddr_un = *reinterpret_cast<const struct sockaddr_un*>(address);
if (sockaddr_un.sun_family != AF_UNIX) if (sockaddr_un.sun_family != AF_UNIX)
return BAN::Error::from_errno(EAFNOSUPPORT); return BAN::Error::from_errno(EAFNOSUPPORT);
@ -301,16 +299,13 @@ namespace Kernel
auto target = s_bound_sockets[canonical_path].lock(); auto target = s_bound_sockets[canonical_path].lock();
if (!target) if (!target)
return BAN::Error::from_errno(EDESTADDRREQ); return BAN::Error::from_errno(EDESTADDRREQ);
TRY(target->add_packet({ reinterpret_cast<const uint8_t*>(arguments->message), arguments->length })); TRY(target->add_packet(message));
return arguments->length; return message.size();
} }
} }
BAN::ErrorOr<size_t> UnixDomainSocket::recvfrom_impl(sys_recvfrom_t* arguments) BAN::ErrorOr<size_t> UnixDomainSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr*, socklen_t*)
{ {
if (arguments->flags)
return BAN::Error::from_errno(ENOTSUP);
if (m_info.has<ConnectionInfo>()) if (m_info.has<ConnectionInfo>())
{ {
auto& connection_info = m_info.get<ConnectionInfo>(); auto& connection_info = m_info.get<ConnectionInfo>();
@ -328,14 +323,14 @@ namespace Kernel
size_t nread = 0; size_t nread = 0;
if (is_streaming()) if (is_streaming())
nread = BAN::Math::min(arguments->length, m_packet_size_total); nread = BAN::Math::min(buffer.size(), m_packet_size_total);
else else
{ {
nread = BAN::Math::min(arguments->length, m_packet_sizes.front()); nread = BAN::Math::min(buffer.size(), m_packet_sizes.front());
m_packet_sizes.pop(); m_packet_sizes.pop();
} }
memcpy(arguments->buffer, packet_buffer, nread); memcpy(buffer.data(), packet_buffer, nread);
memmove(packet_buffer, packet_buffer + nread, m_packet_size_total - nread); memmove(packet_buffer, packet_buffer + nread, m_packet_size_total - nread);
m_packet_size_total -= nread; m_packet_size_total -= nread;

View File

@ -983,7 +983,8 @@ namespace Kernel
if (!inode->mode().ifsock()) if (!inode->mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK); return BAN::Error::from_errno(ENOTSOCK);
return TRY(inode->sendto(arguments)); BAN::ConstByteSpan message { reinterpret_cast<const uint8_t*>(arguments->message), arguments->length };
return TRY(inode->sendto(message, arguments->dest_addr, arguments->dest_len));
} }
BAN::ErrorOr<long> Process::sys_recvfrom(sys_recvfrom_t* arguments) BAN::ErrorOr<long> Process::sys_recvfrom(sys_recvfrom_t* arguments)
@ -1006,7 +1007,8 @@ namespace Kernel
if (!inode->mode().ifsock()) if (!inode->mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK); return BAN::Error::from_errno(ENOTSOCK);
return TRY(inode->recvfrom(arguments)); BAN::ByteSpan buffer { reinterpret_cast<uint8_t*>(arguments->buffer), arguments->length };
return TRY(inode->recvfrom(buffer, arguments->address, arguments->address_len));
} }
BAN::ErrorOr<long> Process::sys_ioctl(int fildes, int request, void* arg) BAN::ErrorOr<long> Process::sys_ioctl(int fildes, int request, void* arg)