diff --git a/kernel/CMakeLists.txt b/kernel/CMakeLists.txt index 8ba07f0a50..2107df31be 100644 --- a/kernel/CMakeLists.txt +++ b/kernel/CMakeLists.txt @@ -51,6 +51,7 @@ set(KERNEL_SOURCES kernel/Memory/VirtualRange.cpp kernel/Networking/E1000/E1000.cpp kernel/Networking/E1000/E1000E.cpp + kernel/Networking/IPv4.cpp kernel/Networking/NetworkInterface.cpp kernel/Networking/NetworkManager.cpp kernel/Networking/NetworkSocket.cpp diff --git a/kernel/include/kernel/FS/Inode.h b/kernel/include/kernel/FS/Inode.h index f8cc7b12d1..5f5ba9f6cd 100644 --- a/kernel/include/kernel/FS/Inode.h +++ b/kernel/include/kernel/FS/Inode.h @@ -101,6 +101,7 @@ namespace Kernel // Socket API BAN::ErrorOr bind(const sockaddr* address, socklen_t address_len); + BAN::ErrorOr sendto(const sys_sendto_t*); // General API BAN::ErrorOr read(off_t, BAN::ByteSpan buffer); @@ -125,6 +126,7 @@ namespace Kernel // Socket API virtual BAN::ErrorOr bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } + virtual BAN::ErrorOr sendto_impl(const sys_sendto_t*) { return BAN::Error::from_errno(ENOTSUP); } // General API virtual BAN::ErrorOr read_impl(off_t, BAN::ByteSpan) { return BAN::Error::from_errno(ENOTSUP); } diff --git a/kernel/include/kernel/Networking/IPv4.h b/kernel/include/kernel/Networking/IPv4.h new file mode 100644 index 0000000000..e111cfc821 --- /dev/null +++ b/kernel/include/kernel/Networking/IPv4.h @@ -0,0 +1,10 @@ +#pragma once + +#include + +namespace Kernel +{ + + BAN::ErrorOr add_ipv4_header(BAN::Vector&, uint32_t src_ipv4, uint32_t dst_ipv4, uint8_t protocol); + +} diff --git a/kernel/include/kernel/Networking/NetworkInterface.h b/kernel/include/kernel/Networking/NetworkInterface.h index 65a1e76683..3312c9d894 100644 --- a/kernel/include/kernel/Networking/NetworkInterface.h +++ b/kernel/include/kernel/Networking/NetworkInterface.h @@ -20,14 +20,16 @@ namespace Kernel virtual ~NetworkInterface() {} virtual uint8_t* get_mac_address() = 0; + uint32_t get_ipv4_address() const { return m_ipv4_address; } virtual bool link_up() = 0; virtual int link_speed() = 0; + BAN::ErrorOr add_interface_header(BAN::Vector&, uint8_t destination_mac[6]); + virtual dev_t rdev() const override { return m_rdev; } virtual BAN::StringView name() const override { return m_name; } - protected: virtual BAN::ErrorOr send_raw_bytes(BAN::ConstByteSpan) = 0; private: diff --git a/kernel/include/kernel/Networking/NetworkSocket.h b/kernel/include/kernel/Networking/NetworkSocket.h index 72a9f42c7c..f17ae34a17 100644 --- a/kernel/include/kernel/Networking/NetworkSocket.h +++ b/kernel/include/kernel/Networking/NetworkSocket.h @@ -9,20 +9,27 @@ namespace Kernel class NetworkSocket : public TmpInode, public BAN::Weakable { + public: + static constexpr uint16_t PORT_NONE = 0; + public: void bind_interface_and_port(NetworkInterface*, uint16_t port); ~NetworkSocket(); + virtual BAN::ErrorOr add_protocol_header(BAN::Vector&, uint16_t src_port, uint16_t dst_port) = 0; + virtual uint8_t protocol() const = 0; + protected: NetworkSocket(mode_t mode, uid_t uid, gid_t gid); virtual void on_close_impl() override; virtual BAN::ErrorOr bind_impl(const sockaddr* address, socklen_t address_len) override; + virtual BAN::ErrorOr sendto_impl(const sys_sendto_t*) override; protected: - NetworkInterface* m_interface = nullptr; - uint16_t m_port = 0; + NetworkInterface* m_interface = nullptr; + uint16_t m_port = PORT_NONE; }; } diff --git a/kernel/include/kernel/Networking/UDPSocket.h b/kernel/include/kernel/Networking/UDPSocket.h index aea0461344..a93dea7a6c 100644 --- a/kernel/include/kernel/Networking/UDPSocket.h +++ b/kernel/include/kernel/Networking/UDPSocket.h @@ -11,18 +11,13 @@ namespace Kernel public: static BAN::ErrorOr> create(mode_t, uid_t, gid_t); - void bind_interface(NetworkInterface*); - - protected: - virtual BAN::ErrorOr read_impl(off_t, BAN::ByteSpan) override; - virtual BAN::ErrorOr write_impl(off_t, BAN::ConstByteSpan) override; + virtual BAN::ErrorOr add_protocol_header(BAN::Vector&, uint16_t src_port, uint16_t dst_port) override; + virtual uint8_t protocol() const override { return 0x11; } private: UDPSocket(mode_t, uid_t, gid_t); private: - NetworkInterface* m_interface = nullptr; - friend class BAN::RefPtr; }; diff --git a/kernel/include/kernel/Process.h b/kernel/include/kernel/Process.h index 3eba7e967d..815dbda6d6 100644 --- a/kernel/include/kernel/Process.h +++ b/kernel/include/kernel/Process.h @@ -114,6 +114,7 @@ namespace Kernel BAN::ErrorOr sys_socket(int domain, int type, int protocol); BAN::ErrorOr sys_bind(int socket, const sockaddr* address, socklen_t address_len); + BAN::ErrorOr sys_sendto(const sys_sendto_t*); BAN::ErrorOr sys_pipe(int fildes[2]); BAN::ErrorOr sys_dup(int fildes); diff --git a/kernel/kernel/FS/Inode.cpp b/kernel/kernel/FS/Inode.cpp index aeb275f730..c32a47a104 100644 --- a/kernel/kernel/FS/Inode.cpp +++ b/kernel/kernel/FS/Inode.cpp @@ -124,6 +124,14 @@ namespace Kernel return bind_impl(address, address_len); } + BAN::ErrorOr Inode::sendto(const sys_sendto_t* arguments) + { + LockGuard _(m_lock); + if (!mode().ifsock()) + return BAN::Error::from_errno(ENOTSOCK); + return sendto_impl(arguments); + }; + BAN::ErrorOr Inode::read(off_t offset, BAN::ByteSpan buffer) { LockGuard _(m_lock); diff --git a/kernel/kernel/Networking/IPv4.cpp b/kernel/kernel/Networking/IPv4.cpp new file mode 100644 index 0000000000..ccd8fa4c2f --- /dev/null +++ b/kernel/kernel/Networking/IPv4.cpp @@ -0,0 +1,53 @@ +#include +#include + +namespace Kernel +{ + + + struct IPv4Header + { + uint8_t version_IHL; + uint8_t DSCP_ECN; + BAN::NetworkEndian total_length; + BAN::NetworkEndian identification; + BAN::NetworkEndian flags_frament; + uint8_t time_to_live; + uint8_t protocol; + BAN::NetworkEndian header_checksum; + BAN::NetworkEndian src_address; + BAN::NetworkEndian dst_address; + + uint16_t checksum() const + { + return 0xFFFF + - (((uint16_t)version_IHL << 8) | DSCP_ECN) + - total_length + - identification + - flags_frament + - (((uint16_t)time_to_live << 8) | protocol); + } + }; + static_assert(sizeof(IPv4Header) == 20); + + BAN::ErrorOr add_ipv4_header(BAN::Vector& packet, uint32_t src_ipv4, uint32_t dst_ipv4, uint8_t protocol) + { + TRY(packet.resize(packet.size() + sizeof(IPv4Header))); + memmove(packet.data() + sizeof(IPv4Header), packet.data(), packet.size() - sizeof(IPv4Header)); + + auto* header = reinterpret_cast(packet.data()); + header->version_IHL = 0x45; + header->DSCP_ECN = 0x10; + header->total_length = packet.size(); + header->identification = 1; + header->flags_frament = 0x00; + header->time_to_live = 0x40; + header->protocol = protocol; + header->header_checksum = header->checksum(); + header->src_address = src_ipv4; + header->dst_address = dst_ipv4; + + return {}; + } + +} diff --git a/kernel/kernel/Networking/NetworkInterface.cpp b/kernel/kernel/Networking/NetworkInterface.cpp index f9a60715b7..6ca4bc57ab 100644 --- a/kernel/kernel/Networking/NetworkInterface.cpp +++ b/kernel/kernel/Networking/NetworkInterface.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -7,6 +8,14 @@ namespace Kernel { + struct EthernetHeader + { + uint8_t dst_mac[6]; + uint8_t src_mac[6]; + BAN::NetworkEndian ether_type; + }; + static_assert(sizeof(EthernetHeader) == 14); + static dev_t get_network_rdev_major() { static dev_t major = DevFileSystem::get().get_next_dev(); @@ -31,4 +40,19 @@ namespace Kernel m_name[3] = minor(m_rdev) + '0'; } + BAN::ErrorOr NetworkInterface::add_interface_header(BAN::Vector& packet, uint8_t destination_mac[6]) + { + ASSERT(m_type == Type::Ethernet); + + TRY(packet.resize(packet.size() + sizeof(EthernetHeader))); + memmove(packet.data() + sizeof(EthernetHeader), packet.data(), packet.size() - sizeof(EthernetHeader)); + + auto* header = reinterpret_cast(packet.data()); + memcpy(header->dst_mac, destination_mac, 6); + memcpy(header->src_mac, get_mac_address(), 6); + header->ether_type = 0x0800; // ipv4 + + return {}; + } + } diff --git a/kernel/kernel/Networking/NetworkManager.cpp b/kernel/kernel/Networking/NetworkManager.cpp index 1f6a166a98..7c350f2aac 100644 --- a/kernel/kernel/Networking/NetworkManager.cpp +++ b/kernel/kernel/Networking/NetworkManager.cpp @@ -65,7 +65,6 @@ namespace Kernel return {}; } - BAN::ErrorOr> NetworkManager::create_socket(SocketType type, mode_t mode, uid_t uid, gid_t gid) { ASSERT((mode & Inode::Mode::TYPE_MASK) == 0); @@ -92,12 +91,16 @@ namespace Kernel { if (m_interfaces.empty()) return BAN::Error::from_errno(EADDRNOTAVAIL); - if (m_bound_sockets.contains(port)) - return BAN::Error::from_errno(EADDRINUSE); + + if (port != NetworkSocket::PORT_NONE) + { + if (m_bound_sockets.contains(port)) + return BAN::Error::from_errno(EADDRINUSE); + TRY(m_bound_sockets.insert(port, socket)); + } // FIXME: actually determine proper interface auto interface = m_interfaces.front(); - TRY(m_bound_sockets.insert(port, socket)); socket->bind_interface_and_port(interface.ptr(), port); return {}; diff --git a/kernel/kernel/Networking/NetworkSocket.cpp b/kernel/kernel/Networking/NetworkSocket.cpp index 70f2d295a6..79c888d78e 100644 --- a/kernel/kernel/Networking/NetworkSocket.cpp +++ b/kernel/kernel/Networking/NetworkSocket.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -39,4 +40,43 @@ namespace Kernel return NetworkManager::get().bind_socket(addr_in->sin_port, this); } + BAN::ErrorOr NetworkSocket::sendto_impl(const sys_sendto_t* arguments) + { + if (arguments->dest_len != sizeof(sockaddr_in)) + return BAN::Error::from_errno(EINVAL); + if (arguments->flags) + { + dprintln("flags not supported"); + return BAN::Error::from_errno(ENOTSUP); + } + + if (!m_interface) + TRY(NetworkManager::get().bind_socket(PORT_NONE, this)); + + auto* destination = reinterpret_cast(arguments->dest_addr); + auto message = BAN::ConstByteSpan((const uint8_t*)arguments->message, arguments->length); + + if (destination->sin_port == PORT_NONE) + return BAN::Error::from_errno(EINVAL); + + if (destination->sin_addr.s_addr != 0xFFFFFFFF) + { + dprintln("Only broadcast ip supported"); + return BAN::Error::from_errno(EINVAL); + } + + static uint8_t dest_mac[6] { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF }; + + BAN::Vector full_packet; + TRY(full_packet.resize(message.size())); + memcpy(full_packet.data(), message.data(), message.size()); + TRY(add_protocol_header(full_packet, m_port, destination->sin_port)); + TRY(add_ipv4_header(full_packet, m_interface->get_ipv4_address(), destination->sin_addr.s_addr, protocol())); + TRY(m_interface->add_interface_header(full_packet, dest_mac)); + + TRY(m_interface->send_raw_bytes(BAN::ConstByteSpan { full_packet.span() })); + + return arguments->length; + } + } diff --git a/kernel/kernel/Networking/UDPSocket.cpp b/kernel/kernel/Networking/UDPSocket.cpp index 967d1228d9..a14a42ddc6 100644 --- a/kernel/kernel/Networking/UDPSocket.cpp +++ b/kernel/kernel/Networking/UDPSocket.cpp @@ -1,8 +1,18 @@ +#include #include namespace Kernel { + struct UDPHeader + { + BAN::NetworkEndian src_port; + BAN::NetworkEndian dst_port; + BAN::NetworkEndian length; + BAN::NetworkEndian checksum; + }; + static_assert(sizeof(UDPHeader) == 8); + BAN::ErrorOr> UDPSocket::create(mode_t mode, uid_t uid, gid_t gid) { return TRY(BAN::RefPtr::create(mode, uid, gid)); @@ -12,14 +22,18 @@ namespace Kernel : NetworkSocket(mode, uid, gid) { } - BAN::ErrorOr UDPSocket::read_impl(off_t, BAN::ByteSpan) + BAN::ErrorOr UDPSocket::add_protocol_header(BAN::Vector& packet, uint16_t src_port, uint16_t dst_port) { - return BAN::Error::from_errno(ENOTSUP); - } + TRY(packet.resize(packet.size() + sizeof(UDPHeader))); + memmove(packet.data() + sizeof(UDPHeader), packet.data(), packet.size() - sizeof(UDPHeader)); - BAN::ErrorOr UDPSocket::write_impl(off_t, BAN::ConstByteSpan) - { - return BAN::Error::from_errno(ENOTSUP); + auto* header = reinterpret_cast(packet.data()); + header->src_port = src_port; + header->dst_port = dst_port; + header->length = packet.size(); + header->checksum = 0; + + return {}; } } diff --git a/kernel/kernel/Process.cpp b/kernel/kernel/Process.cpp index 965da0ec29..7144a4f86b 100644 --- a/kernel/kernel/Process.cpp +++ b/kernel/kernel/Process.cpp @@ -915,6 +915,21 @@ namespace Kernel return 0; } + + BAN::ErrorOr Process::sys_sendto(const sys_sendto_t* arguments) + { + LockGuard _(m_lock); + TRY(validate_pointer_access(arguments, sizeof(sys_sendto_t))); + TRY(validate_pointer_access(arguments->message, arguments->length)); + TRY(validate_pointer_access(arguments->dest_addr, arguments->dest_len)); + + auto inode = TRY(m_open_file_descriptors.inode_of(arguments->socket)); + if (!inode->mode().ifsock()) + return BAN::Error::from_errno(ENOTSOCK); + + return TRY(inode->sendto(arguments)); + } + BAN::ErrorOr Process::sys_pipe(int fildes[2]) { LockGuard _(m_lock); diff --git a/kernel/kernel/Syscall.cpp b/kernel/kernel/Syscall.cpp index 60b2369cb2..45fab33d40 100644 --- a/kernel/kernel/Syscall.cpp +++ b/kernel/kernel/Syscall.cpp @@ -219,6 +219,9 @@ namespace Kernel case SYS_BIND: ret = Process::current().sys_bind((int)arg1, (const sockaddr*)arg2, (socklen_t)arg3); break; + case SYS_SENDTO: + ret = Process::current().sys_sendto((const sys_sendto_t*)arg1); + break; default: dwarnln("Unknown syscall {}", syscall); break; diff --git a/libc/include/sys/socket.h b/libc/include/sys/socket.h index 5c9208a4a7..43e8b1b893 100644 --- a/libc/include/sys/socket.h +++ b/libc/include/sys/socket.h @@ -105,6 +105,16 @@ struct linger #define SHUT_WR 0x02 #define SHUT_RDWR (SHUT_RD | SHUT_WR) +struct sys_sendto_t +{ + int socket; + const void* message; + size_t length; + int flags; + const struct sockaddr* dest_addr; + socklen_t dest_len; +}; + int accept(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len); int bind(int socket, const struct sockaddr* address, socklen_t address_len); int connect(int socket, const struct sockaddr* address, socklen_t address_len); diff --git a/libc/include/sys/syscall.h b/libc/include/sys/syscall.h index 8ce268415f..39e7da66c6 100644 --- a/libc/include/sys/syscall.h +++ b/libc/include/sys/syscall.h @@ -65,6 +65,7 @@ __BEGIN_DECLS #define SYS_LOAD_KEYMAP 64 #define SYS_SOCKET 65 #define SYS_BIND 66 +#define SYS_SENDTO 67 __END_DECLS diff --git a/libc/sys/socket.cpp b/libc/sys/socket.cpp index 2d4d23ef4f..ce05432fed 100644 --- a/libc/sys/socket.cpp +++ b/libc/sys/socket.cpp @@ -7,6 +7,19 @@ int bind(int socket, const struct sockaddr* address, socklen_t address_len) return syscall(SYS_BIND, socket, address, address_len); } +ssize_t sendto(int socket, const void* message, size_t length, int flags, const struct sockaddr* dest_addr, socklen_t dest_len) +{ + sys_sendto_t arguments { + .socket = socket, + .message = message, + .length = length, + .flags = flags, + .dest_addr = dest_addr, + .dest_len = dest_len + }; + return syscall(SYS_SENDTO, &arguments); +} + int socket(int domain, int type, int protocol) { return syscall(SYS_SOCKET, domain, type, protocol);