Compare commits

..

21 Commits

Author SHA1 Message Date
Bananymous 9314528b9b Kernel: Improve syscall handling
Syscalls are now called from a list of function pointers
2024-02-12 21:51:11 +02:00
Bananymous 78ef7e804f BAN: Implement bit_cast 2024-02-12 21:46:33 +02:00
Bananymous 3fc1edede0 Kernel/LibC: Implement super basic select
This does not really even block but it works... :D
2024-02-12 17:26:33 +02:00
Bananymous f50b4be162 Kernel: Cleanup TCP code 2024-02-12 15:44:40 +02:00
Bananymous ccde8148a7 Userspace: Implement basic udp test program 2024-02-12 04:45:42 +02:00
Bananymous b9bbf42538 Userspace: Implement basic test program for tcp connection 2024-02-12 04:45:42 +02:00
Bananymous 435636a655 Kernel: Implement super simple TCP stack
No SACK support and windows are fixed size
2024-02-12 04:45:42 +02:00
Bananymous ba06269b14 Kernel: Move on_close_impl from network socket to udp socket 2024-02-12 04:45:42 +02:00
Bananymous be01ccdb08 Kernel: Fix E1000 mtu 2024-02-12 04:25:39 +02:00
Bananymous b45d27593f Kernel: Implement super simple PRNG 2024-02-12 04:25:06 +02:00
Bananymous ff49d8b84f Kernel: Cleanup OSI layer overlapping 2024-02-09 17:05:07 +02:00
Bananymous 5d78cd3016 Kernel: Add spin lock assert back. I had accidentally deleted it 2024-02-09 16:58:55 +02:00
Bananymous ed0b1a86aa Kernel: Semaphores and Threads can now be blocked with timeout 2024-02-09 15:28:15 +02:00
Bananymous 534b3e6a9a Kernel: Add LockFreeGuard to LockGuard.h 2024-02-09 15:13:54 +02:00
Bananymous d452cf4170 Kernel: Fix checksum for packets with odd number of bytes 2024-02-09 01:20:40 +02:00
Bananymous f117027175 resolver: dump errors to debug output 2024-02-08 18:34:15 +02:00
Bananymous acf79570ef Kernel: Cleanup network APIs and error messages 2024-02-08 18:33:49 +02:00
Bananymous 5a939cf252 Userspace: Add simple test for unix domain sockets 2024-02-08 13:18:54 +02:00
Bananymous 9bc7a72a25 Kernel: Implement unix domain sockets with SOCK_DGRAM
Also unbind sockets on close
2024-02-08 13:18:54 +02:00
Bananymous 065ee9004c Userspace: Add DNS cache to resolver
Also the format of resolver reply is now just sockaddr_storage with
family set and address in the storage field.
2024-02-08 12:06:30 +02:00
Bananymous 6fb69a1dc2 LibC: Implement inet_ntop for IPv4 addresses 2024-02-08 11:59:11 +02:00
73 changed files with 2283 additions and 724 deletions

12
BAN/include/BAN/Bitcast.h Normal file
View File

@ -0,0 +1,12 @@
#pragma once
namespace BAN
{
template<typename To, typename From>
constexpr To bit_cast(const From& from)
{
return __builtin_bit_cast(To, from);
}
}

View File

@ -57,6 +57,7 @@ set(KERNEL_SOURCES
kernel/Networking/NetworkInterface.cpp kernel/Networking/NetworkInterface.cpp
kernel/Networking/NetworkManager.cpp kernel/Networking/NetworkManager.cpp
kernel/Networking/NetworkSocket.cpp kernel/Networking/NetworkSocket.cpp
kernel/Networking/TCPSocket.cpp
kernel/Networking/UDPSocket.cpp kernel/Networking/UDPSocket.cpp
kernel/Networking/UNIX/Socket.cpp kernel/Networking/UNIX/Socket.cpp
kernel/OpenFileDescriptorSet.cpp kernel/OpenFileDescriptorSet.cpp
@ -64,6 +65,7 @@ set(KERNEL_SOURCES
kernel/PCI.cpp kernel/PCI.cpp
kernel/PIC.cpp kernel/PIC.cpp
kernel/Process.cpp kernel/Process.cpp
kernel/Random.cpp
kernel/Scheduler.cpp kernel/Scheduler.cpp
kernel/Semaphore.cpp kernel/Semaphore.cpp
kernel/SpinLock.cpp kernel/SpinLock.cpp

View File

@ -1,7 +1,5 @@
.section .userspace, "aw" .section .userspace, "aw"
#include <sys/syscall.h>
// stack contains // stack contains
// return address // return address
// signal number // signal number

View File

@ -21,6 +21,10 @@ namespace Kernel
virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override { return 0; } virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override { return 0; }
virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan buffer) override; virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan buffer) override;
virtual bool can_read_impl() const override { return false; }
virtual bool can_write_impl() const override { return true; }
virtual bool has_error_impl() const override { return false; }
private: private:
const dev_t m_rdev; const dev_t m_rdev;
}; };

View File

@ -30,6 +30,10 @@ namespace Kernel
virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override; virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override;
virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan) override; virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan) override;
virtual bool can_read_impl() const override { return true; }
virtual bool can_write_impl() const override { return true; }
virtual bool has_error_impl() const override { return false; }
private: private:
FramebufferDevice(mode_t mode, uid_t uid, gid_t gid, dev_t rdev, paddr_t paddr, uint32_t width, uint32_t height, uint32_t pitch, uint8_t bpp); FramebufferDevice(mode_t mode, uid_t uid, gid_t gid, dev_t rdev, paddr_t paddr, uint32_t width, uint32_t height, uint32_t pitch, uint8_t bpp);
BAN::ErrorOr<void> initialize(); BAN::ErrorOr<void> initialize();

View File

@ -23,6 +23,10 @@ namespace Kernel
virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override { return 0; } virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override { return 0; }
virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan buffer) override { return buffer.size(); }; virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan buffer) override { return buffer.size(); };
virtual bool can_read_impl() const override { return false; }
virtual bool can_write_impl() const override { return true; }
virtual bool has_error_impl() const override { return false; }
private: private:
const dev_t m_rdev; const dev_t m_rdev;
}; };

View File

@ -21,6 +21,10 @@ namespace Kernel
virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override; virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override;
virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan buffer) override { return buffer.size(); }; virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan buffer) override { return buffer.size(); };
virtual bool can_read_impl() const override { return true; }
virtual bool can_write_impl() const override { return false; }
virtual bool has_error_impl() const override { return false; }
private: private:
const dev_t m_rdev; const dev_t m_rdev;
}; };

View File

@ -43,6 +43,10 @@ namespace Kernel
virtual BAN::ErrorOr<void> truncate_impl(size_t) override; virtual BAN::ErrorOr<void> truncate_impl(size_t) override;
virtual BAN::ErrorOr<void> chmod_impl(mode_t) override; virtual BAN::ErrorOr<void> chmod_impl(mode_t) override;
virtual bool can_read_impl() const override { return true; }
virtual bool can_write_impl() const override { return true; }
virtual bool has_error_impl() const override { return false; }
private: private:
// Returns maximum number of data blocks in use // Returns maximum number of data blocks in use
// NOTE: the inode might have more blocks than what this suggests if it has been shrinked // NOTE: the inode might have more blocks than what this suggests if it has been shrinked

View File

@ -104,8 +104,8 @@ namespace Kernel
BAN::ErrorOr<void> bind(const sockaddr* address, socklen_t address_len); BAN::ErrorOr<void> bind(const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<void> connect(const sockaddr* address, socklen_t address_len); BAN::ErrorOr<void> connect(const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<void> listen(int backlog); BAN::ErrorOr<void> listen(int backlog);
BAN::ErrorOr<size_t> sendto(const sys_sendto_t*); BAN::ErrorOr<size_t> sendto(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<size_t> recvfrom(sys_recvfrom_t*); BAN::ErrorOr<size_t> recvfrom(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len);
// General API // General API
BAN::ErrorOr<size_t> read(off_t, BAN::ByteSpan buffer); BAN::ErrorOr<size_t> read(off_t, BAN::ByteSpan buffer);
@ -113,7 +113,11 @@ namespace Kernel
BAN::ErrorOr<void> truncate(size_t); BAN::ErrorOr<void> truncate(size_t);
BAN::ErrorOr<void> chmod(mode_t); BAN::ErrorOr<void> chmod(mode_t);
BAN::ErrorOr<void> chown(uid_t, gid_t); BAN::ErrorOr<void> chown(uid_t, gid_t);
bool has_data() const;
// Select/Non blocking API
bool can_read() const;
bool can_write() const;
bool has_error() const;
BAN::ErrorOr<long> ioctl(int request, void* arg); BAN::ErrorOr<long> ioctl(int request, void* arg);
@ -135,8 +139,8 @@ namespace Kernel
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> listen_impl(int) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<void> listen_impl(int) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<size_t> sendto_impl(const sys_sendto_t*) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<size_t> sendto_impl(BAN::ConstByteSpan, const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<size_t> recvfrom_impl(sys_recvfrom_t*) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<size_t> recvfrom_impl(BAN::ByteSpan, sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); }
// General API // General API
virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) { return BAN::Error::from_errno(ENOTSUP); }
@ -144,7 +148,11 @@ namespace Kernel
virtual BAN::ErrorOr<void> truncate_impl(size_t) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<void> truncate_impl(size_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> chmod_impl(mode_t) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<void> chmod_impl(mode_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> chown_impl(uid_t, gid_t) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<void> chown_impl(uid_t, gid_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual bool has_data_impl() const { dwarnln("nonblock not supported"); return true; }
// Select/Non blocking API
virtual bool can_read_impl() const = 0;
virtual bool can_write_impl() const = 0;
virtual bool has_error_impl() const = 0;
virtual BAN::ErrorOr<long> ioctl_impl(int request, void* arg) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<long> ioctl_impl(int request, void* arg) { return BAN::Error::from_errno(ENOTSUP); }

View File

@ -34,6 +34,10 @@ namespace Kernel
virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override; virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override;
virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan) override; virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan) override;
virtual bool can_read_impl() const override { return !m_buffer.empty(); }
virtual bool can_write_impl() const override { return true; }
virtual bool has_error_impl() const override { return false; }
private: private:
Pipe(const Credentials&); Pipe(const Credentials&);

View File

@ -43,7 +43,10 @@ namespace Kernel
// You may not write here and this is always non blocking // You may not write here and this is always non blocking
virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan) override { return BAN::Error::from_errno(EINVAL); } virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan) override { return BAN::Error::from_errno(EINVAL); }
virtual BAN::ErrorOr<void> truncate_impl(size_t) override { return BAN::Error::from_errno(EINVAL); } virtual BAN::ErrorOr<void> truncate_impl(size_t) override { return BAN::Error::from_errno(EINVAL); }
virtual bool has_data_impl() const override { return true; }
virtual bool can_read_impl() const override { return true; }
virtual bool can_write_impl() const override { return false; }
virtual bool has_error_impl() const override { return false; }
private: private:
ProcROInode(Process&, size_t (Process::*)(off_t, BAN::ByteSpan) const, TmpFileSystem&, const TmpInodeInfo&); ProcROInode(Process&, size_t (Process::*)(off_t, BAN::ByteSpan) const, TmpFileSystem&, const TmpInodeInfo&);

View File

@ -72,7 +72,10 @@ namespace Kernel
virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan) override; virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan) override;
virtual BAN::ErrorOr<void> truncate_impl(size_t) override; virtual BAN::ErrorOr<void> truncate_impl(size_t) override;
virtual BAN::ErrorOr<void> chmod_impl(mode_t) override; virtual BAN::ErrorOr<void> chmod_impl(mode_t) override;
virtual bool has_data_impl() const override { return true; }
virtual bool can_read_impl() const override { return true; }
virtual bool can_write_impl() const override { return true; }
virtual bool has_error_impl() const override { return false; }
private: private:
TmpFileInode(TmpFileSystem&, ino_t, const TmpInodeInfo&); TmpFileInode(TmpFileSystem&, ino_t, const TmpInodeInfo&);
@ -91,7 +94,10 @@ namespace Kernel
virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan) override { return BAN::Error::from_errno(ENODEV); } virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan) override { return BAN::Error::from_errno(ENODEV); }
virtual BAN::ErrorOr<void> truncate_impl(size_t) override { return BAN::Error::from_errno(ENODEV); } virtual BAN::ErrorOr<void> truncate_impl(size_t) override { return BAN::Error::from_errno(ENODEV); }
virtual BAN::ErrorOr<void> chmod_impl(mode_t) override; virtual BAN::ErrorOr<void> chmod_impl(mode_t) override;
virtual bool has_data_impl() const override { return true; }
virtual bool can_read_impl() const override { return false; }
virtual bool can_write_impl() const override { return false; }
virtual bool has_error_impl() const override { return false; }
private: private:
TmpSocketInode(TmpFileSystem&, ino_t, const TmpInodeInfo&); TmpSocketInode(TmpFileSystem&, ino_t, const TmpInodeInfo&);
@ -110,6 +116,10 @@ namespace Kernel
protected: protected:
virtual BAN::ErrorOr<BAN::String> link_target_impl() override; virtual BAN::ErrorOr<BAN::String> link_target_impl() override;
virtual bool can_read_impl() const override { return false; }
virtual bool can_write_impl() const override { return false; }
virtual bool has_error_impl() const override { return false; }
private: private:
TmpSymlinkInode(TmpFileSystem&, ino_t, const TmpInodeInfo&); TmpSymlinkInode(TmpFileSystem&, ino_t, const TmpInodeInfo&);
}; };
@ -136,6 +146,10 @@ namespace Kernel
virtual BAN::ErrorOr<void> create_directory_impl(BAN::StringView, mode_t, uid_t, gid_t) override final; virtual BAN::ErrorOr<void> create_directory_impl(BAN::StringView, mode_t, uid_t, gid_t) override final;
virtual BAN::ErrorOr<void> unlink_impl(BAN::StringView) override; virtual BAN::ErrorOr<void> unlink_impl(BAN::StringView) override;
virtual bool can_read_impl() const override { return false; }
virtual bool can_write_impl() const override { return false; }
virtual bool has_error_impl() const override { return false; }
private: private:
template<TmpFuncs::for_each_valid_entry_callback F> template<TmpFuncs::for_each_valid_entry_callback F>
void for_each_valid_entry(F callback); void for_each_valid_entry(F callback);

View File

@ -48,7 +48,10 @@ namespace Kernel::Input
protected: protected:
virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override; virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override;
virtual bool has_data_impl() const override;
virtual bool can_read_impl() const override { return !m_event_queue.empty(); }
virtual bool can_write_impl() const override { return false; }
virtual bool has_error_impl() const override { return false; }
}; };
} }

View File

@ -42,7 +42,10 @@ namespace Kernel::Input
protected: protected:
virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override; virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override;
virtual bool has_data_impl() const override;
virtual bool can_read_impl() const override { return !m_event_queue.empty(); }
virtual bool can_write_impl() const override { return false; }
virtual bool has_error_impl() const override { return false; }
}; };
} }

View File

@ -2,6 +2,8 @@
#include <BAN/NoCopyMove.h> #include <BAN/NoCopyMove.h>
#include <stdint.h>
namespace Kernel namespace Kernel
{ {
@ -27,4 +29,30 @@ namespace Kernel
Lock& m_lock; Lock& m_lock;
}; };
template<typename Lock>
class LockFreeGuard
{
BAN_NON_COPYABLE(LockFreeGuard);
BAN_NON_MOVABLE(LockFreeGuard);
public:
LockFreeGuard(Lock& lock)
: m_lock(lock)
, m_depth(lock.lock_depth())
{
for (uint32_t i = 0; i < m_depth; i++)
m_lock.unlock();
}
~LockFreeGuard()
{
for (uint32_t i = 0; i < m_depth; i++)
m_lock.lock();
}
private:
Lock& m_lock;
const uint32_t m_depth;
};
} }

View File

@ -28,6 +28,8 @@ namespace Kernel
virtual bool link_up() override { return m_link_up; } virtual bool link_up() override { return m_link_up; }
virtual int link_speed() override; virtual int link_speed() override;
virtual size_t payload_mtu() const { return E1000_RX_BUFFER_SIZE - sizeof(EthernetHeader); }
virtual void handle_irq() final override; virtual void handle_irq() final override;
protected: protected:
@ -44,6 +46,10 @@ namespace Kernel
virtual BAN::ErrorOr<void> send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan) override; virtual BAN::ErrorOr<void> send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan) override;
virtual bool can_read_impl() const override { return false; }
virtual bool can_write_impl() const override { return false; }
virtual bool has_error_impl() const override { return false; }
private: private:
BAN::ErrorOr<void> read_mac_address(); BAN::ErrorOr<void> read_mac_address();

View File

@ -29,31 +29,10 @@ namespace Kernel
BAN::NetworkEndian<uint16_t> checksum { 0 }; BAN::NetworkEndian<uint16_t> checksum { 0 };
BAN::IPv4Address src_address; BAN::IPv4Address src_address;
BAN::IPv4Address dst_address; BAN::IPv4Address dst_address;
constexpr uint16_t calculate_checksum() const
{
uint32_t total_sum = 0;
for (size_t i = 0; i < sizeof(IPv4Header) / sizeof(uint16_t); i++)
total_sum += reinterpret_cast<const BAN::NetworkEndian<uint16_t>*>(this)[i];
total_sum -= checksum;
while (total_sum >> 16)
total_sum = (total_sum >> 16) + (total_sum & 0xFFFF);
return ~(uint16_t)total_sum;
}
constexpr bool is_valid_checksum() const
{
uint32_t total_sum = 0;
for (size_t i = 0; i < sizeof(IPv4Header) / sizeof(uint16_t); i++)
total_sum += reinterpret_cast<const BAN::NetworkEndian<uint16_t>*>(this)[i];
while (total_sum >> 16)
total_sum = (total_sum >> 16) + (total_sum & 0xFFFF);
return total_sum == 0xFFFF;
}
}; };
static_assert(sizeof(IPv4Header) == 20); static_assert(sizeof(IPv4Header) == 20);
class IPv4Layer : public NetworkLayer class IPv4Layer final : public NetworkLayer
{ {
BAN_NON_COPYABLE(IPv4Layer); BAN_NON_COPYABLE(IPv4Layer);
BAN_NON_MOVABLE(IPv4Layer); BAN_NON_MOVABLE(IPv4Layer);
@ -66,10 +45,13 @@ namespace Kernel
void add_ipv4_packet(NetworkInterface&, BAN::ConstByteSpan); void add_ipv4_packet(NetworkInterface&, BAN::ConstByteSpan);
virtual void unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) override; virtual void unbind_socket(BAN::RefPtr<NetworkSocket>, uint16_t port) override;
virtual BAN::ErrorOr<void> bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) override; virtual BAN::ErrorOr<void> bind_socket_to_unused(BAN::RefPtr<NetworkSocket>, const sockaddr* send_address, socklen_t send_address_len) override;
virtual BAN::ErrorOr<void> bind_socket_to_address(BAN::RefPtr<NetworkSocket>, const sockaddr* address, socklen_t address_len) override;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, const sys_sendto_t*) override; virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) override;
virtual size_t header_size() const override { return sizeof(IPv4Header); }
private: private:
IPv4Layer(); IPv4Layer();
@ -86,7 +68,7 @@ namespace Kernel
}; };
private: private:
SpinLock m_lock; RecursiveSpinLock m_lock;
BAN::UniqPtr<ARPTable> m_arp_table; BAN::UniqPtr<ARPTable> m_arp_table;
Process* m_process { nullptr }; Process* m_process { nullptr };

View File

