Kernel: Rewrite Sockets to not be TmpInodes

Storing sockets as TmpInodes caused issues because the TmpFS kept them alive.
There was really no reason for sockets to be stored inside a TmpFS in the first place.
This commit is contained in:
Bananymous 2024-06-27 00:35:19 +03:00
parent 44c7fde2f7
commit 31568fc5a1
15 changed files with 127 additions and 108 deletions

View File

@ -1,20 +1,59 @@
#pragma once
#include <kernel/FS/Inode.h>
namespace Kernel
{
enum class SocketDomain
// Common base class for kernel sockets. A Socket is an Inode that carries only
// a small, immutable set of attributes (mode, uid, gid) given at construction.
class Socket : public Inode
{
public:
// Socket domain selector passed to socket creation.
enum class Domain
{
INET,
INET6,
UNIX,
};
enum class SocketType
// Socket communication type selector passed to socket creation.
enum class Type
{
STREAM,
DGRAM,
SEQPACKET,
};
// Attributes stored by the socket and reported through mode(), uid() and gid().
struct Info
{
mode_t mode;
uid_t uid;
gid_t gid;
};
public:
// Only mode(), uid() and gid() return a value (taken from m_info).
// Every other accessor below is not implemented for sockets and
// hits ASSERT_NOT_REACHED() when called.
ino_t ino() const final override { ASSERT_NOT_REACHED(); }
Mode mode() const final override { return Mode(m_info.mode); }
nlink_t nlink() const final override { ASSERT_NOT_REACHED(); }
uid_t uid() const final override { return m_info.uid; }
gid_t gid() const final override { return m_info.gid; }
off_t size() const final override { ASSERT_NOT_REACHED(); }
timespec atime() const final override { ASSERT_NOT_REACHED(); }
timespec mtime() const final override { ASSERT_NOT_REACHED(); }
timespec ctime() const final override { ASSERT_NOT_REACHED(); }
blksize_t blksize() const final override { ASSERT_NOT_REACHED(); }
blkcnt_t blocks() const final override { ASSERT_NOT_REACHED(); }
dev_t dev() const final override { ASSERT_NOT_REACHED(); }
dev_t rdev() const final override { ASSERT_NOT_REACHED(); }
protected:
// Stores a copy of `info`; the copy is const for the lifetime of the socket.
Socket(const Info& info)
: m_info(info)
{}
// read/write ignore the offset argument and forward to recvfrom_impl/sendto_impl
// with null address arguments.
// NOTE(review): recvfrom_impl and sendto_impl are not declared in this class;
// presumably they are inherited from Inode -- confirm.
BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan buffer) override { return recvfrom_impl(buffer, nullptr, nullptr); }
BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan buffer) override { return sendto_impl(buffer, nullptr, 0); }
private:
// Attributes captured at construction; never modified afterwards.
const Info m_info;
};
}

View File

@ -44,14 +44,14 @@ namespace Kernel
void add_ipv4_packet(NetworkInterface&, BAN::ConstByteSpan);
virtual void unbind_socket(BAN::RefPtr<NetworkSocket>, uint16_t port) override;
virtual void unbind_socket(uint16_t port) override;
virtual BAN::ErrorOr<void> bind_socket_to_unused(BAN::RefPtr<NetworkSocket>, const sockaddr* send_address, socklen_t send_address_len) override;
virtual BAN::ErrorOr<void> bind_socket_to_address(BAN::RefPtr<NetworkSocket>, const sockaddr* address, socklen_t address_len) override;
virtual BAN::ErrorOr<void> get_socket_address(BAN::RefPtr<NetworkSocket>, sockaddr* address, socklen_t* address_len) override;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) override;
virtual SocketDomain domain() const override { return SocketDomain::INET ;}
virtual Socket::Domain domain() const override { return Socket::Domain::INET ;}
virtual size_t header_size() const override { return sizeof(IPv4Header); }
private:

View File

