Compare commits

...

6 Commits

Author SHA1 Message Date
Bananymous 010c2c934b BAN: Write RefPtr and WeakPtr to be thread safe 2024-06-28 22:00:29 +03:00
Bananymous 48a76426e7 BAN: Add more APIs for Atomic and make compare_exchage take a reference 2024-06-28 21:47:47 +03:00
Bananymous 0c645ba867 LibGUI: Window now uses double buffering
This allows data in shared memory object be always up to date. With this
change window server can update lazily, and not necessarily on all
invalidate calls
2024-06-27 00:39:59 +03:00
Bananymous f538dd5276 test-tcp: Fix printing of "connection reset" when tcp connection closed 2024-06-27 00:39:22 +03:00
Bananymous 31568fc5a1 Kernel: Rewrite Sockets to not be TmpInodes
TmpInodes just caused issues because TmpFS kept them alive. There was
really no reason for sockets to even be stored inside a TmpFS...
2024-06-27 00:35:19 +03:00
Bananymous 44c7fde2f7 BAN: Fix Function requires clause argument forwariding 2024-06-27 00:33:50 +03:00
24 changed files with 236 additions and 150 deletions

View File

@ -45,9 +45,23 @@ namespace BAN
inline T operator--(int) volatile { return __atomic_fetch_sub(&m_value, 1, MEM_ORDER); }
inline T operator++(int) volatile { return __atomic_fetch_add(&m_value, 1, MEM_ORDER); }
inline bool compare_exchange(T expected, T desired, MemoryOrder mem_order = MEM_ORDER) volatile { return __atomic_compare_exchange_n(&m_value, &expected, desired, false, mem_order, mem_order); }
inline bool compare_exchange(T& expected, T desired, MemoryOrder mem_order = MEM_ORDER) volatile { return __atomic_compare_exchange_n(&m_value, &expected, desired, false, mem_order, mem_order); }
inline T exchange(T desired, MemoryOrder mem_order = MEM_ORDER) volatile { return __atomic_exchange_n(&m_value, desired, mem_order); };
inline T add_fetch (T val, MemoryOrder mem_order = MEM_ORDER) volatile { return __atomic_add_fetch (&m_value, val, mem_order); }
inline T sub_fetch (T val, MemoryOrder mem_order = MEM_ORDER) volatile { return __atomic_sub_fetch (&m_value, val, mem_order); }
inline T and_fetch (T val, MemoryOrder mem_order = MEM_ORDER) volatile { return __atomic_and_fetch (&m_value, val, mem_order); }
inline T xor_fetch (T val, MemoryOrder mem_order = MEM_ORDER) volatile { return __atomic_xor_fetch (&m_value, val, mem_order); }
inline T or_fetch (T val, MemoryOrder mem_order = MEM_ORDER) volatile { return __atomic_or_fetch (&m_value, val, mem_order); }
inline T nand_fetch(T val, MemoryOrder mem_order = MEM_ORDER) volatile { return __atomic_nand_fetch(&m_value, val, mem_order); }
inline T fetch_add (T val, MemoryOrder mem_order = MEM_ORDER) volatile { return __atomic_fetch_add (&m_value, val, mem_order); }
inline T fetch_sub (T val, MemoryOrder mem_order = MEM_ORDER) volatile { return __atomic_fetch_sub (&m_value, val, mem_order); }
inline T fetch_and (T val, MemoryOrder mem_order = MEM_ORDER) volatile { return __atomic_fetch_and (&m_value, val, mem_order); }
inline T fetch_xor (T val, MemoryOrder mem_order = MEM_ORDER) volatile { return __atomic_fetch_xor (&m_value, val, mem_order); }
inline T fetch_or (T val, MemoryOrder mem_order = MEM_ORDER) volatile { return __atomic_fetch__or (&m_value, val, mem_order); }
inline T fetch_nand(T val, MemoryOrder mem_order = MEM_ORDER) volatile { return __atomic_nfetch_and(&m_value, val, mem_order); }
private:
T m_value;
};

View File