@ -52,6 +52,8 @@ namespace Kernel
virtual bool link_up() = 0; virtual bool link_up() = 0;
virtual int link_speed() = 0; virtual int link_speed() = 0;
virtual size_t payload_mtu() const = 0;
virtual dev_t rdev() const override { return m_rdev; } virtual dev_t rdev() const override { return m_rdev; }
virtual BAN::StringView name() const override { return m_name; } virtual BAN::StringView name() const override { return m_name; }

View File

@ -5,6 +5,15 @@
namespace Kernel namespace Kernel
{ {
struct PseudoHeader
{
BAN::IPv4Address src_ipv4 { 0 };
BAN::IPv4Address dst_ipv4 { 0 };
BAN::NetworkEndian<uint16_t> protocol { 0 };
BAN::NetworkEndian<uint16_t> extra { 0 };
};
static_assert(sizeof(PseudoHeader) == 12);
class NetworkSocket; class NetworkSocket;
enum class SocketType; enum class SocketType;
@ -13,13 +22,30 @@ namespace Kernel
public: public:
virtual ~NetworkLayer() {} virtual ~NetworkLayer() {}
virtual void unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) = 0; virtual void unbind_socket(BAN::RefPtr<NetworkSocket>, uint16_t port) = 0;
virtual BAN::ErrorOr<void> bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) = 0; virtual BAN::ErrorOr<void> bind_socket_to_unused(BAN::RefPtr<NetworkSocket>, const sockaddr* send_address, socklen_t send_address_len) = 0;
virtual BAN::ErrorOr<void> bind_socket_to_address(BAN::RefPtr<NetworkSocket>, const sockaddr* address, socklen_t address_len) = 0;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, const sys_sendto_t*) = 0; virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, BAN::ConstByteSpan, const sockaddr*, socklen_t) = 0;
virtual size_t header_size() const = 0;
protected: protected:
NetworkLayer() = default; NetworkLayer() = default;
}; };
static uint16_t calculate_internet_checksum(BAN::ConstByteSpan packet, const PseudoHeader& pseudo_header)
{
uint32_t checksum = 0;
for (size_t i = 0; i < sizeof(pseudo_header) / sizeof(uint16_t); i++)
checksum += BAN::host_to_network_endian(reinterpret_cast<const uint16_t*>(&pseudo_header)[i]);
for (size_t i = 0; i < packet.size() / sizeof(uint16_t); i++)
checksum += BAN::host_to_network_endian(reinterpret_cast<const uint16_t*>(packet.data())[i]);
if (packet.size() % 2)
checksum += (uint16_t)packet[packet.size() - 1] << 8;
while (checksum >> 16)
checksum = (checksum >> 16) + (checksum & 0xFFFF);
return ~(uint16_t)checksum;
}
} }

View File

@ -6,14 +6,13 @@
#include <kernel/Networking/NetworkInterface.h> #include <kernel/Networking/NetworkInterface.h>
#include <kernel/Networking/NetworkLayer.h> #include <kernel/Networking/NetworkLayer.h>
#include <netinet/in.h>
namespace Kernel namespace Kernel
{ {
enum NetworkProtocol : uint8_t enum NetworkProtocol : uint8_t
{ {
ICMP = 0x01, ICMP = 0x01,
TCP = 0x06,
UDP = 0x11, UDP = 0x11,
}; };
@ -32,28 +31,22 @@ namespace Kernel
NetworkInterface& interface() { ASSERT(m_interface); return *m_interface; } NetworkInterface& interface() { ASSERT(m_interface); return *m_interface; }
virtual size_t protocol_header_size() const = 0; virtual size_t protocol_header_size() const = 0;
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) = 0; virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) = 0;
virtual NetworkProtocol protocol() const = 0; virtual NetworkProtocol protocol() const = 0;
virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_address, uint16_t sender_port) = 0; virtual void receive_packet(BAN::ConstByteSpan, const sockaddr_storage& sender) = 0;
bool is_bound() const { return m_interface != nullptr; }
protected: protected:
NetworkSocket(NetworkLayer&, ino_t, const TmpInodeInfo&); NetworkSocket(NetworkLayer&, ino_t, const TmpInodeInfo&);
virtual BAN::ErrorOr<size_t> read_packet(BAN::ByteSpan, sockaddr_in* sender_address) = 0;
virtual void on_close_impl() override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr* address, socklen_t address_len) override;
virtual BAN::ErrorOr<size_t> sendto_impl(const sys_sendto_t*) override;
virtual BAN::ErrorOr<size_t> recvfrom_impl(sys_recvfrom_t*) override;
virtual BAN::ErrorOr<long> ioctl_impl(int request, void* arg) override; virtual BAN::ErrorOr<long> ioctl_impl(int request, void* arg) override;
protected: protected:
NetworkLayer& m_network_layer; NetworkLayer& m_network_layer;
NetworkInterface* m_interface = nullptr; NetworkInterface* m_interface = nullptr;
uint16_t m_port = PORT_NONE; uint16_t m_port { PORT_NONE };
}; };
} }

View File

@ -0,0 +1,138 @@
#pragma once
#include <BAN/Endianness.h>
#include <kernel/Memory/VirtualRange.h>
#include <kernel/Networking/NetworkInterface.h>
#include <kernel/Networking/NetworkSocket.h>
#include <kernel/Process.h>
#include <kernel/Semaphore.h>
namespace Kernel
{
struct TCPHeader
{
BAN::NetworkEndian<uint16_t> src_port { 0 };
BAN::NetworkEndian<uint16_t> dst_port { 0 };
BAN::NetworkEndian<uint32_t> seq_number { 0 };
BAN::NetworkEndian<uint32_t> ack_number { 0 };
uint8_t reserved : 4 { 0 };
uint8_t data_offset : 4 { 0 };
uint8_t fin : 1 { 0 };
uint8_t syn : 1 { 0 };
uint8_t rst : 1 { 0 };
uint8_t psh : 1 { 0 };
uint8_t ack : 1 { 0 };
uint8_t urg : 1 { 0 };
uint8_t ece : 1 { 0 };
uint8_t cwr : 1 { 0 };
BAN::NetworkEndian<uint16_t> window_size { 0 };
BAN::NetworkEndian<uint16_t> checksum { 0 };
BAN::NetworkEndian<uint16_t> urgent_pointer { 0 };
uint8_t options[0];
};
static_assert(sizeof(TCPHeader) == 20);
class TCPSocket final : public NetworkSocket
{
public:
static constexpr size_t m_tcp_options_bytes = 4;
public:
static BAN::ErrorOr<BAN::RefPtr<TCPSocket>> create(NetworkLayer&, ino_t, const TmpInodeInfo&);
~TCPSocket();
virtual NetworkProtocol protocol() const override { return NetworkProtocol::TCP; }
virtual size_t protocol_header_size() const override { return sizeof(TCPHeader) + m_tcp_options_bytes; }
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) override;
protected:
virtual void on_close_impl() override;
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override;
virtual void receive_packet(BAN::ConstByteSpan, const sockaddr_storage& sender) override;
virtual BAN::ErrorOr<size_t> sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) override;
virtual BAN::ErrorOr<size_t> recvfrom_impl(BAN::ByteSpan message, sockaddr* address, socklen_t* address_len) override;
virtual bool can_read_impl() const override { return m_recv_window.data_size; }
virtual bool can_write_impl() const override { return m_state == State::Established; }
virtual bool has_error_impl() const override { return m_state != State::Established && m_state != State::Listen && m_state != State::SynSent && m_state != State::SynReceived; }
private:
enum class State
{
Closed = 0,
Listen,
SynSent,
SynReceived,
Established,
FinWait1,
FinWait2,
CloseWait,
Closing,
LastAck,
TimeWait,
};
struct RecvWindowInfo
{
uint32_t start_seq { 0 }; // sequence number of first byte in buffer
bool has_ghost_byte { false };
uint32_t data_size { 0 }; // number of bytes in this buffer
BAN::UniqPtr<VirtualRange> buffer;
};
struct SendWindowInfo
{
uint32_t mss { 0 }; // maximum segment size
uint16_t non_scaled_size { 0 }; // window size without scaling
uint8_t scale { 0 }; // window scale
uint32_t scaled_size() const { return (uint32_t)non_scaled_size << scale; }
uint32_t start_seq { 0 }; // sequence number of first byte in buffer
uint32_t current_seq { 0 }; // sequence number of next send
uint32_t current_ack { 0 }; // sequence number aknowledged by connection
uint64_t last_send_ms { 0 }; // last send time, used for retransmission timeout
bool has_ghost_byte { false };
uint32_t data_size { 0 }; // number of bytes in this buffer
BAN::UniqPtr<VirtualRange> buffer;
};
private:
TCPSocket(NetworkLayer&, ino_t, const TmpInodeInfo&);
void process_task();
void set_connection_as_closed();
private:
State m_state = State::Closed;
Process* m_process { nullptr };
uint64_t m_time_wait_start_ms { 0 };
RecursiveSpinLock m_lock;
Semaphore m_semaphore;
BAN::Atomic<bool> m_should_ack { false };
RecvWindowInfo m_recv_window;
SendWindowInfo m_send_window;
struct ConnectionInfo
{
sockaddr_storage address;
socklen_t address_len;
};
BAN::Optional<ConnectionInfo> m_connection_info;
};
}

View File

@ -24,30 +24,40 @@ namespace Kernel
public: public:
static BAN::ErrorOr<BAN::RefPtr<UDPSocket>> create(NetworkLayer&, ino_t, const TmpInodeInfo&); static BAN::ErrorOr<BAN::RefPtr<UDPSocket>> create(NetworkLayer&, ino_t, const TmpInodeInfo&);
virtual size_t protocol_header_size() const override { return sizeof(UDPHeader); }
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) override;
virtual NetworkProtocol protocol() const override { return NetworkProtocol::UDP; } virtual NetworkProtocol protocol() const override { return NetworkProtocol::UDP; }
virtual size_t protocol_header_size() const override { return sizeof(UDPHeader); }
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader) override;
protected: protected:
virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_addr, uint16_t sender_port) override; virtual void receive_packet(BAN::ConstByteSpan, const sockaddr_storage& sender) override;
virtual BAN::ErrorOr<size_t> read_packet(BAN::ByteSpan, sockaddr_in* sender_address) override;
virtual void on_close_impl() override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr* address, socklen_t address_len) override;
virtual BAN::ErrorOr<size_t> sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len) override;
virtual BAN::ErrorOr<size_t> recvfrom_impl(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len) override;
virtual bool can_read_impl() const override { return !m_packets.empty(); }
virtual bool can_write_impl() const override { return true; }
virtual bool has_error_impl() const override { return false; }
private: private:
UDPSocket(NetworkLayer&, ino_t, const TmpInodeInfo&); UDPSocket(NetworkLayer&, ino_t, const TmpInodeInfo&);
struct PacketInfo struct PacketInfo
{ {
BAN::IPv4Address sender_addr; sockaddr_storage sender;
uint16_t sender_port;
size_t packet_size; size_t packet_size;
}; };
private: private:
static constexpr size_t packet_buffer_size = 10 * PAGE_SIZE; static constexpr size_t packet_buffer_size = 10 * PAGE_SIZE;
BAN::UniqPtr<VirtualRange> m_packet_buffer; BAN::UniqPtr<VirtualRange> m_packet_buffer;
BAN::CircularQueue<PacketInfo, 128> m_packets; BAN::CircularQueue<PacketInfo, 32> m_packets;
size_t m_packet_total_size { 0 }; size_t m_packet_total_size { 0 };
Semaphore m_semaphore; SpinLock m_packet_lock;
Semaphore m_packet_semaphore;
friend class BAN::RefPtr<UDPSocket>; friend class BAN::RefPtr<UDPSocket>;
}; };

View File

@ -17,12 +17,18 @@ namespace Kernel
static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(SocketType, ino_t, const TmpInodeInfo&); static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(SocketType, ino_t, const TmpInodeInfo&);
protected: protected:
virtual void on_close_impl() override;
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) override; virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) override;
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override; virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<void> listen_impl(int) override; virtual BAN::ErrorOr<void> listen_impl(int) override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) override; virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<size_t> sendto_impl(const sys_sendto_t*) override; virtual BAN::ErrorOr<size_t> sendto_impl(BAN::ConstByteSpan, const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<size_t> recvfrom_impl(sys_recvfrom_t*) override; virtual BAN::ErrorOr<size_t> recvfrom_impl(BAN::ByteSpan, sockaddr*, socklen_t*) override;
virtual bool can_read_impl() const override;
virtual bool can_write_impl() const override;
virtual bool has_error_impl() const override { return false; }
private: private:
UnixDomainSocket(SocketType, ino_t, const TmpInodeInfo&); UnixDomainSocket(SocketType, ino_t, const TmpInodeInfo&);
@ -47,7 +53,7 @@ namespace Kernel
struct ConnectionlessInfo struct ConnectionlessInfo
{ {
BAN::String peer_address;
}; };
private: private:

View File

@ -16,6 +16,7 @@
#include <sys/banan-os.h> #include <sys/banan-os.h>
#include <sys/mman.h> #include <sys/mman.h>
#include <sys/select.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <termios.h> #include <termios.h>
@ -95,6 +96,8 @@ namespace Kernel
BAN::ErrorOr<long> sys_getegid() const { return m_credentials.egid(); } BAN::ErrorOr<long> sys_getegid() const { return m_credentials.egid(); }
BAN::ErrorOr<long> sys_getpgid(pid_t); BAN::ErrorOr<long> sys_getpgid(pid_t);
BAN::ErrorOr<long> sys_getpid() const { return pid(); }
BAN::ErrorOr<long> open_inode(BAN::RefPtr<Inode>, int flags); BAN::ErrorOr<long> open_inode(BAN::RefPtr<Inode>, int flags);
BAN::ErrorOr<void> create_file_or_dir(BAN::StringView name, mode_t mode); BAN::ErrorOr<void> create_file_or_dir(BAN::StringView name, mode_t mode);
@ -126,6 +129,8 @@ namespace Kernel
BAN::ErrorOr<long> sys_ioctl(int fildes, int request, void* arg); BAN::ErrorOr<long> sys_ioctl(int fildes, int request, void* arg);
BAN::ErrorOr<long> sys_pselect(sys_pselect_t* arguments);
BAN::ErrorOr<long> sys_pipe(int fildes[2]); BAN::ErrorOr<long> sys_pipe(int fildes[2]);
BAN::ErrorOr<long> sys_dup(int fildes); BAN::ErrorOr<long> sys_dup(int fildes);
BAN::ErrorOr<long> sys_dup2(int fildes, int fildes2); BAN::ErrorOr<long> sys_dup2(int fildes, int fildes2);
@ -145,7 +150,7 @@ namespace Kernel
BAN::ErrorOr<void> mount(BAN::StringView source, BAN::StringView target); BAN::ErrorOr<void> mount(BAN::StringView source, BAN::StringView target);
BAN::ErrorOr<long> sys_read_dir_entries(int fd, DirectoryEntryList* buffer, size_t buffer_size); BAN::ErrorOr<long> sys_readdir(int fd, DirectoryEntryList* buffer, size_t buffer_size);
BAN::ErrorOr<long> sys_mmap(const sys_mmap_t*); BAN::ErrorOr<long> sys_mmap(const sys_mmap_t*);
BAN::ErrorOr<long> sys_munmap(void* addr, size_t len); BAN::ErrorOr<long> sys_munmap(void* addr, size_t len);
@ -154,7 +159,7 @@ namespace Kernel
BAN::ErrorOr<long> sys_tty_ctrl(int fildes, int command, int flags); BAN::ErrorOr<long> sys_tty_ctrl(int fildes, int command, int flags);
BAN::ErrorOr<long> sys_signal(int, void (*)(int)); BAN::ErrorOr<long> sys_signal(int, void (*)(int));
static BAN::ErrorOr<long> sys_kill(pid_t pid, int signal); BAN::ErrorOr<long> sys_kill(pid_t pid, int signal);
BAN::ErrorOr<long> sys_tcsetpgrp(int fd, pid_t pgid); BAN::ErrorOr<long> sys_tcsetpgrp(int fd, pid_t pgid);

View File

@ -0,0 +1,16 @@
#pragma once
#include <stdint.h>
namespace Kernel
{
class Random
{
public:
static void initialize();
static uint32_t get_u32();
static uint64_t get_u64();
};
}

View File

@ -19,9 +19,9 @@ namespace Kernel
void reschedule(); void reschedule();
void reschedule_if_idling(); void reschedule_if_idling();
void set_current_thread_sleeping(uint64_t); void set_current_thread_sleeping(uint64_t wake_time);
void block_current_thread(Semaphore*); void block_current_thread(Semaphore*, uint64_t wake_time);
void unblock_threads(Semaphore*); void unblock_threads(Semaphore*);
// Makes sleeping or blocked thread with tid active. // Makes sleeping or blocked thread with tid active.
void unblock_thread(pid_t tid); void unblock_thread(pid_t tid);
@ -36,6 +36,8 @@ namespace Kernel
private: private:
Scheduler() = default; Scheduler() = default;
void set_current_thread_sleeping_impl(uint64_t wake_time);
void wake_threads(); void wake_threads();
[[nodiscard]] bool save_current_thread(); [[nodiscard]] bool save_current_thread();
void remove_and_advance_current_thread(); void remove_and_advance_current_thread();
@ -51,17 +53,13 @@ namespace Kernel
{} {}
Thread* thread; Thread* thread;
union
{
uint64_t wake_time; uint64_t wake_time;
Semaphore* semaphore; Semaphore* semaphore;
}; };
};
Thread* m_idle_thread { nullptr }; Thread* m_idle_thread { nullptr };
BAN::LinkedList<SchedulerThread> m_active_threads; BAN::LinkedList<SchedulerThread> m_active_threads;
BAN::LinkedList<SchedulerThread> m_sleeping_threads; BAN::LinkedList<SchedulerThread> m_sleeping_threads;
BAN::LinkedList<SchedulerThread> m_blocking_threads;
BAN::LinkedList<SchedulerThread>::iterator m_current_thread; BAN::LinkedList<SchedulerThread>::iterator m_current_thread;

View File

@ -6,7 +6,9 @@ namespace Kernel
class Semaphore class Semaphore
{ {
public: public:
void block(); void block_indefinite();
void block_with_timeout(uint64_t timeout_ms);
void block_with_wake_time(uint64_t wake_time_ms);
void unblock(); void unblock();
}; };

View File

@ -19,6 +19,8 @@ namespace Kernel
void unlock(); void unlock();
bool is_locked() const; bool is_locked() const;
uint32_t lock_depth() const { return m_locker != -1; }
private: private:
BAN::Atomic<pid_t> m_locker = -1; BAN::Atomic<pid_t> m_locker = -1;
}; };
@ -34,6 +36,8 @@ namespace Kernel
void unlock(); void unlock();
bool is_locked() const; bool is_locked() const;
uint32_t lock_depth() const { return m_lock_depth; }
private: private:
BAN::Atomic<pid_t> m_locker = -1; BAN::Atomic<pid_t> m_locker = -1;
BAN::Atomic<uint32_t> m_lock_depth = 0; BAN::Atomic<uint32_t> m_lock_depth = 0;

View File

@ -24,6 +24,11 @@ namespace Kernel
virtual dev_t rdev() const override { return m_rdev; } virtual dev_t rdev() const override { return m_rdev; }
virtual BAN::StringView name() const override { return m_name; } virtual BAN::StringView name() const override { return m_name; }
protected:
virtual bool can_read_impl() const override { return false; }
virtual bool can_write_impl() const override { return false; }
virtual bool has_error_impl() const override { return false; }
private: private:
NVMeController(PCI::Device& pci_device); NVMeController(PCI::Device& pci_device);
virtual BAN::ErrorOr<void> initialize() override; virtual BAN::ErrorOr<void> initialize() override;