@ -1,5 +1,6 @@
#pragma once
#include <kernel/FS/Socket.h>
#include <kernel/Networking/NetworkInterface.h>
namespace Kernel
@ -15,22 +16,20 @@ namespace Kernel
static_assert(sizeof(PseudoHeader) == 12);
class NetworkSocket;
enum class SocketDomain;
enum class SocketType;
class NetworkLayer
{
public:
virtual ~NetworkLayer() {}
virtual void unbind_socket(BAN::RefPtr<NetworkSocket>, uint16_t port) = 0;
virtual void unbind_socket(uint16_t port) = 0;
virtual BAN::ErrorOr<void> bind_socket_to_unused(BAN::RefPtr<NetworkSocket>, const sockaddr* send_address, socklen_t send_address_len) = 0;
virtual BAN::ErrorOr<void> bind_socket_to_address(BAN::RefPtr<NetworkSocket>, const sockaddr* address, socklen_t address_len) = 0;
virtual BAN::ErrorOr<void> get_socket_address(BAN::RefPtr<NetworkSocket>, sockaddr* address, socklen_t* address_len) = 0;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) = 0;
virtual SocketDomain domain() const = 0;
virtual Socket::Domain domain() const = 0;
virtual size_t header_size() const = 0;
protected:

View File

@ -1,10 +1,9 @@
#pragma once
#include <BAN/Vector.h>
#include <kernel/FS/TmpFS/FileSystem.h>
#include <kernel/FS/Socket.h>
#include <kernel/Networking/IPv4Layer.h>
#include <kernel/Networking/NetworkInterface.h>
#include <kernel/Networking/NetworkSocket.h>
#include <kernel/PCI.h>
#include <netinet/in.h>
@ -12,7 +11,7 @@
namespace Kernel
{
class NetworkManager : public TmpFileSystem
class NetworkManager
{
BAN_NON_COPYABLE(NetworkManager);
BAN_NON_MOVABLE(NetworkManager);
@ -25,16 +24,18 @@ namespace Kernel
BAN::Vector<BAN::RefPtr<NetworkInterface>> interfaces() { return m_interfaces; }
BAN::ErrorOr<BAN::RefPtr<TmpInode>> create_socket(SocketDomain, SocketType, mode_t, uid_t, gid_t);
BAN::ErrorOr<BAN::RefPtr<Socket>> create_socket(Socket::Domain, Socket::Type, mode_t, uid_t, gid_t);
void on_receive(NetworkInterface&, BAN::ConstByteSpan);
private:
NetworkManager();
NetworkManager() {}
private:
BAN::UniqPtr<IPv4Layer> m_ipv4_layer;
BAN::Vector<BAN::RefPtr<NetworkInterface>> m_interfaces;
friend class BAN::UniqPtr<NetworkManager>;
};
}

View File

@ -2,7 +2,6 @@
#include <BAN/WeakPtr.h>
#include <kernel/FS/Socket.h>
#include <kernel/FS/TmpFS/Inode.h>
#include <kernel/Networking/NetworkInterface.h>
#include <kernel/Networking/NetworkLayer.h>
@ -16,7 +15,7 @@ namespace Kernel
UDP = 0x11,
};
class NetworkSocket : public TmpInode, public BAN::Weakable<NetworkSocket>
class NetworkSocket : public Socket, public BAN::Weakable<NetworkSocket>
{
BAN_NON_COPYABLE(NetworkSocket);
BAN_NON_MOVABLE(NetworkSocket);
@ -39,7 +38,7 @@ namespace Kernel
bool is_bound() const { return m_interface != nullptr; }
protected:
NetworkSocket(NetworkLayer&, ino_t, const TmpInodeInfo&);
NetworkSocket(NetworkLayer&, const Socket::Info&);
virtual BAN::ErrorOr<long> ioctl_impl(int request, void* arg) override;
virtual BAN::ErrorOr<void> getsockname_impl(sockaddr*, socklen_t*) override;

View File

@ -46,7 +46,7 @@ namespace Kernel
static constexpr size_t m_tcp_options_bytes = 4;
public:
static BAN::ErrorOr<BAN::RefPtr<TCPSocket>> create(NetworkLayer&, ino_t, const TmpInodeInfo&);
static BAN::ErrorOr<BAN::RefPtr<TCPSocket>> create(NetworkLayer&, const Info&);
~TCPSocket();
virtual NetworkProtocol protocol() const override { return NetworkProtocol::TCP; }
@ -141,7 +141,7 @@ namespace Kernel
};
private:
TCPSocket(NetworkLayer&, ino_t, const TmpInodeInfo&);
TCPSocket(NetworkLayer&, const Info&);
void process_task();
void start_close_sequence();

