Kernel: Implement connect for UDP socket

This commit is contained in:
Bananymous 2025-11-12 03:34:16 +02:00
parent 59cfc339b0
commit c700d9f714
2 changed files with 29 additions and 1 deletions

View File

@ -33,6 +33,7 @@ namespace Kernel
protected:
virtual void receive_packet(BAN::ConstByteSpan, const sockaddr* sender, socklen_t sender_len) override;
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr* address, socklen_t address_len) override;
virtual BAN::ErrorOr<size_t> recvmsg_impl(msghdr& message, int flags) override;
virtual BAN::ErrorOr<size_t> sendmsg_impl(const msghdr& message, int flags) override;
@ -63,6 +64,9 @@ namespace Kernel
SpinLock m_packet_lock;
ThreadBlocker m_packet_thread_blocker;
sockaddr_storage m_peer_address {};
socklen_t m_peer_address_len { 0 };
friend class BAN::RefPtr<UDPSocket>;
};

View File

@ -79,6 +79,15 @@ namespace Kernel
m_packet_thread_blocker.unblock();
}
BAN::ErrorOr<void> UDPSocket::connect_impl(const sockaddr* address, socklen_t address_len)
{
if (address_len > static_cast<socklen_t>(sizeof(m_peer_address)))
address_len = sizeof(m_peer_address);
memcpy(&m_peer_address, address, address_len);
m_peer_address_len = address_len;
return {};
}
BAN::ErrorOr<void> UDPSocket::bind_impl(const sockaddr* address, socklen_t address_len)
{
if (is_bound())
@ -187,7 +196,22 @@ namespace Kernel
offset += message.msg_iov[i].iov_len;
}
return TRY(m_network_layer.sendto(*this, buffer.span(), static_cast<sockaddr*>(message.msg_name), message.msg_namelen));
sockaddr* address;
socklen_t address_len;
if (!message.msg_name || message.msg_namelen == 0)
{
if (m_peer_address_len == 0)
return BAN::Error::from_errno(EDESTADDRREQ);
address = reinterpret_cast<sockaddr*>(&m_peer_address);
address_len = m_peer_address_len;
}
else
{
address = static_cast<sockaddr*>(message.msg_name);
address_len = message.msg_namelen;
}
return TRY(m_network_layer.sendto(*this, buffer.span(), address, address_len));
}
BAN::ErrorOr<long> UDPSocket::ioctl_impl(int request, void* argument)