View File

@ -46,6 +46,10 @@ namespace Kernel
protected: protected:
virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override; virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override;
virtual bool can_read_impl() const override { return true; }
virtual bool can_write_impl() const override { return true; }
virtual bool has_error_impl() const override { return false; }
private: private:
const dev_t m_rdev; const dev_t m_rdev;
}; };

View File

@ -39,6 +39,10 @@ namespace Kernel
virtual BAN::ErrorOr<void> write_sectors_impl(uint64_t lba, uint64_t sector_count, BAN::ConstByteSpan) = 0; virtual BAN::ErrorOr<void> write_sectors_impl(uint64_t lba, uint64_t sector_count, BAN::ConstByteSpan) = 0;
void add_disk_cache(); void add_disk_cache();
virtual bool can_read_impl() const override { return true; }
virtual bool can_write_impl() const override { return true; }
virtual bool has_error_impl() const override { return false; }
private: private:
SpinLock m_lock; SpinLock m_lock;
BAN::Optional<DiskCache> m_disk_cache; BAN::Optional<DiskCache> m_disk_cache;

View File

@ -44,7 +44,9 @@ namespace Kernel
virtual BAN::ErrorOr<void> chmod_impl(mode_t) override; virtual BAN::ErrorOr<void> chmod_impl(mode_t) override;
virtual BAN::ErrorOr<void> chown_impl(uid_t, gid_t) override; virtual BAN::ErrorOr<void> chown_impl(uid_t, gid_t) override;
virtual bool has_data_impl() const override; virtual bool can_read_impl() const override { return m_output.flush; }
virtual bool can_write_impl() const override { return true; }
virtual bool has_error_impl() const override { return false; }
protected: protected:
TTY(mode_t mode, uid_t uid, gid_t gid) TTY(mode_t mode, uid_t uid, gid_t gid)

View File

@ -47,8 +47,10 @@ namespace Kernel
void handle_signal(int signal = 0); void handle_signal(int signal = 0);
bool add_signal(int signal); bool add_signal(int signal);
// blocks semaphore and returns either on unblock, eintr or spuriously // blocks semaphore and returns either on unblock, eintr, spuriously or after timeout
BAN::ErrorOr<void> block_or_eintr(Semaphore&); BAN::ErrorOr<void> block_or_eintr_indefinite(Semaphore& semaphore);
BAN::ErrorOr<void> block_or_eintr_or_timeout(Semaphore& semaphore, uint64_t timeout_ms, bool etimedout);
BAN::ErrorOr<void> block_or_eintr_or_waketime(Semaphore& semaphore, uint64_t wake_time_ms, bool etimedout);
void set_return_rsp(uintptr_t& rsp) { m_return_rsp = &rsp; } void set_return_rsp(uintptr_t& rsp) { m_return_rsp = &rsp; }
void set_return_rip(uintptr_t& rip) { m_return_rip = &rip; } void set_return_rip(uintptr_t& rip) { m_return_rip = &rip; }

View File

@ -64,9 +64,8 @@ namespace Kernel
{ {
while (!s_instance->m_should_sync) while (!s_instance->m_should_sync)
{ {
s_instance->m_device_lock.unlock(); LockFreeGuard _(s_instance->m_device_lock);
s_instance->m_sync_semaphore.block(); s_instance->m_sync_semaphore.block_indefinite();
s_instance->m_device_lock.lock();
} }
for (auto& device : s_instance->m_devices) for (auto& device : s_instance->m_devices)
@ -105,7 +104,7 @@ namespace Kernel
m_sync_semaphore.unblock(); m_sync_semaphore.unblock();
} }
if (should_block) if (should_block)
m_sync_done.block(); m_sync_done.block_indefinite();
} }
void DevFileSystem::add_device(BAN::RefPtr<Device> device) void DevFileSystem::add_device(BAN::RefPtr<Device> device)

View File

@ -148,20 +148,20 @@ namespace Kernel
return listen_impl(backlog); return listen_impl(backlog);
} }
BAN::ErrorOr<size_t> Inode::sendto(const sys_sendto_t* arguments) BAN::ErrorOr<size_t> Inode::sendto(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len)
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
if (!mode().ifsock()) if (!mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK); return BAN::Error::from_errno(ENOTSOCK);
return sendto_impl(arguments); return sendto_impl(message, address, address_len);
}; };
BAN::ErrorOr<size_t> Inode::recvfrom(sys_recvfrom_t* arguments) BAN::ErrorOr<size_t> Inode::recvfrom(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len)
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
if (!mode().ifsock()) if (!mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK); return BAN::Error::from_errno(ENOTSOCK);
return recvfrom_impl(arguments); return recvfrom_impl(buffer, address, address_len);
}; };
BAN::ErrorOr<size_t> Inode::read(off_t offset, BAN::ByteSpan buffer) BAN::ErrorOr<size_t> Inode::read(off_t offset, BAN::ByteSpan buffer)
@ -201,10 +201,22 @@ namespace Kernel
return chown_impl(uid, gid); return chown_impl(uid, gid);
} }
bool Inode::has_data() const bool Inode::can_read() const
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
return has_data_impl(); return can_read_impl();
}
bool Inode::can_write() const
{
LockGuard _(m_lock);
return can_write_impl();
}
bool Inode::has_error() const
{
LockGuard _(m_lock);
return has_error_impl();
} }
BAN::ErrorOr<long> Inode::ioctl(int request, void* arg) BAN::ErrorOr<long> Inode::ioctl(int request, void* arg)

View File

@ -1,5 +1,6 @@
#include <kernel/FS/Pipe.h> #include <kernel/FS/Pipe.h>
#include <kernel/LockGuard.h> #include <kernel/LockGuard.h>
#include <kernel/Thread.h>
#include <kernel/Timer/Timer.h> #include <kernel/Timer/Timer.h>
namespace Kernel namespace Kernel
@ -46,9 +47,8 @@ namespace Kernel
{ {
if (m_writing_count == 0) if (m_writing_count == 0)
return 0; return 0;
m_lock.unlock(); LockFreeGuard lock_free(m_lock);
m_semaphore.block(); TRY(Thread::current().block_or_eintr_indefinite(m_semaphore));
m_lock.lock();
} }
size_t to_copy = BAN::Math::min<size_t>(buffer.size(), m_buffer.size()); size_t to_copy = BAN::Math::min<size_t>(buffer.size(), m_buffer.size());

View File

@ -195,7 +195,7 @@ namespace Kernel::Input
while (true) while (true)
{ {
if (m_event_queue.empty()) if (m_event_queue.empty())
TRY(Thread::current().block_or_eintr(m_semaphore)); TRY(Thread::current().block_or_eintr_indefinite(m_semaphore));
CriticalScope _; CriticalScope _;
if (m_event_queue.empty()) if (m_event_queue.empty())
@ -208,10 +208,4 @@ namespace Kernel::Input
} }
} }
bool PS2Keyboard::has_data_impl() const
{
CriticalScope _;
return !m_event_queue.empty();
}
} }

View File

@ -179,7 +179,7 @@ namespace Kernel::Input
while (true) while (true)
{ {
if (m_event_queue.empty()) if (m_event_queue.empty())
TRY(Thread::current().block_or_eintr(m_semaphore)); TRY(Thread::current().block_or_eintr_indefinite(m_semaphore));
CriticalScope _; CriticalScope _;
if (m_event_queue.empty()) if (m_event_queue.empty())
@ -192,10 +192,4 @@ namespace Kernel::Input
} }
} }
bool PS2Mouse::has_data_impl() const
{
CriticalScope _;
return !m_event_queue.empty();
}
} }

View File

@ -155,7 +155,7 @@ namespace Kernel
if (!pending.has_value()) if (!pending.has_value())
{ {
m_pending_semaphore.block(); m_pending_semaphore.block_indefinite();
continue; continue;
} }

View File

@ -3,7 +3,9 @@
#include <kernel/Networking/ICMP.h> #include <kernel/Networking/ICMP.h>
#include <kernel/Networking/IPv4Layer.h> #include <kernel/Networking/IPv4Layer.h>
#include <kernel/Networking/NetworkManager.h> #include <kernel/Networking/NetworkManager.h>
#include <kernel/Networking/TCPSocket.h>
#include <kernel/Networking/UDPSocket.h> #include <kernel/Networking/UDPSocket.h>
#include <kernel/Random.h>
#include <netinet/in.h> #include <netinet/in.h>
@ -12,6 +14,11 @@
namespace Kernel namespace Kernel
{ {
enum IPv4Flags : uint16_t
{
DF = 1 << 14,
};
BAN::ErrorOr<BAN::UniqPtr<IPv4Layer>> IPv4Layer::create() BAN::ErrorOr<BAN::UniqPtr<IPv4Layer>> IPv4Layer::create()
{ {
auto ipv4_manager = TRY(BAN::UniqPtr<IPv4Layer>::create()); auto ipv4_manager = TRY(BAN::UniqPtr<IPv4Layer>::create());
@ -57,10 +64,11 @@ namespace Kernel
header.protocol = protocol; header.protocol = protocol;
header.src_address = src_ipv4; header.src_address = src_ipv4;
header.dst_address = dst_ipv4; header.dst_address = dst_ipv4;
header.checksum = header.calculate_checksum(); header.checksum = 0;
header.checksum = calculate_internet_checksum(BAN::ConstByteSpan::from(header), {});
} }
void IPv4Layer::unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket> socket) void IPv4Layer::unbind_socket(BAN::RefPtr<NetworkSocket> socket, uint16_t port)
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
if (m_bound_sockets.contains(port)) if (m_bound_sockets.contains(port))
@ -72,33 +80,55 @@ namespace Kernel
NetworkManager::get().TmpFileSystem::remove_from_cache(socket); NetworkManager::get().TmpFileSystem::remove_from_cache(socket);
} }
BAN::ErrorOr<void> IPv4Layer::bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket> socket) BAN::ErrorOr<void> IPv4Layer::bind_socket_to_unused(BAN::RefPtr<NetworkSocket> socket, const sockaddr* address, socklen_t address_len)
{ {
if (NetworkManager::get().interfaces().empty()) if (!address || address_len < (socklen_t)sizeof(sockaddr_in))
return BAN::Error::from_errno(EADDRNOTAVAIL); return BAN::Error::from_errno(EINVAL);
if (address->sa_family != AF_INET)
return BAN::Error::from_errno(EAFNOSUPPORT);
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(address);
LockGuard _(m_lock); LockGuard _(m_lock);
if (port == NetworkSocket::PORT_NONE) uint16_t port = NetworkSocket::PORT_NONE;
{ for (uint32_t i = 0; i < 100 && port == NetworkSocket::PORT_NONE; i++)
for (uint32_t temp = 0xC000; temp < 0xFFFF; temp++) if (uint32_t temp = 0xC000 | (Random::get_u32() & 0x3FFF); !m_bound_sockets.contains(temp))
{ port = temp;
if (!m_bound_sockets.contains(temp)) for (uint32_t temp = 0xC000; temp < 0xFFFF && port == NetworkSocket::PORT_NONE; temp++)
{ if (!m_bound_sockets.contains(temp))
port = temp; port = temp;
break;
}
}
if (port == NetworkSocket::PORT_NONE) if (port == NetworkSocket::PORT_NONE)
{ {
dwarnln("No ports available"); dwarnln("No ports available");
return BAN::Error::from_errno(EAGAIN); return BAN::Error::from_errno(EAGAIN);
} }
dprintln_if(DEBUG_IPV4, "using port {}", port);
struct sockaddr_in target;
target.sin_family = AF_INET;
target.sin_port = BAN::host_to_network_endian(port);
target.sin_addr.s_addr = sockaddr_in.sin_addr.s_addr;
return bind_socket_to_address(socket, (sockaddr*)&target, sizeof(sockaddr_in));
} }
BAN::ErrorOr<void> IPv4Layer::bind_socket_to_address(BAN::RefPtr<NetworkSocket> socket, const sockaddr* address, socklen_t address_len)
{
if (NetworkManager::get().interfaces().empty())
return BAN::Error::from_errno(EADDRNOTAVAIL);
if (!address || address_len < (socklen_t)sizeof(sockaddr_in))
return BAN::Error::from_errno(EINVAL);
if (address->sa_family != AF_INET)
return BAN::Error::from_errno(EAFNOSUPPORT);
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(address);
uint16_t port = BAN::host_to_network_endian(sockaddr_in.sin_port);
LockGuard _(m_lock);
if (m_bound_sockets.contains(port)) if (m_bound_sockets.contains(port))
return BAN::Error::from_errno(EADDRINUSE); return BAN::Error::from_errno(EADDRINUSE);
TRY(m_bound_sockets.insert(port, socket)); TRY(m_bound_sockets.insert(port, TRY(socket->get_weak_ptr())));
// FIXME: actually determine proper interface // FIXME: actually determine proper interface
auto interface = NetworkManager::get().interfaces().front(); auto interface = NetworkManager::get().interfaces().front();
@ -107,28 +137,37 @@ namespace Kernel
return {}; return {};
} }
BAN::ErrorOr<size_t> IPv4Layer::sendto(NetworkSocket& socket, const sys_sendto_t* arguments) BAN::ErrorOr<size_t> IPv4Layer::sendto(NetworkSocket& socket, BAN::ConstByteSpan buffer, const sockaddr* address, socklen_t address_len)
{ {
if (arguments->dest_addr->sa_family != AF_INET) if (address->sa_family != AF_INET)
return BAN::Error::from_errno(EINVAL); return BAN::Error::from_errno(EINVAL);
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(arguments->dest_addr); if (address == nullptr || address_len != sizeof(sockaddr_in))
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(address);
auto dst_port = BAN::host_to_network_endian(sockaddr_in.sin_port); auto dst_port = BAN::host_to_network_endian(sockaddr_in.sin_port);
auto dst_ipv4 = BAN::IPv4Address { sockaddr_in.sin_addr.s_addr }; auto dst_ipv4 = BAN::IPv4Address { sockaddr_in.sin_addr.s_addr };
auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(socket.interface(), dst_ipv4)); auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(socket.interface(), dst_ipv4));
BAN::Vector<uint8_t> packet_buffer; BAN::Vector<uint8_t> packet_buffer;
TRY(packet_buffer.resize(arguments->length + sizeof(IPv4Header) + socket.protocol_header_size())); TRY(packet_buffer.resize(buffer.size() + sizeof(IPv4Header) + socket.protocol_header_size()));
auto packet = BAN::ByteSpan { packet_buffer.span() }; auto packet = BAN::ByteSpan { packet_buffer.span() };
auto pseudo_header = PseudoHeader {
.src_ipv4 = socket.interface().get_ipv4_address(),
.dst_ipv4 = dst_ipv4,
.protocol = socket.protocol()
};
memcpy( memcpy(
packet.slice(sizeof(IPv4Header)).slice(socket.protocol_header_size()).data(), packet.slice(sizeof(IPv4Header)).slice(socket.protocol_header_size()).data(),
arguments->message, buffer.data(),
arguments->length buffer.size()
); );
socket.add_protocol_header( socket.add_protocol_header(
packet.slice(sizeof(IPv4Header)), packet.slice(sizeof(IPv4Header)),
dst_port dst_port,
pseudo_header
); );
add_ipv4_header( add_ipv4_header(
packet, packet,
@ -139,17 +178,7 @@ namespace Kernel
TRY(socket.interface().send_bytes(dst_mac, EtherType::IPv4, packet)); TRY(socket.interface().send_bytes(dst_mac, EtherType::IPv4, packet));
return arguments->length; return buffer.size();
}
static uint16_t calculate_internet_checksum(BAN::ConstByteSpan packet)
{
uint32_t checksum = 0;
for (size_t i = 0; i < packet.size() / sizeof(uint16_t); i++)
checksum += BAN::host_to_network_endian(reinterpret_cast<const uint16_t*>(packet.data())[i]);
while (checksum >> 16)
checksum = (checksum >> 16) | (checksum & 0xFFFF);
return ~(uint16_t)checksum;
} }
BAN::ErrorOr<void> IPv4Layer::handle_ipv4_packet(NetworkInterface& interface, BAN::ByteSpan packet) BAN::ErrorOr<void> IPv4Layer::handle_ipv4_packet(NetworkInterface& interface, BAN::ByteSpan packet)
@ -157,9 +186,11 @@ namespace Kernel
auto& ipv4_header = packet.as<const IPv4Header>(); auto& ipv4_header = packet.as<const IPv4Header>();
auto ipv4_data = packet.slice(sizeof(IPv4Header)); auto ipv4_data = packet.slice(sizeof(IPv4Header));
ASSERT(ipv4_header.is_valid_checksum());
auto src_ipv4 = ipv4_header.src_address; auto src_ipv4 = ipv4_header.src_address;
uint16_t dst_port = NetworkSocket::PORT_NONE;
uint16_t src_port = NetworkSocket::PORT_NONE;
switch (ipv4_header.protocol) switch (ipv4_header.protocol)
{ {
case NetworkProtocol::ICMP: case NetworkProtocol::ICMP:
@ -174,7 +205,7 @@ namespace Kernel
auto& reply_icmp_header = ipv4_data.as<ICMPHeader>(); auto& reply_icmp_header = ipv4_data.as<ICMPHeader>();
reply_icmp_header.type = ICMPType::EchoReply; reply_icmp_header.type = ICMPType::EchoReply;
reply_icmp_header.checksum = 0; reply_icmp_header.checksum = 0;
reply_icmp_header.checksum = calculate_internet_checksum(ipv4_data); reply_icmp_header.checksum = calculate_internet_checksum(ipv4_data, {});
add_ipv4_header(packet, interface.get_ipv4_address(), src_ipv4, NetworkProtocol::ICMP); add_ipv4_header(packet, interface.get_ipv4_address(), src_ipv4, NetworkProtocol::ICMP);
@ -185,31 +216,60 @@ namespace Kernel
dprintln("Unhandleded ICMP packet (type {2H})", icmp_header.type); dprintln("Unhandleded ICMP packet (type {2H})", icmp_header.type);
break; break;
} }
break; return {};
} }
case NetworkProtocol::UDP: case NetworkProtocol::UDP:
{ {
auto& udp_header = ipv4_data.as<const UDPHeader>(); auto& udp_header = ipv4_data.as<const UDPHeader>();
uint16_t src_port = udp_header.src_port; dst_port = udp_header.dst_port;
uint16_t dst_port = udp_header.dst_port; src_port = udp_header.src_port;
break;
}
case NetworkProtocol::TCP:
{
auto& tcp_header = ipv4_data.as<const TCPHeader>();
dst_port = tcp_header.dst_port;
src_port = tcp_header.src_port;
break;
}
default:
dprintln_if(DEBUG_IPV4, "Unknown network protocol 0x{2H}", ipv4_header.protocol);
return {};
}
ASSERT(dst_port != NetworkSocket::PORT_NONE);
ASSERT(src_port != NetworkSocket::PORT_NONE);
BAN::RefPtr<Kernel::NetworkSocket> bound_socket;
{
LockGuard _(m_lock); LockGuard _(m_lock);
if (!m_bound_sockets.contains(dst_port))
{
dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port);
return {};
}
bound_socket = m_bound_sockets[dst_port].lock();
}
if (!m_bound_sockets.contains(dst_port) || !m_bound_sockets[dst_port].valid()) if (!bound_socket)
{ {
dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port); dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port);
return {}; return {};
} }
auto udp_data = ipv4_data.slice(sizeof(UDPHeader)); if (bound_socket->protocol() != ipv4_header.protocol)
m_bound_sockets[dst_port].lock()->add_packet(udp_data, src_ipv4, src_port); {
break; dprintln_if(DEBUG_IPV4, "got data with wrong protocol ({}) on port {} (bound as {})", ipv4_header.protocol, dst_port, (uint8_t)bound_socket->protocol());
} return {};
default:
dprintln_if(DEBUG_IPV4, "Unknown network protocol 0x{2H}", ipv4_header.protocol);
break;
} }
sockaddr_in sender;
sender.sin_family = AF_INET;
sender.sin_port = BAN::NetworkEndian<uint16_t>(src_port);
sender.sin_addr.s_addr = src_ipv4.raw;
bound_socket->receive_packet(ipv4_data, *reinterpret_cast<const sockaddr_storage*>(&sender));
return {}; return {};
} }
@ -230,7 +290,7 @@ namespace Kernel
if (!pending.has_value()) if (!pending.has_value())
{ {
m_pending_semaphore.block(); m_pending_semaphore.block_indefinite();
continue; continue;
} }
@ -262,14 +322,17 @@ namespace Kernel
} }
auto& ipv4_header = buffer.as<const IPv4Header>(); auto& ipv4_header = buffer.as<const IPv4Header>();
if (!ipv4_header.is_valid_checksum()) if (calculate_internet_checksum(BAN::ConstByteSpan::from(ipv4_header), {}) != 0)
{ {
dwarnln("Invalid IPv4 packet"); dwarnln("Invalid IPv4 packet");
return; return;
} }
if (ipv4_header.total_length > buffer.size()) if (ipv4_header.total_length > buffer.size() || ipv4_header.total_length > interface.payload_mtu())
{ {
dwarnln("Too short IPv4 packet"); if (ipv4_header.flags_frament & IPv4Flags::DF)
dwarnln("Invalid IPv4 packet");
else
dwarnln("IPv4 fragmentation not supported");
return; return;
} }

