682 lines
20 KiB
C++
682 lines
20 KiB
C++
#include <BAN/HashMap.h>
|
|
|
|
#include <kernel/FS/VirtualFileSystem.h>
|
|
#include <kernel/Lock/LockGuard.h>
|
|
#include <kernel/Networking/NetworkManager.h>
|
|
#include <kernel/Networking/UNIX/Socket.h>
|
|
#include <kernel/Process.h>
|
|
#include <kernel/Scheduler.h>
|
|
|
|
#include <fcntl.h>
|
|
#include <sys/epoll.h>
|
|
#include <sys/un.h>
|
|
|
|
namespace Kernel
|
|
{
|
|
|
|
struct UnixSocketHash
|
|
{
|
|
BAN::hash_t operator()(const BAN::RefPtr<Inode>& socket)
|
|
{
|
|
return BAN::hash<const Inode*>{}(socket.ptr());
|
|
}
|
|
};
|
|
|
|
static BAN::HashMap<BAN::RefPtr<Inode>, BAN::WeakPtr<UnixDomainSocket>, UnixSocketHash> s_bound_sockets;
|
|
static Mutex s_bound_socket_lock;
|
|
|
|
static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE;
|
|
|
|
// Checks that `address` is an AF_UNIX address and returns a view of its path.
// The path ends at the first NUL or at the end of the supplied length, whichever
// comes first; bytes beyond sizeof(sockaddr_un) are ignored. The returned view
// points into `address` and is NOT guaranteed to be NUL-terminated.
// Errors: EINVAL if the length cannot hold the family field or the family is wrong.
static BAN::ErrorOr<BAN::StringView> validate_sockaddr_un(const sockaddr* address, socklen_t address_len)
{
	if (address_len < static_cast<socklen_t>(sizeof(sa_family_t)))
		return BAN::Error::from_errno(EINVAL);

	const size_t usable_len = BAN::Math::min<size_t>(address_len, sizeof(sockaddr_un));

	const auto* unix_address = reinterpret_cast<const sockaddr_un*>(address);
	if (unix_address->sun_family != AF_UNIX)
		return BAN::Error::from_errno(EINVAL);

	const size_t max_path_len = usable_len - sizeof(sa_family_t);
	size_t path_len = 0;
	for (; path_len < max_path_len; path_len++)
		if (unix_address->sun_path[path_len] == '\0')
			break;

	return BAN::StringView { unix_address->sun_path, path_len };
}
|
|
|
|
// FIXME: why is this using spinlocks instead of mutexes??
|
|
|
|
// Allocates a UNIX domain socket of `socket_type` and backs it with a
// kernel-mapped, read/write packet buffer of s_packet_buffer_size bytes.
// Any allocation failure is propagated and the half-built socket is dropped.
BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> UnixDomainSocket::create(Socket::Type socket_type, const Socket::Info& info)
{
	auto new_socket = TRY(BAN::RefPtr<UnixDomainSocket>::create(socket_type, info));

	auto packet_range = TRY(VirtualRange::create_to_vaddr_range(
		PageTable::kernel(),
		KERNEL_OFFSET,
		~static_cast<uintptr_t>(0),
		s_packet_buffer_size,
		PageTable::Flags::ReadWrite | PageTable::Flags::Present,
		true, false
	));
	new_socket->m_packet_buffer = BAN::move(packet_range);

	return new_socket;
}
|
|
|
|
// Stream and seqpacket sockets carry ConnectionInfo (peer, accept queue);
// datagram sockets carry ConnectionlessInfo (default peer address).
// Any other socket type is a programming error.
UnixDomainSocket::UnixDomainSocket(Socket::Type socket_type, const Socket::Info& info)
	: Socket(info)
	, m_socket_type(socket_type)
{
	if (socket_type == Socket::Type::STREAM || socket_type == Socket::Type::SEQPACKET)
	{
		m_info.emplace<ConnectionInfo>();
		return;
	}

	if (socket_type == Socket::Type::DGRAM)
	{
		m_info.emplace<ConnectionlessInfo>();
		return;
	}

	ASSERT_NOT_REACHED();
}
|
|
|
|
// Unregisters a path-bound socket from s_bound_sockets and, for
// connection-oriented sockets, tells a still-alive peer that this end is gone
// (target_closed + EPOLLHUP) and wakes anything blocked on the peer's packets.
UnixDomainSocket::~UnixDomainSocket()
{
	const bool registered_in_table = is_bound() && !is_bound_to_unused();
	if (registered_in_table)
	{
		LockGuard _(s_bound_socket_lock);
		s_bound_sockets.remove(m_bound_file.inode);
	}

	if (!m_info.has<ConnectionInfo>())
		return;

	auto peer = m_info.get<ConnectionInfo>().connection.lock();
	if (!peer || !peer->m_info.has<ConnectionInfo>())
		return;

	peer->m_info.get<ConnectionInfo>().target_closed = true;
	peer->epoll_notify(EPOLLHUP);
	peer->m_packet_thread_blocker.unblock();
}
|
|
|
|
// Links this socket and `other` as each other's connection (socketpair()).
// Both must be connection-oriented (EINVAL otherwise). Both weak references are
// obtained before either socket is modified, so a failure leaves them unpaired.
BAN::ErrorOr<void> UnixDomainSocket::make_socket_pair(UnixDomainSocket& other)
{
	const bool both_connection_oriented = m_info.has<ConnectionInfo>() && other.m_info.has<ConnectionInfo>();
	if (!both_connection_oriented)
		return BAN::Error::from_errno(EINVAL);

	auto weak_self = TRY(this->get_weak_ptr());
	auto weak_other = TRY(other.get_weak_ptr());

	this->m_info.get<ConnectionInfo>().connection = BAN::move(weak_other);
	other.m_info.get<ConnectionInfo>().connection = BAN::move(weak_self);

	return {};
}
|
|
|
|
// Accepts one queued connection on a listening connection-oriented socket.
// Blocks (interruptibly) until a connector has queued itself. Then it creates the
// server-side socket, links both ends, releases the connector and opens a file
// descriptor for the new socket with O_RDWR | flags.
// If `address`/`address_len` are given, the connector's address is written there,
// truncated to *address_len, and *address_len receives the number of bytes copied.
// Errors: EOPNOTSUPP for connectionless sockets, EINVAL if not listening,
// EINTR while waiting, or any allocation/open failure.
BAN::ErrorOr<long> UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len, int flags)
{
	if (!m_info.has<ConnectionInfo>())
		return BAN::Error::from_errno(EOPNOTSUPP);
	auto& connection_info = m_info.get<ConnectionInfo>();
	if (!connection_info.listening)
		return BAN::Error::from_errno(EINVAL);

	BAN::RefPtr<UnixDomainSocket> pending;

	{
		LockGuard _(connection_info.pending_lock);
		while (connection_info.pending_connections.empty())
			TRY(Thread::current().block_or_eintr_indefinite(connection_info.pending_thread_blocker, &connection_info.pending_lock));

		pending = connection_info.pending_connections.front();
		connection_info.pending_connections.pop();
		// Wake connectors that were waiting for space in the accept queue.
		connection_info.pending_thread_blocker.unblock();
	}

	BAN::RefPtr<UnixDomainSocket> return_inode;

	{
		auto return_inode_tmp = TRY(NetworkManager::get().create_socket(Socket::Domain::UNIX, m_socket_type, mode().mode & ~Mode::TYPE_MASK, uid(), gid()));
		return_inode = reinterpret_cast<UnixDomainSocket*>(return_inode_tmp.ptr());
	}

	// The accepted socket gets the 'X' placeholder path. Both ends are then linked,
	// and the connector busy-waiting on connection_done in connect_impl() is released.
	TRY(return_inode->m_bound_file.canonical_path.push_back('X'));
	return_inode->m_info.get<ConnectionInfo>().connection = TRY(pending->get_weak_ptr());
	pending->m_info.get<ConnectionInfo>().connection = TRY(return_inode->get_weak_ptr());
	pending->m_info.get<ConnectionInfo>().connection_done = true;

	if (address && address_len)
	{
		sockaddr_un sa_un {
			.sun_family = AF_UNIX,
			.sun_path {},
		};

		// Report a path only for a connector bound to a real file. The check is on the
		// connector, not the listener, which is always bound. NOTE(review): assumes
		// is_bound_to_unused() covers the 'X' placeholder set by connect_impl(); confirm.
		// The canonical path may be longer than sun_path, so copy at most
		// sizeof(sun_path) - 1 bytes; the zero-initialization keeps the result NUL-terminated.
		if (!pending->is_bound_to_unused())
		{
			const auto& peer_path = pending->m_bound_file.canonical_path;
			const size_t path_len = BAN::Math::min<size_t>(peer_path.size(), sizeof(sa_un.sun_path) - 1);
			memcpy(sa_un.sun_path, peer_path.data(), path_len);
		}

		const size_t to_copy = BAN::Math::min<size_t>(*address_len, sizeof(sockaddr_un));
		memcpy(address, &sa_un, to_copy);
		*address_len = to_copy;
	}

	return TRY(Process::current().open_inode(VirtualFileSystem::File(return_inode, "<unix socket>"_sv), O_RDWR | flags));
}
|
|
|
|
// Connects this socket to the socket bound at the path in `address`.
// For datagram sockets this only records the default peer path.
// For connection-oriented sockets it queues this socket on the target's accept
// queue, blocking interruptibly while the queue is full, and then waits until
// accept_impl() has paired the two ends.
// Errors:
// - ECONNREFUSED: nothing is bound at the path, or the target is not listening.
// - EPROTOTYPE: the socket types differ.
// - EISCONN: this socket is already connected.
// - EOPNOTSUPP: this socket is listening.
// - EINTR: interrupted while waiting for queue space.
// - Any path-resolution error.
BAN::ErrorOr<void> UnixDomainSocket::connect_impl(const sockaddr* address, socklen_t address_len)
{
	const auto sun_path = TRY(validate_sockaddr_un(address, address_len));
	// An unbound socket gets the 'X' placeholder path so it counts as bound from now on.
	if (!is_bound())
		TRY(m_bound_file.canonical_path.push_back('X'));

	auto absolute_path = TRY(Process::current().absolute_path_of(sun_path));
	auto file = TRY(VirtualFileSystem::get().file_from_absolute_path(
		Process::current().root_file().inode,
		Process::current().credentials(),
		absolute_path,
		O_RDWR
	));

	BAN::RefPtr<UnixDomainSocket> target;

	{
		LockGuard _(s_bound_socket_lock);
		auto it = s_bound_sockets.find(file.inode);
		if (it == s_bound_sockets.end())
			return BAN::Error::from_errno(ECONNREFUSED);
		target = it->value.lock();
		if (!target)
			return BAN::Error::from_errno(ECONNREFUSED);
	}

	if (m_socket_type != target->m_socket_type)
		return BAN::Error::from_errno(EPROTOTYPE);

	if (m_info.has<ConnectionlessInfo>())
	{
		auto& connectionless_info = m_info.get<ConnectionlessInfo>();
		connectionless_info.peer_address = BAN::move(file.canonical_path);
		return {};
	}

	auto& connection_info = m_info.get<ConnectionInfo>();
	if (connection_info.connection)
		return BAN::Error::from_errno(EISCONN);
	if (connection_info.listening)
		return BAN::Error::from_errno(EOPNOTSUPP);

	connection_info.connection_done = false;

	// Same socket type was checked above, so the target is connection-oriented too.
	auto& target_info = target->m_info.get<ConnectionInfo>();
	for (;;)
	{
		LockGuard _(target_info.pending_lock);

		// A bound but non-listening socket has no accept queue (capacity 0 until
		// listen_impl() reserves it). Waiting for space would block forever, so refuse.
		if (!target_info.listening)
			return BAN::Error::from_errno(ECONNREFUSED);

		if (target_info.pending_connections.size() < target_info.pending_connections.capacity())
		{
			MUST(target_info.pending_connections.push(this));
			target_info.pending_thread_blocker.unblock();
			break;
		}

		TRY(Thread::current().block_or_eintr_indefinite(target_info.pending_thread_blocker, &target_info.pending_lock));
	}

	target->epoll_notify(EPOLLIN);

	// NOTE(review): only accept_impl() sets connection_done. If the listener never accepts
	// this connection, this busy-wait does not end; consider a blocker instead.
	while (!connection_info.connection_done)
		Processor::yield();

	return {};
}
|
|
|
|
// Turns a bound, unconnected, connection-oriented socket into a listener.
// The backlog is clamped to [1, SOMAXCONN] and reserved as the capacity of
// the accept queue; connect_impl() treats that capacity as the queue limit.
// Errors: EDESTADDRREQ if unbound, EOPNOTSUPP for datagram sockets,
// EINVAL if already connected, or an allocation failure.
BAN::ErrorOr<void> UnixDomainSocket::listen_impl(int backlog)
{
	const int queue_limit = BAN::Math::clamp(backlog, 1, SOMAXCONN);

	if (!is_bound())
		return BAN::Error::from_errno(EDESTADDRREQ);
	if (!m_info.has<ConnectionInfo>())
		return BAN::Error::from_errno(EOPNOTSUPP);

	auto& connection_info = m_info.get<ConnectionInfo>();
	if (connection_info.connection)
		return BAN::Error::from_errno(EINVAL);

	TRY(connection_info.pending_connections.reserve(queue_limit));
	connection_info.listening = true;

	return {};
}
|
|
|
|
// Binds this socket to a new socket file at the path in `address`.
// The file is created with mode 0755 | S_IFSOCK (relative paths are resolved
// against the working directory, absolute ones against the root). The resolved
// inode is then registered in s_bound_sockets.
// Errors: EINVAL if already bound or the path is empty; EADDRINUSE if the
// file already exists or its inode is already registered; ENAMETOOLONG if the
// path does not fit; or any filesystem error.
BAN::ErrorOr<void> UnixDomainSocket::bind_impl(const sockaddr* address, socklen_t address_len)
{
	if (is_bound())
		return BAN::Error::from_errno(EINVAL);

	const auto sun_path = TRY(validate_sockaddr_un(address, address_len));
	if (sun_path.empty())
		return BAN::Error::from_errno(EINVAL);

	// sun_path is a length-delimited view into the caller's sockaddr. It is only
	// NUL-terminated if the terminator lies inside address_len, which is not the case
	// for SUN_LEN-style lengths. Pass a terminated copy to the path-creation API.
	char path_cstr[sizeof(sockaddr_un::sun_path) + 1] {};
	if (sun_path.size() >= sizeof(path_cstr))
		return BAN::Error::from_errno(ENAMETOOLONG);
	memcpy(path_cstr, sun_path.data(), sun_path.size());

	// FIXME: This feels sketchy
	auto parent_file = sun_path.front() == '/'
		? TRY(Process::current().root_file().clone())
		: TRY(Process::current().working_directory().clone());
	if (auto ret = Process::current().create_file_or_dir(AT_FDCWD, path_cstr, 0755 | S_IFSOCK); ret.is_error())
	{
		if (ret.error().get_error_code() == EEXIST)
			return BAN::Error::from_errno(EADDRINUSE);
		return ret.release_error();
	}
	auto file = TRY(VirtualFileSystem::get().file_from_relative_path(
		Process::current().root_file().inode,
		parent_file,
		Process::current().credentials(),
		sun_path,
		O_RDWR
	));

	LockGuard _(s_bound_socket_lock);
	if (s_bound_sockets.contains(file.inode))
		return BAN::Error::from_errno(EADDRINUSE);
	TRY(s_bound_sockets.emplace(file.inode, TRY(get_weak_ptr())));
	m_bound_file = BAN::move(file);

	return {};
}
|
|
|
|
bool UnixDomainSocket::is_streaming() const
|
|
{
|
|
switch (m_socket_type)
|
|
{
|
|
case Socket::Type::STREAM:
|
|
return true;
|
|
case Socket::Type::SEQPACKET:
|
|
case Socket::Type::DGRAM:
|
|
return false;
|
|
default:
|
|
ASSERT_NOT_REACHED();
|
|
}
|
|
}
|
|
|
|
// Appends one message, gathered from `packet`'s iovecs, to this socket's receive queue.
// Blocks (interruptibly) until a packet-info slot is free and the payload fits
// in the remaining buffer space. Afterwards it wakes blocked readers and signals EPOLLIN.
// packet_info.size must equal the total iovec length (asserted).
BAN::ErrorOr<void> UnixDomainSocket::add_packet(const msghdr& packet, PacketInfo&& packet_info)
{
	LockGuard _(m_packet_lock);

	for (;;)
	{
		const bool has_free_slot = !m_packet_infos.full();
		const bool payload_fits = m_packet_size_total + packet_info.size <= s_packet_buffer_size;
		if (has_free_slot && payload_fits)
			break;
		TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock));
	}

	// New bytes go directly after the bytes already queued.
	uint8_t* write_base = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total);

	size_t bytes_written = 0;
	for (int iov_idx = 0; iov_idx < packet.msg_iovlen; iov_idx++)
	{
		const auto& segment = packet.msg_iov[iov_idx];
		memcpy(write_base + bytes_written, segment.iov_base, segment.iov_len);
		bytes_written += segment.iov_len;
	}

	ASSERT(bytes_written == packet_info.size);
	m_packet_size_total += packet_info.size;
	m_packet_infos.emplace(BAN::move(packet_info));

	m_packet_thread_blocker.unblock();
	epoll_notify(EPOLLIN);

	return {};
}
|
|
|
|
bool UnixDomainSocket::can_read_impl() const
|
|
{
|
|
if (m_info.has<ConnectionInfo>())
|
|
{
|
|
auto& connection_info = m_info.get<ConnectionInfo>();
|
|
if (connection_info.listening)
|
|
return !connection_info.pending_connections.empty();
|
|
if (connection_info.target_closed)
|
|
return true;
|
|
if (!connection_info.connection)
|
|
return false;
|
|
}
|
|
|
|
return m_packet_size_total > 0;
|
|
}
|
|
|
|
bool UnixDomainSocket::can_write_impl() const
|
|
{
|
|
if (m_info.has<ConnectionInfo>())
|
|
{
|
|
auto& connection_info = m_info.get<ConnectionInfo>();
|
|
auto connection = connection_info.connection.lock();
|
|
if (!connection)
|
|
return false;
|
|
if (connection->m_packet_infos.full())
|
|
return false;
|
|
if (connection->m_packet_size_total >= s_packet_buffer_size)
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
bool UnixDomainSocket::has_hungup_impl() const
|
|
{
|
|
if (m_info.has<ConnectionInfo>())
|
|
{
|
|
auto& connection_info = m_info.get<ConnectionInfo>();
|
|
return connection_info.target_closed;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
// Receives queued data into `message`'s iovecs, plus any ancillary data
// (SCM_RIGHTS file descriptors, SCM_CREDENTIALS) into its control buffer.
// Blocks (interruptibly) until data is queued. A closed peer is reported once as 0.
// Stream sockets may gather several packets but stop after one that carried
// ancillary data. Other socket types return exactly one packet and set MSG_TRUNC
// if it did not fit. MSG_CTRUNC is set when ancillary data had to be dropped.
// On return msg_controllen is the number of control bytes used, padding included.
// Errors: ENOTSUP for any of MSG_OOB/MSG_PEEK/MSG_WAITALL, ENOTCONN, EINTR.
BAN::ErrorOr<size_t> UnixDomainSocket::recvmsg_impl(msghdr& message, int flags)
{
	flags &= (MSG_OOB | MSG_PEEK | MSG_WAITALL);
	if (flags != 0)
	{
		dwarnln("TODO: recvmsg with flags 0x{H}", flags);
		return BAN::Error::from_errno(ENOTSUP);
	}

	LockGuard _(m_packet_lock);
	while (m_packet_size_total == 0)
	{
		if (m_info.has<ConnectionInfo>())
		{
			auto& connection_info = m_info.get<ConnectionInfo>();
			// Consume the "peer closed" flag so end-of-stream is reported once.
			bool expected = true;
			if (connection_info.target_closed.compare_exchange(expected, false))
				return 0;
			if (!connection_info.connection)
				return BAN::Error::from_errno(ENOTCONN);
		}

		TRY(Thread::current().block_or_eintr_indefinite(m_packet_thread_blocker, &m_packet_lock));
	}

	// Control messages are laid out back to back at CMSG_SPACE() strides, which is
	// where userspace CMSG_NXTHDR() expects them. control_used counts the consumed
	// bytes of the caller's control buffer, padding included, and becomes the
	// returned msg_controllen. It never exceeds control_capacity.
	auto* control_buffer = static_cast<uint8_t*>(message.msg_control);
	const size_t control_capacity = control_buffer ? static_cast<size_t>(message.msg_controllen) : 0;
	const size_t cmsg_header_size = CMSG_LEN(0);
	size_t control_used = 0;

	// Next unused control header, or nullptr if not even an empty header fits.
	const auto next_control_header =
		[&]() -> cmsghdr* {
			if (control_capacity - control_used < cmsg_header_size)
				return nullptr;
			return reinterpret_cast<cmsghdr*>(control_buffer + control_used);
		};

	// Payload bytes available after the header returned by next_control_header().
	const auto control_payload_space =
		[&]() -> size_t {
			return control_capacity - control_used - cmsg_header_size;
		};

	// Completes a SOL_SOCKET header and advances past it, alignment padding included.
	const auto commit_control_header =
		[&](cmsghdr* header, int type, size_t payload_size) {
			header->cmsg_level = SOL_SOCKET;
			header->cmsg_type = type;
			header->cmsg_len = CMSG_LEN(payload_size);
			control_used = BAN::Math::min<size_t>(control_used + CMSG_SPACE(payload_size), control_capacity);
		};

	uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());

	message.msg_flags = 0;

	int iov_index = 0;
	size_t iov_offset = 0;
	size_t total_recv = 0;

	while (!m_packet_infos.empty() && iov_index < message.msg_iovlen)
	{
		auto& packet_info = m_packet_infos.front();

		// Ancillary data is delivered with the first read of a packet and never again.
		auto fds_to_open = BAN::move(packet_info.fds);
		auto ucred_to_recv = BAN::move(packet_info.ucred);
		const bool had_ancillary_data = !fds_to_open.empty() || ucred_to_recv.has_value();

		if (!fds_to_open.empty())
		{
			if (auto* header = next_control_header(); header == nullptr)
			{
				dwarnln("no space to receive {} fds", fds_to_open.size());
				message.msg_flags |= MSG_CTRUNC;
			}
			else
			{
				const size_t max_fd_count = control_payload_space() / sizeof(int);
				if (max_fd_count < fds_to_open.size())
					message.msg_flags |= MSG_CTRUNC;

				const size_t fd_count = BAN::Math::min(fds_to_open.size(), max_fd_count);
				const size_t fds_opened = Process::current().open_file_descriptor_set().open_all_fd_wrappers(fds_to_open.span().slice(0, fd_count));

				auto* fd_data = reinterpret_cast<int*>(CMSG_DATA(header));
				for (size_t i = 0; i < fds_opened; i++)
					fd_data[i] = fds_to_open[i].fd();

				commit_control_header(header, SCM_RIGHTS, fds_opened * sizeof(int));
			}
		}

		if (ucred_to_recv.has_value())
		{
			auto* header = next_control_header();
			if (header == nullptr || control_payload_space() < sizeof(struct ucred))
			{
				dwarnln("no space to receive credentials");
				message.msg_flags |= MSG_CTRUNC;
			}
			else
			{
				*reinterpret_cast<struct ucred*>(CMSG_DATA(header)) = ucred_to_recv.value();
				commit_control_header(header, SCM_CREDENTIALS, sizeof(struct ucred));
			}
		}

		size_t packet_received = 0;
		while (iov_index < message.msg_iovlen && packet_received < packet_info.size)
		{
			auto& iov = message.msg_iov[iov_index];
			uint8_t* iov_base = static_cast<uint8_t*>(iov.iov_base);

			const size_t nrecv = BAN::Math::min<size_t>(iov.iov_len - iov_offset, packet_info.size - packet_received);
			memcpy(iov_base + iov_offset, packet_buffer + packet_received, nrecv);

			packet_received += nrecv;

			iov_offset += nrecv;
			if (iov_offset >= iov.iov_len)
			{
				iov_offset = 0;
				iov_index++;
			}
		}

		if (!is_streaming() && packet_received < packet_info.size)
			message.msg_flags |= MSG_TRUNC;

		// Stream sockets keep the unread tail of a packet; other types drop it.
		const size_t to_discard = is_streaming() ? packet_received : packet_info.size;

		packet_info.size -= to_discard;
		if (packet_info.size == 0)
			m_packet_infos.pop();

		// FIXME: get rid of this memmove :)
		memmove(packet_buffer, packet_buffer + to_discard, m_packet_size_total - to_discard);
		m_packet_size_total -= to_discard;

		total_recv += packet_received;

		// on linux ancillary data is a barrier on stream sockets, lets do the same
		if (!is_streaming() || had_ancillary_data)
			break;
	}

	message.msg_controllen = control_used;

	m_packet_thread_blocker.unblock();

	epoll_notify(EPOLLOUT);

	return total_recv;
}
|
|
|
|
// Sends one message, gathered from `message`'s iovecs, together with optional
// SCM_RIGHTS / SCM_CREDENTIALS ancillary data.
// - Connection-oriented sockets queue the message on the connected peer.
// - Datagram sockets queue it on the socket bound at msg_name, or at the path
//   remembered by connect() when msg_name is absent.
// Returns the number of payload bytes queued.
// Errors:
// - ENOTSUP: unsupported flags.
// - ENOBUFS: the message is larger than the buffer.
// - EINVAL: malformed control data or duplicate headers.
// - EPERM: forged credentials.
// - ENOTCONN / EDESTADDRREQ: no destination.
// - Also any error from add_packet().
BAN::ErrorOr<size_t> UnixDomainSocket::sendmsg_impl(const msghdr& message, int flags)
{
	if (flags & MSG_NOSIGNAL)
		dwarnln("sendmsg ignoring MSG_NOSIGNAL");
	flags &= (MSG_EOR | MSG_OOB /* | MSG_NOSIGNAL */);
	if (flags != 0)
	{
		dwarnln("TODO: sendmsg with flags 0x{H}", flags);
		return BAN::Error::from_errno(ENOTSUP);
	}

	const size_t total_message_size =
		[&message]() -> size_t {
			size_t result = 0;
			for (int i = 0; i < message.msg_iovlen; i++)
				result += message.msg_iov[i].iov_len;
			return result;
		}();

	if (total_message_size > s_packet_buffer_size)
		return BAN::Error::from_errno(ENOBUFS);

	PacketInfo packet_info {
		.size = total_message_size,
		.fds = {},
		.ucred = {},
	};

	const size_t cmsg_header_size = CMSG_LEN(0);
	const size_t control_len = message.msg_controllen;

	for (const auto* header = CMSG_FIRSTHDR(&message); header; header = CMSG_NXTHDR(&message, header))
	{
		// cmsg_len is caller-controlled. A value below the header size would underflow
		// the payload size (and can stall CMSG_NXTHDR). A value past the end of the
		// control buffer would make us read out of bounds. Reject both before using it.
		const auto* control_begin = static_cast<const uint8_t*>(message.msg_control);
		const size_t header_offset = reinterpret_cast<const uint8_t*>(header) - control_begin;
		const size_t control_left = header_offset < control_len ? control_len - header_offset : 0;
		const size_t cmsg_len = header->cmsg_len;
		if (cmsg_len < cmsg_header_size || cmsg_len > control_left)
			return BAN::Error::from_errno(EINVAL);
		const size_t payload_size = cmsg_len - cmsg_header_size;

		if (header->cmsg_level != SOL_SOCKET)
		{
			dwarnln("ignoring control message with level {}", header->cmsg_level);
			continue;
		}

		switch (header->cmsg_type)
		{
			case SCM_RIGHTS:
			{
				if (!packet_info.fds.empty())
				{
					dwarnln("multiple SCM_RIGHTS in one sendmsg");
					return BAN::Error::from_errno(EINVAL);
				}

				const auto* fd_data = reinterpret_cast<const int*>(CMSG_DATA(header));
				const size_t fd_count = payload_size / sizeof(int);
				for (size_t i = 0; i < fd_count; i++)
					TRY(packet_info.fds.push_back(TRY(Process::current().open_file_descriptor_set().get_fd_wrapper(fd_data[i]))));
				break;
			}
			case SCM_CREDENTIALS:
			{
				if (packet_info.ucred.has_value())
				{
					dwarnln("multiple SCM_CREDENTIALS in one sendmsg");
					return BAN::Error::from_errno(EINVAL);
				}

				if (payload_size < sizeof(struct ucred))
					return BAN::Error::from_errno(EINVAL);
				const ucred* ucred = reinterpret_cast<const struct ucred*>(CMSG_DATA(header));

				// Unprivileged senders may only claim their own pid and one of their own ids.
				const bool is_valid_ucred =
					[ucred]() -> bool
					{
						const auto& creds = Process::current().credentials();
						if (creds.is_superuser())
							return true;
						if (ucred->pid != Process::current().pid())
							return false;
						if (ucred->uid != creds.ruid() && ucred->uid != creds.euid() && ucred->uid != creds.suid())
							return false;
						if (ucred->gid != creds.rgid() && !creds.has_egid(ucred->gid) && ucred->gid != creds.sgid())
							return false;
						return true;
					}();

				if (!is_valid_ucred)
					return BAN::Error::from_errno(EPERM);

				packet_info.ucred = *ucred;

				break;
			}
			default:
				dwarnln("ignoring control message with type {}", header->cmsg_type);
				break;
		}
	}

	if (m_info.has<ConnectionInfo>())
	{
		auto& connection_info = m_info.get<ConnectionInfo>();
		auto target = connection_info.connection.lock();
		if (!target)
			return BAN::Error::from_errno(ENOTCONN);
		TRY(target->add_packet(message, BAN::move(packet_info)));
		return total_message_size;
	}
	else
	{
		BAN::RefPtr<Inode> target_inode;

		if (!message.msg_name || message.msg_namelen == 0)
		{
			auto& connectionless_info = m_info.get<ConnectionlessInfo>();
			if (connectionless_info.peer_address.empty())
				return BAN::Error::from_errno(EDESTADDRREQ);

			auto absolute_path = TRY(Process::current().absolute_path_of(connectionless_info.peer_address));
			target_inode = TRY(VirtualFileSystem::get().file_from_absolute_path(
				Process::current().root_file().inode,
				Process::current().credentials(),
				absolute_path,
				O_RDWR
			)).inode;
		}
		else
		{
			const auto sun_path = TRY(validate_sockaddr_un(static_cast<sockaddr*>(message.msg_name), message.msg_namelen));
			auto absolute_path = TRY(Process::current().absolute_path_of(sun_path));
			target_inode = TRY(VirtualFileSystem::get().file_from_absolute_path(
				Process::current().root_file().inode,
				Process::current().credentials(),
				absolute_path,
				O_WRONLY
			)).inode;
		}

		BAN::RefPtr<UnixDomainSocket> target;
		{
			LockGuard _(s_bound_socket_lock);
			auto it = s_bound_sockets.find(target_inode);
			if (it == s_bound_sockets.end())
				return BAN::Error::from_errno(EDESTADDRREQ);
			target = it->value.lock();
		}

		if (!target)
			return BAN::Error::from_errno(EDESTADDRREQ);
		TRY(target->add_packet(message, BAN::move(packet_info)));

		return total_message_size;
	}
}
|
|
|
|
// Writes the connected peer's address into `address`, truncated to *address_len,
// and stores the number of bytes copied in *address_len.
// Errors: ENOTCONN for datagram sockets or when no live peer exists.
BAN::ErrorOr<void> UnixDomainSocket::getpeername_impl(sockaddr* address, socklen_t* address_len)
{
	if (!m_info.has<ConnectionInfo>())
		return BAN::Error::from_errno(ENOTCONN);
	auto connection = m_info.get<ConnectionInfo>().connection.lock();
	if (!connection)
		return BAN::Error::from_errno(ENOTCONN);

	sockaddr_un sa_un {
		.sun_family = AF_UNIX,
		.sun_path = {},
	};

	// Only a peer bound to a real file has a path to report. NOTE(review): assumes
	// is_bound_to_unused() covers the 'X' placeholder path; confirm.
	// The canonical path may exceed sun_path, so copy at most sizeof(sun_path) - 1
	// bytes; the zero-initialization keeps the result NUL-terminated.
	if (!connection->is_bound_to_unused())
	{
		const auto& peer_path = connection->m_bound_file.canonical_path;
		const size_t path_len = BAN::Math::min<size_t>(peer_path.size(), sizeof(sa_un.sun_path) - 1);
		memcpy(sa_un.sun_path, peer_path.data(), path_len);
	}

	const size_t to_copy = BAN::Math::min<socklen_t>(sizeof(sockaddr_un), *address_len);
	memcpy(address, &sa_un, to_copy);
	*address_len = to_copy;

	return {};
}
|
|
|
|
}
|