From ab150b458a6e71b1674ba2416527d8b679a30b89 Mon Sep 17 00:00:00 2001 From: Bananymous Date: Fri, 2 Feb 2024 01:31:58 +0200 Subject: [PATCH] Kernel/LibC: Implement basic socket binding --- kernel/include/kernel/FS/Inode.h | 11 +++++++ .../kernel/Networking/NetworkInterface.h | 2 ++ .../kernel/Networking/NetworkManager.h | 8 +++-- .../include/kernel/Networking/NetworkSocket.h | 13 ++++++-- kernel/include/kernel/Process.h | 2 ++ kernel/kernel/FS/Inode.cpp | 14 +++++++++ kernel/kernel/Networking/NetworkManager.cpp | 30 ++++++++++++++++++- kernel/kernel/Networking/NetworkSocket.cpp | 21 ++++++++++++- kernel/kernel/OpenFileDescriptorSet.cpp | 2 ++ kernel/kernel/Process.cpp | 14 +++++++++ kernel/kernel/Syscall.cpp | 3 ++ libc/include/sys/syscall.h | 1 + libc/sys/socket.cpp | 5 ++++ 13 files changed, 119 insertions(+), 7 deletions(-) diff --git a/kernel/include/kernel/FS/Inode.h b/kernel/include/kernel/FS/Inode.h index 5fbe8564..f8cc7b12 100644 --- a/kernel/include/kernel/FS/Inode.h +++ b/kernel/include/kernel/FS/Inode.h @@ -11,6 +11,7 @@ #include #include +#include #include #include @@ -86,6 +87,8 @@ namespace Kernel virtual bool is_pipe() const { return false; } virtual bool is_tty() const { return false; } + void on_close(); + // Directory API BAN::ErrorOr> find_inode(BAN::StringView); BAN::ErrorOr list_next_inodes(off_t, DirectoryEntryList*, size_t); @@ -96,6 +99,9 @@ namespace Kernel // Link API BAN::ErrorOr link_target(); + // Socket API + BAN::ErrorOr bind(const sockaddr* address, socklen_t address_len); + // General API BAN::ErrorOr read(off_t, BAN::ByteSpan buffer); BAN::ErrorOr write(off_t, BAN::ConstByteSpan buffer); @@ -105,6 +111,8 @@ namespace Kernel bool has_data() const; protected: + virtual void on_close_impl() {} + // Directory API virtual BAN::ErrorOr> find_inode_impl(BAN::StringView) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr list_next_inodes_impl(off_t, DirectoryEntryList*, size_t) { return BAN::Error::from_errno(ENOTSUP); } @@ -115,6 +123,9 @@ namespace Kernel // Link API virtual BAN::ErrorOr link_target_impl() { return BAN::Error::from_errno(ENOTSUP); } + // Socket API + virtual BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } + // General API virtual BAN::ErrorOr read_impl(off_t, BAN::ByteSpan) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr write_impl(off_t, BAN::ConstByteSpan) { return BAN::Error::from_errno(ENOTSUP); } diff --git a/kernel/include/kernel/Networking/NetworkInterface.h b/kernel/include/kernel/Networking/NetworkInterface.h index 4888615b..65a1e766 100644 --- a/kernel/include/kernel/Networking/NetworkInterface.h +++ b/kernel/include/kernel/Networking/NetworkInterface.h @@ -35,6 +35,8 @@ namespace Kernel const dev_t m_rdev; char m_name[10]; + + uint32_t m_ipv4_address {}; }; } diff --git a/kernel/include/kernel/Networking/NetworkManager.h b/kernel/include/kernel/Networking/NetworkManager.h index 112474bd..3777262a 100644 --- a/kernel/include/kernel/Networking/NetworkManager.h +++ b/kernel/include/kernel/Networking/NetworkManager.h @@ -6,6 +6,8 @@ #include #include +#include + namespace Kernel { @@ -24,7 +26,9 @@ namespace Kernel static NetworkManager& get(); BAN::ErrorOr add_interface(PCI::Device& pci_device); - BAN::ErrorOr bind_socket(int port, BAN::RefPtr); + + void unbind_socket(uint16_t port, BAN::RefPtr); + BAN::ErrorOr bind_socket(uint16_t port, BAN::RefPtr); BAN::ErrorOr> create_socket(SocketType, mode_t, uid_t, gid_t); @@ -33,7 +37,7 @@ namespace Kernel private: BAN::Vector> m_interfaces; - BAN::HashMap> m_bound_sockets; + BAN::HashMap> m_bound_sockets; }; } diff --git a/kernel/include/kernel/Networking/NetworkSocket.h b/kernel/include/kernel/Networking/NetworkSocket.h index b03072d0..72a9f42c 100644 --- a/kernel/include/kernel/Networking/NetworkSocket.h +++ b/kernel/include/kernel/Networking/NetworkSocket.h @@ -1,21 +1,28 @@ #pragma once +#include #include #include namespace Kernel { - class NetworkSocket : public TmpInode + class NetworkSocket : public TmpInode, public BAN::Weakable { public: - void bind_interface(NetworkInterface*); + void bind_interface_and_port(NetworkInterface*, uint16_t port); + ~NetworkSocket(); protected: NetworkSocket(mode_t mode, uid_t uid, gid_t gid); + virtual void on_close_impl() override; + + virtual BAN::ErrorOr bind_impl(const sockaddr* address, socklen_t address_len) override; + protected: - NetworkInterface* m_interface = nullptr; + NetworkInterface* m_interface = nullptr; + uint16_t m_port = 0; }; } diff --git a/kernel/include/kernel/Process.h b/kernel/include/kernel/Process.h index 25742af2..3eba7e96 100644 --- a/kernel/include/kernel/Process.h +++ b/kernel/include/kernel/Process.h @@ -16,6 +16,7 @@ #include #include +#include #include namespace LibELF { class LoadableELF; } @@ -112,6 +113,7 @@ namespace Kernel BAN::ErrorOr sys_chown(const char*, uid_t, gid_t); BAN::ErrorOr sys_socket(int domain, int type, int protocol); + BAN::ErrorOr sys_bind(int socket, const sockaddr* address, socklen_t address_len); BAN::ErrorOr sys_pipe(int fildes[2]); BAN::ErrorOr sys_dup(int fildes); diff --git a/kernel/kernel/FS/Inode.cpp b/kernel/kernel/FS/Inode.cpp index a3815f4c..aeb275f7 100644 --- a/kernel/kernel/FS/Inode.cpp +++ b/kernel/kernel/FS/Inode.cpp @@ -56,6 +56,12 @@ namespace Kernel return true; } + void Inode::on_close() + { + LockGuard _(m_lock); + on_close_impl(); + } + BAN::ErrorOr> Inode::find_inode(BAN::StringView name) { LockGuard _(m_lock); @@ -110,6 +116,14 @@ namespace Kernel return link_target_impl(); } + BAN::ErrorOr Inode::bind(const sockaddr* address, socklen_t address_len) + { + LockGuard _(m_lock); + if (!mode().ifsock()) + return BAN::Error::from_errno(ENOTSOCK); + return bind_impl(address, address_len); + } + BAN::ErrorOr Inode::read(off_t offset, BAN::ByteSpan buffer) { LockGuard _(m_lock); diff --git a/kernel/kernel/Networking/NetworkManager.cpp b/kernel/kernel/Networking/NetworkManager.cpp index f0ce65c8..1f6a166a 100644 --- a/kernel/kernel/Networking/NetworkManager.cpp +++ b/kernel/kernel/Networking/NetworkManager.cpp @@ -68,11 +68,39 @@ namespace Kernel BAN::ErrorOr> NetworkManager::create_socket(SocketType type, mode_t mode, uid_t uid, gid_t gid) { + ASSERT((mode & Inode::Mode::TYPE_MASK) == 0); + if (type != SocketType::DGRAM) return BAN::Error::from_errno(EPROTOTYPE); - auto udp_socket = TRY(UDPSocket::create(mode, uid, gid)); + auto udp_socket = TRY(UDPSocket::create(mode | Inode::Mode::IFSOCK, uid, gid)); return BAN::RefPtr(udp_socket); } + void NetworkManager::unbind_socket(uint16_t port, BAN::RefPtr socket) + { + if (m_bound_sockets.contains(port)) + { + ASSERT(m_bound_sockets[port].valid()); + ASSERT(m_bound_sockets[port].lock() == socket); + m_bound_sockets.remove(port); + } + NetworkManager::get().remove_from_cache(socket); + } + + BAN::ErrorOr NetworkManager::bind_socket(uint16_t port, BAN::RefPtr socket) + { + if (m_interfaces.empty()) + return BAN::Error::from_errno(EADDRNOTAVAIL); + if (m_bound_sockets.contains(port)) + return BAN::Error::from_errno(EADDRINUSE); + + // FIXME: actually determine proper interface + auto interface = m_interfaces.front(); + TRY(m_bound_sockets.insert(port, socket)); + socket->bind_interface_and_port(interface.ptr(), port); + + return {}; + } + } diff --git a/kernel/kernel/Networking/NetworkSocket.cpp b/kernel/kernel/Networking/NetworkSocket.cpp index fd8b3f3c..70f2d295 100644 --- a/kernel/kernel/Networking/NetworkSocket.cpp +++ b/kernel/kernel/Networking/NetworkSocket.cpp @@ -13,11 +13,30 @@ namespace Kernel ) { } - void NetworkSocket::bind_interface(NetworkInterface* interface) + NetworkSocket::~NetworkSocket() + { + } + + void NetworkSocket::on_close_impl() + { + if (m_interface) + NetworkManager::get().unbind_socket(m_port, this); + } + + void NetworkSocket::bind_interface_and_port(NetworkInterface* interface, uint16_t port) { ASSERT(!m_interface); ASSERT(interface); m_interface = interface; + m_port = port; + } + + BAN::ErrorOr NetworkSocket::bind_impl(const sockaddr* address, socklen_t address_len) + { + if (address_len != sizeof(sockaddr_in)) + return BAN::Error::from_errno(EINVAL); + auto* addr_in = reinterpret_cast(address); + return NetworkManager::get().bind_socket(addr_in->sin_port, this); } } diff --git a/kernel/kernel/OpenFileDescriptorSet.cpp b/kernel/kernel/OpenFileDescriptorSet.cpp index 207ec2b8..7cce348d 100644 --- a/kernel/kernel/OpenFileDescriptorSet.cpp +++ b/kernel/kernel/OpenFileDescriptorSet.cpp @@ -276,6 +276,8 @@ namespace Kernel if (m_open_files[fd]->flags & O_WRONLY && m_open_files[fd]->inode->is_pipe()) ((Pipe*)m_open_files[fd]->inode.ptr())->close_writing(); + m_open_files[fd]->inode->on_close(); + m_open_files[fd].clear(); return {}; diff --git a/kernel/kernel/Process.cpp b/kernel/kernel/Process.cpp index 607308b4..965da0ec 100644 --- a/kernel/kernel/Process.cpp +++ b/kernel/kernel/Process.cpp @@ -901,6 +901,20 @@ namespace Kernel return TRY(m_open_file_descriptors.socket(domain, type, protocol)); } + + BAN::ErrorOr Process::sys_bind(int socket, const sockaddr* address, socklen_t address_len) + { + LockGuard _(m_lock); + TRY(validate_pointer_access(address, address_len)); + + auto inode = TRY(m_open_file_descriptors.inode_of(socket)); + if (!inode->mode().ifsock()) + return BAN::Error::from_errno(ENOTSOCK); + + TRY(inode->bind(address, address_len)); + return 0; + } + BAN::ErrorOr Process::sys_pipe(int fildes[2]) { LockGuard _(m_lock); diff --git a/kernel/kernel/Syscall.cpp b/kernel/kernel/Syscall.cpp index 5a37b4a8..60b2369c 100644 --- a/kernel/kernel/Syscall.cpp +++ b/kernel/kernel/Syscall.cpp @@ -216,6 +216,9 @@ namespace Kernel case SYS_SOCKET: ret = Process::current().sys_socket((int)arg1, (int)arg2, (int)arg3); break; + case SYS_BIND: + ret = Process::current().sys_bind((int)arg1, (const sockaddr*)arg2, (socklen_t)arg3); + break; default: dwarnln("Unknown syscall {}", syscall); break; diff --git a/libc/include/sys/syscall.h b/libc/include/sys/syscall.h index 584c2fc0..8ce26841 100644 --- a/libc/include/sys/syscall.h +++ b/libc/include/sys/syscall.h @@ -64,6 +64,7 @@ __BEGIN_DECLS #define SYS_CHOWN 63 #define SYS_LOAD_KEYMAP 64 #define SYS_SOCKET 65 +#define SYS_BIND 66 __END_DECLS diff --git a/libc/sys/socket.cpp b/libc/sys/socket.cpp index e0e3aed5..2d4d23ef 100644 --- a/libc/sys/socket.cpp +++ b/libc/sys/socket.cpp @@ -2,6 +2,11 @@ #include #include +int bind(int socket, const struct sockaddr* address, socklen_t address_len) +{ + return syscall(SYS_BIND, socket, address, address_len); +} + int socket(int domain, int type, int protocol) { return syscall(SYS_SOCKET, domain, type, protocol);