View File

@ -5,6 +5,7 @@
#include <kernel/Networking/E1000/E1000E.h> #include <kernel/Networking/E1000/E1000E.h>
#include <kernel/Networking/ICMP.h> #include <kernel/Networking/ICMP.h>
#include <kernel/Networking/NetworkManager.h> #include <kernel/Networking/NetworkManager.h>
#include <kernel/Networking/TCPSocket.h>
#include <kernel/Networking/UDPSocket.h> #include <kernel/Networking/UDPSocket.h>
#include <kernel/Networking/UNIX/Socket.h> #include <kernel/Networking/UNIX/Socket.h>
@ -76,15 +77,17 @@ namespace Kernel
switch (domain) switch (domain)
{ {
case SocketDomain::INET: case SocketDomain::INET:
switch (type)
{ {
if (type != SocketType::DGRAM) case SocketType::DGRAM:
case SocketType::STREAM:
break;
default:
return BAN::Error::from_errno(EPROTOTYPE); return BAN::Error::from_errno(EPROTOTYPE);
break;
} }
break;
case SocketDomain::UNIX: case SocketDomain::UNIX:
{
break; break;
}
default: default:
return BAN::Error::from_errno(EAFNOSUPPORT); return BAN::Error::from_errno(EAFNOSUPPORT);
} }
@ -100,9 +103,18 @@ namespace Kernel
{ {
case SocketDomain::INET: case SocketDomain::INET:
{ {
if (type == SocketType::DGRAM) switch (type)
{
case SocketType::DGRAM:
socket = TRY(UDPSocket::create(*m_ipv4_layer, ino, inode_info)); socket = TRY(UDPSocket::create(*m_ipv4_layer, ino, inode_info));
break; break;
case SocketType::STREAM:
socket = TRY(TCPSocket::create(*m_ipv4_layer, ino, inode_info));
break;
default:
ASSERT_NOT_REACHED();
}
break;
} }
case SocketDomain::UNIX: case SocketDomain::UNIX:
{ {

View File

@ -15,12 +15,6 @@ namespace Kernel
{ {
} }
void NetworkSocket::on_close_impl()
{
if (m_interface)
m_network_layer.unbind_socket(m_port, this);
}
void NetworkSocket::bind_interface_and_port(NetworkInterface* interface, uint16_t port) void NetworkSocket::bind_interface_and_port(NetworkInterface* interface, uint16_t port)
{ {
ASSERT(!m_interface); ASSERT(!m_interface);
@ -29,59 +23,6 @@ namespace Kernel
m_port = port; m_port = port;
} }
BAN::ErrorOr<void> NetworkSocket::bind_impl(const sockaddr* address, socklen_t address_len)
{
if (m_interface || address_len != sizeof(sockaddr_in))
return BAN::Error::from_errno(EINVAL);
auto* addr_in = reinterpret_cast<const sockaddr_in*>(address);
uint16_t dst_port = BAN::host_to_network_endian(addr_in->sin_port);
return m_network_layer.bind_socket(dst_port, this);
}
BAN::ErrorOr<size_t> NetworkSocket::sendto_impl(const sys_sendto_t* arguments)
{
if (arguments->flags)
{
dprintln("flags not supported");
return BAN::Error::from_errno(ENOTSUP);
}
if (!m_interface)
TRY(m_network_layer.bind_socket(PORT_NONE, this));
return TRY(m_network_layer.sendto(*this, arguments));
}
BAN::ErrorOr<size_t> NetworkSocket::recvfrom_impl(sys_recvfrom_t* arguments)
{
sockaddr_in* sender_addr = nullptr;
if (arguments->address)
{
ASSERT(arguments->address_len);
if (*arguments->address_len < (socklen_t)sizeof(sockaddr_in))
*arguments->address_len = 0;
else
{
sender_addr = reinterpret_cast<sockaddr_in*>(arguments->address);
*arguments->address_len = sizeof(sockaddr_in);
}
}
if (!m_interface)
{
dprintln("No interface bound");
return BAN::Error::from_errno(EINVAL);
}
if (m_port == PORT_NONE)
{
dprintln("No port bound");
return BAN::Error::from_errno(EINVAL);
}
return TRY(read_packet(BAN::ByteSpan { reinterpret_cast<uint8_t*>(arguments->buffer), arguments->length }, sender_addr));
}
BAN::ErrorOr<long> NetworkSocket::ioctl_impl(int request, void* arg) BAN::ErrorOr<long> NetworkSocket::ioctl_impl(int request, void* arg)
{ {
if (!arg) if (!arg)

View File

@ -0,0 +1,644 @@
#include <kernel/LockGuard.h>
#include <kernel/Networking/TCPSocket.h>
#include <kernel/Random.h>
#include <kernel/Timer/Timer.h>
#include <netinet/in.h>
#define DEBUG_TCP 0
namespace Kernel
{
enum TCPOption : uint8_t
{
End = 0x00,
NOP = 0x01,
MaximumSeqmentSize = 0x02,
WindowScale = 0x03,
};
static constexpr size_t s_window_buffer_size = 15 * PAGE_SIZE;
static_assert(s_window_buffer_size <= UINT16_MAX);
BAN::ErrorOr<BAN::RefPtr<TCPSocket>> TCPSocket::create(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info)
{
auto* socket_ptr = new TCPSocket(network_layer, ino, inode_info);
if (socket_ptr == nullptr)
return BAN::Error::from_errno(ENOMEM);
auto socket = BAN::RefPtr<TCPSocket>::adopt(socket_ptr);
socket->m_recv_window.buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
~(vaddr_t)0,
s_window_buffer_size,
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
true
));
socket->m_send_window.buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
~(vaddr_t)0,
s_window_buffer_size,
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
true
));
socket->m_process = Process::create_kernel(
[](void* socket_ptr)
{
reinterpret_cast<TCPSocket*>(socket_ptr)->process_task();
}, socket.ptr()
);
return socket;
}
TCPSocket::TCPSocket(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info)
: NetworkSocket(network_layer, ino, inode_info)
{
m_send_window.start_seq = Random::get_u32() & 0x7FFFFFFF;
m_send_window.current_seq = m_send_window.start_seq;
}
TCPSocket::~TCPSocket()
{
ASSERT(!is_bound());
ASSERT(m_process == nullptr);
dprintln_if(DEBUG_TCP, "socket destroyed");
}
void TCPSocket::on_close_impl()
{
LockGuard _(m_lock);
if (!is_bound())
return;
switch (m_state)
{
case State::Established:
break;
case State::SynSent:
set_connection_as_closed();
// fall through
case State::SynReceived:
case State::FinWait1:
case State::FinWait2:
case State::CloseWait:
case State::Closing:
case State::TimeWait:
case State::LastAck:
return;
case State::Closed: ASSERT_NOT_REACHED();
case State::Listen: ASSERT_NOT_REACHED();
}
m_state = State::FinWait1;
m_should_ack = true;
dprintln_if(DEBUG_TCP, "Initiated close");
}
BAN::ErrorOr<void> TCPSocket::connect_impl(const sockaddr* address, socklen_t address_len)
{
if (address_len > (socklen_t)sizeof(sockaddr_storage))
address_len = sizeof(sockaddr_storage);
LockGuard _(m_lock);
ASSERT(!m_connection_info.has_value());
switch (m_state)
{
case State::Closed:
break;
case State::SynSent:
case State::SynReceived:
return BAN::Error::from_errno(EALREADY);
case State::Established:
case State::FinWait1:
case State::FinWait2:
case State::CloseWait:
case State::Closing:
case State::LastAck:
case State::TimeWait:
return BAN::Error::from_errno(EISCONN);
case State::Listen:
return BAN::Error::from_errno(EOPNOTSUPP);
};
if (!is_bound())
TRY(m_network_layer.bind_socket_to_unused(this, address, address_len));
m_connection_info.emplace(sockaddr_storage {}, address_len);
memcpy(&m_connection_info->address, address, address_len);
TRY(m_network_layer.sendto(*this, {}, address, address_len));
ASSERT(m_state == State::SynSent);
dprintln_if(DEBUG_TCP, "Sent SYN");
uint64_t wake_time_ms = SystemTimer::get().ms_since_boot() + 5000;
while (m_state != State::Established)
{
LockFreeGuard free(m_lock);
if (SystemTimer::get().ms_since_boot() >= wake_time_ms)
return BAN::Error::from_errno(ECONNREFUSED);
TRY(Thread::current().block_or_eintr_or_waketime(m_semaphore, wake_time_ms, true));
}
return {};
}
template<size_t Off, TCPOption Op>
static void add_tcp_header_option(TCPHeader& header, uint32_t value)
{
if constexpr(Op == TCPOption::MaximumSeqmentSize)
{
header.options[Off + 0] = Op;
header.options[Off + 1] = 0x04;
header.options[Off + 2] = value >> 8;
header.options[Off + 3] = value;
}
else if constexpr(Op == TCPOption::WindowScale)
{
header.options[Off + 0] = Op;
header.options[Off + 1] = 0x03;
header.options[Off + 2] = value;
}
}
struct ParsedTCPOptions
{
BAN::Optional<uint16_t> maximum_seqment_size;
BAN::Optional<uint8_t> window_scale;
};
static ParsedTCPOptions parse_tcp_options(const TCPHeader& header)
{
ParsedTCPOptions result;
for (size_t i = 0; i < header.data_offset * sizeof(uint32_t) - sizeof(TCPHeader) - 1; i++)
{
if (header.options[i] == TCPOption::End)
break;
if (header.options[i] == TCPOption::NOP)
continue;
if (header.options[i] == TCPOption::MaximumSeqmentSize)
result.maximum_seqment_size = BAN::host_to_network_endian(*reinterpret_cast<const uint16_t*>(&header.options[i + 2]));
if (header.options[i] == TCPOption::WindowScale)
result.window_scale = header.options[i + 2];
if (header.options[i + 1] == 0)
break;
i += header.options[i + 1] - 1;
}
return result;
}
void TCPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader pseudo_header)
{
auto& header = packet.as<TCPHeader>();
memset(&header, 0, sizeof(TCPHeader));
memset(header.options, TCPOption::End, m_tcp_options_bytes);
header.dst_port = dst_port;
header.src_port = m_port;
header.seq_number = m_send_window.current_seq + m_send_window.has_ghost_byte;
header.ack_number = m_recv_window.start_seq + m_recv_window.data_size + m_recv_window.has_ghost_byte;
header.data_offset = (sizeof(TCPHeader) + m_tcp_options_bytes) / sizeof(uint32_t);
header.window_size = m_recv_window.buffer->size();
ASSERT(m_recv_window.buffer->size() < 1 << (8 * sizeof(header.window_size)));
switch (m_state)
{
case State::Closed:
{
LockGuard _(m_lock);
header.syn = 1;
add_tcp_header_option<0, TCPOption::MaximumSeqmentSize>(header, m_interface->payload_mtu() - m_network_layer.header_size());
add_tcp_header_option<4, TCPOption::WindowScale>(header, 0);
m_state = State::SynSent;
m_send_window.start_seq++;
m_send_window.current_seq = m_send_window.start_seq;
break;
}
case State::SynSent:
header.ack = 1;
break;
case State::SynReceived:
header.ack = 1;
m_state = State::Established;
break;
case State::Established:
header.ack = 1;
break;
case State::CloseWait:
{
LockGuard _(m_lock);
header.ack = 1;
header.fin = 1;
m_state = State::LastAck;
dprintln_if(DEBUG_TCP, "Waiting for last ACK");
break;
}
case State::FinWait1:
{
LockGuard _(m_lock);
header.ack = 1;
header.fin = 1;
m_state = State::FinWait2;
break;
}
case State::FinWait2:
{
LockGuard _(m_lock);
header.ack = 1;
m_state = State::TimeWait;
m_time_wait_start_ms = SystemTimer::get().ms_since_boot();
dprintln_if(DEBUG_TCP, "Sent final ACK");
break;
}
case State::Listen: ASSERT_NOT_REACHED();
case State::Closing: ASSERT_NOT_REACHED();
case State::LastAck: ASSERT_NOT_REACHED();
case State::TimeWait: ASSERT_NOT_REACHED();
}
pseudo_header.extra = packet.size();
header.checksum = calculate_internet_checksum(packet, pseudo_header);
}
void TCPSocket::receive_packet(BAN::ConstByteSpan buffer, const sockaddr_storage& sender)
{
{
uint16_t checksum = 0;
if (sender.ss_family == AF_INET)
{
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(&sender);
checksum = calculate_internet_checksum(buffer,
PseudoHeader {
.src_ipv4 = BAN::IPv4Address(sockaddr_in.sin_addr.s_addr),
.dst_ipv4 = m_interface->get_ipv4_address(),
.protocol = NetworkProtocol::TCP,
.extra = buffer.size()
}
);
}
else
{
dwarnln("No tcp checksum validation for socket family {}", sender.ss_family);
return;
}
if (checksum != 0)
{
dprintln("Checksum does not match");
return;
}
}
auto& header = buffer.as<const TCPHeader>();
m_send_window.non_scaled_size = header.window_size;
auto payload = buffer.slice(header.data_offset * sizeof(uint32_t));
switch (m_state)
{
case State::Closed:
break;
case State::SynSent:
{
if (!header.ack || !header.syn)
break;
LockGuard _(m_lock);
if (header.ack_number != m_send_window.current_seq)
{
dprintln_if(DEBUG_TCP, "Invalid ack number in SYN/ACK", (uint32_t)header.ack_number, m_send_window.current_seq);
break;
}
auto options = parse_tcp_options(header);
if (options.maximum_seqment_size.has_value())
m_send_window.mss = *options.maximum_seqment_size;
if (options.window_scale.has_value())
m_send_window.scale = *options.window_scale;
m_send_window.start_seq = m_send_window.current_seq;
m_send_window.current_ack = m_send_window.current_seq;
m_recv_window.start_seq = header.seq_number + 1;
dprintln_if(DEBUG_TCP, "Got SYN/ACK");
m_should_ack = true;
m_state = State::SynReceived;
break;
}
case State::FinWait2:
case State::TimeWait:
case State::CloseWait:
case State::Established:
{
if (!header.ack)
break;
LockGuard _(m_lock);
if (header.fin)
{
if (m_recv_window.start_seq + m_recv_window.data_size != header.seq_number)
dprintln_if(DEBUG_TCP, "Got FIN, but missing packets");
else
{
if (m_state == State::FinWait2)
m_send_window.has_ghost_byte = true;
else
m_state = State::CloseWait;
m_recv_window.has_ghost_byte = true;
m_should_ack = true;
dprintln_if(DEBUG_TCP, "Got FIN");
}
break;
}
if (header.ack_number > m_send_window.current_ack)
m_send_window.current_ack = header.ack_number;
if (payload.size() > 0)
{
if (header.seq_number != m_recv_window.start_seq + m_recv_window.data_size)
{
dprintln_if(DEBUG_TCP, "Missing packet");
break;
}
if (m_recv_window.data_size + payload.size() > m_recv_window.buffer->size())
{
dprintln_if(DEBUG_TCP, "Cannot fit received bytes to window, waiting for retransmission");
break;
}
auto* buffer = reinterpret_cast<uint8_t*>(m_recv_window.buffer->vaddr());
memcpy(buffer + m_recv_window.data_size, payload.data(), payload.size());
m_recv_window.data_size += payload.size();
m_should_ack = true;
dprintln_if(DEBUG_TCP, "Received {} bytes", payload.size());
}
break;
}
case State::LastAck:
if (!header.ack)
break;
dprintln_if(DEBUG_TCP, "Got final ACK");
set_connection_as_closed();
break;
case State::Listen: ASSERT_NOT_REACHED();
case State::SynReceived: ASSERT_NOT_REACHED();
case State::FinWait1: ASSERT_NOT_REACHED();
case State::Closing: ASSERT_NOT_REACHED();
}
m_semaphore.unblock();
}
void TCPSocket::set_connection_as_closed()
{
if (is_bound())
{
m_network_layer.unbind_socket(this, m_port);
m_interface = nullptr;
m_port = PORT_NONE;
dprintln_if(DEBUG_TCP, "Socket unbound");
}
m_process = nullptr;
}
void TCPSocket::process_task()
{
// FIXME: this should be dynamic
static constexpr uint32_t retransmit_timeout_ms = 1000;
BAN::RefPtr<TCPSocket> keep_alive = this;
while (m_process)
{
uint64_t current_ms = SystemTimer::get().ms_since_boot();
if (m_state == State::TimeWait && current_ms >= m_time_wait_start_ms + 30'000)
set_connection_as_closed();
{
LockGuard _(m_lock);
if (m_should_ack)
{
m_should_ack = false;
ASSERT(m_connection_info.has_value());
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
auto target_address_len = m_connection_info->address_len;
if (auto ret = m_network_layer.sendto(*this, {}, target_address, target_address_len); ret.is_error())
dwarnln("{}", ret.error());
continue;
}
if (m_send_window.data_size > 0 && m_send_window.current_ack - m_send_window.has_ghost_byte > m_send_window.start_seq)
{
uint32_t acknowledged_bytes = m_send_window.current_ack - m_send_window.start_seq - m_send_window.has_ghost_byte;
ASSERT_LTE(acknowledged_bytes, m_send_window.data_size);
m_send_window.data_size -= acknowledged_bytes;
m_send_window.start_seq += acknowledged_bytes;
if (m_send_window.data_size > 0)
{
auto* send_buffer = reinterpret_cast<uint8_t*>(m_send_window.buffer->vaddr());
memmove(send_buffer, send_buffer + acknowledged_bytes, m_send_window.data_size);
}
else
{
m_send_window.last_send_ms = 0;
}
dprintln_if(DEBUG_TCP, "Target acknowledged {} bytes", acknowledged_bytes);
continue;
}
if (m_send_window.data_size > 0 && current_ms >= m_send_window.last_send_ms + retransmit_timeout_ms)
{
ASSERT(m_connection_info.has_value());
auto* target_address = reinterpret_cast<const sockaddr*>(&m_connection_info->address);
auto target_address_len = m_connection_info->address_len;
const uint32_t total_send = BAN::Math::min<uint32_t>(m_send_window.data_size, m_send_window.scaled_size());
m_send_window.current_seq = m_send_window.start_seq;
auto* send_buffer = reinterpret_cast<const uint8_t*>(m_send_window.buffer->vaddr());
for (uint32_t i = 0; i < total_send;)
{
const uint32_t to_send = BAN::Math::min(total_send - i, m_send_window.mss);
auto message = BAN::ConstByteSpan(send_buffer + i, to_send);
if (auto ret = m_network_layer.sendto(*this, message, target_address, target_address_len); ret.is_error())
{
dwarnln("{}", ret.error());
break;
}
dprintln_if(DEBUG_TCP, "Sent {} bytes", to_send);
m_send_window.current_seq += to_send;
i += to_send;
}
m_send_window.last_send_ms = current_ms;
continue;
}
}
m_semaphore.block_with_wake_time(current_ms + retransmit_timeout_ms);
}
m_semaphore.unblock();
}
BAN::ErrorOr<size_t> TCPSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr*, socklen_t*)
{
LockGuard _(m_lock);
if (m_state == State::Closed)
return BAN::Error::from_errno(ENOTCONN);
while (m_recv_window.data_size == 0)
{
switch (m_state)
{
case State::SynSent:
case State::SynReceived:
case State::Established:
case State::CloseWait:
case State::Listen:
break;
case State::FinWait1:
case State::FinWait2:
case State::LastAck:
case State::TimeWait:
return BAN::Error::from_errno(ECONNRESET);
case State::Closed: ASSERT_NOT_REACHED();
case State::Closing: ASSERT_NOT_REACHED();
};
LockFreeGuard free(m_lock);
TRY(Thread::current().block_or_eintr_indefinite(m_semaphore));
}
uint32_t to_recv = BAN::Math::min<uint32_t>(buffer.size(), m_recv_window.data_size);
auto* recv_buffer = reinterpret_cast<uint8_t*>(m_recv_window.buffer->vaddr());
memcpy(buffer.data(), recv_buffer, to_recv);
m_recv_window.data_size -= to_recv;
m_recv_window.start_seq += to_recv;
if (m_recv_window.data_size > 0)
memmove(recv_buffer, recv_buffer + to_recv, m_recv_window.data_size);
return to_recv;
}
BAN::ErrorOr<size_t> TCPSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len)
{
if (address)
return BAN::Error::from_errno(EISCONN);
if (message.size() > m_send_window.buffer->size())
{
for (size_t i = 0; i < message.size(); i++)
{
const size_t to_send = BAN::Math::min<size_t>(message.size() - i, m_send_window.buffer->size());
TRY(sendto_impl(message.slice(i, to_send), address, address_len));
i += to_send;
}
return message.size();
}
LockGuard _(m_lock);
if (m_state == State::Closed)
return BAN::Error::from_errno(ENOTCONN);
while (true)
{
switch (m_state)
{
case State::SynSent:
case State::SynReceived:
case State::Established:
case State::CloseWait:
case State::Listen:
break;
case State::FinWait1:
case State::FinWait2:
case State::LastAck:
case State::TimeWait:
return BAN::Error::from_errno(ECONNRESET);
case State::Closed: ASSERT_NOT_REACHED();
case State::Closing: ASSERT_NOT_REACHED();
};
if (m_send_window.data_size + message.size() <= m_send_window.buffer->size())
break;
LockFreeGuard free(m_lock);
TRY(Thread::current().block_or_eintr_indefinite(m_semaphore));
}
{
auto* buffer = reinterpret_cast<uint8_t*>(m_send_window.buffer->vaddr());
memcpy(buffer + m_send_window.data_size, message.data(), message.size());
m_send_window.data_size += message.size();
}
uint32_t target_ack = m_send_window.start_seq + m_send_window.data_size;
m_semaphore.unblock();
while (m_send_window.start_seq < target_ack)
{
switch (m_state)
{
case State::SynSent:
case State::SynReceived:
case State::Established:
case State::CloseWait:
case State::Listen:
case State::TimeWait:
case State::FinWait1:
case State::FinWait2:
break;
case State::LastAck:
return BAN::Error::from_errno(ECONNRESET);
case State::Closed: ASSERT_NOT_REACHED();
case State::Closing: ASSERT_NOT_REACHED();
};
LockFreeGuard free(m_lock);
TRY(Thread::current().block_or_eintr_indefinite(m_semaphore));
}
return message.size();
}
}