View File

@ -23,7 +23,7 @@ namespace Kernel
class UDPSocket final : public NetworkSocket
{
public:
static BAN::ErrorOr<BAN::RefPtr<UDPSocket>> create(NetworkLayer&, ino_t, const TmpInodeInfo&);
static BAN::ErrorOr<BAN::RefPtr<UDPSocket>> create(NetworkLayer&, const Socket::Info&);
virtual NetworkProtocol protocol() const override { return NetworkProtocol::UDP; }
@ -42,7 +42,7 @@ namespace Kernel
virtual bool has_error_impl() const override { return false; }
private:
UDPSocket(NetworkLayer&, ino_t, const TmpInodeInfo&);
UDPSocket(NetworkLayer&, const Socket::Info&);
~UDPSocket();
struct PacketInfo

View File

@ -9,13 +9,13 @@
namespace Kernel
{
class UnixDomainSocket final : public TmpInode, public BAN::Weakable<UnixDomainSocket>
class UnixDomainSocket final : public Socket, public BAN::Weakable<UnixDomainSocket>
{
BAN_NON_COPYABLE(UnixDomainSocket);
BAN_NON_MOVABLE(UnixDomainSocket);
public:
static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(SocketType, ino_t, const TmpInodeInfo&);
static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(Socket::Type, const Socket::Info&);
protected:
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) override;
@ -30,7 +30,7 @@ namespace Kernel
virtual bool has_error_impl() const override { return false; }
private:
UnixDomainSocket(SocketType, ino_t, const TmpInodeInfo&);
UnixDomainSocket(Socket::Type, const Socket::Info&);
~UnixDomainSocket();
BAN::ErrorOr<void> add_packet(BAN::ConstByteSpan);
@ -58,7 +58,7 @@ namespace Kernel
};
private:
const SocketType m_socket_type;
const Socket::Type m_socket_type;
BAN::String m_bound_path;
BAN::Variant<ConnectionInfo, ConnectionlessInfo> m_info;

View File

@ -68,19 +68,13 @@ namespace Kernel
header.checksum = calculate_internet_checksum(BAN::ConstByteSpan::from(header), {});
}
void IPv4Layer::unbind_socket(BAN::RefPtr<NetworkSocket> socket, uint16_t port)
{
void IPv4Layer::unbind_socket(uint16_t port)
{
SpinLockGuard _(m_bound_socket_lock);
auto it = m_bound_sockets.find(port);
if (it != m_bound_sockets.end())
{
ASSERT(it->value.lock() == socket);
ASSERT(it != m_bound_sockets.end());
m_bound_sockets.remove(it);
}
}
NetworkManager::get().TmpFileSystem::remove_from_cache(socket);
}
BAN::ErrorOr<void> IPv4Layer::bind_socket_to_unused(BAN::RefPtr<NetworkSocket> socket, const sockaddr* address, socklen_t address_len)
{

View File

@ -19,11 +19,7 @@ namespace Kernel
BAN::ErrorOr<void> NetworkManager::initialize()
{
ASSERT(!s_instance);
NetworkManager* manager_ptr = new NetworkManager();
if (manager_ptr == nullptr)
return BAN::Error::from_errno(ENOMEM);
auto manager = BAN::UniqPtr<NetworkManager>::adopt(manager_ptr);
TRY(manager->TmpFileSystem::initialize(0777, 0, 0));
auto manager = TRY(BAN::UniqPtr<NetworkManager>::create());
manager->m_ipv4_layer = TRY(IPv4Layer::create());
s_instance = BAN::move(manager);
return {};
@ -35,10 +31,6 @@ namespace Kernel
return *s_instance;
}
NetworkManager::NetworkManager()
: TmpFileSystem(128)
{ }
BAN::ErrorOr<void> NetworkManager::add_interface(PCI::Device& pci_device)
{
BAN::RefPtr<NetworkInterface> interface;
@ -72,21 +64,21 @@ namespace Kernel
return {};
}
BAN::ErrorOr<BAN::RefPtr<TmpInode>> NetworkManager::create_socket(SocketDomain domain, SocketType type, mode_t mode, uid_t uid, gid_t gid)
BAN::ErrorOr<BAN::RefPtr<Socket>> NetworkManager::create_socket(Socket::Domain domain, Socket::Type type, mode_t mode, uid_t uid, gid_t gid)
{
switch (domain)
{
case SocketDomain::INET:
case Socket::Domain::INET:
switch (type)
{
case SocketType::DGRAM:
case SocketType::STREAM:
case Socket::Type::DGRAM:
case Socket::Type::STREAM:
break;
default:
return BAN::Error::from_errno(EPROTOTYPE);
}
break;
case SocketDomain::UNIX:
case Socket::Domain::UNIX:
break;
default:
return BAN::Error::from_errno(EAFNOSUPPORT);
@ -95,30 +87,28 @@ namespace Kernel
ASSERT((mode & Inode::Mode::TYPE_MASK) == 0);
mode |= Inode::Mode::IFSOCK;
auto inode_info = create_inode_info(mode, uid, gid);
ino_t ino = TRY(allocate_inode(inode_info));
BAN::RefPtr<TmpInode> socket;
auto socket_info = Socket::Info { .mode = mode, .uid = uid, .gid = gid };
BAN::RefPtr<Socket> socket;
switch (domain)
{
case SocketDomain::INET:
case Socket::Domain::INET:
{
switch (type)
{
case SocketType::DGRAM:
socket = TRY(UDPSocket::create(*m_ipv4_layer, ino, inode_info));
case Socket::Type::DGRAM:
socket = TRY(UDPSocket::create(*m_ipv4_layer, socket_info));
break;
case SocketType::STREAM:
socket = TRY(TCPSocket::create(*m_ipv4_layer, ino, inode_info));
case Socket::Type::STREAM:
socket = TRY(TCPSocket::create(*m_ipv4_layer, socket_info));
break;
default:
ASSERT_NOT_REACHED();
}
break;
}
case SocketDomain::UNIX:
case Socket::Domain::UNIX:
{
socket = TRY(UnixDomainSocket::create(type, ino, inode_info));
socket = TRY(UnixDomainSocket::create(type, socket_info));
break;
}
default:

View File

@ -6,8 +6,8 @@
namespace Kernel
{
NetworkSocket::NetworkSocket(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info)
: TmpInode(NetworkManager::get(), ino, inode_info)
NetworkSocket::NetworkSocket(NetworkLayer& network_layer, const Socket::Info& info)
: Socket(info)
, m_network_layer(network_layer)
{ }

View File

@ -23,9 +23,9 @@ namespace Kernel
static constexpr size_t s_window_buffer_size = 15 * PAGE_SIZE;
static_assert(s_window_buffer_size <= UINT16_MAX);
BAN::ErrorOr<BAN::RefPtr<TCPSocket>> TCPSocket::create(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info)
BAN::ErrorOr<BAN::RefPtr<TCPSocket>> TCPSocket::create(NetworkLayer& network_layer, const Info& info)
{
auto socket = TRY(BAN::RefPtr<TCPSocket>::create(network_layer, ino, inode_info));
auto socket = TRY(BAN::RefPtr<TCPSocket>::create(network_layer, info));
socket->m_recv_window.buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
@ -48,11 +48,13 @@ namespace Kernel
reinterpret_cast<TCPSocket*>(socket_ptr)->process_task();
}, socket.ptr()
);
// hack to keep socket alive until its process starts
socket->ref();
return socket;
}
TCPSocket::TCPSocket(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info)
: NetworkSocket(network_layer, ino, inode_info)
TCPSocket::TCPSocket(NetworkLayer& network_layer, const Info& info)
: NetworkSocket(network_layer, info)
{
m_send_window.start_seq = Random::get_u32() & 0x7FFFFFFF;
m_send_window.current_seq = m_send_window.start_seq;
@ -89,7 +91,7 @@ namespace Kernel
BAN::RefPtr<TCPSocket> return_inode;
{
auto return_inode_tmp = TRY(NetworkManager::get().create_socket(m_network_layer.domain(), SocketType::STREAM, mode().mode & ~Mode::TYPE_MASK, uid(), gid()));
auto return_inode_tmp = TRY(NetworkManager::get().create_socket(m_network_layer.domain(), Socket::Type::STREAM, mode().mode & ~Mode::TYPE_MASK, uid(), gid()));
return_inode = static_cast<TCPSocket*>(return_inode_tmp.ptr());
}
@ -605,13 +607,9 @@ namespace Kernel
// NOTE: Only listen socket can unbind the socket as
// listen socket is always alive to redirect packets
if (!m_listen_parent)
m_network_layer.unbind_socket(this, m_port);
m_network_layer.unbind_socket(m_port);
else
{
m_listen_parent->remove_listen_child(this);
// Listen children are not actually bound, so they have to be manually removed
NetworkManager::get().TmpFileSystem::remove_from_cache(this);
}
m_interface = nullptr;
m_port = PORT_NONE;
dprintln_if(DEBUG_TCP, "Socket unbound");
@ -643,6 +641,7 @@ namespace Kernel
static constexpr uint32_t retransmit_timeout_ms = 1000;
BAN::RefPtr<TCPSocket> keep_alive { this };
this->unref();
while (m_process)
{
@ -657,8 +656,8 @@ namespace Kernel
continue;
}
// This is the last instance (one instance in network manager and another keep_alive)
if (ref_count() == 2)
// This is the last instance
if (ref_count() == 1)
{
if (m_state == State::Listen)
{

View File

@ -5,9 +5,9 @@
namespace Kernel
{
BAN::ErrorOr<BAN::RefPtr<UDPSocket>> UDPSocket::create(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info)
BAN::ErrorOr<BAN::RefPtr<UDPSocket>> UDPSocket::create(NetworkLayer& network_layer, const Socket::Info& info)
{
auto socket = TRY(BAN::RefPtr<UDPSocket>::create(network_layer, ino, inode_info));
auto socket = TRY(BAN::RefPtr<UDPSocket>::create(network_layer, info));
socket->m_packet_buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
@ -19,14 +19,14 @@ namespace Kernel
return socket;
}
UDPSocket::UDPSocket(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info)
: NetworkSocket(network_layer, ino, inode_info)
UDPSocket::UDPSocket(NetworkLayer& network_layer, const Socket::Info& info)
: NetworkSocket(network_layer, info)
{ }
UDPSocket::~UDPSocket()
{
if (is_bound())
m_network_layer.unbind_socket(this, m_port);
m_network_layer.unbind_socket(m_port);
m_port = PORT_NONE;
m_interface = nullptr;
}

View File

@ -15,9 +15,9 @@ namespace Kernel
static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE;
BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> UnixDomainSocket::create(SocketType socket_type, ino_t ino, const TmpInodeInfo& inode_info)
BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> UnixDomainSocket::create(Socket::Type socket_type, const Socket::Info& info)
{
auto socket = TRY(BAN::RefPtr<UnixDomainSocket>::create(socket_type, ino, inode_info));
auto socket = TRY(BAN::RefPtr<UnixDomainSocket>::create(socket_type, info));
socket->m_packet_buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
@ -29,17 +29,17 @@ namespace Kernel
return socket;
}
UnixDomainSocket::UnixDomainSocket(SocketType socket_type, ino_t ino, const TmpInodeInfo& inode_info)
: TmpInode(NetworkManager::get(), ino, inode_info)
UnixDomainSocket::UnixDomainSocket(Socket::Type socket_type, const Socket::Info& info)
: Socket(info)
, m_socket_type(socket_type)
{
switch (socket_type)
{
case SocketType::STREAM:
case SocketType::SEQPACKET:
case Socket::Type::STREAM:
case Socket::Type::SEQPACKET:
m_info.emplace<ConnectionInfo>();
break;
case SocketType::DGRAM:
case Socket::Type::DGRAM:
m_info.emplace<ConnectionlessInfo>();
break;
default:
@ -55,7 +55,6 @@ namespace Kernel
auto it = s_bound_sockets.find(m_bound_path);
if (it != s_bound_sockets.end())
s_bound_sockets.remove(it);
m_bound_path.clear();
}
if (m_info.has<ConnectionInfo>())
{
@ -63,7 +62,6 @@ namespace Kernel
if (auto connection = connection_info.connection.lock(); connection && connection->m_info.has<ConnectionInfo>())
connection->m_info.get<ConnectionInfo>().target_closed = true;
}
m_info.clear();
}
BAN::ErrorOr<long> UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len)
@ -89,7 +87,7 @@ namespace Kernel
BAN::RefPtr<UnixDomainSocket> return_inode;
{
auto return_inode_tmp = TRY(NetworkManager::get().create_socket(SocketDomain::UNIX, m_socket_type, mode().mode & ~Mode::TYPE_MASK, uid(), gid()));
auto return_inode_tmp = TRY(NetworkManager::get().create_socket(Socket::Domain::UNIX, m_socket_type, mode().mode & ~Mode::TYPE_MASK, uid(), gid()));
return_inode = reinterpret_cast<UnixDomainSocket*>(return_inode_tmp.ptr());
}
@ -227,10 +225,10 @@ namespace Kernel
{
switch (m_socket_type)
{
case SocketType::STREAM:
case Socket::Type::STREAM:
return true;
case SocketType::SEQPACKET:
case SocketType::DGRAM:
case Socket::Type::SEQPACKET:
case Socket::Type::DGRAM:
return false;
default:
ASSERT_NOT_REACHED();

View File

@ -97,38 +97,38 @@ namespace Kernel
{
bool valid_protocol = true;
SocketDomain sock_domain;
Socket::Domain sock_domain;
switch (domain)
{
case AF_INET:
sock_domain = SocketDomain::INET;
sock_domain = Socket::Domain::INET;
break;
case AF_INET6:
sock_domain = SocketDomain::INET6;
sock_domain = Socket::Domain::INET6;
break;
case AF_UNIX:
sock_domain = SocketDomain::UNIX;
sock_domain = Socket::Domain::UNIX;
valid_protocol = false;
break;
default:
return BAN::Error::from_errno(EPROTOTYPE);
}
SocketType sock_type;
Socket::Type sock_type;
switch (type)
{
case SOCK_STREAM:
sock_type = SocketType::STREAM;
sock_type = Socket::Type::STREAM;
if (protocol != IPPROTO_TCP)
valid_protocol = false;
break;
case SOCK_DGRAM:
sock_type = SocketType::DGRAM;
sock_type = Socket::Type::DGRAM;
if (protocol != IPPROTO_UDP)
valid_protocol = false;
break;
case SOCK_SEQPACKET:
sock_type = SocketType::SEQPACKET;
sock_type = Socket::Type::SEQPACKET;
valid_protocol = false;
break;
default: