// banan-os/kernel/kernel/Networking/UNIX/Socket.cpp
// (repository listing metadata: 394 lines, 12 KiB, C++)

#include <BAN/HashMap.h>
#include <kernel/FS/VirtualFileSystem.h>
#include <kernel/Networking/NetworkManager.h>
#include <kernel/Networking/UNIX/Socket.h>
#include <kernel/Scheduler.h>
#include <fcntl.h>
#include <sys/un.h>
namespace Kernel
{
// Registry of path-bound sockets, keyed by the canonical filesystem path that
// bind_impl stores. Values are weak pointers, so the registry itself does not
// keep a socket alive; the destructor removes the socket's entry.
static BAN::HashMap<BAN::String, BAN::WeakPtr<UnixDomainSocket>> s_bound_sockets;
// Taken around every lookup, insertion and removal on s_bound_sockets in this file.
static SpinLock s_bound_socket_lock;
// Size in bytes of the per-socket receive buffer allocated in create(); also
// the largest single message sendto_impl accepts.
static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE;
// Factory for UNIX domain sockets: constructs the socket object and then
// attaches a kernel virtual range of s_packet_buffer_size bytes that serves
// as its receive buffer. Any allocation failure is propagated to the caller.
BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> UnixDomainSocket::create(Socket::Type socket_type, const Socket::Info& info)
{
	auto new_socket = TRY(BAN::RefPtr<UnixDomainSocket>::create(socket_type, info));

	// The range may be placed anywhere from KERNEL_OFFSET to the top of the
	// address space and is mapped present + read/write.
	auto receive_buffer = TRY(VirtualRange::create_to_vaddr_range(
		PageTable::kernel(),
		KERNEL_OFFSET,
		~(uintptr_t)0,
		s_packet_buffer_size,
		PageTable::Flags::ReadWrite | PageTable::Flags::Present,
		true
	));
	new_socket->m_packet_buffer = BAN::move(receive_buffer);

	return new_socket;
}
// Initializes the per-type state held in m_info: STREAM and SEQPACKET sockets
// get ConnectionInfo, DGRAM sockets get ConnectionlessInfo. Any other socket
// type is a programming error and trips ASSERT_NOT_REACHED.
UnixDomainSocket::UnixDomainSocket(Socket::Type socket_type, const Socket::Info& info)
	: Socket(info)
	, m_socket_type(socket_type)
{
	const bool connection_mode =
		socket_type == Socket::Type::STREAM ||
		socket_type == Socket::Type::SEQPACKET;

	if (connection_mode)
		m_info.emplace<ConnectionInfo>();
	else if (socket_type == Socket::Type::DGRAM)
		m_info.emplace<ConnectionlessInfo>();
	else
		ASSERT_NOT_REACHED();
}
// Tears down the socket: drops its entry from the bound-socket registry (when
// it was bound to a real path) and flags the connected peer, if it is still
// alive, so that the peer observes the close.
UnixDomainSocket::~UnixDomainSocket()
{
	const bool has_registry_entry = is_bound() && !is_bound_to_unused();
	if (has_registry_entry)
	{
		SpinLockGuard guard(s_bound_socket_lock);
		if (auto entry = s_bound_sockets.find(m_bound_path); entry != s_bound_sockets.end())
			s_bound_sockets.remove(entry);
	}

	// Only connection-mode sockets have a peer to notify.
	if (!m_info.has<ConnectionInfo>())
		return;

	auto peer = m_info.get<ConnectionInfo>().connection.lock();
	if (!peer || !peer->m_info.has<ConnectionInfo>())
		return;
	peer->m_info.get<ConnectionInfo>().target_closed = true;
}
// Accept one queued connection on a listening connection-mode socket.
//
// Blocks until a connecting socket is available in the pending queue, creates
// a new UNIX socket of the same type for the accepting side, links the two
// sockets through weak pointers and marks the connecting side as done so its
// connect_impl can return. When `address` and `address_len` are supplied the
// connecting socket's bound path is written into `address` as a sockaddr_un.
//
// Returns the result of Process::open_inode for the new socket (presumably
// the new file descriptor - confirm against Process).
//
// Errors returned directly:
//   EOPNOTSUPP  this socket is not connection-mode
//   EINVAL      this socket is not listening
// Errors from blocking, socket creation and open_inode are propagated.
BAN::ErrorOr<long> UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len, int flags)
{
	if (!m_info.has<ConnectionInfo>())
		return BAN::Error::from_errno(EOPNOTSUPP);
	auto& connection_info = m_info.get<ConnectionInfo>();
	if (!connection_info.listening)
		return BAN::Error::from_errno(EINVAL);

	// Check for and dequeue a pending connection under pending_lock in one
	// step, so that concurrent acceptors can never pop from an empty queue.
	BAN::RefPtr<UnixDomainSocket> pending;
	for (;;)
	{
		{
			SpinLockGuard _(connection_info.pending_lock);
			if (!connection_info.pending_connections.empty())
			{
				pending = connection_info.pending_connections.front();
				connection_info.pending_connections.pop();
				// Wake connectors that are waiting for a free queue slot.
				connection_info.pending_thread_blocker.unblock();
				break;
			}
		}
		TRY(Thread::current().block_or_eintr_indefinite(connection_info.pending_thread_blocker));
	}

	// NOTE(review): if anything below fails, `pending` has already been
	// dequeued and its connection_done flag is never set, while connect_impl
	// waits on that flag - confirm whether this needs handling.
	BAN::RefPtr<UnixDomainSocket> return_inode;
	{
		auto return_inode_tmp = TRY(NetworkManager::get().create_socket(Socket::Domain::UNIX, m_socket_type, mode().mode & ~Mode::TYPE_MASK, uid(), gid()));
		return_inode = reinterpret_cast<UnixDomainSocket*>(return_inode_tmp.ptr());
	}

	// Give the accepted socket the same single-character placeholder path
	// that connect_impl assigns to sockets that were never bound.
	TRY(return_inode->m_bound_path.push_back('X'));

	return_inode->m_info.get<ConnectionInfo>().connection = TRY(pending->get_weak_ptr());
	pending->m_info.get<ConnectionInfo>().connection = TRY(return_inode->get_weak_ptr());
	pending->m_info.get<ConnectionInfo>().connection_done = true;

	// NOTE(review): the condition tests this (listening) socket, while the
	// path that is copied belongs to `pending` - confirm which is intended.
	if (address && address_len && !is_bound_to_unused())
	{
		auto& sockaddr_un = *reinterpret_cast<struct sockaddr_un*>(address);

		// Everything written must fit in the *address_len bytes the caller
		// provided; sun_path starts path_offset bytes into that buffer.
		const size_t path_offset = reinterpret_cast<uintptr_t>(sockaddr_un.sun_path) - reinterpret_cast<uintptr_t>(address);
		const size_t capacity = *address_len;
		if (capacity >= path_offset)
		{
			sockaddr_un.sun_family = AF_UNIX;

			// Copy at most the path plus its terminator, never exceeding
			// either sun_path or the space left in the caller's buffer.
			const size_t path_capacity = BAN::Math::min<size_t>(capacity - path_offset, sizeof(sockaddr_un.sun_path));
			const size_t copy_len = BAN::Math::min<size_t>(path_capacity, pending->m_bound_path.size() + 1);
			strncpy(sockaddr_un.sun_path, pending->m_bound_path.data(), copy_len);
		}
	}

	return TRY(Process::current().open_inode(VirtualFileSystem::File(return_inode, "<unix socket>"_sv), O_RDWR | flags));
}
// Connect this socket to the UNIX socket bound at the path given in `address`.
//
// For connectionless sockets this only records the target's canonical path,
// which sendto_impl uses when it is called without an address. For
// connection-mode sockets it queues this socket on the target's pending
// queue and waits until the target's accept_impl marks the connection done.
//
// Errors returned directly:
//   EINVAL        address_len is not sizeof(sockaddr_un)
//   EAFNOSUPPORT  sun_family is not AF_UNIX
//   ECONNREFUSED  no live socket is bound to the path, the target is not
//                 listening, or this socket already has a connection
//   EPROTOTYPE    the target's socket type differs from this socket's
//   EOPNOTSUPP    this socket is itself listening
// Errors from path resolution and from blocking are propagated.
BAN::ErrorOr<void> UnixDomainSocket::connect_impl(const sockaddr* address, socklen_t address_len)
{
	if (address_len != sizeof(sockaddr_un))
		return BAN::Error::from_errno(EINVAL);
	auto& sockaddr_un = *reinterpret_cast<const struct sockaddr_un*>(address);
	if (sockaddr_un.sun_family != AF_UNIX)
		return BAN::Error::from_errno(EAFNOSUPPORT);

	// An unbound socket gets the single-character placeholder path "X".
	if (!is_bound())
		TRY(m_bound_path.push_back('X'));

	auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path));
	auto file = TRY(VirtualFileSystem::get().file_from_absolute_path(
		Process::current().credentials(),
		absolute_path,
		O_RDWR
	));

	// Take a strong reference to the target while holding the registry lock,
	// then release the lock before doing anything that may block.
	BAN::RefPtr<UnixDomainSocket> target;
	{
		SpinLockGuard _(s_bound_socket_lock);
		auto it = s_bound_sockets.find(file.canonical_path);
		if (it == s_bound_sockets.end())
			return BAN::Error::from_errno(ECONNREFUSED);
		target = it->value.lock();
		if (!target)
			return BAN::Error::from_errno(ECONNREFUSED);
	}

	if (m_socket_type != target->m_socket_type)
		return BAN::Error::from_errno(EPROTOTYPE);

	if (m_info.has<ConnectionlessInfo>())
	{
		auto& connectionless_info = m_info.get<ConnectionlessInfo>();
		connectionless_info.peer_address = BAN::move(file.canonical_path);
		return {};
	}

	auto& connection_info = m_info.get<ConnectionInfo>();
	if (connection_info.connection)
		return BAN::Error::from_errno(ECONNREFUSED);
	if (connection_info.listening)
		return BAN::Error::from_errno(EOPNOTSUPP);

	// The socket types match and this socket is connection-mode, so the
	// target holds ConnectionInfo as well.
	auto& target_info = target->m_info.get<ConnectionInfo>();

	// A target that never called listen has no pending-queue capacity, so
	// waiting for a free slot would block forever. Refuse immediately.
	if (!target_info.listening)
		return BAN::Error::from_errno(ECONNREFUSED);

	connection_info.connection_done = false;

	// Queue this socket on the target, waiting for a free slot when the
	// backlog is full.
	for (;;)
	{
		{
			SpinLockGuard _(target_info.pending_lock);
			if (target_info.pending_connections.size() < target_info.pending_connections.capacity())
			{
				MUST(target_info.pending_connections.push(this));
				target_info.pending_thread_blocker.unblock();
				break;
			}
		}
		TRY(Thread::current().block_or_eintr_indefinite(target_info.pending_thread_blocker));
	}

	// accept_impl on the target sets connection_done once both ends are linked.
	while (!connection_info.connection_done)
		Processor::yield();

	return {};
}
// Put a bound, connection-mode socket into the listening state.
//
// `backlog` is clamped to the range [1, SOMAXCONN] and used as the capacity
// reserved for the pending-connection queue.
//
// Errors returned directly:
//   EDESTADDRREQ  the socket is not bound
//   EOPNOTSUPP    the socket is not connection-mode
//   EINVAL        the socket already has a connection
// A failure to reserve the queue is propagated.
BAN::ErrorOr<void> UnixDomainSocket::listen_impl(int backlog)
{
	if (!is_bound())
		return BAN::Error::from_errno(EDESTADDRREQ);
	if (!m_info.has<ConnectionInfo>())
		return BAN::Error::from_errno(EOPNOTSUPP);

	auto& listener_state = m_info.get<ConnectionInfo>();
	if (listener_state.connection)
		return BAN::Error::from_errno(EINVAL);

	const int queue_capacity = BAN::Math::clamp(backlog, 1, SOMAXCONN);
	TRY(listener_state.pending_connections.reserve(queue_capacity));

	listener_state.listening = true;
	return {};
}
// Bind this socket to the filesystem path given in `address`.
//
// Creates a socket-type file at the path, resolves its canonical path and
// registers this socket under that path in the bound-socket registry.
//
// Errors returned directly:
//   EINVAL        the socket is already bound, address_len is not
//                 sizeof(sockaddr_un), or the path is empty
//   EAFNOSUPPORT  sun_family is not AF_UNIX
//   EADDRINUSE    the path already exists, or another socket is still
//                 registered under the resulting canonical path
// Other errors from file creation, path resolution and allocation are
// propagated.
BAN::ErrorOr<void> UnixDomainSocket::bind_impl(const sockaddr* address, socklen_t address_len)
{
	if (is_bound())
		return BAN::Error::from_errno(EINVAL);
	if (address_len != sizeof(sockaddr_un))
		return BAN::Error::from_errno(EINVAL);
	auto& sockaddr_un = *reinterpret_cast<const struct sockaddr_un*>(address);
	if (sockaddr_un.sun_family != AF_UNIX)
		return BAN::Error::from_errno(EAFNOSUPPORT);

	auto bind_path = BAN::StringView(sockaddr_un.sun_path);
	if (bind_path.empty())
		return BAN::Error::from_errno(EINVAL);

	// FIXME: This feels sketchy
	auto parent_file = bind_path.front() == '/'
		? VirtualFileSystem::get().root_file()
		: TRY(Process::current().working_directory().clone());

	if (auto ret = Process::current().create_file_or_dir(AT_FDCWD, bind_path.data(), 0755 | S_IFSOCK); ret.is_error())
	{
		// An existing file at the path means the address is taken.
		if (ret.error().get_error_code() == EEXIST)
			return BAN::Error::from_errno(EADDRINUSE);
		return ret.release_error();
	}

	auto file = TRY(VirtualFileSystem::get().file_from_relative_path(
		parent_file,
		Process::current().credentials(),
		bind_path,
		O_RDWR
	));

	// Create the weak pointer before taking the spinlock so the lock is held
	// for as little work as possible.
	auto weak_this = TRY(get_weak_ptr());

	SpinLockGuard _(s_bound_socket_lock);

	// The file was just created, but a socket bound earlier may still be
	// registered under this canonical path (for example when its file was
	// removed while the socket stayed alive). User space can reach this
	// state, so report it as an error instead of asserting.
	if (s_bound_sockets.contains(file.canonical_path))
		return BAN::Error::from_errno(EADDRINUSE);

	TRY(s_bound_sockets.emplace(file.canonical_path, BAN::move(weak_this)));
	m_bound_path = BAN::move(file.canonical_path);
	return {};
}
// Reports whether this socket carries a byte stream (STREAM) as opposed to
// discrete messages (SEQPACKET, DGRAM). Any other socket type is a
// programming error and trips ASSERT_NOT_REACHED.
bool UnixDomainSocket::is_streaming() const
{
	if (m_socket_type == Socket::Type::STREAM)
		return true;
	if (m_socket_type == Socket::Type::SEQPACKET || m_socket_type == Socket::Type::DGRAM)
		return false;
	ASSERT_NOT_REACHED();
}
// Append `packet` to this socket's receive buffer, blocking while there is
// no room for it. For non-streaming sockets the packet's size is also
// recorded so that message boundaries can be recovered on receive.
// Waiters on m_packet_thread_blocker are woken once the data is stored.
// An error from the blocking wait is propagated with the lock released.
BAN::ErrorOr<void> UnixDomainSocket::add_packet(BAN::ConstByteSpan packet)
{
	auto lock_state = m_packet_lock.lock();

	for (;;)
	{
		const bool has_room =
			!m_packet_sizes.full() &&
			m_packet_size_total + packet.size() <= s_packet_buffer_size;
		if (has_room)
			break;

		// Release the lock while sleeping and re-take it before re-checking.
		m_packet_lock.unlock(lock_state);
		TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker));
		lock_state = m_packet_lock.lock();
	}

	// New data goes directly after the bytes that are already buffered.
	uint8_t* write_position = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total);
	memcpy(write_position, packet.data(), packet.size());
	m_packet_size_total += packet.size();

	if (!is_streaming())
		m_packet_sizes.push(packet.size());

	m_packet_thread_blocker.unblock();
	m_packet_lock.unlock(lock_state);

	return {};
}
// Readiness check for reading. A connection-mode socket is readable when its
// peer has closed or when connections are pending; without a connection it is
// otherwise not readable. In every remaining case the socket is readable
// exactly when its receive buffer holds data.
bool UnixDomainSocket::can_read_impl() const
{
	if (!m_info.has<ConnectionInfo>())
		return m_packet_size_total > 0;

	const auto& state = m_info.get<ConnectionInfo>();
	if (state.target_closed)
		return true;
	if (!state.pending_connections.empty())
		return true;
	if (!state.connection)
		return false;

	return m_packet_size_total > 0;
}
// Readiness check for writing. Connectionless sockets are always writable;
// connection-mode sockets are writable while their peer reference is valid.
bool UnixDomainSocket::can_write_impl() const
{
	if (!m_info.has<ConnectionInfo>())
		return true;
	return m_info.get<ConnectionInfo>().connection.valid();
}
// Send `message` to the connected peer (connection-mode sockets) or to the
// socket bound at a path (connectionless sockets). For connectionless
// sockets the path comes from `address` when given, otherwise from the peer
// address recorded by connect_impl.
//
// Returns message.size() once the whole message has been queued on the
// target's receive buffer.
//
// Errors returned directly:
//   ENOBUFS       message is larger than the receive buffer
//   EISCONN       an address was supplied on a connection-mode socket
//   ENOTCONN      a connection-mode socket has no live peer
//   EDESTADDRREQ  no address was given and none was recorded, or no live
//                 socket is bound to the destination path
//   EINVAL        address_len is not sizeof(sockaddr_un)
//   EAFNOSUPPORT  sun_family is not AF_UNIX
// Errors from path resolution and from add_packet are propagated.
BAN::ErrorOr<size_t> UnixDomainSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len)
{
	if (message.size() > s_packet_buffer_size)
		return BAN::Error::from_errno(ENOBUFS);

	if (m_info.has<ConnectionInfo>())
	{
		auto& connection_info = m_info.get<ConnectionInfo>();
		if (address)
			return BAN::Error::from_errno(EISCONN);
		auto target = connection_info.connection.lock();
		if (!target)
			return BAN::Error::from_errno(ENOTCONN);
		TRY(target->add_packet(message));
		return message.size();
	}

	// Connectionless: work out the destination's canonical path.
	BAN::String canonical_path;
	if (!address)
	{
		auto& connectionless_info = m_info.get<ConnectionlessInfo>();
		if (connectionless_info.peer_address.empty())
			return BAN::Error::from_errno(EDESTADDRREQ);
		TRY(canonical_path.append(connectionless_info.peer_address));
	}
	else
	{
		if (address_len != sizeof(sockaddr_un))
			return BAN::Error::from_errno(EINVAL);
		auto& sockaddr_un = *reinterpret_cast<const struct sockaddr_un*>(address);
		if (sockaddr_un.sun_family != AF_UNIX)
			return BAN::Error::from_errno(EAFNOSUPPORT);
		auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path));
		auto file = TRY(VirtualFileSystem::get().file_from_absolute_path(
			Process::current().credentials(),
			absolute_path,
			O_WRONLY
		));
		canonical_path = BAN::move(file.canonical_path);
	}

	// Hold the registry spinlock only for the lookup. add_packet may block,
	// and dropping the last reference to the target runs its destructor,
	// which takes the same lock; neither may happen while it is held.
	BAN::RefPtr<UnixDomainSocket> target;
	{
		SpinLockGuard _(s_bound_socket_lock);
		auto it = s_bound_sockets.find(canonical_path);
		if (it == s_bound_sockets.end())
			return BAN::Error::from_errno(EDESTADDRREQ);
		target = it->value.lock();
	}
	if (!target)
		return BAN::Error::from_errno(EDESTADDRREQ);

	TRY(target->add_packet(message));
	return message.size();
}
// Receive data from this socket's receive buffer into `buffer`, blocking
// until data is available. The source-address out-parameters are unused.
//
// Streaming sockets return up to buffer.size() bytes of the buffered stream.
// Message-based sockets return one whole message per call; when the message
// is longer than `buffer` the excess bytes are discarded so that the next
// call starts at the next message boundary.
//
// Returns the number of bytes copied into `buffer`. On a connection-mode
// socket, returns 0 once after the peer has been flagged as closed (the flag
// is cleared by that call).
//
// Errors returned directly:
//   ENOTCONN  a connection-mode socket has no connection
// An error from the blocking wait is propagated with the lock released.
BAN::ErrorOr<size_t> UnixDomainSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr*, socklen_t*)
{
	if (m_info.has<ConnectionInfo>())
	{
		auto& connection_info = m_info.get<ConnectionInfo>();
		bool expected = true;
		if (connection_info.target_closed.compare_exchange(expected, false))
			return 0;
		if (!connection_info.connection)
			return BAN::Error::from_errno(ENOTCONN);
	}

	auto state = m_packet_lock.lock();
	while (m_packet_size_total == 0)
	{
		// Release the lock while sleeping and re-take it before re-checking.
		m_packet_lock.unlock(state);
		TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker));
		state = m_packet_lock.lock();
	}

	uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());

	size_t nread = 0;     // bytes handed to the caller
	size_t nconsumed = 0; // bytes removed from the receive buffer
	if (is_streaming())
	{
		nread = BAN::Math::min<size_t>(buffer.size(), m_packet_size_total);
		nconsumed = nread;
	}
	else
	{
		// Always remove the whole message, even when only part of it fits in
		// the caller's buffer; leaving the tail behind would shift every
		// later message boundary.
		nconsumed = m_packet_sizes.front();
		m_packet_sizes.pop();
		nread = BAN::Math::min<size_t>(buffer.size(), nconsumed);
	}

	memcpy(buffer.data(), packet_buffer, nread);
	// Slide the remaining data to the front of the buffer.
	memmove(packet_buffer, packet_buffer + nconsumed, m_packet_size_total - nconsumed);
	m_packet_size_total -= nconsumed;

	m_packet_thread_blocker.unblock();
	m_packet_lock.unlock(state);

	return nread;
}
}