View File

@ -1,3 +1,4 @@
#include <kernel/LockGuard.h>
#include <kernel/Memory/Heap.h> #include <kernel/Memory/Heap.h>
#include <kernel/Networking/UDPSocket.h> #include <kernel/Networking/UDPSocket.h>
#include <kernel/Thread.h> #include <kernel/Thread.h>
@ -23,7 +24,15 @@ namespace Kernel
: NetworkSocket(network_layer, ino, inode_info) : NetworkSocket(network_layer, ino, inode_info)
{ } { }
void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) void UDPSocket::on_close_impl()
{
if (is_bound())
m_network_layer.unbind_socket(this, m_port);
m_port = PORT_NONE;
m_interface = nullptr;
}
void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port, PseudoHeader)
{ {
auto& header = packet.as<UDPHeader>(); auto& header = packet.as<UDPHeader>();
header.src_port = m_port; header.src_port = m_port;
@ -32,9 +41,12 @@ namespace Kernel
header.checksum = 0; header.checksum = 0;
} }
void UDPSocket::add_packet(BAN::ConstByteSpan packet, BAN::IPv4Address sender_addr, uint16_t sender_port) void UDPSocket::receive_packet(BAN::ConstByteSpan packet, const sockaddr_storage& sender)
{ {
CriticalScope _; //auto& header = packet.as<const UDPHeader>();
auto payload = packet.slice(sizeof(UDPHeader));
LockGuard _(m_packet_lock);
if (m_packets.full()) if (m_packets.full())
{ {
@ -42,60 +54,82 @@ namespace Kernel
return; return;
} }
if (!m_packets.empty() && m_packet_total_size > m_packet_buffer->size()) if (m_packet_total_size + payload.size() > m_packet_buffer->size())
{ {
dprintln("Packet buffer full, dropping packet"); dprintln("Packet buffer full, dropping packet");
return; return;
} }
void* buffer = reinterpret_cast<void*>(m_packet_buffer->vaddr() + m_packet_total_size); void* buffer = reinterpret_cast<void*>(m_packet_buffer->vaddr() + m_packet_total_size);
memcpy(buffer, packet.data(), packet.size()); memcpy(buffer, payload.data(), payload.size());
m_packets.push(PacketInfo { m_packets.emplace(PacketInfo {
.sender_addr = sender_addr, .sender = sender,
.sender_port = sender_port, .packet_size = payload.size()
.packet_size = packet.size()
}); });
m_packet_total_size += packet.size(); m_packet_total_size += payload.size();
m_semaphore.unblock(); m_packet_semaphore.unblock();
} }
BAN::ErrorOr<size_t> UDPSocket::read_packet(BAN::ByteSpan buffer, sockaddr_in* sender_addr) BAN::ErrorOr<void> UDPSocket::bind_impl(const sockaddr* address, socklen_t address_len)
{ {
while (m_packets.empty()) if (is_bound())
TRY(Thread::current().block_or_eintr(m_semaphore)); return BAN::Error::from_errno(EINVAL);
return m_network_layer.bind_socket_to_address(this, address, address_len);
}
CriticalScope _; BAN::ErrorOr<size_t> UDPSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr* address, socklen_t* address_len)
if (m_packets.empty()) {
return read_packet(buffer, sender_addr); if (!is_bound())
{
dprintln("No interface bound");
return BAN::Error::from_errno(EINVAL);
}
ASSERT(m_port != PORT_NONE);
LockGuard _(m_packet_lock);
while (m_packets.empty())
{
LockFreeGuard free(m_packet_lock);
TRY(Thread::current().block_or_eintr_indefinite(m_packet_semaphore));
}
auto packet_info = m_packets.front(); auto packet_info = m_packets.front();
m_packets.pop(); m_packets.pop();
size_t nread = BAN::Math::min<size_t>(packet_info.packet_size, buffer.size()); size_t nread = BAN::Math::min<size_t>(packet_info.packet_size, buffer.size());
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
memcpy( memcpy(
buffer.data(), buffer.data(),
(const void*)m_packet_buffer->vaddr(), packet_buffer,
nread nread
); );
memmove( memmove(
(void*)m_packet_buffer->vaddr(), packet_buffer,
(void*)(m_packet_buffer->vaddr() + packet_info.packet_size), packet_buffer + packet_info.packet_size,
m_packet_total_size - packet_info.packet_size m_packet_total_size - packet_info.packet_size
); );
m_packet_total_size -= packet_info.packet_size; m_packet_total_size -= packet_info.packet_size;
if (sender_addr) if (address && address_len)
{ {
sender_addr->sin_family = AF_INET; if (*address_len > (socklen_t)sizeof(sockaddr_storage))
sender_addr->sin_port = BAN::NetworkEndian(packet_info.sender_port); *address_len = sizeof(sockaddr_storage);
sender_addr->sin_addr.s_addr = packet_info.sender_addr.raw; memcpy(address, &packet_info.sender, *address_len);
} }
return nread; return nread;
} }
BAN::ErrorOr<size_t> UDPSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len)
{
if (!is_bound())
TRY(m_network_layer.bind_socket_to_unused(this, address, address_len));
return TRY(m_network_layer.sendto(*this, message, address, address_len));
}
} }

View File

@ -10,7 +10,7 @@
namespace Kernel namespace Kernel
{ {
static BAN::HashMap<BAN::String, BAN::RefPtr<UnixDomainSocket>> s_bound_sockets; static BAN::HashMap<BAN::String, BAN::WeakPtr<UnixDomainSocket>> s_bound_sockets;
static SpinLock s_bound_socket_lock; static SpinLock s_bound_socket_lock;
static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE; static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE;
@ -47,6 +47,16 @@ namespace Kernel
} }
} }
void UnixDomainSocket::on_close_impl()
{
if (is_bound() && !is_bound_to_unused())
{
LockGuard _(s_bound_socket_lock);
if (s_bound_sockets.contains(m_bound_path))
s_bound_sockets.remove(m_bound_path);
}
}
BAN::ErrorOr<long> UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len) BAN::ErrorOr<long> UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len)
{ {
if (!m_info.has<ConnectionInfo>()) if (!m_info.has<ConnectionInfo>())
@ -56,7 +66,7 @@ namespace Kernel
return BAN::Error::from_errno(EINVAL); return BAN::Error::from_errno(EINVAL);
while (connection_info.pending_connections.empty()) while (connection_info.pending_connections.empty())
TRY(Thread::current().block_or_eintr(connection_info.pending_semaphore)); TRY(Thread::current().block_or_eintr_indefinite(connection_info.pending_semaphore));
BAN::RefPtr<UnixDomainSocket> pending; BAN::RefPtr<UnixDomainSocket> pending;
@ -74,7 +84,7 @@ namespace Kernel
return_inode = reinterpret_cast<UnixDomainSocket*>(return_inode_tmp.ptr()); return_inode = reinterpret_cast<UnixDomainSocket*>(return_inode_tmp.ptr());
} }
TRY(return_inode->m_bound_path.append(m_bound_path)); TRY(return_inode->m_bound_path.push_back('X'));
return_inode->m_info.get<ConnectionInfo>().connection = TRY(pending->get_weak_ptr()); return_inode->m_info.get<ConnectionInfo>().connection = TRY(pending->get_weak_ptr());
pending->m_info.get<ConnectionInfo>().connection = TRY(return_inode->get_weak_ptr()); pending->m_info.get<ConnectionInfo>().connection = TRY(return_inode->get_weak_ptr());
pending->m_info.get<ConnectionInfo>().connection_done = true; pending->m_info.get<ConnectionInfo>().connection_done = true;
@ -113,14 +123,21 @@ namespace Kernel
LockGuard _(s_bound_socket_lock); LockGuard _(s_bound_socket_lock);
if (!s_bound_sockets.contains(file.canonical_path)) if (!s_bound_sockets.contains(file.canonical_path))
return BAN::Error::from_errno(ECONNREFUSED); return BAN::Error::from_errno(ECONNREFUSED);
target = s_bound_sockets[file.canonical_path]; target = s_bound_sockets[file.canonical_path].lock();
if (!target)
return BAN::Error::from_errno(ECONNREFUSED);
} }
if (m_socket_type != target->m_socket_type) if (m_socket_type != target->m_socket_type)
return BAN::Error::from_errno(EPROTOTYPE); return BAN::Error::from_errno(EPROTOTYPE);
if (m_info.has<ConnectionInfo>()) if (m_info.has<ConnectionlessInfo>())
{ {
auto& connectionless_info = m_info.get<ConnectionlessInfo>();
connectionless_info.peer_address = BAN::move(file.canonical_path);
return {};
}
auto& connection_info = m_info.get<ConnectionInfo>(); auto& connection_info = m_info.get<ConnectionInfo>();
if (connection_info.connection) if (connection_info.connection)
return BAN::Error::from_errno(ECONNREFUSED); return BAN::Error::from_errno(ECONNREFUSED);
@ -141,7 +158,7 @@ namespace Kernel
break; break;
} }
} }
TRY(Thread::current().block_or_eintr(target_info.pending_semaphore)); TRY(Thread::current().block_or_eintr_indefinite(target_info.pending_semaphore));
} }
while (!connection_info.connection_done) while (!connection_info.connection_done)
@ -149,11 +166,6 @@ namespace Kernel
return {}; return {};
} }
else
{
return BAN::Error::from_errno(ENOTSUP);
}
}
BAN::ErrorOr<void> UnixDomainSocket::listen_impl(int backlog) BAN::ErrorOr<void> UnixDomainSocket::listen_impl(int backlog)
{ {
@ -195,7 +207,7 @@ namespace Kernel
LockGuard _(s_bound_socket_lock); LockGuard _(s_bound_socket_lock);
ASSERT(!s_bound_sockets.contains(file.canonical_path)); ASSERT(!s_bound_sockets.contains(file.canonical_path));
TRY(s_bound_sockets.emplace(file.canonical_path, this)); TRY(s_bound_sockets.emplace(file.canonical_path, TRY(get_weak_ptr())));
m_bound_path = BAN::move(file.canonical_path); m_bound_path = BAN::move(file.canonical_path);
return {}; return {};
@ -215,28 +227,6 @@ namespace Kernel
} }
} }
// This to feels too hacky to expose out of here
struct LockFreeGuard
{
LockFreeGuard(RecursivePrioritySpinLock& lock)
: m_lock(lock)
, m_depth(lock.lock_depth())
{
for (uint32_t i = 0; i < m_depth; i++)
m_lock.unlock();
}
~LockFreeGuard()
{
for (uint32_t i = 0; i < m_depth; i++)
m_lock.lock();
}
private:
RecursivePrioritySpinLock& m_lock;
const uint32_t m_depth;
};
BAN::ErrorOr<void> UnixDomainSocket::add_packet(BAN::ConstByteSpan packet) BAN::ErrorOr<void> UnixDomainSocket::add_packet(BAN::ConstByteSpan packet)
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
@ -244,7 +234,7 @@ namespace Kernel
while (m_packet_sizes.full() || m_packet_size_total + packet.size() > s_packet_buffer_size) while (m_packet_sizes.full() || m_packet_size_total + packet.size() > s_packet_buffer_size)
{ {
LockFreeGuard _(m_lock); LockFreeGuard _(m_lock);
TRY(Thread::current().block_or_eintr(m_packet_semaphore)); TRY(Thread::current().block_or_eintr_indefinite(m_packet_semaphore));
} }
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total); uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total);
@ -258,35 +248,87 @@ namespace Kernel
return {}; return {};
} }
BAN::ErrorOr<size_t> UnixDomainSocket::sendto_impl(const sys_sendto_t* arguments) bool UnixDomainSocket::can_read_impl() const
{ {
if (arguments->flags) if (m_info.has<ConnectionInfo>())
return BAN::Error::from_errno(ENOTSUP); {
if (arguments->length > s_packet_buffer_size) auto& connection_info = m_info.get<ConnectionInfo>();
if (!connection_info.connection)
return false;
}
return m_packet_size_total > 0;
}
bool UnixDomainSocket::can_write_impl() const
{
if (m_info.has<ConnectionInfo>())
{
auto& connection_info = m_info.get<ConnectionInfo>();
return connection_info.connection.valid();
}
return true;
}
BAN::ErrorOr<size_t> UnixDomainSocket::sendto_impl(BAN::ConstByteSpan message, const sockaddr* address, socklen_t address_len)
{
if (message.size() > s_packet_buffer_size)
return BAN::Error::from_errno(ENOBUFS); return BAN::Error::from_errno(ENOBUFS);
if (m_info.has<ConnectionInfo>()) if (m_info.has<ConnectionInfo>())
{ {
auto& connection_info = m_info.get<ConnectionInfo>(); auto& connection_info = m_info.get<ConnectionInfo>();
if (arguments->dest_addr) if (address)
return BAN::Error::from_errno(EISCONN); return BAN::Error::from_errno(EISCONN);
auto target = connection_info.connection.lock(); auto target = connection_info.connection.lock();
if (!target) if (!target)
return BAN::Error::from_errno(ENOTCONN); return BAN::Error::from_errno(ENOTCONN);
TRY(target->add_packet({ reinterpret_cast<const uint8_t*>(arguments->message), arguments->length })); TRY(target->add_packet(message));
return arguments->length; return message.size();
} }
else else
{ {
return BAN::Error::from_errno(ENOTSUP); BAN::String canonical_path;
}
}
BAN::ErrorOr<size_t> UnixDomainSocket::recvfrom_impl(sys_recvfrom_t* arguments) if (!address)
{ {
if (arguments->flags) auto& connectionless_info = m_info.get<ConnectionlessInfo>();
return BAN::Error::from_errno(ENOTSUP); if (connectionless_info.peer_address.empty())
return BAN::Error::from_errno(EDESTADDRREQ);
TRY(canonical_path.append(connectionless_info.peer_address));
}
else
{
if (address_len != sizeof(sockaddr_un))
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_un = *reinterpret_cast<const struct sockaddr_un*>(address);
if (sockaddr_un.sun_family != AF_UNIX)
return BAN::Error::from_errno(EAFNOSUPPORT);
auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path));
auto file = TRY(VirtualFileSystem::get().file_from_absolute_path(
Process::current().credentials(),
absolute_path,
O_WRONLY
));
canonical_path = BAN::move(file.canonical_path);
}
LockGuard _(s_bound_socket_lock);
if (!s_bound_sockets.contains(canonical_path))
return BAN::Error::from_errno(EDESTADDRREQ);
auto target = s_bound_sockets[canonical_path].lock();
if (!target)
return BAN::Error::from_errno(EDESTADDRREQ);
TRY(target->add_packet(message));
return message.size();
}
}
BAN::ErrorOr<size_t> UnixDomainSocket::recvfrom_impl(BAN::ByteSpan buffer, sockaddr*, socklen_t*)
{
if (m_info.has<ConnectionInfo>()) if (m_info.has<ConnectionInfo>())
{ {
auto& connection_info = m_info.get<ConnectionInfo>(); auto& connection_info = m_info.get<ConnectionInfo>();
@ -297,21 +339,21 @@ namespace Kernel
while (m_packet_size_total == 0) while (m_packet_size_total == 0)
{ {
LockFreeGuard _(m_lock); LockFreeGuard _(m_lock);
TRY(Thread::current().block_or_eintr(m_packet_semaphore)); TRY(Thread::current().block_or_eintr_indefinite(m_packet_semaphore));
} }
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr()); uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
size_t nread = 0; size_t nread = 0;
if (is_streaming()) if (is_streaming())
nread = BAN::Math::min(arguments->length, m_packet_size_total); nread = BAN::Math::min(buffer.size(), m_packet_size_total);
else else
{ {
nread = BAN::Math::min(arguments->length, m_packet_sizes.front()); nread = BAN::Math::min(buffer.size(), m_packet_sizes.front());
m_packet_sizes.pop(); m_packet_sizes.pop();
} }
memcpy(arguments->buffer, packet_buffer, nread); memcpy(buffer.data(), packet_buffer, nread);
memmove(packet_buffer, packet_buffer + nread, m_packet_size_total - nread); memmove(packet_buffer, packet_buffer + nread, m_packet_size_total - nread);
m_packet_size_total -= nread; m_packet_size_total -= nread;

View File

@ -331,7 +331,7 @@ namespace Kernel
{ {
TRY(validate_fd(fd)); TRY(validate_fd(fd));
auto& open_file = m_open_files[fd]; auto& open_file = m_open_files[fd];
if ((open_file->flags & O_NONBLOCK) && !open_file->inode->has_data()) if ((open_file->flags & O_NONBLOCK) && !open_file->inode->can_read())
return 0; return 0;
size_t nread = TRY(open_file->inode->read(open_file->offset, buffer)); size_t nread = TRY(open_file->inode->read(open_file->offset, buffer));
open_file->offset += nread; open_file->offset += nread;
@ -342,6 +342,8 @@ namespace Kernel
{ {
TRY(validate_fd(fd)); TRY(validate_fd(fd));
auto& open_file = m_open_files[fd]; auto& open_file = m_open_files[fd];
if ((open_file->flags & O_NONBLOCK) && !open_file->inode->can_write())
return 0;
if (open_file->flags & O_APPEND) if (open_file->flags & O_APPEND)
open_file->offset = open_file->inode->size(); open_file->offset = open_file->inode->size();
size_t nwrite = TRY(open_file->inode->write(open_file->offset, buffer)); size_t nwrite = TRY(open_file->inode->write(open_file->offset, buffer));

View File

@ -568,7 +568,7 @@ namespace Kernel
return BAN::Error::from_errno(ECHILD); return BAN::Error::from_errno(ECHILD);
while (!target->m_exit_status.exited) while (!target->m_exit_status.exited)
TRY(Thread::current().block_or_eintr(target->m_exit_status.semaphore)); TRY(Thread::current().block_or_eintr_indefinite(target->m_exit_status.semaphore));
int exit_status = target->m_exit_status.exit_code; int exit_status = target->m_exit_status.exit_code;
target->m_exit_status.waiting--; target->m_exit_status.waiting--;
@ -983,7 +983,8 @@ namespace Kernel
if (!inode->mode().ifsock()) if (!inode->mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK); return BAN::Error::from_errno(ENOTSOCK);
return TRY(inode->sendto(arguments)); BAN::ConstByteSpan message { reinterpret_cast<const uint8_t*>(arguments->message), arguments->length };
return TRY(inode->sendto(message, arguments->dest_addr, arguments->dest_len));
} }
BAN::ErrorOr<long> Process::sys_recvfrom(sys_recvfrom_t* arguments) BAN::ErrorOr<long> Process::sys_recvfrom(sys_recvfrom_t* arguments)
@ -1006,7 +1007,8 @@ namespace Kernel
if (!inode->mode().ifsock()) if (!inode->mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK); return BAN::Error::from_errno(ENOTSOCK);
return TRY(inode->recvfrom(arguments)); BAN::ByteSpan buffer { reinterpret_cast<uint8_t*>(arguments->buffer), arguments->length };
return TRY(inode->recvfrom(buffer, arguments->address, arguments->address_len));
} }
BAN::ErrorOr<long> Process::sys_ioctl(int fildes, int request, void* arg) BAN::ErrorOr<long> Process::sys_ioctl(int fildes, int request, void* arg)
@ -1016,6 +1018,86 @@ namespace Kernel
return TRY(inode->ioctl(request, arg)); return TRY(inode->ioctl(request, arg));
} }
BAN::ErrorOr<long> Process::sys_pselect(sys_pselect_t* arguments)
{
LockGuard _(m_lock);
TRY(validate_pointer_access(arguments, sizeof(sys_pselect_t)));
if (arguments->readfds)
TRY(validate_pointer_access(arguments->readfds, sizeof(fd_set)));
if (arguments->writefds)
TRY(validate_pointer_access(arguments->writefds, sizeof(fd_set)));
if (arguments->errorfds)
TRY(validate_pointer_access(arguments->errorfds, sizeof(fd_set)));
if (arguments->timeout)
TRY(validate_pointer_access(arguments->timeout, sizeof(timespec)));
if (arguments->sigmask)
TRY(validate_pointer_access(arguments->sigmask, sizeof(sigset_t)));
if (arguments->sigmask)
return BAN::Error::from_errno(ENOTSUP);
uint64_t timedout_ms = SystemTimer::get().ms_since_boot();
if (arguments->timeout)
{
timedout_ms += arguments->timeout->tv_sec * 1000;
timedout_ms += arguments->timeout->tv_nsec / 1'000'000;
}
fd_set readfds; FD_ZERO(&readfds);
fd_set writefds; FD_ZERO(&writefds);
fd_set errorfds; FD_ZERO(&errorfds);
long set_bits = 0;
while (set_bits == 0)
{
if (arguments->timeout && SystemTimer::get().ms_since_boot() >= timedout_ms)
break;
auto update_fds =
[&](int fd, fd_set* source, fd_set* dest, bool (Inode::*func)() const)
{
if (source == nullptr)
return;
if (!FD_ISSET(fd, source))
return;
auto inode_or_error = m_open_file_descriptors.inode_of(fd);
if (inode_or_error.is_error())
return;
auto inode = inode_or_error.release_value();
auto mode = inode->mode();
if (!mode.ifreg() && !mode.ififo() && !mode.ifsock() && !inode->is_pipe() && !inode->is_tty())
return;
if ((inode_or_error.value().ptr()->*func)())
{
FD_SET(fd, dest);
set_bits++;
}
};
for (int i = 0; i < arguments->nfds; i++)
{
update_fds(i, arguments->readfds, &readfds, &Inode::can_read);
update_fds(i, arguments->writefds, &writefds, &Inode::can_write);
update_fds(i, arguments->errorfds, &errorfds, &Inode::can_read);
}
SystemTimer::get().sleep(1);
}
if (arguments->readfds)
memcpy(arguments->readfds, &readfds, sizeof(fd_set));
if (arguments->writefds)
memcpy(arguments->writefds, &writefds, sizeof(fd_set));
if (arguments->errorfds)
memcpy(arguments->errorfds, &errorfds, sizeof(fd_set));
return set_bits;
}
BAN::ErrorOr<long> Process::sys_pipe(int fildes[2]) BAN::ErrorOr<long> Process::sys_pipe(int fildes[2])
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
@ -1136,7 +1218,7 @@ namespace Kernel
return BAN::Error::from_errno(EUNKNOWN); return BAN::Error::from_errno(EUNKNOWN);
} }
BAN::ErrorOr<long> Process::sys_read_dir_entries(int fd, DirectoryEntryList* list, size_t list_size) BAN::ErrorOr<long> Process::sys_readdir(int fd, DirectoryEntryList* list, size_t list_size)
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
TRY(validate_pointer_access(list, list_size)); TRY(validate_pointer_access(list, list_size));

44
kernel/kernel/Random.cpp Normal file
View File

@ -0,0 +1,44 @@
#include <kernel/Debug.h>
#include <kernel/CPUID.h>
#include <kernel/Random.h>
namespace Kernel
{
// Constants and algorithm from https://en.wikipedia.org/wiki/Permuted_congruential_generator
static uint64_t s_rand_seed = 0x4d595df4d0f33173;
static constexpr uint64_t s_rand_multiplier = 6364136223846793005;
static constexpr uint64_t s_rand_increment = 1442695040888963407;
void Random::initialize()
{
uint32_t ecx, edx;
CPUID::get_features(ecx, edx);
if (ecx & CPUID::ECX_RDRND)
asm volatile("rdrand %0" : "=a"(s_rand_seed));
else
dprintln("No RDRAND available");
}
uint32_t Random::get_u32()
{
auto rotr32 = [](uint32_t x, unsigned r) { return x >> r | x << (-r & 31); };
uint64_t x = s_rand_seed;
unsigned count = (unsigned)(x >> 59);
s_rand_seed = x * s_rand_multiplier + s_rand_increment;
x ^= x >> 18;
return rotr32(x >> 27, count) % UINT32_MAX;
}
uint64_t Random::get_u64()
{
return ((uint64_t)get_u32() << 32) | get_u32();
}
}

View File

@ -278,12 +278,9 @@ namespace Kernel
ASSERT_NOT_REACHED(); ASSERT_NOT_REACHED();
} }
void Scheduler::set_current_thread_sleeping(uint64_t wake_time) void Scheduler::set_current_thread_sleeping_impl(uint64_t wake_time)
{ {
VERIFY_STI(); VERIFY_CLI();
DISABLE_INTERRUPTS();
ASSERT(m_current_thread);
if (save_current_thread()) if (save_current_thread())
{ {
@ -310,42 +307,37 @@ namespace Kernel
ASSERT_NOT_REACHED(); ASSERT_NOT_REACHED();
} }
void Scheduler::block_current_thread(Semaphore* semaphore) void Scheduler::set_current_thread_sleeping(uint64_t wake_time)
{ {
VERIFY_STI(); VERIFY_STI();
DISABLE_INTERRUPTS(); DISABLE_INTERRUPTS();
ASSERT(m_current_thread); ASSERT(m_current_thread);
if (save_current_thread()) m_current_thread->semaphore = nullptr;
{ set_current_thread_sleeping_impl(wake_time);
ENABLE_INTERRUPTS();
return;
} }
void Scheduler::block_current_thread(Semaphore* semaphore, uint64_t wake_time)
{
VERIFY_STI();
DISABLE_INTERRUPTS();
ASSERT(m_current_thread);
m_current_thread->semaphore = semaphore; m_current_thread->semaphore = semaphore;
m_active_threads.move_element_to_other_linked_list( set_current_thread_sleeping_impl(wake_time);
m_blocking_threads,
m_blocking_threads.end(),
m_current_thread
);
m_current_thread = {};
advance_current_thread();
execute_current_thread();
ASSERT_NOT_REACHED();
} }
void Scheduler::unblock_threads(Semaphore* semaphore) void Scheduler::unblock_threads(Semaphore* semaphore)
{ {
CriticalScope critical; CriticalScope critical;
for (auto it = m_blocking_threads.begin(); it != m_blocking_threads.end();) for (auto it = m_sleeping_threads.begin(); it != m_sleeping_threads.end();)
{ {
if (it->semaphore == semaphore) if (it->semaphore == semaphore)
{ {
it = m_blocking_threads.move_element_to_other_linked_list( it = m_sleeping_threads.move_element_to_other_linked_list(
m_active_threads, m_active_threads,
m_active_threads.end(), m_active_threads.end(),
it it
@ -362,19 +354,6 @@ namespace Kernel
{ {
CriticalScope _; CriticalScope _;
for (auto it = m_blocking_threads.begin(); it != m_blocking_threads.end(); it++)
{
if (it->thread->tid() == tid)
{
m_blocking_threads.move_element_to_other_linked_list(
m_active_threads,
m_active_threads.end(),
it
);
return;
}
}
for (auto it = m_sleeping_threads.begin(); it != m_sleeping_threads.end(); it++) for (auto it = m_sleeping_threads.begin(); it != m_sleeping_threads.end(); it++)
{ {
if (it->thread->tid() == tid) if (it->thread->tid() == tid)

View File

@ -1,12 +1,23 @@
#include <kernel/Scheduler.h> #include <kernel/Scheduler.h>
#include <kernel/Semaphore.h> #include <kernel/Semaphore.h>
#include <kernel/Timer/Timer.h>
namespace Kernel namespace Kernel
{ {
void Semaphore::block() void Semaphore::block_indefinite()
{ {
Scheduler::get().block_current_thread(this); Scheduler::get().block_current_thread(this, ~(uint64_t)0);
}
void Semaphore::block_with_timeout(uint64_t timeout_ms)
{
Scheduler::get().block_current_thread(this, SystemTimer::get().ms_since_boot() + timeout_ms);
}
void Semaphore::block_with_wake_time(uint64_t wake_time)
{
Scheduler::get().block_current_thread(this, wake_time);
} }
void Semaphore::unblock() void Semaphore::unblock()

View File

@ -7,6 +7,7 @@ namespace Kernel
void SpinLock::lock() void SpinLock::lock()
{ {
pid_t tid = Scheduler::current_tid(); pid_t tid = Scheduler::current_tid();
ASSERT(tid != m_locker);
while (!m_locker.compare_exchange(-1, tid)) while (!m_locker.compare_exchange(-1, tid))
Scheduler::get().reschedule(); Scheduler::get().reschedule();
} }

View File

@ -67,14 +67,13 @@ namespace Kernel
while (SystemTimer::get().ms_since_boot() < start_time + s_nvme_command_timeout_ms) while (SystemTimer::get().ms_since_boot() < start_time + s_nvme_command_timeout_ms)
{ {
if (!m_done) if (m_done)
{ {
m_semaphore.block();
continue;
}
m_done = false; m_done = false;
return m_status; return m_status;
} }
m_semaphore.block_with_wake_time(start_time + s_nvme_command_timeout_ms);
}
return 0xFFFF; return 0xFFFF;
} }

View File

@ -1,3 +1,4 @@
#include <BAN/Bitcast.h>
#include <kernel/Debug.h> #include <kernel/Debug.h>
#include <kernel/InterruptStack.h> #include <kernel/InterruptStack.h>
#include <kernel/Process.h> #include <kernel/Process.h>
@ -19,6 +20,14 @@ namespace Kernel
extern "C" long sys_fork_trampoline(); extern "C" long sys_fork_trampoline();
using SyscallHandler = BAN::ErrorOr<long> (Process::*)(uintptr_t, uintptr_t, uintptr_t, uintptr_t, uintptr_t);
static const SyscallHandler s_syscall_handlers[] = {
#define O(enum, name) BAN::bit_cast<SyscallHandler>(&Process::sys_ ## name),
__SYSCALL_LIST(O)
#undef O
};
extern "C" long cpp_syscall_handler(int syscall, uintptr_t arg1, uintptr_t arg2, uintptr_t arg3, uintptr_t arg4, uintptr_t arg5, InterruptStack& interrupt_stack) extern "C" long cpp_syscall_handler(int syscall, uintptr_t arg1, uintptr_t arg2, uintptr_t arg3, uintptr_t arg4, uintptr_t arg5, InterruptStack& interrupt_stack)
{ {
ASSERT((interrupt_stack.cs & 0b11) == 0b11); ASSERT((interrupt_stack.cs & 0b11) == 0b11);
@ -28,219 +37,14 @@ namespace Kernel
asm volatile("sti"); asm volatile("sti");
(void)arg1;
(void)arg2;
(void)arg3;
(void)arg4;
(void)arg5;
(void)interrupt_stack;
BAN::ErrorOr<long> ret = BAN::Error::from_errno(ENOSYS); BAN::ErrorOr<long> ret = BAN::Error::from_errno(ENOSYS);
switch (syscall) if (syscall < 0 || syscall >= __SYSCALL_COUNT)
{ dwarnln("No syscall {}", syscall);
case SYS_EXIT: else if (syscall == SYS_FORK)
ret = Process::current().sys_exit((int)arg1);
break;
case SYS_READ:
ret = Process::current().sys_read((int)arg1, (void*)arg2, (size_t)arg3);
break;
case SYS_WRITE:
ret = Process::current().sys_write((int)arg1, (const void*)arg2, (size_t)arg3);
break;
case SYS_TERMID:
ret = Process::current().sys_termid((char*)arg1);
break;
case SYS_CLOSE:
ret = Process::current().sys_close((int)arg1);
break;
case SYS_OPEN:
ret = Process::current().sys_open((const char*)arg1, (int)arg2, (mode_t)arg3);
break;
case SYS_OPENAT:
ret = Process::current().sys_openat((int)arg1, (const char*)arg2, (int)arg3, (mode_t)arg4);
break;
case SYS_SEEK:
ret = Process::current().sys_seek((int)arg1, (long)arg2, (int)arg3);
break;
case SYS_TELL:
ret = Process::current().sys_tell((int)arg1);
break;
case SYS_GET_TERMIOS:
ret = Process::current().sys_gettermios((::termios*)arg1);
break;
case SYS_SET_TERMIOS:
ret = Process::current().sys_settermios((const ::termios*)arg1);
break;
case SYS_FORK:
ret = sys_fork_trampoline(); ret = sys_fork_trampoline();
break; else
case SYS_EXEC: ret = (Process::current().*s_syscall_handlers[syscall])(arg1, arg2, arg3, arg4, arg5);
ret = Process::current().sys_exec((const char*)arg1, (const char* const*)arg2, (const char* const*)arg3);
break;
case SYS_SLEEP:
ret = Process::current().sys_sleep((unsigned int)arg1);
break;
case SYS_WAIT:
ret = Process::current().sys_wait((pid_t)arg1, (int*)arg2, (int)arg3);
break;
case SYS_FSTAT:
ret = Process::current().sys_fstat((int)arg1, (struct stat*)arg2);
break;
case SYS_READ_DIR_ENTRIES:
ret = Process::current().sys_read_dir_entries((int)arg1, (API::DirectoryEntryList*)arg2, (size_t)arg3);
break;
case SYS_SET_UID:
ret = Process::current().sys_setuid((uid_t)arg1);
break;
case SYS_SET_GID:
ret = Process::current().sys_setgid((gid_t)arg1);
break;
case SYS_SET_EUID:
ret = Process::current().sys_seteuid((uid_t)arg1);
break;
case SYS_SET_EGID:
ret = Process::current().sys_setegid((gid_t)arg1);
break;
case SYS_SET_REUID:
ret = Process::current().sys_setreuid((uid_t)arg1, (uid_t)arg2);
break;
case SYS_SET_REGID:
ret = Process::current().sys_setregid((gid_t)arg1, (gid_t)arg2);
break;
case SYS_GET_UID:
ret = Process::current().sys_getuid();
break;
case SYS_GET_GID:
ret = Process::current().sys_getgid();
break;
case SYS_GET_EUID:
ret = Process::current().sys_geteuid();
break;
case SYS_GET_EGID:
ret = Process::current().sys_getegid();
break;
case SYS_GET_PWD:
ret = Process::current().sys_getpwd((char*)arg1, (size_t)arg2);
break;
case SYS_SET_PWD:
ret = Process::current().sys_setpwd((const char*)arg1);
break;
case SYS_CLOCK_GETTIME:
ret = Process::current().sys_clock_gettime((clockid_t)arg1, (timespec*)arg2);
break;
case SYS_PIPE:
ret = Process::current().sys_pipe((int*)arg1);
break;
case SYS_DUP:
ret = Process::current().sys_dup((int)arg1);
break;
case SYS_DUP2:
ret = Process::current().sys_dup2((int)arg1, (int)arg2);
break;
case SYS_KILL:
ret = Process::current().sys_kill((pid_t)arg1, (int)arg2);
break;
case SYS_SIGNAL:
ret = Process::current().sys_signal((int)arg1, (void (*)(int))arg2);
break;
case SYS_TCSETPGRP:
ret = Process::current().sys_tcsetpgrp((int)arg1, (pid_t)arg2);
break;
case SYS_GET_PID:
ret = Process::current().pid();
break;
case SYS_GET_PGID:
ret = Process::current().sys_getpgid((pid_t)arg1);
break;
case SYS_SET_PGID:
ret = Process::current().sys_setpgid((pid_t)arg1, (pid_t)arg2);
break;
case SYS_FCNTL:
ret = Process::current().sys_fcntl((int)arg1, (int)arg2, (int)arg3);
break;
case SYS_NANOSLEEP:
ret = Process::current().sys_nanosleep((const timespec*)arg1, (timespec*)arg2);
break;
case SYS_FSTATAT:
ret = Process::current().sys_fstatat((int)arg1, (const char*)arg2, (struct stat*)arg3, (int)arg4);
break;
case SYS_STAT:
ret = Process::current().sys_stat((const char*)arg1, (struct stat*)arg2, (int)arg3);
break;
case SYS_SYNC:
ret = Process::current().sys_sync((bool)arg1);
break;
case SYS_MMAP:
ret = Process::current().sys_mmap((const sys_mmap_t*)arg1);
break;
case SYS_MUNMAP:
ret = Process::current().sys_munmap((void*)arg1, (size_t)arg2);
break;
case SYS_TTY_CTRL:
ret = Process::current().sys_tty_ctrl((int)arg1, (int)arg2, (int)arg3);
break;
case SYS_POWEROFF:
ret = Process::current().sys_poweroff((int)arg1);
break;
case SYS_CHMOD:
ret = Process::current().sys_chmod((const char*)arg1, (mode_t)arg2);
break;
case SYS_CREATE:
ret = Process::current().sys_create((const char*)arg1, (mode_t)arg2);
break;
case SYS_CREATE_DIR:
ret = Process::current().sys_create_dir((const char*)arg1, (mode_t)arg2);
break;
case SYS_UNLINK:
ret = Process::current().sys_unlink((const char*)arg1);
break;
case SYS_READLINK:
ret = Process::current().sys_readlink((const char*)arg1, (char*)arg2, (size_t)arg3);
break;
case SYS_READLINKAT:
ret = Process::current().sys_readlinkat((int)arg1, (const char*)arg2, (char*)arg3, (size_t)arg4);
break;
case SYS_MSYNC:
ret = Process::current().sys_msync((void*)arg1, (size_t)arg2, (int)arg3);
break;
case SYS_PREAD:
ret = Process::current().sys_pread((int)arg1, (void*)arg2, (size_t)arg3, (off_t)arg4);
break;
case SYS_CHOWN:
ret = Process::current().sys_chown((const char*)arg1, (uid_t)arg2, (gid_t)arg3);
break;
case SYS_LOAD_KEYMAP:
ret = Process::current().sys_load_keymap((const char*)arg1);
break;
case SYS_SOCKET:
ret = Process::current().sys_socket((int)arg1, (int)arg2, (int)arg3);
break;
case SYS_BIND:
ret = Process::current().sys_bind((int)arg1, (const sockaddr*)arg2, (socklen_t)arg3);
break;
case SYS_SENDTO:
ret = Process::current().sys_sendto((const sys_sendto_t*)arg1);
break;
case SYS_RECVFROM:
ret = Process::current().sys_recvfrom((sys_recvfrom_t*)arg1);
break;
case SYS_IOCTL:
ret = Process::current().sys_ioctl((int)arg1, (int)arg2, (void*)arg3);
break;
case SYS_ACCEPT:
ret = Process::current().sys_accept((int)arg1, (sockaddr*)arg2, (socklen_t*)arg3);
break;
case SYS_CONNECT:
ret = Process::current().sys_connect((int)arg1, (const sockaddr*)arg2, (socklen_t)arg3);
break;
case SYS_LISTEN:
ret = Process::current().sys_listen((int)arg1, (int)arg2);
break;
default:
dwarnln("Unknown syscall {}", syscall);
break;
}
asm volatile("cli"); asm volatile("cli");

View File

@ -92,7 +92,7 @@ namespace Kernel
while (true) while (true)
{ {
while (!TTY::current()->m_tty_ctrl.receive_input) while (!TTY::current()->m_tty_ctrl.receive_input)
TTY::current()->m_tty_ctrl.semaphore.block(); TTY::current()->m_tty_ctrl.semaphore.block_indefinite();
Input::KeyEvent event; Input::KeyEvent event;
size_t read = MUST(inode->read(0, BAN::ByteSpan::from(event))); size_t read = MUST(inode->read(0, BAN::ByteSpan::from(event)));
@ -210,7 +210,7 @@ namespace Kernel
// ^C // ^C
if (ch == '\x03') if (ch == '\x03')
{ {
if (auto ret = Process::sys_kill(-m_foreground_pgrp, SIGINT); ret.is_error()) if (auto ret = Process::current().sys_kill(-m_foreground_pgrp, SIGINT); ret.is_error())
dwarnln("TTY: {}", ret.error()); dwarnln("TTY: {}", ret.error());
return; return;
} }
@ -323,7 +323,7 @@ namespace Kernel
uint32_t depth = m_lock.lock_depth(); uint32_t depth = m_lock.lock_depth();
for (uint32_t i = 0; i < depth; i++) for (uint32_t i = 0; i < depth; i++)
m_lock.unlock(); m_lock.unlock();
auto eintr = Thread::current().block_or_eintr(m_output.semaphore); auto eintr = Thread::current().block_or_eintr_indefinite(m_output.semaphore);
for (uint32_t i = 0; i < depth; i++) for (uint32_t i = 0; i < depth; i++)
m_lock.lock(); m_lock.lock();
if (eintr.is_error()) if (eintr.is_error())
@ -358,12 +358,6 @@ namespace Kernel
return buffer.size(); return buffer.size();
} }
bool TTY::has_data_impl() const
{
LockGuard _(m_lock);
return m_output.flush;
}
void TTY::putchar_current(uint8_t ch) void TTY::putchar_current(uint8_t ch)
{ {
ASSERT(s_tty); ASSERT(s_tty);

View File

@ -8,6 +8,7 @@
#include <kernel/Process.h> #include <kernel/Process.h>
#include <kernel/Scheduler.h> #include <kernel/Scheduler.h>
#include <kernel/Thread.h> #include <kernel/Thread.h>
#include <kernel/Timer/Timer.h>
namespace Kernel namespace Kernel
{ {
@ -30,7 +31,6 @@ namespace Kernel
void Thread::terminate() void Thread::terminate()
{ {
CriticalScope _; CriticalScope _;
ASSERT(this == &Thread::current());
m_state = Thread::State::Terminated; m_state = Thread::State::Terminated;
if (this == &Thread::current()) if (this == &Thread::current())
Scheduler::get().execute_current_thread(); Scheduler::get().execute_current_thread();
@ -343,16 +343,34 @@ namespace Kernel
return false; return false;
} }
BAN::ErrorOr<void> Thread::block_or_eintr(Semaphore& semaphore) BAN::ErrorOr<void> Thread::block_or_eintr_indefinite(Semaphore& semaphore)
{ {
if (is_interrupted_by_signal()) if (is_interrupted_by_signal())
return BAN::Error::from_errno(EINTR); return BAN::Error::from_errno(EINTR);
semaphore.block(); semaphore.block_indefinite();
if (is_interrupted_by_signal()) if (is_interrupted_by_signal())
return BAN::Error::from_errno(EINTR); return BAN::Error::from_errno(EINTR);
return {}; return {};
} }
BAN::ErrorOr<void> Thread::block_or_eintr_or_timeout(Semaphore& semaphore, uint64_t timeout_ms, bool etimedout)
{
uint64_t wake_time_ms = SystemTimer::get().ms_since_boot() + timeout_ms;
return block_or_eintr_or_waketime(semaphore, wake_time_ms, etimedout);
}
BAN::ErrorOr<void> Thread::block_or_eintr_or_waketime(Semaphore& semaphore, uint64_t wake_time_ms, bool etimedout)
{
if (is_interrupted_by_signal())
return BAN::Error::from_errno(EINTR);
semaphore.block_with_wake_time(wake_time_ms);
if (is_interrupted_by_signal())
return BAN::Error::from_errno(EINTR);
if (etimedout && SystemTimer::get().ms_since_boot() >= wake_time_ms)
return BAN::Error::from_errno(ETIMEDOUT);
return {};
}
void Thread::validate_stack() const void Thread::validate_stack() const
{ {
if (stack_base() <= m_rsp && m_rsp <= stack_base() + stack_size()) if (stack_base() <= m_rsp && m_rsp <= stack_base() + stack_size())

View File

@ -19,6 +19,7 @@
#include <kernel/PCI.h> #include <kernel/PCI.h>
#include <kernel/PIC.h> #include <kernel/PIC.h>
#include <kernel/Process.h> #include <kernel/Process.h>
#include <kernel/Random.h>
#include <kernel/Scheduler.h> #include <kernel/Scheduler.h>
#include <kernel/Syscall.h> #include <kernel/Syscall.h>
#include <kernel/Terminal/Serial.h> #include <kernel/Terminal/Serial.h>
@ -153,6 +154,9 @@ extern "C" void kernel_main(uint32_t boot_magic, uint32_t boot_info)
dprintln("Virtual TTY initialized"); dprintln("Virtual TTY initialized");
} }
Random::initialize();
dprintln("RNG initialized");
MUST(Scheduler::initialize()); MUST(Scheduler::initialize());
dprintln("Scheduler initialized"); dprintln("Scheduler initialized");

View File

@ -21,6 +21,7 @@ set(LIBC_SOURCES
stropts.cpp stropts.cpp
sys/banan-os.cpp sys/banan-os.cpp
sys/mman.cpp sys/mman.cpp
sys/select.cpp
sys/socket.cpp sys/socket.cpp
sys/stat.cpp sys/stat.cpp
sys/wait.cpp sys/wait.cpp

View File

@ -50,3 +50,26 @@ char* inet_ntoa(struct in_addr in)
); );
return buffer; return buffer;
} }
const char* inet_ntop(int af, const void* __restrict src, char* __restrict dst, socklen_t size)
{
if (af == AF_INET)
{
if (size < INET_ADDRSTRLEN)
{
errno = ENOSPC;
return nullptr;
}
uint32_t he = ntohl(reinterpret_cast<const in_addr*>(src)->s_addr);
sprintf(dst, "%u.%u.%u.%u",
(he >> 24) & 0xFF,
(he >> 16) & 0xFF,
(he >> 8) & 0xFF,
(he >> 0) & 0xFF
);
return dst;
}
errno = EAFNOSUPPORT;
return nullptr;
}

View File

@ -79,7 +79,7 @@ struct dirent* readdir(DIR* dirp)
return &dirp->current->dirent; return &dirp->current->dirent;
} }
if (syscall(SYS_READ_DIR_ENTRIES, dirp->fd, dirp->buffer, dirp->buffer_size) == -1) if (syscall(SYS_READ_DIR, dirp->fd, dirp->buffer, dirp->buffer_size) == -1)
return nullptr; return nullptr;
if (dirp->buffer->entry_count == 0) if (dirp->buffer->entry_count == 0)

View File

@ -0,0 +1,20 @@
#ifndef _BITS_TIMEVAL_H
#define _BITS_TIMEVAL_H 1
#include <sys/cdefs.h>
__BEGIN_DECLS
#define __need_time_t
#define __need_suseconds_t
#include <sys/types.h>
struct timeval
{
time_t tv_sec; /* Seconds. */
suseconds_t tc_usec; /* Microseconds. */
};
__END_DECLS
#endif

View File

@ -7,9 +7,7 @@
__BEGIN_DECLS __BEGIN_DECLS
#define __need_time_t #include <bits/types/timeval.h>
#define __need_suseconds_t
#include <sys/types.h>
#include <signal.h> #include <signal.h>
#include <time.h> #include <time.h>
@ -50,6 +48,16 @@ typedef struct {
(setp)->__bits[i] = (__fd_mask)0; \ (setp)->__bits[i] = (__fd_mask)0; \
} while (0) } while (0)
struct sys_pselect_t
{
int nfds;
fd_set* readfds;
fd_set* writefds;
fd_set* errorfds;
const struct timespec* timeout;
const sigset_t* sigmask;
};
int pselect(int nfds, fd_set* __restrict readfds, fd_set* __restrict writefds, fd_set* __restrict errorfds, const struct timespec* __restrict timeout, const sigset_t* __restrict sigmask); int pselect(int nfds, fd_set* __restrict readfds, fd_set* __restrict writefds, fd_set* __restrict errorfds, const struct timespec* __restrict timeout, const sigset_t* __restrict sigmask);
int select(int nfds, fd_set* __restrict readfds, fd_set* __restrict writefds, fd_set* __restrict errorfds, struct timeval* __restrict timeout); int select(int nfds, fd_set* __restrict readfds, fd_set* __restrict writefds, fd_set* __restrict errorfds, struct timeval* __restrict timeout);

View File

@ -5,72 +5,82 @@
__BEGIN_DECLS __BEGIN_DECLS
#define SYS_EXIT 1 #define __SYSCALL_LIST(O) \
#define SYS_READ 2 O(SYS_EXIT, exit) \
#define SYS_WRITE 3 O(SYS_READ, read) \
#define SYS_TERMID 4 O(SYS_WRITE, write) \
#define SYS_CLOSE 5 O(SYS_TERMID, termid) \
#define SYS_OPEN 6 O(SYS_CLOSE, close) \
#define SYS_OPENAT 7 O(SYS_OPEN, open) \
#define SYS_SEEK 11 O(SYS_OPENAT, openat) \
#define SYS_TELL 12 O(SYS_SEEK, seek) \
#define SYS_GET_TERMIOS 13 O(SYS_TELL, tell) \
#define SYS_SET_TERMIOS 14 O(SYS_GET_TERMIOS, gettermios) \
#define SYS_FORK 15 O(SYS_SET_TERMIOS, settermios) \
#define SYS_EXEC 16 O(SYS_FORK, fork) \
#define SYS_SLEEP 17 O(SYS_EXEC, exec) \
#define SYS_WAIT 18 O(SYS_SLEEP, sleep) \
#define SYS_FSTAT 19 O(SYS_WAIT, wait) \
#define SYS_READ_DIR_ENTRIES 21 O(SYS_FSTAT, fstat) \
#define SYS_SET_UID 22 O(SYS_READ_DIR, readdir) \
#define SYS_SET_GID 23 O(SYS_SET_UID, setuid) \
#define SYS_SET_EUID 24 O(SYS_SET_GID, setgid) \
#define SYS_SET_EGID 25 O(SYS_SET_EUID, seteuid) \
#define SYS_SET_REUID 26 O(SYS_SET_EGID, setegid) \
#define SYS_SET_REGID 27 O(SYS_SET_REUID, setreuid) \
#define SYS_GET_UID 28 O(SYS_SET_REGID, setregid) \
#define SYS_GET_GID 29 O(SYS_GET_UID, getuid) \
#define SYS_GET_EUID 30 O(SYS_GET_GID, getgid) \
#define SYS_GET_EGID 31 O(SYS_GET_EUID, geteuid) \
#define SYS_GET_PWD 32 O(SYS_GET_EGID, getegid) \
#define SYS_SET_PWD 33 O(SYS_GET_PWD, getpwd) \
#define SYS_CLOCK_GETTIME 34 O(SYS_SET_PWD, setpwd) \
#define SYS_PIPE 35 O(SYS_CLOCK_GETTIME, clock_gettime) \
#define SYS_DUP 36 O(SYS_PIPE, pipe) \
#define SYS_DUP2 37 O(SYS_DUP, dup) \
#define SYS_KILL 39 O(SYS_DUP2, dup2) \
#define SYS_SIGNAL 40 O(SYS_KILL, kill) \
#define SYS_TCSETPGRP 42 O(SYS_SIGNAL, signal) \
#define SYS_GET_PID 43 O(SYS_TCSETPGRP, tcsetpgrp) \
#define SYS_GET_PGID 44 O(SYS_GET_PID, getpid) \
#define SYS_SET_PGID 45 O(SYS_GET_PGID, getpgid) \
#define SYS_FCNTL 46 O(SYS_SET_PGID, setpgid) \
#define SYS_NANOSLEEP 47 O(SYS_FCNTL, fcntl) \
#define SYS_FSTATAT 48 O(SYS_NANOSLEEP, nanosleep) \
#define SYS_STAT 49 // stat/lstat O(SYS_FSTATAT, fstatat) \
#define SYS_SYNC 50 O(SYS_STAT, stat) \
#define SYS_MMAP 51 O(SYS_SYNC, sync) \
#define SYS_MUNMAP 52 O(SYS_MMAP, mmap) \
#define SYS_TTY_CTRL 53 O(SYS_MUNMAP, munmap) \
#define SYS_POWEROFF 54 O(SYS_TTY_CTRL, tty_ctrl) \
#define SYS_CHMOD 55 O(SYS_POWEROFF, poweroff) \
#define SYS_CREATE 56 // creat, mkfifo O(SYS_CHMOD, chmod) \
#define SYS_CREATE_DIR 57 // mkdir O(SYS_CREATE, create) \
#define SYS_UNLINK 58 O(SYS_CREATE_DIR, create_dir) \
#define SYS_READLINK 59 O(SYS_UNLINK, unlink) \
#define SYS_READLINKAT 60 O(SYS_READLINK, readlink) \
#define SYS_MSYNC 61 O(SYS_READLINKAT, readlinkat) \
#define SYS_PREAD 62 O(SYS_MSYNC, msync) \
#define SYS_CHOWN 63 O(SYS_PREAD, pread) \
#define SYS_LOAD_KEYMAP 64 O(SYS_CHOWN, chown) \
#define SYS_SOCKET 65 O(SYS_LOAD_KEYMAP, load_keymap) \
#define SYS_BIND 66 O(SYS_SOCKET, socket) \
#define SYS_SENDTO 67 O(SYS_BIND, bind) \
#define SYS_RECVFROM 68 O(SYS_SENDTO, sendto) \
#define SYS_IOCTL 69 O(SYS_RECVFROM, recvfrom) \
#define SYS_ACCEPT 70 O(SYS_IOCTL, ioctl) \
#define SYS_CONNECT 71 O(SYS_ACCEPT, accept) \
#define SYS_LISTEN 72 O(SYS_CONNECT, connect) \
O(SYS_LISTEN, listen) \
O(SYS_PSELECT, pselect) \
enum Syscall
{
#define O(enum, name) enum,
__SYSCALL_LIST(O)
#undef O
__SYSCALL_COUNT
};
__END_DECLS __END_DECLS

View File

@ -7,18 +7,10 @@
__BEGIN_DECLS __BEGIN_DECLS
#define __need_time_t
#define __need_suseconds_t
#include <sys/types.h>
// NOTE: select is declared from here // NOTE: select is declared from here
#include <sys/select.h> #include <sys/select.h>
struct timeval #include <bits/types/timeval.h>
{
time_t tv_sec; /* Seconds. */
suseconds_t tc_usec; /* Microseconds. */
};
struct itimerval struct itimerval
{ {

31
libc/sys/select.cpp Normal file
View File

@ -0,0 +1,31 @@
#include <sys/select.h>
#include <sys/syscall.h>
#include <unistd.h>
int pselect(int nfds, fd_set* __restrict readfds, fd_set* __restrict writefds, fd_set* __restrict errorfds, const struct timespec* __restrict timeout, const sigset_t* __restrict sigmask)
{
sys_pselect_t arguments {
.nfds = nfds,
.readfds = readfds,
.writefds = writefds,
.errorfds = errorfds,
.timeout = timeout,
.sigmask = sigmask
};
return syscall(SYS_PSELECT, &arguments);
}
int select(int nfds, fd_set* __restrict readfds, fd_set* __restrict writefds, fd_set* __restrict errorfds, struct timeval* __restrict timeout)
{
timespec* pts = nullptr;
timespec ts;
if (timeout)
{
ts.tv_sec = timeout->tv_sec;
ts.tv_nsec = timeout->tc_usec * 1000;
pts = &ts;
}
// TODO: "select may update timeout", should we?
return pselect(nfds, readfds, writefds, errorfds, pts, nullptr);
}

View File

@ -34,6 +34,9 @@ set(USERSPACE_PROJECTS
test-globals test-globals
test-mouse test-mouse
test-sort test-sort
test-tcp
test-udp
test-unix-socket
touch touch
u8sum u8sum
whoami whoami

View File

@ -1,7 +1,12 @@
#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdio.h> #include <stdio.h>
#include <string.h> #include <string.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <sys/un.h> #include <sys/un.h>
#include <unistd.h>
#define MAX(a, b) ((a) < (b) ? (b) : (a))
int main(int argc, char** argv) int main(int argc, char** argv)
{ {
@ -33,15 +38,17 @@ int main(int argc, char** argv)
return 1; return 1;
} }
char buffer[128]; sockaddr_storage storage;
ssize_t nrecv = recv(socket, buffer, sizeof(buffer), 0); if (recv(socket, &storage, sizeof(storage), 0) == -1)
if (nrecv == -1)
{ {
perror("recv"); perror("recv");
return 1; return 1;
} }
buffer[nrecv] = '\0';
printf("%s\n", buffer); close(socket);
char buffer[MAX(INET_ADDRSTRLEN, INET6_ADDRSTRLEN)];
printf("%s\n", inet_ntop(storage.ss_family, storage.ss_storage, buffer, sizeof(buffer)));
return 0; return 0;
} }

View File

@ -1,4 +1,5 @@
#include <BAN/ByteSpan.h> #include <BAN/ByteSpan.h>
#include <BAN/Debug.h>
#include <BAN/Endianness.h> #include <BAN/Endianness.h>
#include <BAN/HashMap.h> #include <BAN/HashMap.h>
#include <BAN/IPv4.h> #include <BAN/IPv4.h>
@ -39,6 +40,19 @@ struct DNSAnswer
}; };
static_assert(sizeof(DNSAnswer) == 12); static_assert(sizeof(DNSAnswer) == 12);
enum QTYPE : uint16_t
{
A = 0x0001,
CNAME = 0x0005,
AAAA = 0x001C,
};
struct DNSEntry
{
time_t valid_until { 0 };
BAN::IPv4Address address { 0 };
};
bool send_dns_query(int socket, BAN::StringView domain, uint16_t id) bool send_dns_query(int socket, BAN::StringView domain, uint16_t id)
{ {
static uint8_t buffer[4096]; static uint8_t buffer[4096];
@ -61,8 +75,8 @@ bool send_dns_query(int socket, BAN::StringView domain, uint16_t id)
} }
request.data[idx++] = 0x00; request.data[idx++] = 0x00;
*(uint16_t*)&request.data[idx] = htons(0x01); idx += 2; *(uint16_t*)&request.data[idx] = htons(QTYPE::A); idx += 2;
*(uint16_t*)&request.data[idx] = htons(0x01); idx += 2; *(uint16_t*)&request.data[idx] = htons(0x0001); idx += 2;
sockaddr_in nameserver; sockaddr_in nameserver;
nameserver.sin_family = AF_INET; nameserver.sin_family = AF_INET;
@ -70,33 +84,33 @@ bool send_dns_query(int socket, BAN::StringView domain, uint16_t id)
nameserver.sin_addr.s_addr = inet_addr("8.8.8.8"); nameserver.sin_addr.s_addr = inet_addr("8.8.8.8");
if (sendto(socket, &request, sizeof(DNSPacket) + idx, 0, (sockaddr*)&nameserver, sizeof(nameserver)) == -1) if (sendto(socket, &request, sizeof(DNSPacket) + idx, 0, (sockaddr*)&nameserver, sizeof(nameserver)) == -1)
{ {
perror("sendto"); dprintln("sendto: {}", strerror(errno));
return false; return false;
} }
return true; return true;
} }
BAN::Optional<BAN::String> read_dns_response(int socket, uint16_t id) BAN::Optional<DNSEntry> read_dns_response(int socket, uint16_t id)
{ {
static uint8_t buffer[4096]; static uint8_t buffer[4096];
ssize_t nrecv = recvfrom(socket, buffer, sizeof(buffer), 0, nullptr, nullptr); ssize_t nrecv = recvfrom(socket, buffer, sizeof(buffer), 0, nullptr, nullptr);
if (nrecv == -1) if (nrecv == -1)
{ {
perror("recvfrom"); dprintln("recvfrom: {}", strerror(errno));
return {}; return {};
} }
DNSPacket& reply = *reinterpret_cast<DNSPacket*>(buffer); DNSPacket& reply = *reinterpret_cast<DNSPacket*>(buffer);
if (reply.identification != id) if (reply.identification != id)
{ {
fprintf(stderr, "Reply to invalid packet\n"); dprintln("Reply to invalid packet");
return {}; return {};
} }
if (reply.flags & 0x0F) if (reply.flags & 0x0F)
{ {
fprintf(stderr, "DNS error (rcode %u)\n", (unsigned)(reply.flags & 0xF)); dprintln("DNS error (rcode {})", (unsigned)(reply.flags & 0xF));
return {}; return {};
} }
@ -109,13 +123,22 @@ BAN::Optional<BAN::String> read_dns_response(int socket, uint16_t id)
} }
DNSAnswer& answer = *reinterpret_cast<DNSAnswer*>(&reply.data[idx]); DNSAnswer& answer = *reinterpret_cast<DNSAnswer*>(&reply.data[idx]);
if (answer.type() != QTYPE::A)
{
dprintln("Not A record");
return {};
}
if (answer.data_len() != 4) if (answer.data_len() != 4)
{ {
fprintf(stderr, "Not IPv4\n"); dprintln("corrupted package");
return {}; return {};
} }
return inet_ntoa({ .s_addr = *reinterpret_cast<uint32_t*>(answer.data) }); DNSEntry result;
result.valid_until = time(nullptr) + answer.ttl();
result.address = BAN::IPv4Address(*reinterpret_cast<uint32_t*>(answer.data));
return result;
} }
int create_service_socket() int create_service_socket()
@ -123,7 +146,7 @@ int create_service_socket()
int socket = ::socket(AF_UNIX, SOCK_SEQPACKET, 0); int socket = ::socket(AF_UNIX, SOCK_SEQPACKET, 0);
if (socket == -1) if (socket == -1)
{ {
perror("socket"); dprintln("socket: {}", strerror(errno));
return -1; return -1;
} }
@ -132,21 +155,21 @@ int create_service_socket()
strcpy(addr.sun_path, "/tmp/resolver.sock"); strcpy(addr.sun_path, "/tmp/resolver.sock");
if (bind(socket, (sockaddr*)&addr, sizeof(addr)) == -1) if (bind(socket, (sockaddr*)&addr, sizeof(addr)) == -1)
{ {
perror("bind"); dprintln("bind: {}", strerror(errno));
close(socket); close(socket);
return -1; return -1;
} }
if (chmod("/tmp/resolver.sock", 0777) == -1) if (chmod("/tmp/resolver.sock", 0777) == -1)
{ {
perror("chmod"); dprintln("chmod: {}", strerror(errno));
close(socket); close(socket);
return -1; return -1;
} }
if (listen(socket, 10) == -1) if (listen(socket, 10) == -1)
{ {
perror("listen"); dprintln("listen: {}", strerror(errno));
close(socket); close(socket);
return -1; return -1;
} }
@ -160,7 +183,7 @@ BAN::Optional<BAN::String> read_service_query(int socket)
ssize_t nrecv = recv(socket, buffer, sizeof(buffer), 0); ssize_t nrecv = recv(socket, buffer, sizeof(buffer), 0);
if (nrecv == -1) if (nrecv == -1)
{ {
perror("recv"); dprintln("recv: {}", strerror(errno));
return {}; return {};
} }
buffer[nrecv] = '\0'; buffer[nrecv] = '\0';
@ -178,39 +201,60 @@ int main(int, char**)
int dns_socket = socket(AF_INET, SOCK_DGRAM, 0); int dns_socket = socket(AF_INET, SOCK_DGRAM, 0);
if (dns_socket == -1) if (dns_socket == -1)
{ {
perror("socket"); dprintln("socket: {}", strerror(errno));
return 1; return 1;
} }
BAN::HashMap<BAN::String, DNSEntry> dns_cache;
for (;;) for (;;)
{ {
int client = accept(service_socket, nullptr, nullptr); int client = accept(service_socket, nullptr, nullptr);
if (client == -1) if (client == -1)
{ {
perror("accept"); dprintln("accept: {}", strerror(errno));
continue; continue;
} }
auto query = read_service_query(client); auto query = read_service_query(client);
if (!query.has_value()) if (!query.has_value())
continue;
uint16_t id = rand() % 0xFFFF;
if (send_dns_query(dns_socket, *query, id))
{ {
auto response = read_dns_response(dns_socket, id);
if (response.has_value())
{
if (send(client, response->data(), response->size() + 1, 0) == -1)
perror("send");
close(client); close(client);
continue; continue;
} }
BAN::Optional<DNSEntry> result;
if (dns_cache.contains(*query))
{
auto& cached = dns_cache[*query];
if (time(nullptr) <= cached.valid_until)
result = cached;
else
dns_cache.remove(*query);
} }
char message[] = "unavailable"; if (!result.has_value())
send(client, message, sizeof(message), 0); {
uint16_t id = rand() % 0xFFFF;
if (send_dns_query(dns_socket, *query, id))
{
result = read_dns_response(dns_socket, id);
if (result.has_value())
(void)dns_cache.insert(*query, *result);
}
}
if (!result.has_value())
result = DNSEntry { .valid_until = 0, .address = BAN::IPv4Address(INADDR_ANY) };
sockaddr_storage storage;
storage.ss_family = AF_INET;
memcpy(storage.ss_storage, &result->address.raw, sizeof(result->address.raw));
if (send(client, &storage, sizeof(storage), 0) == -1)
dprintln("send: {}", strerror(errno));
close(client); close(client);
} }

View File

@ -0,0 +1,16 @@
cmake_minimum_required(VERSION 3.26)
project(test-tcp CXX)
set(SOURCES
main.cpp
)
add_executable(test-tcp ${SOURCES})
target_compile_options(test-tcp PUBLIC -O2 -g)
target_link_libraries(test-tcp PUBLIC libc)
add_custom_target(test-tcp-install
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/test-tcp ${BANAN_BIN}/
DEPENDS test-tcp
)

111
userspace/test-tcp/main.cpp Normal file
View File

@ -0,0 +1,111 @@
#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
in_addr_t get_ipv4_address(const char* query)
{
if (in_addr_t ipv4 = inet_addr(query); ipv4 != (in_addr_t)(-1))
return ipv4;
int socket = ::socket(AF_UNIX, SOCK_SEQPACKET, 0);
if (socket == -1)
{
perror("socket");
return -1;
}
sockaddr_un addr;
addr.sun_family = AF_UNIX;
strcpy(addr.sun_path, "/tmp/resolver.sock");
if (connect(socket, (sockaddr*)&addr, sizeof(addr)) == -1)
{
perror("connect");
close(socket);
return -1;
}
if (send(socket, query, strlen(query), 0) == -1)
{
perror("send");
close(socket);
return -1;
}
sockaddr_storage storage;
if (recv(socket, &storage, sizeof(storage), 0) == -1)
{
perror("recv");
close(socket);
return -1;
}
close(socket);
return *reinterpret_cast<in_addr_t*>(storage.ss_storage);
}
int main(int argc, char** argv)
{
if (argc != 2)
{
fprintf(stderr, "usage: %s IPADDR\n", argv[0]);
return 1;
}
in_addr_t ipv4 = get_ipv4_address(argv[1]);
if (ipv4 == (in_addr_t)(-1))
{
fprintf(stderr, "could not parse address '%s'\n", argv[1]);
return 1;
}
int socket = ::socket(AF_INET, SOCK_STREAM, 0);
if (socket == -1)
{
perror("socket");
return 1;
}
printf("connecting to %s\n", inet_ntoa({ .s_addr = ipv4 }));
sockaddr_in server_addr;
server_addr.sin_family = AF_INET;
server_addr.sin_port = htons(80);
server_addr.sin_addr.s_addr = ipv4;
if (connect(socket, (sockaddr*)&server_addr, sizeof(server_addr)) == -1)
{
perror("connect");
return 1;
}
char request[128];
strcpy(request, "GET / HTTP/1.1\r\n");
strcat(request, "Host: "); strcat(request, argv[1]); strcat(request, "\r\n");
strcat(request, "Accept: */*\r\n");
strcat(request, "Connection: close\r\n");
strcat(request, "\r\n");
if (send(socket, request, strlen(request), 0) == -1)
{
perror("send");
return 1;
}
char buffer[1024];
for (;;)
{
ssize_t nrecv = recv(socket, buffer, sizeof(buffer), 0);
if (nrecv == -1)
{
perror("recv");
break;
}
write(STDOUT_FILENO, buffer, nrecv);
}
close(socket);
return 0;
}

View File

@ -0,0 +1,16 @@
cmake_minimum_required(VERSION 3.26)
project(test-udp CXX)
set(SOURCES
main.cpp
)
add_executable(test-udp ${SOURCES})
target_compile_options(test-udp PUBLIC -O2 -g)
target_link_libraries(test-udp PUBLIC libc)
add_custom_target(test-udp-install
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/test-udp ${BANAN_BIN}/
DEPENDS test-udp
)

View File

@ -0,0 +1,88 @@
#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>
int usage(const char* argv0)
{
fprintf(stderr, "usage: %s [-s|-c] [-a addr] [-p port]\n", argv0);
return 1;
}
int main(int argc, char** argv)
{
bool server = false;
uint32_t addr = 0;
uint16_t port = 0;
for (int i = 1; i < argc; i++)
{
if (strcmp(argv[i], "-s") == 0)
server = true;
else if (strcmp(argv[i], "-c") == 0)
server = false;
else if (strcmp(argv[i], "-a") == 0)
addr = inet_addr(argv[++i]);
else if (strcmp(argv[i], "-p") == 0)
sscanf(argv[++i], "%hu", &port);
else
return usage(argv[0]);
}
int socket = ::socket(AF_INET, SOCK_DGRAM, 0);
if (socket == -1)
{
perror("socket");
return 1;
}
if (server)
{
sockaddr_in bind_addr;
bind_addr.sin_family = AF_INET;
bind_addr.sin_port = htons(port);
bind_addr.sin_addr.s_addr = addr;
if (bind(socket, (sockaddr*)&bind_addr, sizeof(bind_addr)) == -1)
{
perror("bind");
return 1;
}
printf("listening on %s:%hu\n", inet_ntoa(bind_addr.sin_addr), ntohs(bind_addr.sin_port));
char buffer[1024];
sockaddr_in sender;
socklen_t sender_len = sizeof(sender);
if (recvfrom(socket, buffer, sizeof(buffer), 0, (sockaddr*)&sender, &sender_len) == -1)
{
perror("recvfrom");
return 1;
}
printf("received from %s:%hu\n", inet_ntoa(sender.sin_addr), ntohs(sender.sin_port));
printf(" %s\n", buffer);
}
else
{
const char buffer[] = "Hello from banan-os!";
sockaddr_in server_addr;
server_addr.sin_family = AF_INET;
server_addr.sin_port = htons(port);
server_addr.sin_addr.s_addr = addr;
printf("sending to %s:%hu\n", inet_ntoa(server_addr.sin_addr), ntohs(server_addr.sin_port));
if (sendto(socket, buffer, sizeof(buffer), 0, (sockaddr*)&server_addr, sizeof(server_addr)) == -1)
{
perror("sendto");
return 1;
}
}
close(socket);
return 0;
}

View File

@ -0,0 +1,16 @@
cmake_minimum_required(VERSION 3.26)
project(test-unix-socket CXX)
set(SOURCES
main.cpp
)
add_executable(test-unix-socket ${SOURCES})
target_compile_options(test-unix-socket PUBLIC -O2 -g)
target_link_libraries(test-unix-socket PUBLIC libc)
add_custom_target(test-unix-socket-install
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/test-unix-socket ${BANAN_BIN}/
DEPENDS test-unix-socket
)

View File

@ -0,0 +1,200 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <sys/wait.h>
#include <unistd.h>
#define SOCK_PATH "/tmp/test.sock"
int server_connection()
{
int socket = ::socket(AF_UNIX, SOCK_STREAM, 0);
if (socket == -1)
{
perror("server: socket");
return 1;
}
sockaddr_un addr;
addr.sun_family = AF_UNIX;
strcpy(addr.sun_path, SOCK_PATH);
if (bind(socket, (sockaddr*)&addr, sizeof(addr)))
{
perror("server: bind");
return 1;
}
if (listen(socket, 0) == -1)
{
perror("server: listen");
return 1;
}
int client = accept(socket, nullptr, nullptr);
if (client == -1)
{
perror("server: accept");
return 1;
}
sleep(2);
char buffer[128];
ssize_t nrecv = recv(client, buffer, sizeof(buffer), 0);
if (nrecv == -1)
{
perror("server: recv");
return 1;
}
printf("server: read %d bytes\n", (int)nrecv);
printf("server: '%s'\n", buffer);
char message[] = "Hello from server";
if (send(client, message, sizeof(message), 0) == -1)
{
perror("server: send");
return 1;
}
close(client);
close(socket);
return 0;
}
int client_connection()
{
sleep(1);
int socket = ::socket(AF_UNIX, SOCK_STREAM, 0);
if (socket == -1)
{
perror("client: socket");
return 1;
}
sockaddr_un addr;
addr.sun_family = AF_UNIX;
strcpy(addr.sun_path, SOCK_PATH);
if (connect(socket, (sockaddr*)&addr, sizeof(addr)) == -1)
{
perror("client: connect");
return 1;
}
char message[] = "Hello from client";
if (send(socket, message, sizeof(message), 0) == -1)
{
perror("client: send");
return 1;
}
char buffer[128];
ssize_t nrecv = recv(socket, buffer, sizeof(buffer), 0);
if (nrecv == -1)
{
perror("client: recv");
return 1;
}
printf("client: read %d bytes\n", (int)nrecv);
printf("client: '%s'\n", buffer);
close(socket);
return 0;
}
int server_connectionless()
{
int socket = ::socket(AF_UNIX, SOCK_DGRAM, 0);
if (socket == -1)
{
perror("server: socket");
return 1;
}
sockaddr_un addr;
addr.sun_family = AF_UNIX;
strcpy(addr.sun_path, SOCK_PATH);
if (bind(socket, (sockaddr*)&addr, sizeof(addr)))
{
perror("server: bind");
return 1;
}
sleep(2);
char buffer[128];
ssize_t nrecv = recv(socket, buffer, sizeof(buffer), 0);
if (nrecv == -1)
{
perror("server: recv");
return 1;
}
close(socket);
return 0;
}
int client_connectionless()
{
sleep(1);
int socket = ::socket(AF_UNIX, SOCK_DGRAM, 0);
if (socket == -1)
{
perror("client: socket");
return 1;
}
sockaddr_un addr;
addr.sun_family = AF_UNIX;
strcpy(addr.sun_path, SOCK_PATH);
char message[] = "Hello from client";
if (sendto(socket, message, sizeof(message), 0, (sockaddr*)&addr, sizeof(addr)) == -1)
{
perror("client: send");
return 1;
}
close(socket);
return 0;
}
int test_mode(int (*client)(), int (*server)())
{
pid_t pid = fork();
if (pid == -1)
{
perror("fork");
return 1;
}
if (pid == 0)
exit(server());
if (int ret = client())
{
kill(pid, SIGKILL);
return ret;
}
int ret;
waitpid(pid, &ret, 0);
if (remove(SOCK_PATH) == -1)
perror("remove");
return ret;
}
int main()
{
if (test_mode(client_connection, server_connection))
return 1;
if (test_mode(client_connectionless, server_connectionless))
return 2;
return 0;
}