Kernel: Implement unix domain sockets with SOCK_DGRAM

Also unbind sockets on close
This commit is contained in:
Bananymous 2024-02-08 13:17:00 +02:00
parent 065ee9004c
commit 9bc7a72a25
2 changed files with 84 additions and 36 deletions

View File

@ -17,6 +17,8 @@ namespace Kernel
static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(SocketType, ino_t, const TmpInodeInfo&);
protected:
virtual void on_close_impl() override;
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) override;
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<void> listen_impl(int) override;
@ -47,7 +49,7 @@ namespace Kernel
struct ConnectionlessInfo
{
BAN::String peer_address;
};
private:

View File

@ -10,8 +10,8 @@
namespace Kernel
{
static BAN::HashMap<BAN::String, BAN::RefPtr<UnixDomainSocket>> s_bound_sockets;
static SpinLock s_bound_socket_lock;
static BAN::HashMap<BAN::String, BAN::WeakPtr<UnixDomainSocket>> s_bound_sockets;
static SpinLock s_bound_socket_lock;
static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE;
@ -47,6 +47,16 @@ namespace Kernel
}
}
void UnixDomainSocket::on_close_impl()
{
if (is_bound() && !is_bound_to_unused())
{
LockGuard _(s_bound_socket_lock);
if (s_bound_sockets.contains(m_bound_path))
s_bound_sockets.remove(m_bound_path);
}
}
BAN::ErrorOr<long> UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len)
{
if (!m_info.has<ConnectionInfo>())
@ -74,7 +84,7 @@ namespace Kernel
return_inode = reinterpret_cast<UnixDomainSocket*>(return_inode_tmp.ptr());
}
TRY(return_inode->m_bound_path.append(m_bound_path));
TRY(return_inode->m_bound_path.push_back('X'));
return_inode->m_info.get<ConnectionInfo>().connection = TRY(pending->get_weak_ptr());
pending->m_info.get<ConnectionInfo>().connection = TRY(return_inode->get_weak_ptr());
pending->m_info.get<ConnectionInfo>().connection_done = true;
@ -113,46 +123,48 @@ namespace Kernel
LockGuard _(s_bound_socket_lock);
if (!s_bound_sockets.contains(file.canonical_path))
return BAN::Error::from_errno(ECONNREFUSED);
target = s_bound_sockets[file.canonical_path];
target = s_bound_sockets[file.canonical_path].lock();
if (!target)
return BAN::Error::from_errno(ECONNREFUSED);
}
if (m_socket_type != target->m_socket_type)
return BAN::Error::from_errno(EPROTOTYPE);
if (m_info.has<ConnectionInfo>())
if (m_info.has<ConnectionlessInfo>())
{
auto& connection_info = m_info.get<ConnectionInfo>();
if (connection_info.connection)
return BAN::Error::from_errno(ECONNREFUSED);
if (connection_info.listening)
return BAN::Error::from_errno(EOPNOTSUPP);
connection_info.connection_done = false;
for (;;)
{
auto& target_info = target->m_info.get<ConnectionInfo>();
{
LockGuard _(target_info.pending_lock);
if (target_info.pending_connections.size() < target_info.pending_connections.capacity())
{
MUST(target_info.pending_connections.push(this));
target_info.pending_semaphore.unblock();
break;
}
}
TRY(Thread::current().block_or_eintr(target_info.pending_semaphore));
}
while (!connection_info.connection_done)
Scheduler::get().reschedule();
auto& connectionless_info = m_info.get<ConnectionlessInfo>();
connectionless_info.peer_address = BAN::move(file.canonical_path);
return {};
}
else
auto& connection_info = m_info.get<ConnectionInfo>();
if (connection_info.connection)
return BAN::Error::from_errno(ECONNREFUSED);
if (connection_info.listening)
return BAN::Error::from_errno(EOPNOTSUPP);
connection_info.connection_done = false;
for (;;)
{
return BAN::Error::from_errno(ENOTSUP);
auto& target_info = target->m_info.get<ConnectionInfo>();
{
LockGuard _(target_info.pending_lock);
if (target_info.pending_connections.size() < target_info.pending_connections.capacity())
{
MUST(target_info.pending_connections.push(this));
target_info.pending_semaphore.unblock();
break;
}
}
TRY(Thread::current().block_or_eintr(target_info.pending_semaphore));
}
while (!connection_info.connection_done)
Scheduler::get().reschedule();
return {};
}
BAN::ErrorOr<void> UnixDomainSocket::listen_impl(int backlog)
@ -195,7 +207,7 @@ namespace Kernel
LockGuard _(s_bound_socket_lock);
ASSERT(!s_bound_sockets.contains(file.canonical_path));
TRY(s_bound_sockets.emplace(file.canonical_path, this));
TRY(s_bound_sockets.emplace(file.canonical_path, TRY(get_weak_ptr())));
m_bound_path = BAN::move(file.canonical_path);
return {};
@ -278,7 +290,41 @@ namespace Kernel
}
else
{
return BAN::Error::from_errno(ENOTSUP);
BAN::String canonical_path;
if (!arguments->dest_addr)
{
auto& connectionless_info = m_info.get<ConnectionlessInfo>();
if (connectionless_info.peer_address.empty())
return BAN::Error::from_errno(EDESTADDRREQ);
TRY(canonical_path.append(connectionless_info.peer_address));
}
else
{
if (arguments->dest_len != sizeof(sockaddr_un))
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_un = *reinterpret_cast<const struct sockaddr_un*>(arguments->dest_addr);
if (sockaddr_un.sun_family != AF_UNIX)
return BAN::Error::from_errno(EAFNOSUPPORT);
auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path));
auto file = TRY(VirtualFileSystem::get().file_from_absolute_path(
Process::current().credentials(),
absolute_path,
O_WRONLY
));
canonical_path = BAN::move(file.canonical_path);
}
LockGuard _(s_bound_socket_lock);
if (!s_bound_sockets.contains(canonical_path))
return BAN::Error::from_errno(EDESTADDRREQ);
auto target = s_bound_sockets[canonical_path].lock();
if (!target)
return BAN::Error::from_errno(EDESTADDRREQ);
TRY(target->add_packet({ reinterpret_cast<const uint8_t*>(arguments->message), arguments->length }));
return arguments->length;
}
}