@ -32,7 +32,7 @@ namespace BAN
new (m_storage) CallableMemberConst<Own>(function, owner);
}
template<typename Lambda>
Function(Lambda lambda) requires requires(Lambda lamda, Args... args) { { lambda(args...) } -> BAN::same_as<Ret>; }
Function(Lambda lambda) requires requires(Lambda lamda, Args&&... args) { { lambda(forward<Args>(args)...) } -> BAN::same_as<Ret>; }
{
static_assert(sizeof(CallableLambda<Lambda>) <= m_size);
new (m_storage) CallableLambda<Lambda>(lambda);

View File

@ -1,5 +1,6 @@
#pragma once
#include <BAN/Atomic.h>
#include <BAN/Errors.h>
#include <BAN/Move.h>
#include <BAN/NoCopyMove.h>
@ -22,15 +23,27 @@ namespace BAN
void ref() const
{
ASSERT(m_ref_count > 0);
m_ref_count++;
uint32_t old = m_ref_count.fetch_add(1, MemoryOrder::memory_order_relaxed);
ASSERT(old > 0);
}
bool try_ref() const
{
uint32_t expected = m_ref_count.load(MemoryOrder::memory_order_relaxed);
for (;;)
{
if (expected == 0)
return false;
if (m_ref_count.compare_exchange(expected, expected + 1, MemoryOrder::memory_order_acquire))
return true;
}
}
void unref() const
{
ASSERT(m_ref_count > 0);
m_ref_count--;
if (m_ref_count == 0)
uint32_t old = m_ref_count.fetch_sub(1);
ASSERT(old > 0);
if (old == 1)
delete (const T*)this;
}
@ -39,7 +52,7 @@ namespace BAN
virtual ~RefCounted() { ASSERT(m_ref_count == 0); }
private:
mutable uint32_t m_ref_count = 1;
mutable Atomic<uint32_t> m_ref_count = 1;
};
template<typename T>

View File

@ -2,6 +2,10 @@
#include <BAN/RefPtr.h>
#if __is_kernel
#include <kernel/Lock/SpinLock.h>
#endif
namespace BAN
{
@ -11,22 +15,37 @@ namespace BAN
template<typename T>
class WeakPtr;
// FIXME: Write this without using locks...
template<typename T>
class WeakLink : public RefCounted<WeakLink<T>>
{
public:
RefPtr<T> lock() { ASSERT(m_ptr); return raw_ptr(); }
T* raw_ptr() { return m_ptr; }
RefPtr<T> try_lock()
{
#if __is_kernel
Kernel::SpinLockGuard _(m_weak_lock);
#endif
if (m_ptr && m_ptr->try_ref())
return RefPtr<T>::adopt(m_ptr);
return nullptr;
}
bool valid() const { return m_ptr; }
void invalidate() { m_ptr = nullptr; }
void invalidate()
{
#if __is_kernel
Kernel::SpinLockGuard _(m_weak_lock);
#endif
m_ptr = nullptr;
}
private:
WeakLink(T* ptr) : m_ptr(ptr) {}
private:
T* m_ptr;
#if __is_kernel
Kernel::SpinLock m_weak_lock;
#endif
friend class RefPtr<WeakLink<T>>;
};
@ -82,8 +101,8 @@ namespace BAN
RefPtr<T> lock()
{
if (valid())
return m_link->lock();
if (m_link)
return m_link->try_lock();
return nullptr;
}

View File

@ -1,20 +1,59 @@
#pragma once
#include <kernel/FS/Inode.h>
namespace Kernel
{
enum class SocketDomain
class Socket : public Inode
{
INET,
INET6,
UNIX,
};
public:
enum class Domain
{
INET,
INET6,
UNIX,
};
enum class SocketType
{
STREAM,
DGRAM,
SEQPACKET,
enum class Type
{
STREAM,
DGRAM,
SEQPACKET,
};
struct Info
{
mode_t mode;
uid_t uid;
gid_t gid;
};
public:
ino_t ino() const final override { ASSERT_NOT_REACHED(); }
Mode mode() const final override { return Mode(m_info.mode); }
nlink_t nlink() const final override { ASSERT_NOT_REACHED(); }
uid_t uid() const final override { return m_info.uid; }
gid_t gid() const final override { return m_info.gid; }
off_t size() const final override { ASSERT_NOT_REACHED(); }
timespec atime() const final override { ASSERT_NOT_REACHED(); }
timespec mtime() const final override { ASSERT_NOT_REACHED(); }
timespec ctime() const final override { ASSERT_NOT_REACHED(); }
blksize_t blksize() const final override { ASSERT_NOT_REACHED(); }
blkcnt_t blocks() const final override { ASSERT_NOT_REACHED(); }
dev_t dev() const final override { ASSERT_NOT_REACHED(); }
dev_t rdev() const final override { ASSERT_NOT_REACHED(); }
protected:
Socket(const Info& info)
: m_info(info)
{}
BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan buffer) override { return recvfrom_impl(buffer, nullptr, nullptr); }
BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan buffer) override { return sendto_impl(buffer, nullptr, 0); }
private:
const Info m_info;
};
}

View File

@ -24,8 +24,12 @@ namespace Kernel
ASSERT(m_lock_depth > 0);
else
{
while (!m_locker.compare_exchange(-1, tid))
pid_t expected = -1;
while (!m_locker.compare_exchange(expected, tid))
{
Scheduler::get().yield();
expected = -1;
}
ASSERT(m_lock_depth == 0);
if (Scheduler::current_tid())
Thread::current().add_mutex();
@ -40,7 +44,8 @@ namespace Kernel
ASSERT(m_lock_depth > 0);
else
{
if (!m_locker.compare_exchange(-1, tid))
pid_t expected = -1;
if (!m_locker.compare_exchange(expected, tid))
return false;
ASSERT(m_lock_depth == 0);
if (Scheduler::current_tid())
@ -89,8 +94,12 @@ namespace Kernel
bool has_priority = tid ? !Thread::current().is_userspace() : true;
if (has_priority)
m_queue_length++;
while (!(has_priority || m_queue_length == 0) || !m_locker.compare_exchange(-1, tid))
pid_t expected = -1;
while (!(has_priority || m_queue_length == 0) || !m_locker.compare_exchange(expected, tid))
{
Scheduler::get().yield();
expected = -1;
}
ASSERT(m_lock_depth == 0);
if (Scheduler::current_tid())
Thread::current().add_mutex();
@ -106,7 +115,8 @@ namespace Kernel
else
{
bool has_priority = tid ? !Thread::current().is_userspace() : true;
if (!(has_priority || m_queue_length == 0) || !m_locker.compare_exchange(-1, tid))
pid_t expected = -1;
if (!(has_priority || m_queue_length == 0) || !m_locker.compare_exchange(expected, tid))
return false;
if (has_priority)
m_queue_length++;

View File

@ -26,8 +26,12 @@ namespace Kernel
auto id = Processor::current_id();
ASSERT(m_locker != id);
while (!m_locker.compare_exchange(PROCESSOR_NONE, id, BAN::MemoryOrder::memory_order_acquire))
ProcessorID expected = PROCESSOR_NONE;
while (!m_locker.compare_exchange(expected, id, BAN::MemoryOrder::memory_order_acquire))
{
__builtin_ia32_pause();
expected = PROCESSOR_NONE;
}
return state;
}
@ -67,8 +71,12 @@ namespace Kernel
ASSERT(m_lock_depth > 0);
else
{
while (!m_locker.compare_exchange(PROCESSOR_NONE, id, BAN::MemoryOrder::memory_order_acquire))
ProcessorID expected = PROCESSOR_NONE;
while (!m_locker.compare_exchange(expected, id, BAN::MemoryOrder::memory_order_acquire))
{
__builtin_ia32_pause();
expected = PROCESSOR_NONE;
}
ASSERT(m_lock_depth == 0);
}

View File

@ -44,14 +44,14 @@ namespace Kernel
void add_ipv4_packet(NetworkInterface&, BAN::ConstByteSpan);
virtual void unbind_socket(BAN::RefPtr<NetworkSocket>, uint16_t port) override;
virtual void unbind_socket(uint16_t port) override;
virtual BAN::ErrorOr<void> bind_socket_to_unused(BAN::RefPtr<NetworkSocket>, const sockaddr* send_address, socklen_t send_address_len) override;
virtual BAN::ErrorOr<void> bind_socket_to_address(BAN::RefPtr<NetworkSocket>, const sockaddr* address, socklen_t address_len) override;
virtual BAN::ErrorOr<void> get_socket_address(BAN::RefPtr<NetworkSocket>, sockaddr* address, socklen_t* address_len) override;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) override;
virtual SocketDomain domain() const override { return SocketDomain::INET ;}
virtual Socket::Domain domain() const override { return Socket::Domain::INET ;}
virtual size_t header_size() const override { return sizeof(IPv4Header); }
private:

View File

@ -1,5 +1,6 @@
#pragma once
#include <kernel/FS/Socket.h>
#include <kernel/Networking/NetworkInterface.h>
namespace Kernel
@ -15,22 +16,20 @@ namespace Kernel
static_assert(sizeof(PseudoHeader) == 12);
class NetworkSocket;
enum class SocketDomain;
enum class SocketType;
class NetworkLayer
{
public:
virtual ~NetworkLayer() {}
virtual void unbind_socket(BAN::RefPtr<NetworkSocket>, uint16_t port) = 0;
virtual void unbind_socket(uint16_t port) = 0;
virtual BAN::ErrorOr<void> bind_socket_to_unused(BAN::RefPtr<NetworkSocket>, const sockaddr* send_address, socklen_t send_address_len) = 0;
virtual BAN::ErrorOr<void> bind_socket_to_address(BAN::RefPtr<NetworkSocket>, const sockaddr* address, socklen_t address_len) = 0;
virtual BAN::ErrorOr<void> get_socket_address(BAN::RefPtr<NetworkSocket>, sockaddr* address, socklen_t* address_len) = 0;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) = 0;
virtual SocketDomain domain() const = 0;
virtual Socket::Domain domain() const = 0;
virtual size_t header_size() const = 0;
protected:

View File

@ -1,10 +1,9 @@
#pragma once
#include <BAN/Vector.h>
#include <kernel/FS/TmpFS/FileSystem.h>
#include <kernel/FS/Socket.h>
#include <kernel/Networking/IPv4Layer.h>
#include <kernel/Networking/NetworkInterface.h>
#include <kernel/Networking/NetworkSocket.h>
#include <kernel/PCI.h>
#include <netinet/in.h>
@ -12,7 +11,7 @@
namespace Kernel
{
class NetworkManager : public TmpFileSystem
class NetworkManager
{
BAN_NON_COPYABLE(NetworkManager);
BAN_NON_MOVABLE(NetworkManager);
@ -25,16 +24,18 @@ namespace Kernel
BAN::Vector<BAN::RefPtr<NetworkInterface>> interfaces() { return m_interfaces; }
BAN::ErrorOr<BAN::RefPtr<TmpInode>> create_socket(SocketDomain, SocketType, mode_t, uid_t, gid_t);
BAN::ErrorOr<BAN::RefPtr<Socket>> create_socket(Socket::Domain, Socket::Type, mode_t, uid_t, gid_t);
void on_receive(NetworkInterface&, BAN::ConstByteSpan);
private:
NetworkManager();
NetworkManager() {}
private:
BAN::UniqPtr<IPv4Layer> m_ipv4_layer;
BAN::Vector<BAN::RefPtr<NetworkInterface>> m_interfaces;
friend class BAN::UniqPtr<NetworkManager>;
};
}

View File

@ -2,7 +2,6 @@
#include <BAN/WeakPtr.h>
#include <kernel/FS/Socket.h>
#include <kernel/FS/TmpFS/Inode.h>
#include <kernel/Networking/NetworkInterface.h>
#include <kernel/Networking/NetworkLayer.h>
@ -16,7 +15,7 @@ namespace Kernel
UDP = 0x11,
};
class NetworkSocket : public TmpInode, public BAN::Weakable<NetworkSocket>
class NetworkSocket : public Socket, public BAN::Weakable<NetworkSocket>
{
BAN_NON_COPYABLE(NetworkSocket);
BAN_NON_MOVABLE(NetworkSocket);
@ -39,7 +38,7 @@ namespace Kernel
bool is_bound() const { return m_interface != nullptr; }
protected:
NetworkSocket(NetworkLayer&, ino_t, const TmpInodeInfo&);
NetworkSocket(NetworkLayer&, const Socket::Info&);
virtual BAN::ErrorOr<long> ioctl_impl(int request, void* arg) override;
virtual BAN::ErrorOr<void> getsockname_impl(sockaddr*, socklen_t*) override;

View File

@ -46,7 +46,7 @@ namespace Kernel
static constexpr size_t m_tcp_options_bytes = 4;
public:
static BAN::ErrorOr<BAN::RefPtr<TCPSocket>> create(NetworkLayer&, ino_t, const TmpInodeInfo&);
static BAN::ErrorOr<BAN::RefPtr<TCPSocket>> create(NetworkLayer&, const Info&);
~TCPSocket();
virtual NetworkProtocol protocol() const override { return NetworkProtocol::TCP; }
@ -141,7 +141,7 @@ namespace Kernel
};
private:
TCPSocket(NetworkLayer&, ino_t, const TmpInodeInfo&);
TCPSocket(NetworkLayer&, const Info&);
void process_task();
void start_close_sequence();

View File

@ -23,7 +23,7 @@ namespace Kernel
class UDPSocket final : public NetworkSocket
{
public:
static BAN::ErrorOr<BAN::RefPtr<UDPSocket>> create(NetworkLayer&, ino_t, const TmpInodeInfo&);
static BAN::ErrorOr<BAN::RefPtr<UDPSocket>> create(NetworkLayer&, const Socket::Info&);
virtual NetworkProtocol protocol() const override { return NetworkProtocol::UDP; }
@ -42,7 +42,7 @@ namespace Kernel
virtual bool has_error_impl() const override { return false; }
private:
UDPSocket(NetworkLayer&, ino_t, const TmpInodeInfo&);
UDPSocket(NetworkLayer&, const Socket::Info&);
~UDPSocket();
struct PacketInfo

View File

@ -9,13 +9,13 @@
namespace Kernel
{
class UnixDomainSocket final : public TmpInode, public BAN::Weakable<UnixDomainSocket>
class UnixDomainSocket final : public Socket, public BAN::Weakable<UnixDomainSocket>
{
BAN_NON_COPYABLE(UnixDomainSocket);
BAN_NON_MOVABLE(UnixDomainSocket);
public:
static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(SocketType, ino_t, const TmpInodeInfo&);
static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(Socket::Type, const Socket::Info&);
protected:
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) override;
@ -30,7 +30,7 @@ namespace Kernel
virtual bool has_error_impl() const override { return false; }
private:
UnixDomainSocket(SocketType, ino_t, const TmpInodeInfo&);
UnixDomainSocket(Socket::Type, const Socket::Info&);
~UnixDomainSocket();
BAN::ErrorOr<void> add_packet(BAN::ConstByteSpan);
@ -58,7 +58,7 @@ namespace Kernel
};
private:
const SocketType m_socket_type;
const Socket::Type m_socket_type;
BAN::String m_bound_path;
BAN::Variant<ConnectionInfo, ConnectionlessInfo> m_info;

View File

@ -68,18 +68,12 @@ namespace Kernel
header.checksum = calculate_internet_checksum(BAN::ConstByteSpan::from(header), {});
}
void IPv4Layer::unbind_socket(BAN::RefPtr<NetworkSocket> socket, uint16_t port)
void IPv4Layer::unbind_socket(uint16_t port)
{
{
SpinLockGuard _(m_bound_socket_lock);
auto it = m_bound_sockets.find(port);
if (it != m_bound_sockets.end())
{
ASSERT(it->value.lock() == socket);
m_bound_sockets.remove(it);
}
}
NetworkManager::get().TmpFileSystem::remove_from_cache(socket);
SpinLockGuard _(m_bound_socket_lock);
auto it = m_bound_sockets.find(port);
ASSERT(it != m_bound_sockets.end());
m_bound_sockets.remove(it);
}
BAN::ErrorOr<void> IPv4Layer::bind_socket_to_unused(BAN::RefPtr<NetworkSocket> socket, const sockaddr* address, socklen_t address_len)

View File

@ -19,11 +19,7 @@ namespace Kernel
BAN::ErrorOr<void> NetworkManager::initialize()
{
ASSERT(!s_instance);
NetworkManager* manager_ptr = new NetworkManager();
if (manager_ptr == nullptr)
return BAN::Error::from_errno(ENOMEM);
auto manager = BAN::UniqPtr<NetworkManager>::adopt(manager_ptr);
TRY(manager->TmpFileSystem::initialize(0777, 0, 0));
auto manager = TRY(BAN::UniqPtr<NetworkManager>::create());
manager->m_ipv4_layer = TRY(IPv4Layer::create());
s_instance = BAN::move(manager);
return {};
@ -35,10 +31,6 @@ namespace Kernel
return *s_instance;
}
NetworkManager::NetworkManager()
: TmpFileSystem(128)
{ }
BAN::ErrorOr<void> NetworkManager::add_interface(PCI::Device& pci_device)
{
BAN::RefPtr<NetworkInterface> interface;
@ -72,21 +64,21 @@ namespace Kernel
return {};
}
BAN::ErrorOr<BAN::RefPtr<TmpInode>> NetworkManager::create_socket(SocketDomain domain, SocketType type, mode_t mode, uid_t uid, gid_t gid)
BAN::ErrorOr<BAN::RefPtr<Socket>> NetworkManager::create_socket(Socket::Domain domain, Socket::Type type, mode_t mode, uid_t uid, gid_t gid)
{
switch (domain)
{
case SocketDomain::INET:
case Socket::Domain::INET:
switch (type)
{
case SocketType::DGRAM:
case SocketType::STREAM:
case Socket::Type::DGRAM:
case Socket::Type::STREAM:
break;
default:
return BAN::Error::from_errno(EPROTOTYPE);
}
break;
case SocketDomain::UNIX:
case Socket::Domain::UNIX:
break;
default:
return BAN::Error::from_errno(EAFNOSUPPORT);
@ -95,30 +87,28 @@ namespace Kernel
ASSERT((mode & Inode::Mode::TYPE_MASK) == 0);
mode |= Inode::Mode::IFSOCK;
auto inode_info = create_inode_info(mode, uid, gid);
ino_t ino = TRY(allocate_inode(inode_info));
BAN::RefPtr<TmpInode> socket;
auto socket_info = Socket::Info { .mode = mode, .uid = uid, .gid = gid };
BAN::RefPtr<Socket> socket;
switch (domain)
{
case SocketDomain::INET:
case Socket::Domain::INET:
{
switch (type)
{
case SocketType::DGRAM:
socket = TRY(UDPSocket::create(*m_ipv4_layer, ino, inode_info));
case Socket::Type::DGRAM:
socket = TRY(UDPSocket::create(*m_ipv4_layer, socket_info));
break;
case SocketType::STREAM:
socket = TRY(TCPSocket::create(*m_ipv4_layer, ino, inode_info));
case Socket::Type::STREAM:
socket = TRY(TCPSocket::create(*m_ipv4_layer, socket_info));
break;
default:
ASSERT_NOT_REACHED();
}
break;
}
case SocketDomain::UNIX:
case Socket::Domain::UNIX:
{
socket = TRY(UnixDomainSocket::create(type, ino, inode_info));
socket = TRY(UnixDomainSocket::create(type, socket_info));
break;
}
default:

View File

@ -6,8 +6,8 @@
namespace Kernel
{
NetworkSocket::NetworkSocket(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info)
: TmpInode(NetworkManager::get(), ino, inode_info)
NetworkSocket::NetworkSocket(NetworkLayer& network_layer, const Socket::Info& info)
: Socket(info)
, m_network_layer(network_layer)
{ }

View File

@ -23,9 +23,9 @@ namespace Kernel
static constexpr size_t s_window_buffer_size = 15 * PAGE_SIZE;
static_assert(s_window_buffer_size <= UINT16_MAX);
BAN::ErrorOr<BAN::RefPtr<TCPSocket>> TCPSocket::create(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info)
BAN::ErrorOr<BAN::RefPtr<TCPSocket>> TCPSocket::create(NetworkLayer& network_layer, const Info& info)
{
auto socket = TRY(BAN::RefPtr<TCPSocket>::create(network_layer, ino, inode_info));
auto socket = TRY(BAN::RefPtr<TCPSocket>::create(network_layer, info));
socket->m_recv_window.buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
@ -48,11 +48,13 @@ namespace Kernel
reinterpret_cast<TCPSocket*>(socket_ptr)->process_task();
}, socket.ptr()
);
// hack to keep socket alive until its process starts
socket->ref();
return socket;
}
TCPSocket::TCPSocket(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info)
: NetworkSocket(network_layer, ino, inode_info)
TCPSocket::TCPSocket(NetworkLayer& network_layer, const Info& info)
: NetworkSocket(network_layer, info)
{
m_send_window.start_seq = Random::get_u32() & 0x7FFFFFFF;
m_send_window.current_seq = m_send_window.start_seq;
@ -89,7 +91,7 @@ namespace Kernel
BAN::RefPtr<TCPSocket> return_inode;
{
auto return_inode_tmp = TRY(NetworkManager::get().create_socket(m_network_layer.domain(), SocketType::STREAM, mode().mode & ~Mode::TYPE_MASK, uid(), gid()));
auto return_inode_tmp = TRY(NetworkManager::get().create_socket(m_network_layer.domain(), Socket::Type::STREAM, mode().mode & ~Mode::TYPE_MASK, uid(), gid()));
return_inode = static_cast<TCPSocket*>(return_inode_tmp.ptr());
}
@ -605,13 +607,9 @@ namespace Kernel
// NOTE: Only listen socket can unbind the socket as
// listen socket is always alive to redirect packets
if (!m_listen_parent)
m_network_layer.unbind_socket(this, m_port);
m_network_layer.unbind_socket(m_port);
else
{
m_listen_parent->remove_listen_child(this);
// Listen children are not actually bound, so they have to be manually removed
NetworkManager::get().TmpFileSystem::remove_from_cache(this);
}
m_interface = nullptr;
m_port = PORT_NONE;
dprintln_if(DEBUG_TCP, "Socket unbound");
@ -643,6 +641,7 @@ namespace Kernel
static constexpr uint32_t retransmit_timeout_ms = 1000;
BAN::RefPtr<TCPSocket> keep_alive { this };
this->unref();
while (m_process)
{
@ -657,8 +656,8 @@ namespace Kernel
continue;
}
// This is the last instance (one instance in network manager and another keep_alive)
if (ref_count() == 2)
// This is the last instance
if (ref_count() == 1)
{
if (m_state == State::Listen)
{

View File

@ -5,9 +5,9 @@
namespace Kernel
{
BAN::ErrorOr<BAN::RefPtr<UDPSocket>> UDPSocket::create(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info)
BAN::ErrorOr<BAN::RefPtr<UDPSocket>> UDPSocket::create(NetworkLayer& network_layer, const Socket::Info& info)
{
auto socket = TRY(BAN::RefPtr<UDPSocket>::create(network_layer, ino, inode_info));
auto socket = TRY(BAN::RefPtr<UDPSocket>::create(network_layer, info));
socket->m_packet_buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
@ -19,14 +19,14 @@ namespace Kernel
return socket;
}
UDPSocket::UDPSocket(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info)
: NetworkSocket(network_layer, ino, inode_info)
UDPSocket::UDPSocket(NetworkLayer& network_layer, const Socket::Info& info)
: NetworkSocket(network_layer, info)
{ }
UDPSocket::~UDPSocket()
{
if (is_bound())
m_network_layer.unbind_socket(this, m_port);
m_network_layer.unbind_socket(m_port);
m_port = PORT_NONE;
m_interface = nullptr;
}

View File

@ -15,9 +15,9 @@ namespace Kernel
static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE;
BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> UnixDomainSocket::create(SocketType socket_type, ino_t ino, const TmpInodeInfo& inode_info)
BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> UnixDomainSocket::create(Socket::Type socket_type, const Socket::Info& info)
{
auto socket = TRY(BAN::RefPtr<UnixDomainSocket>::create(socket_type, ino, inode_info));
auto socket = TRY(BAN::RefPtr<UnixDomainSocket>::create(socket_type, info));
socket->m_packet_buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
@ -29,17 +29,17 @@ namespace Kernel
return socket;
}
UnixDomainSocket::UnixDomainSocket(SocketType socket_type, ino_t ino, const TmpInodeInfo& inode_info)
: TmpInode(NetworkManager::get(), ino, inode_info)
UnixDomainSocket::UnixDomainSocket(Socket::Type socket_type, const Socket::Info& info)
: Socket(info)
, m_socket_type(socket_type)
{
switch (socket_type)
{
case SocketType::STREAM:
case SocketType::SEQPACKET:
case Socket::Type::STREAM:
case Socket::Type::SEQPACKET:
m_info.emplace<ConnectionInfo>();
break;
case SocketType::DGRAM:
case Socket::Type::DGRAM:
m_info.emplace<ConnectionlessInfo>();
break;
default:
@ -55,7 +55,6 @@ namespace Kernel
auto it = s_bound_sockets.find(m_bound_path);
if (it != s_bound_sockets.end())
s_bound_sockets.remove(it);
m_bound_path.clear();
}
if (m_info.has<ConnectionInfo>())
{
@ -63,7 +62,6 @@ namespace Kernel
if (auto connection = connection_info.connection.lock(); connection && connection->m_info.has<ConnectionInfo>())
connection->m_info.get<ConnectionInfo>().target_closed = true;
}
m_info.clear();
}
BAN::ErrorOr<long> UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len)
@ -89,7 +87,7 @@ namespace Kernel
BAN::RefPtr<UnixDomainSocket> return_inode;
{
auto return_inode_tmp = TRY(NetworkManager::get().create_socket(SocketDomain::UNIX, m_socket_type, mode().mode & ~Mode::TYPE_MASK, uid(), gid()));
auto return_inode_tmp = TRY(NetworkManager::get().create_socket(Socket::Domain::UNIX, m_socket_type, mode().mode & ~Mode::TYPE_MASK, uid(), gid()));
return_inode = reinterpret_cast<UnixDomainSocket*>(return_inode_tmp.ptr());
}
@ -227,10 +225,10 @@ namespace Kernel
{
switch (m_socket_type)
{
case SocketType::STREAM:
case Socket::Type::STREAM:
return true;
case SocketType::SEQPACKET:
case SocketType::DGRAM:
case Socket::Type::SEQPACKET:
case Socket::Type::DGRAM:
return false;
default:
ASSERT_NOT_REACHED();
@ -348,7 +346,8 @@ namespace Kernel
if (m_info.has<ConnectionInfo>())
{
auto& connection_info = m_info.get<ConnectionInfo>();
if (connection_info.target_closed.compare_exchange(true, false))
bool expected = true;
if (connection_info.target_closed.compare_exchange(expected, false))
return 0;
if (!connection_info.connection)
return BAN::Error::from_errno(ENOTCONN);

View File

@ -97,38 +97,38 @@ namespace Kernel
{
bool valid_protocol = true;
SocketDomain sock_domain;
Socket::Domain sock_domain;
switch (domain)
{
case AF_INET:
sock_domain = SocketDomain::INET;
sock_domain = Socket::Domain::INET;
break;
case AF_INET6:
sock_domain = SocketDomain::INET6;
sock_domain = Socket::Domain::INET6;
break;
case AF_UNIX:
sock_domain = SocketDomain::UNIX;
sock_domain = Socket::Domain::UNIX;
valid_protocol = false;
break;
default:
return BAN::Error::from_errno(EPROTOTYPE);
}
SocketType sock_type;
Socket::Type sock_type;
switch (type)
{
case SOCK_STREAM:
sock_type = SocketType::STREAM;
sock_type = Socket::Type::STREAM;
if (protocol != IPPROTO_TCP)
valid_protocol = false;
break;
case SOCK_DGRAM:
sock_type = SocketType::DGRAM;
sock_type = Socket::Type::DGRAM;
if (protocol != IPPROTO_UDP)
valid_protocol = false;
break;
case SOCK_SEQPACKET:
sock_type = SocketType::SEQPACKET;
sock_type = Socket::Type::SEQPACKET;
valid_protocol = false;
break;
default:

View File

@ -1,5 +1,7 @@
#include "LibGUI/Window.h"
#include <BAN/ScopeGuard.h>
#include <LibFont/Font.h>
#include <fcntl.h>
@ -16,7 +18,7 @@ namespace LibGUI
Window::~Window()
{
munmap(m_framebuffer, m_width * m_height * 4);
munmap(m_framebuffer_smo, m_width * m_height * 4);
close(m_server_fd);
}
@ -25,9 +27,13 @@ namespace LibGUI
if (title.size() >= sizeof(WindowCreatePacket::title))
return BAN::Error::from_errno(EINVAL);
BAN::Vector<uint32_t> framebuffer;
TRY(framebuffer.resize(width * height));
int server_fd = socket(AF_UNIX, SOCK_SEQPACKET, 0);
if (server_fd == -1)
return BAN::Error::from_errno(errno);
BAN::ScopeGuard server_closer([server_fd] { close(server_fd); });
if (fcntl(server_fd, F_SETFL, fcntl(server_fd, F_GETFL) | O_CLOEXEC) == -1)
return BAN::Error::from_errno(errno);
@ -46,11 +52,8 @@ namespace LibGUI
timespec current_time;
clock_gettime(CLOCK_MONOTONIC, &current_time);
time_t duration_s = (current_time.tv_sec - start_time.tv_sec) + (current_time.tv_nsec >= start_time.tv_nsec);
if (duration_s > 10)
{
close(server_fd);
if (duration_s > 1)
return BAN::Error::from_errno(ETIMEDOUT);
}
timespec sleep_time;
sleep_time.tv_sec = 0;
@ -64,28 +67,22 @@ namespace LibGUI
strncpy(packet.title, title.data(), title.size());
packet.title[title.size()] = '\0';
if (send(server_fd, &packet, sizeof(packet), 0) != sizeof(packet))
{
close(server_fd);
return BAN::Error::from_errno(errno);
}
WindowCreateResponse response;
if (recv(server_fd, &response, sizeof(response), 0) != sizeof(response))
{
close(server_fd);
return BAN::Error::from_errno(errno);
}
void* framebuffer_addr = smo_map(response.framebuffer_smo_key);
if (framebuffer_addr == nullptr)
{
close(server_fd);
return BAN::Error::from_errno(errno);
}
server_closer.disable();
return TRY(BAN::UniqPtr<Window>::create(
server_fd,
static_cast<uint32_t*>(framebuffer_addr),
BAN::move(framebuffer),
width,
height
));
@ -143,8 +140,8 @@ namespace LibGUI
uint32_t amount_abs = BAN::Math::abs(amount);
if (amount_abs == 0 || amount_abs >= height())
return;
uint32_t* dst = (amount > 0) ? m_framebuffer + width() * amount_abs : m_framebuffer;
uint32_t* src = (amount < 0) ? m_framebuffer + width() * amount_abs : m_framebuffer;
uint32_t* dst = (amount > 0) ? m_framebuffer.data() + width() * amount_abs : m_framebuffer.data();
uint32_t* src = (amount < 0) ? m_framebuffer.data() + width() * amount_abs : m_framebuffer.data();
memmove(dst, src, width() * (height() - amount_abs) * 4);
}
@ -172,6 +169,9 @@ namespace LibGUI
if (!clamp_to_framebuffer(x, y, width, height))
return true;
for (uint32_t i = 0; i < height; i++)
memcpy(&m_framebuffer_smo[(y + i) * m_width + x], &m_framebuffer[(y + i) * m_width + x], width * sizeof(uint32_t));
WindowInvalidatePacket packet;
packet.x = x;
packet.y = y;

View File

@ -136,9 +136,10 @@ namespace LibGUI
int server_fd() const { return m_server_fd; }
private:
Window(int server_fd, uint32_t* framebuffer, uint32_t width, uint32_t height)
Window(int server_fd, uint32_t* framebuffer_smo, BAN::Vector<uint32_t>&& framebuffer, uint32_t width, uint32_t height)
: m_server_fd(server_fd)
, m_framebuffer(framebuffer)
, m_framebuffer_smo(framebuffer_smo)
, m_width(width)
, m_height(height)
{ }
@ -147,7 +148,9 @@ namespace LibGUI
private:
int m_server_fd;
uint32_t* m_framebuffer;
BAN::Vector<uint32_t> m_framebuffer;
uint32_t* m_framebuffer_smo;
uint32_t m_width;
uint32_t m_height;

View File

@ -99,10 +99,9 @@ int main(int argc, char** argv)
{
ssize_t nrecv = recv(socket, buffer, sizeof(buffer), 0);
if (nrecv == -1)
{
perror("recv");
if (nrecv <= 0)
break;
}
write(STDOUT_FILENO, buffer, nrecv);
}