Kernel: Implement basic connection-mode unix domain sockets

This commit is contained in:
2024-02-08 02:28:19 +02:00
parent 0c8e9fe095
commit e7dd03e551
13 changed files with 454 additions and 22 deletions

View File

@@ -100,7 +100,7 @@ namespace Kernel
BAN::ErrorOr<BAN::String> link_target();
// Socket API
BAN::ErrorOr<void> accept(sockaddr* address, socklen_t* address_len);
BAN::ErrorOr<long> accept(sockaddr* address, socklen_t* address_len);
BAN::ErrorOr<void> bind(const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<void> connect(const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<void> listen(int backlog);
@@ -131,7 +131,7 @@ namespace Kernel
virtual BAN::ErrorOr<BAN::String> link_target_impl() { return BAN::Error::from_errno(ENOTSUP); }
// Socket API
virtual BAN::ErrorOr<void> accept_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> listen_impl(int) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }

View File

@@ -0,0 +1,20 @@
#pragma once
namespace Kernel
{
enum class SocketDomain
{
INET,
INET6,
UNIX,
};
enum class SocketType
{
STREAM,
DGRAM,
SEQPACKET,
};
}

View File

@@ -25,7 +25,7 @@ namespace Kernel
BAN::Vector<BAN::RefPtr<NetworkInterface>> interfaces() { return m_interfaces; }
BAN::ErrorOr<BAN::RefPtr<NetworkSocket>> create_socket(SocketDomain, SocketType, mode_t, uid_t, gid_t);
BAN::ErrorOr<BAN::RefPtr<TmpInode>> create_socket(SocketDomain, SocketType, mode_t, uid_t, gid_t);
void on_receive(NetworkInterface&, BAN::ConstByteSpan);

View File

@@ -1,6 +1,7 @@
#pragma once
#include <BAN/WeakPtr.h>
#include <kernel/FS/Socket.h>
#include <kernel/FS/TmpFS/Inode.h>
#include <kernel/Networking/NetworkInterface.h>
#include <kernel/Networking/NetworkLayer.h>
@@ -16,20 +17,6 @@ namespace Kernel
UDP = 0x11,
};
enum class SocketDomain
{
INET,
INET6,
UNIX,
};
enum class SocketType
{
STREAM,
DGRAM,
SEQPACKET,
};
class NetworkSocket : public TmpInode, public BAN::Weakable<NetworkSocket>
{
BAN_NON_COPYABLE(NetworkSocket);

View File

@@ -0,0 +1,67 @@
#pragma once
#include <BAN/Queue.h>
#include <BAN/WeakPtr.h>
#include <kernel/FS/Socket.h>
#include <kernel/FS/TmpFS/Inode.h>
namespace Kernel
{
class UnixDomainSocket final : public TmpInode, public BAN::Weakable<UnixDomainSocket>
{
BAN_NON_COPYABLE(UnixDomainSocket);
BAN_NON_MOVABLE(UnixDomainSocket);
public:
static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(SocketType, ino_t, const TmpInodeInfo&);
protected:
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) override;
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<void> listen_impl(int) override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<size_t> sendto_impl(const sys_sendto_t*) override;
virtual BAN::ErrorOr<size_t> recvfrom_impl(sys_recvfrom_t*) override;
private:
UnixDomainSocket(SocketType, ino_t, const TmpInodeInfo&);
BAN::ErrorOr<void> add_packet(BAN::ConstByteSpan);
bool is_bound() const { return !m_bound_path.empty(); }
bool is_bound_to_unused() const { return m_bound_path == "X"sv; }
bool is_streaming() const;
private:
struct ConnectionInfo
{
bool listening { false };
BAN::Atomic<bool> connection_done { false };
BAN::WeakPtr<UnixDomainSocket> connection;
BAN::Queue<BAN::RefPtr<UnixDomainSocket>> pending_connections;
Semaphore pending_semaphore;
SpinLock pending_lock;
};
struct ConnectionlessInfo
{
};
private:
const SocketType m_socket_type;
BAN::String m_bound_path;
BAN::Variant<ConnectionInfo, ConnectionlessInfo> m_info;
BAN::CircularQueue<size_t, 128> m_packet_sizes;
size_t m_packet_size_total { 0 };
BAN::UniqPtr<VirtualRange> m_packet_buffer;
Semaphore m_packet_semaphore;
friend class BAN::RefPtr<UnixDomainSocket>;
};
}

View File

@@ -21,6 +21,7 @@ namespace Kernel
BAN::ErrorOr<void> clone_from(const OpenFileDescriptorSet&);
BAN::ErrorOr<int> open(BAN::RefPtr<Inode>, int flags);
BAN::ErrorOr<int> open(BAN::StringView absolute_path, int flags);
BAN::ErrorOr<int> socket(int domain, int type, int protocol);

View File

@@ -95,6 +95,8 @@ namespace Kernel
BAN::ErrorOr<long> sys_getegid() const { return m_credentials.egid(); }
BAN::ErrorOr<long> sys_getpgid(pid_t);
BAN::ErrorOr<long> open_inode(BAN::RefPtr<Inode>, int flags);
BAN::ErrorOr<void> create_file_or_dir(BAN::StringView name, mode_t mode);
BAN::ErrorOr<long> open_file(BAN::StringView path, int oflag, mode_t = 0);
BAN::ErrorOr<long> sys_open(const char* path, int, mode_t);