Compare commits

...

22 Commits

Author SHA1 Message Date
Bananymous 49889858fa Kernel: Allow chmod on TmpSocketInode 2024-02-08 03:16:01 +02:00
Bananymous 2424f38a62 Userspace: Implement super simple DNS resolver in userspace
You connect to this service using unix domain sockets and send the
asked domain name. It will respond with ip address or 'unavailable'

There is no DNS cache implemented so all calls ask the nameserver.
2024-02-08 03:14:00 +02:00
Bananymous 218456d127 BAN: Fix some includes 2024-02-08 03:13:21 +02:00
Bananymous e7dd03e551 Kernel: Implement basic connection-mode unix domain sockets 2024-02-08 02:28:19 +02:00
Bananymous 0c8e9fe095 Kernel: Add operator bool() for WeakPtr 2024-02-08 02:26:46 +02:00
Bananymous 5b4acec4ca BAN: Add capacity() getter for Queue 2024-02-07 22:53:56 +02:00
Bananymous e26f360d93 Kernel: allow kmalloc of size 0 2024-02-07 22:36:24 +02:00
Bananymous 2cc9534570 BAN: Add emplace for Variant
This allows variant to store values that are not copy/move
constructible.
2024-02-07 22:33:16 +02:00
Bananymous 572c4052f6 Kernel: Fix Process APIs 2024-02-07 15:57:45 +02:00
Bananymous 132286895f Kernel: Implement Socket inodes for tmpfs 2024-02-07 15:57:45 +02:00
Bananymous 454bee3f02 LibC: Fix sockaddr_un implementation 2024-02-07 15:57:45 +02:00
Bananymous 41cad88d6e Kernel/LibC: Implement dummy syscalls for accept, connect, listen 2024-02-07 15:57:45 +02:00
Bananymous 40e341b0ee BAN: Remove unstable hash map and set
These can now be implemented safely with new linked list api
2024-02-06 17:35:15 +02:00
Bananymous 5da59c9151 Kernel: Make better abstractions for networking 2024-02-06 16:45:39 +02:00
Bananymous f804e87f7d Kernel: Implement basic gateway for network interfaces 2024-02-05 18:18:56 +02:00
Bananymous dd3641f054 Kernel: Cleanup ARPTable code
Packet process is now killed if ARPTable dies.

ARP wait loop now just reschecules so timeout actually works.
2024-02-05 18:18:56 +02:00
Bananymous b2291ce162 Kernel/BAN: Fix network strucute endianness 2024-02-05 18:18:56 +02:00
Bananymous c35ed6570b LibC: Implement endiannes and ip address functions 2024-02-05 18:18:56 +02:00
Bananymous d15cbb2d6a Kernel: Fix IPv4 header checksum calculation 2024-02-05 18:18:56 +02:00
Bananymous b8cf6432ef BAN: Implement host_to_network_endian 2024-02-05 17:29:24 +02:00
Bananymous 89805fb092 dhcp-client: Use dprintln for debug printing 2024-02-05 01:24:45 +02:00
Bananymous 692cec8458 Kernel/Userspace/LibC: Implement basic dprintln for userspace 2024-02-05 01:24:09 +02:00
67 changed files with 1920 additions and 468 deletions

View File

@ -2,6 +2,8 @@
#include <BAN/Span.h>
#include <stdint.h>
namespace BAN
{

31
BAN/include/BAN/Debug.h Normal file
View File

@ -0,0 +1,31 @@
#pragma once
#if __is_kernel
#error "This is userspace only file"
#endif
#include <BAN/Formatter.h>
#include <stdio.h>
#define __debug_putchar [](int c) { putc(c, stddbg); }
#define dprintln(...) \
do { \
BAN::Formatter::print(__debug_putchar, __VA_ARGS__); \
BAN::Formatter::print(__debug_putchar,"\r\n"); \
fflush(stddbg); \
} while (false)
#define dwarnln(...) \
do { \
BAN::Formatter::print(__debug_putchar, "\e[33m"); \
dprintln(__VA_ARGS__); \
BAN::Formatter::print(__debug_putchar, "\e[m"); \
} while(false)
#define derrorln(...) \
do { \
BAN::Formatter::print(__debug_putchar, "\e[31m"); \
dprintln(__VA_ARGS__); \
BAN::Formatter::print(__debug_putchar, "\e[m"); \
} while(false)

View File

@ -90,4 +90,10 @@ namespace BAN
template<integral T>
using NetworkEndian = BigEndian<T>;
template<integral T>
constexpr T host_to_network_endian(T value)
{
return host_to_big_endian(value);
}
}

View File

@ -7,7 +7,7 @@
namespace BAN
{
template<typename Key, typename T, typename HASH = BAN::hash<Key>, bool STABLE = true>
template<typename Key, typename T, typename HASH = BAN::hash<Key>>
class HashMap
{
public:
@ -32,12 +32,12 @@ namespace BAN
public:
HashMap() = default;
HashMap(const HashMap<Key, T, HASH, STABLE>&);
HashMap(HashMap<Key, T, HASH, STABLE>&&);
HashMap(const HashMap<Key, T, HASH>&);
HashMap(HashMap<Key, T, HASH>&&);
~HashMap();
HashMap<Key, T, HASH, STABLE>& operator=(const HashMap<Key, T, HASH, STABLE>&);
HashMap<Key, T, HASH, STABLE>& operator=(HashMap<Key, T, HASH, STABLE>&&);
HashMap<Key, T, HASH>& operator=(const HashMap<Key, T, HASH>&);
HashMap<Key, T, HASH>& operator=(HashMap<Key, T, HASH>&&);
ErrorOr<void> insert(const Key&, const T&);
ErrorOr<void> insert(const Key&, T&&);
@ -74,26 +74,26 @@ namespace BAN
friend iterator;
};
template<typename Key, typename T, typename HASH, bool STABLE>
HashMap<Key, T, HASH, STABLE>::HashMap(const HashMap<Key, T, HASH, STABLE>& other)
template<typename Key, typename T, typename HASH>
HashMap<Key, T, HASH>::HashMap(const HashMap<Key, T, HASH>& other)
{
*this = other;
}
template<typename Key, typename T, typename HASH, bool STABLE>
HashMap<Key, T, HASH, STABLE>::HashMap(HashMap<Key, T, HASH, STABLE>&& other)
template<typename Key, typename T, typename HASH>
HashMap<Key, T, HASH>::HashMap(HashMap<Key, T, HASH>&& other)
{
*this = move(other);
}
template<typename Key, typename T, typename HASH, bool STABLE>
HashMap<Key, T, HASH, STABLE>::~HashMap()
template<typename Key, typename T, typename HASH>
HashMap<Key, T, HASH>::~HashMap()
{
clear();
}
template<typename Key, typename T, typename HASH, bool STABLE>
HashMap<Key, T, HASH, STABLE>& HashMap<Key, T, HASH, STABLE>::operator=(const HashMap<Key, T, HASH, STABLE>& other)
template<typename Key, typename T, typename HASH>
HashMap<Key, T, HASH>& HashMap<Key, T, HASH>::operator=(const HashMap<Key, T, HASH>& other)
{
clear();
m_buckets = other.m_buckets;
@ -101,8 +101,8 @@ namespace BAN
return *this;
}
template<typename Key, typename T, typename HASH, bool STABLE>
HashMap<Key, T, HASH, STABLE>& HashMap<Key, T, HASH, STABLE>::operator=(HashMap<Key, T, HASH, STABLE>&& other)
template<typename Key, typename T, typename HASH>
HashMap<Key, T, HASH>& HashMap<Key, T, HASH>::operator=(HashMap<Key, T, HASH>&& other)
{
clear();
m_buckets = move(other.m_buckets);
@ -111,21 +111,21 @@ namespace BAN
return *this;
}
template<typename Key, typename T, typename HASH, bool STABLE>
ErrorOr<void> HashMap<Key, T, HASH, STABLE>::insert(const Key& key, const T& value)
template<typename Key, typename T, typename HASH>
ErrorOr<void> HashMap<Key, T, HASH>::insert(const Key& key, const T& value)
{
return insert(key, move(T(value)));
}
template<typename Key, typename T, typename HASH, bool STABLE>
ErrorOr<void> HashMap<Key, T, HASH, STABLE>::insert(const Key& key, T&& value)
template<typename Key, typename T, typename HASH>
ErrorOr<void> HashMap<Key, T, HASH>::insert(const Key& key, T&& value)
{
return emplace(key, move(value));
}
template<typename Key, typename T, typename HASH, bool STABLE>
template<typename Key, typename T, typename HASH>
template<typename... Args>
ErrorOr<void> HashMap<Key, T, HASH, STABLE>::emplace(const Key& key, Args&&... args)
ErrorOr<void> HashMap<Key, T, HASH>::emplace(const Key& key, Args&&... args)
{
ASSERT(!contains(key));
TRY(rebucket(m_size + 1));
@ -135,15 +135,15 @@ namespace BAN
return {};
}
template<typename Key, typename T, typename HASH, bool STABLE>
ErrorOr<void> HashMap<Key, T, HASH, STABLE>::reserve(size_type size)
template<typename Key, typename T, typename HASH>
ErrorOr<void> HashMap<Key, T, HASH>::reserve(size_type size)
{
TRY(rebucket(size));
return {};
}
template<typename Key, typename T, typename HASH, bool STABLE>
void HashMap<Key, T, HASH, STABLE>::remove(const Key& key)
template<typename Key, typename T, typename HASH>
void HashMap<Key, T, HASH>::remove(const Key& key)
{
if (empty()) return;
auto& bucket = get_bucket(key);
@ -158,15 +158,15 @@ namespace BAN
}
}
template<typename Key, typename T, typename HASH, bool STABLE>
void HashMap<Key, T, HASH, STABLE>::clear()
template<typename Key, typename T, typename HASH>
void HashMap<Key, T, HASH>::clear()
{
m_buckets.clear();
m_size = 0;
}
template<typename Key, typename T, typename HASH, bool STABLE>
T& HashMap<Key, T, HASH, STABLE>::operator[](const Key& key)
template<typename Key, typename T, typename HASH>
T& HashMap<Key, T, HASH>::operator[](const Key& key)
{
ASSERT(!empty());
auto& bucket = get_bucket(key);
@ -176,8 +176,8 @@ namespace BAN
ASSERT(false);
}
template<typename Key, typename T, typename HASH, bool STABLE>
const T& HashMap<Key, T, HASH, STABLE>::operator[](const Key& key) const
template<typename Key, typename T, typename HASH>
const T& HashMap<Key, T, HASH>::operator[](const Key& key) const
{
ASSERT(!empty());
const auto& bucket = get_bucket(key);
@ -187,8 +187,8 @@ namespace BAN
ASSERT(false);
}
template<typename Key, typename T, typename HASH, bool STABLE>
bool HashMap<Key, T, HASH, STABLE>::contains(const Key& key) const
template<typename Key, typename T, typename HASH>
bool HashMap<Key, T, HASH>::contains(const Key& key) const
{
if (empty()) return false;
const auto& bucket = get_bucket(key);
@ -198,20 +198,20 @@ namespace BAN
return false;
}
template<typename Key, typename T, typename HASH, bool STABLE>
bool HashMap<Key, T, HASH, STABLE>::empty() const
template<typename Key, typename T, typename HASH>
bool HashMap<Key, T, HASH>::empty() const
{
return m_size == 0;
}
template<typename Key, typename T, typename HASH, bool STABLE>
typename HashMap<Key, T, HASH, STABLE>::size_type HashMap<Key, T, HASH, STABLE>::size() const
template<typename Key, typename T, typename HASH>
typename HashMap<Key, T, HASH>::size_type HashMap<Key, T, HASH>::size() const
{
return m_size;
}
template<typename Key, typename T, typename HASH, bool STABLE>
ErrorOr<void> HashMap<Key, T, HASH, STABLE>::rebucket(size_type bucket_count)
template<typename Key, typename T, typename HASH>
ErrorOr<void> HashMap<Key, T, HASH>::rebucket(size_type bucket_count)
{
if (m_buckets.size() >= bucket_count)
return {};
@ -222,13 +222,10 @@ namespace BAN
for (auto& bucket : m_buckets)
{
for (Entry& entry : bucket)
for (auto it = bucket.begin(); it != bucket.end();)
{
size_type bucket_index = HASH()(entry.key) % new_buckets.size();
if constexpr(STABLE)
TRY(new_buckets[bucket_index].push_back(entry));
else
TRY(new_buckets[bucket_index].push_back(move(entry)));
size_type new_bucket_index = HASH()(it->key) % new_buckets.size();
it = bucket.move_element_to_other_linked_list(new_buckets[new_bucket_index], new_buckets[new_bucket_index].end(), it);
}
}
@ -236,27 +233,20 @@ namespace BAN
return {};
}
template<typename Key, typename T, typename HASH, bool STABLE>
LinkedList<typename HashMap<Key, T, HASH, STABLE>::Entry>& HashMap<Key, T, HASH, STABLE>::get_bucket(const Key& key)
template<typename Key, typename T, typename HASH>
LinkedList<typename HashMap<Key, T, HASH>::Entry>& HashMap<Key, T, HASH>::get_bucket(const Key& key)
{
ASSERT(!m_buckets.empty());
auto index = HASH()(key) % m_buckets.size();
return m_buckets[index];
}
template<typename Key, typename T, typename HASH, bool STABLE>
const LinkedList<typename HashMap<Key, T, HASH, STABLE>::Entry>& HashMap<Key, T, HASH, STABLE>::get_bucket(const Key& key) const
template<typename Key, typename T, typename HASH>
const LinkedList<typename HashMap<Key, T, HASH>::Entry>& HashMap<Key, T, HASH>::get_bucket(const Key& key) const
{
ASSERT(!m_buckets.empty());
auto index = HASH()(key) % m_buckets.size();
return m_buckets[index];
}
// Unstable hash map moves values between container during rebucketing.
// This means that if insertion to map fails, elements could be in invalid state
// and that container is no longer usable. This is better if either way you are
// going to stop using the hash map after insertion fails.
template<typename Key, typename T, typename HASH = BAN::hash<Key>>
using HashMapUnstable = HashMap<Key, T, HASH, false>;
}

View File

@ -11,7 +11,7 @@
namespace BAN
{
template<typename T, typename HASH = hash<T>, bool STABLE = true>
template<typename T, typename HASH = hash<T>>
class HashSet
{
public:
@ -55,23 +55,23 @@ namespace BAN
size_type m_size = 0;
};
template<typename T, typename HASH, bool STABLE>
HashSet<T, HASH, STABLE>::HashSet(const HashSet& other)
template<typename T, typename HASH>
HashSet<T, HASH>::HashSet(const HashSet& other)
: m_buckets(other.m_buckets)
, m_size(other.m_size)
{
}
template<typename T, typename HASH, bool STABLE>
HashSet<T, HASH, STABLE>::HashSet(HashSet&& other)
template<typename T, typename HASH>
HashSet<T, HASH>::HashSet(HashSet&& other)
: m_buckets(move(other.m_buckets))
, m_size(other.m_size)
{
other.clear();
}
template<typename T, typename HASH, bool STABLE>
HashSet<T, HASH, STABLE>& HashSet<T, HASH, STABLE>::operator=(const HashSet& other)
template<typename T, typename HASH>
HashSet<T, HASH>& HashSet<T, HASH>::operator=(const HashSet& other)
{
clear();
m_buckets = other.m_buckets;
@ -79,8 +79,8 @@ namespace BAN
return *this;
}
template<typename T, typename HASH, bool STABLE>
HashSet<T, HASH, STABLE>& HashSet<T, HASH, STABLE>::operator=(HashSet&& other)
template<typename T, typename HASH>
HashSet<T, HASH>& HashSet<T, HASH>::operator=(HashSet&& other)
{
clear();
m_buckets = move(other.m_buckets);
@ -89,14 +89,14 @@ namespace BAN
return *this;
}
template<typename T, typename HASH, bool STABLE>
ErrorOr<void> HashSet<T, HASH, STABLE>::insert(const T& key)
template<typename T, typename HASH>
ErrorOr<void> HashSet<T, HASH>::insert(const T& key)
{
return insert(move(T(key)));
}
template<typename T, typename HASH, bool STABLE>
ErrorOr<void> HashSet<T, HASH, STABLE>::insert(T&& key)
template<typename T, typename HASH>
ErrorOr<void> HashSet<T, HASH>::insert(T&& key)
{
if (!empty() && get_bucket(key).contains(key))
return {};
@ -107,8 +107,8 @@ namespace BAN
return {};
}
template<typename T, typename HASH, bool STABLE>
void HashSet<T, HASH, STABLE>::remove(const T& key)
template<typename T, typename HASH>
void HashSet<T, HASH>::remove(const T& key)
{
if (empty()) return;
auto& bucket = get_bucket(key);
@ -123,41 +123,41 @@ namespace BAN
}
}
template<typename T, typename HASH, bool STABLE>
void HashSet<T, HASH, STABLE>::clear()
template<typename T, typename HASH>
void HashSet<T, HASH>::clear()
{
m_buckets.clear();
m_size = 0;
}
template<typename T, typename HASH, bool STABLE>
ErrorOr<void> HashSet<T, HASH, STABLE>::reserve(size_type size)
template<typename T, typename HASH>
ErrorOr<void> HashSet<T, HASH>::reserve(size_type size)
{
TRY(rebucket(size));
return {};
}
template<typename T, typename HASH, bool STABLE>
bool HashSet<T, HASH, STABLE>::contains(const T& key) const
template<typename T, typename HASH>
bool HashSet<T, HASH>::contains(const T& key) const
{
if (empty()) return false;
return get_bucket(key).contains(key);
}
template<typename T, typename HASH, bool STABLE>
typename HashSet<T, HASH, STABLE>::size_type HashSet<T, HASH, STABLE>::size() const
template<typename T, typename HASH>
typename HashSet<T, HASH>::size_type HashSet<T, HASH>::size() const
{
return m_size;
}
template<typename T, typename HASH, bool STABLE>
bool HashSet<T, HASH, STABLE>::empty() const
template<typename T, typename HASH>
bool HashSet<T, HASH>::empty() const
{
return m_size == 0;
}
template<typename T, typename HASH, bool STABLE>
ErrorOr<void> HashSet<T, HASH, STABLE>::rebucket(size_type bucket_count)
template<typename T, typename HASH>
ErrorOr<void> HashSet<T, HASH>::rebucket(size_type bucket_count)
{
if (m_buckets.size() >= bucket_count)
return {};
@ -169,13 +169,10 @@ namespace BAN
for (auto& bucket : m_buckets)
{
for (T& key : bucket)
for (auto it = bucket.begin(); it != bucket.end();)
{
size_type bucket_index = HASH()(key) % new_buckets.size();
if constexpr(STABLE)
TRY(new_buckets[bucket_index].push_back(key));
else
TRY(new_buckets[bucket_index].push_back(move(key)));
size_type new_bucket_index = HASH()(*it) % new_buckets.size();
it = bucket.move_element_to_other_linked_list(new_buckets[new_bucket_index], new_buckets[new_bucket_index].end(), it);
}
}
@ -183,27 +180,20 @@ namespace BAN
return {};
}
template<typename T, typename HASH, bool STABLE>
LinkedList<T>& HashSet<T, HASH, STABLE>::get_bucket(const T& key)
template<typename T, typename HASH>
LinkedList<T>& HashSet<T, HASH>::get_bucket(const T& key)
{
ASSERT(!m_buckets.empty());
size_type index = HASH()(key) % m_buckets.size();
return m_buckets[index];
}
template<typename T, typename HASH, bool STABLE>
const LinkedList<T>& HashSet<T, HASH, STABLE>::get_bucket(const T& key) const
template<typename T, typename HASH>
const LinkedList<T>& HashSet<T, HASH>::get_bucket(const T& key) const
{
ASSERT(!m_buckets.empty());
size_type index = HASH()(key) % m_buckets.size();
return m_buckets[index];
}
// Unstable hash set moves values between container during rebucketing.
// This means that if insertion to set fails, elements could be in invalid state
// and that container is no longer usable. This is better if either way you are
// going to stop using the hash set after insertion fails.
template<typename T, typename HASH = hash<T>>
using HashSetUnstable = HashSet<T, HASH, false>;
}

View File

@ -1,5 +1,6 @@
#pragma once
#include <BAN/Endianness.h>
#include <BAN/Formatter.h>
#include <BAN/Hash.h>
@ -10,31 +11,32 @@ namespace BAN
{
constexpr IPv4Address(uint32_t u32_address)
{
address[0] = u32_address >> 24;
address[1] = u32_address >> 16;
address[2] = u32_address >> 8;
address[3] = u32_address >> 0;
raw = u32_address;
}
constexpr uint32_t as_u32() const
constexpr IPv4Address(uint8_t oct1, uint8_t oct2, uint8_t oct3, uint8_t oct4)
{
return
((uint32_t)address[0] << 24) |
((uint32_t)address[1] << 16) |
((uint32_t)address[2] << 8) |
((uint32_t)address[3] << 0);
octets[0] = oct1;
octets[1] = oct2;
octets[2] = oct3;
octets[3] = oct4;
}
constexpr bool operator==(const IPv4Address& other) const
{
return
address[0] == other.address[0] &&
address[1] == other.address[1] &&
address[2] == other.address[2] &&
address[3] == other.address[3];
return raw == other.raw;
}
uint8_t address[4];
constexpr IPv4Address mask(const IPv4Address& other) const
{
return IPv4Address(raw & other.raw);
}
union
{
uint8_t octets[4];
uint32_t raw;
} __attribute__((packed));
};
static_assert(sizeof(IPv4Address) == 4);
@ -43,7 +45,7 @@ namespace BAN
{
constexpr hash_t operator()(IPv4Address ipv4) const
{
return hash<uint32_t>()(ipv4.as_u32());
return hash<uint32_t>()(ipv4.raw);
}
};
@ -62,11 +64,11 @@ namespace BAN::Formatter
.upper = false,
};
print_argument(putc, ipv4.address[0], format);
print_argument(putc, ipv4.octets[0], format);
for (size_t i = 1; i < 4; i++)
{
putc('.');
print_argument(putc, ipv4.address[i], format);
print_argument(putc, ipv4.octets[i], format);
}
}

View File

@ -3,6 +3,8 @@
#include <BAN/Assert.h>
#include <BAN/Traits.h>
#include <stddef.h>
namespace BAN
{

View File

@ -45,6 +45,7 @@ namespace BAN
void clear();
bool empty() const;
size_type capacity() const;
size_type size() const;
const T& front() const;
@ -186,6 +187,12 @@ namespace BAN
return m_size == 0;
}
template<typename T>
typename Queue<T>::size_type Queue<T>::capacity() const
{
return m_capacity;
}
template<typename T>
typename Queue<T>::size_type Queue<T>::size() const
{

View File

@ -216,6 +216,14 @@ namespace BAN
return m_index == detail::index<T, Ts...>();
}
template<typename T, typename... Args>
void emplace(Args&&... args) requires (can_have<T>())
{
clear();
m_index = detail::index<T, Ts...>();
new (m_storage) T(BAN::forward<Args>(args)...);
}
template<typename T>
void set(T&& value) requires (can_have<T>() && !is_lvalue_reference_v<T>)
{

View File

@ -91,6 +91,8 @@ namespace BAN
bool valid() const { return m_link && m_link->valid(); }
operator bool() const { return valid(); }
private:
WeakPtr(const RefPtr<WeakLink<T>>& link)
: m_link(link)

View File

@ -16,6 +16,7 @@ set(KERNEL_SOURCES
kernel/CPUID.cpp
kernel/Credentials.cpp
kernel/Debug.cpp
kernel/Device/DebugDevice.cpp
kernel/Device/Device.cpp
kernel/Device/FramebufferDevice.cpp
kernel/Device/NullDevice.cpp
@ -52,11 +53,12 @@ set(KERNEL_SOURCES
kernel/Networking/ARPTable.cpp
kernel/Networking/E1000/E1000.cpp
kernel/Networking/E1000/E1000E.cpp
kernel/Networking/IPv4.cpp
kernel/Networking/IPv4Layer.cpp
kernel/Networking/NetworkInterface.cpp
kernel/Networking/NetworkManager.cpp
kernel/Networking/NetworkSocket.cpp
kernel/Networking/UDPSocket.cpp
kernel/Networking/UNIX/Socket.cpp
kernel/OpenFileDescriptorSet.cpp
kernel/Panic.cpp
kernel/PCI.cpp

View File

@ -5,7 +5,7 @@
#define dprintln(...) \
do { \
Debug::DebugLock::lock(); \
Debug::print_prefix(__FILE__, __LINE__); \
Debug::print_prefix(__FILE__, __LINE__); \
BAN::Formatter::print(Debug::putchar, __VA_ARGS__); \
BAN::Formatter::print(Debug::putchar, "\r\n"); \
Debug::DebugLock::unlock(); \

View File

@ -0,0 +1,28 @@
#include <kernel/Device/Device.h>
namespace Kernel
{
class DebugDevice : public CharacterDevice
{
public:
static BAN::ErrorOr<BAN::RefPtr<DebugDevice>> create(mode_t, uid_t, gid_t);
virtual dev_t rdev() const override { return m_rdev; }
virtual BAN::StringView name() const override { return "debug"sv; }
protected:
DebugDevice(mode_t mode, uid_t uid, gid_t gid, dev_t rdev)
: CharacterDevice(mode, uid, gid)
, m_rdev(rdev)
{ }
virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override { return 0; }
virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan buffer) override;
private:
const dev_t m_rdev;
};
}

View File

@ -100,9 +100,12 @@ namespace Kernel
BAN::ErrorOr<BAN::String> link_target();
// Socket API
BAN::ErrorOr<long> accept(sockaddr* address, socklen_t* address_len);
BAN::ErrorOr<void> bind(const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<ssize_t> sendto(const sys_sendto_t*);
BAN::ErrorOr<ssize_t> recvfrom(sys_recvfrom_t*);
BAN::ErrorOr<void> connect(const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<void> listen(int backlog);
BAN::ErrorOr<size_t> sendto(const sys_sendto_t*);
BAN::ErrorOr<size_t> recvfrom(sys_recvfrom_t*);
// General API
BAN::ErrorOr<size_t> read(off_t, BAN::ByteSpan buffer);
@ -128,9 +131,12 @@ namespace Kernel
virtual BAN::ErrorOr<BAN::String> link_target_impl() { return BAN::Error::from_errno(ENOTSUP); }
// Socket API
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> listen_impl(int) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<ssize_t> sendto_impl(const sys_sendto_t*) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<ssize_t> recvfrom_impl(sys_recvfrom_t*) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<size_t> sendto_impl(const sys_sendto_t*) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<size_t> recvfrom_impl(sys_recvfrom_t*) { return BAN::Error::from_errno(ENOTSUP); }
// General API
virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) { return BAN::Error::from_errno(ENOTSUP); }

View File

@ -0,0 +1,20 @@
#pragma once
namespace Kernel
{
enum class SocketDomain
{
INET,
INET6,
UNIX,
};
enum class SocketType
{
STREAM,
DGRAM,
SEQPACKET,
};
}

View File

@ -80,6 +80,25 @@ namespace Kernel
friend class TmpInode;
};
class TmpSocketInode : public TmpInode
{
public:
static BAN::ErrorOr<BAN::RefPtr<TmpSocketInode>> create_new(TmpFileSystem&, mode_t, uid_t, gid_t);
~TmpSocketInode();
protected:
virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override { return BAN::Error::from_errno(ENODEV); }
virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan) override { return BAN::Error::from_errno(ENODEV); }
virtual BAN::ErrorOr<void> truncate_impl(size_t) override { return BAN::Error::from_errno(ENODEV); }
virtual BAN::ErrorOr<void> chmod_impl(mode_t) override;
virtual bool has_data_impl() const override { return true; }
private:
TmpSocketInode(TmpFileSystem&, ino_t, const TmpInodeInfo&);
friend class TmpInode;
};
class TmpSymlinkInode : public TmpInode
{
public:

View File

@ -31,6 +31,7 @@ namespace Kernel
public:
static BAN::ErrorOr<BAN::UniqPtr<ARPTable>> create();
~ARPTable();
BAN::ErrorOr<BAN::MACAddress> get_mac_from_ipv4(NetworkInterface&, BAN::IPv4Address);

View File

@ -42,7 +42,7 @@ namespace Kernel
uint32_t read32(uint16_t reg);
void write32(uint16_t reg, uint32_t value);
virtual BAN::ErrorOr<void> send_raw_bytes(BAN::ConstByteSpan) override;
virtual BAN::ErrorOr<void> send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan) override;
private:
BAN::ErrorOr<void> read_mac_address();

View File

@ -0,0 +1,24 @@
#pragma once
#include <BAN/Endianness.h>
#include <stdint.h>
namespace Kernel
{
struct ICMPHeader
{
uint8_t type;
uint8_t code;
BAN::NetworkEndian<uint16_t> checksum;
BAN::NetworkEndian<uint32_t> rest;
};
static_assert(sizeof(ICMPHeader) == 8);
enum ICMPType : uint8_t
{
EchoReply = 0x00,
EchoRequest = 0x08,
};
}

View File

@ -1,38 +0,0 @@
#pragma once
#include <BAN/ByteSpan.h>
#include <BAN/Endianness.h>
#include <BAN/IPv4.h>
#include <BAN/Vector.h>
namespace Kernel
{
struct IPv4Header
{
uint8_t version_IHL;
uint8_t DSCP_ECN;
BAN::NetworkEndian<uint16_t> total_length { 0 };
BAN::NetworkEndian<uint16_t> identification { 0 };
BAN::NetworkEndian<uint16_t> flags_frament { 0 };
uint8_t time_to_live;
uint8_t protocol;
BAN::NetworkEndian<uint16_t> checksum { 0 };
BAN::IPv4Address src_address;
BAN::IPv4Address dst_address;
constexpr uint16_t calculate_checksum() const
{
return 0xFFFF
- (((uint16_t)version_IHL << 8) | DSCP_ECN)
- total_length
- identification
- flags_frament
- (((uint16_t)time_to_live << 8) | protocol);
}
};
static_assert(sizeof(IPv4Header) == 20);
void add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol);
}

View File

@ -0,0 +1,105 @@
#pragma once
#include <BAN/Array.h>
#include <BAN/ByteSpan.h>
#include <BAN/CircularQueue.h>
#include <BAN/Endianness.h>
#include <BAN/IPv4.h>
#include <BAN/NoCopyMove.h>
#include <BAN/UniqPtr.h>
#include <kernel/Networking/ARPTable.h>
#include <kernel/Networking/NetworkInterface.h>
#include <kernel/Networking/NetworkLayer.h>
#include <kernel/Networking/NetworkSocket.h>
#include <kernel/Process.h>
#include <kernel/SpinLock.h>
namespace Kernel
{
struct IPv4Header
{
uint8_t version_IHL;
uint8_t DSCP_ECN;
BAN::NetworkEndian<uint16_t> total_length { 0 };
BAN::NetworkEndian<uint16_t> identification { 0 };
BAN::NetworkEndian<uint16_t> flags_frament { 0 };
uint8_t time_to_live;
uint8_t protocol;
BAN::NetworkEndian<uint16_t> checksum { 0 };
BAN::IPv4Address src_address;
BAN::IPv4Address dst_address;
constexpr uint16_t calculate_checksum() const
{
uint32_t total_sum = 0;
for (size_t i = 0; i < sizeof(IPv4Header) / sizeof(uint16_t); i++)
total_sum += reinterpret_cast<const BAN::NetworkEndian<uint16_t>*>(this)[i];
total_sum -= checksum;
while (total_sum >> 16)
total_sum = (total_sum >> 16) + (total_sum & 0xFFFF);
return ~(uint16_t)total_sum;
}
constexpr bool is_valid_checksum() const
{
uint32_t total_sum = 0;
for (size_t i = 0; i < sizeof(IPv4Header) / sizeof(uint16_t); i++)
total_sum += reinterpret_cast<const BAN::NetworkEndian<uint16_t>*>(this)[i];
while (total_sum >> 16)
total_sum = (total_sum >> 16) + (total_sum & 0xFFFF);
return total_sum == 0xFFFF;
}
};
static_assert(sizeof(IPv4Header) == 20);
class IPv4Layer : public NetworkLayer
{
BAN_NON_COPYABLE(IPv4Layer);
BAN_NON_MOVABLE(IPv4Layer);
public:
static BAN::ErrorOr<BAN::UniqPtr<IPv4Layer>> create();
~IPv4Layer();
ARPTable& arp_table() { return *m_arp_table; }
void add_ipv4_packet(NetworkInterface&, BAN::ConstByteSpan);
virtual void unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) override;
virtual BAN::ErrorOr<void> bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) override;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, const sys_sendto_t*) override;
private:
IPv4Layer();
void add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol) const;
void packet_handle_task();
BAN::ErrorOr<void> handle_ipv4_packet(NetworkInterface&, BAN::ByteSpan);
private:
struct PendingIPv4Packet
{
NetworkInterface& interface;
};
private:
SpinLock m_lock;
BAN::UniqPtr<ARPTable> m_arp_table;
Process* m_process { nullptr };
static constexpr size_t pending_packet_buffer_size = 128 * PAGE_SIZE;
BAN::UniqPtr<VirtualRange> m_pending_packet_buffer;
BAN::CircularQueue<PendingIPv4Packet, 128> m_pending_packets;
Semaphore m_pending_semaphore;
size_t m_pending_total_size { 0 };
BAN::HashMap<int, BAN::WeakPtr<NetworkSocket>> m_bound_sockets;
friend class BAN::UniqPtr<IPv4Layer>;
};
}

View File

@ -1,10 +1,10 @@
#pragma once
#include <BAN/Errors.h>
#include <BAN/ByteSpan.h>
#include <BAN/Errors.h>
#include <BAN/IPv4.h>
#include <BAN/MAC.h>
#include <kernel/Device/Device.h>
#include <kernel/Networking/IPv4.h>
namespace Kernel
{
@ -46,16 +46,16 @@ namespace Kernel
BAN::IPv4Address get_netmask() const { return m_netmask; }
void set_netmask(BAN::IPv4Address new_netmask) { m_netmask = new_netmask; }
BAN::IPv4Address get_gateway() const { return m_gateway; }
void set_gateway(BAN::IPv4Address new_gateway) { m_gateway = new_gateway; }
virtual bool link_up() = 0;
virtual int link_speed() = 0;
size_t interface_header_size() const;
void add_interface_header(BAN::ByteSpan packet, BAN::MACAddress destination);
virtual dev_t rdev() const override { return m_rdev; }
virtual BAN::StringView name() const override { return m_name; }
virtual BAN::ErrorOr<void> send_raw_bytes(BAN::ConstByteSpan) = 0;
virtual BAN::ErrorOr<void> send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan) = 0;
private:
const Type m_type;
@ -65,6 +65,7 @@ namespace Kernel
BAN::IPv4Address m_ipv4_address { 0 };
BAN::IPv4Address m_netmask { 0 };
BAN::IPv4Address m_gateway { 0 };
};
}

View File

@ -0,0 +1,25 @@
#pragma once
#include <kernel/Networking/NetworkInterface.h>
namespace Kernel
{
class NetworkSocket;
enum class SocketType;
class NetworkLayer
{
public:
virtual ~NetworkLayer() {}
virtual void unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) = 0;
virtual BAN::ErrorOr<void> bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) = 0;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, const sys_sendto_t*) = 0;
protected:
NetworkLayer() = default;
};
}

View File

@ -2,7 +2,7 @@
#include <BAN/Vector.h>
#include <kernel/FS/TmpFS/FileSystem.h>
#include <kernel/Networking/ARPTable.h>
#include <kernel/Networking/IPv4Layer.h>
#include <kernel/Networking/NetworkInterface.h>
#include <kernel/Networking/NetworkSocket.h>
#include <kernel/PCI.h>
@ -17,26 +17,15 @@ namespace Kernel
BAN_NON_COPYABLE(NetworkManager);
BAN_NON_MOVABLE(NetworkManager);
public:
enum class SocketType
{
STREAM,
DGRAM,
SEQPACKET,
};
public:
static BAN::ErrorOr<void> initialize();
static NetworkManager& get();
ARPTable& arp_table() { return *m_arp_table; }
BAN::ErrorOr<void> add_interface(PCI::Device& pci_device);
void unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>);
BAN::ErrorOr<void> bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>);
BAN::Vector<BAN::RefPtr<NetworkInterface>> interfaces() { return m_interfaces; }
BAN::ErrorOr<BAN::RefPtr<NetworkSocket>> create_socket(SocketType, mode_t, uid_t, gid_t);
BAN::ErrorOr<BAN::RefPtr<TmpInode>> create_socket(SocketDomain, SocketType, mode_t, uid_t, gid_t);
void on_receive(NetworkInterface&, BAN::ConstByteSpan);
@ -44,9 +33,8 @@ namespace Kernel
NetworkManager();
private:
BAN::UniqPtr<ARPTable> m_arp_table;
BAN::Vector<BAN::RefPtr<NetworkInterface>> m_interfaces;
BAN::HashMap<int, BAN::WeakPtr<NetworkSocket>> m_bound_sockets;
BAN::UniqPtr<IPv4Layer> m_ipv4_layer;
BAN::Vector<BAN::RefPtr<NetworkInterface>> m_interfaces;
};
}

View File

@ -1,8 +1,10 @@
#pragma once
#include <BAN/WeakPtr.h>
#include <kernel/FS/Socket.h>
#include <kernel/FS/TmpFS/Inode.h>
#include <kernel/Networking/NetworkInterface.h>
#include <kernel/Networking/NetworkLayer.h>
#include <netinet/in.h>
@ -11,6 +13,7 @@ namespace Kernel
enum NetworkProtocol : uint8_t
{
ICMP = 0x01,
UDP = 0x11,
};
@ -26,26 +29,29 @@ namespace Kernel
void bind_interface_and_port(NetworkInterface*, uint16_t port);
~NetworkSocket();
NetworkInterface& interface() { ASSERT(m_interface); return *m_interface; }
virtual size_t protocol_header_size() const = 0;
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t src_port, uint16_t dst_port) = 0;
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) = 0;
virtual NetworkProtocol protocol() const = 0;
virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_address, uint16_t sender_port) = 0;
protected:
NetworkSocket(mode_t mode, uid_t uid, gid_t gid);
NetworkSocket(NetworkLayer&, ino_t, const TmpInodeInfo&);
virtual BAN::ErrorOr<size_t> read_packet(BAN::ByteSpan, sockaddr_in* sender_address) = 0;
virtual void on_close_impl() override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr* address, socklen_t address_len) override;
virtual BAN::ErrorOr<ssize_t> sendto_impl(const sys_sendto_t*) override;
virtual BAN::ErrorOr<ssize_t> recvfrom_impl(sys_recvfrom_t*) override;
virtual BAN::ErrorOr<size_t> sendto_impl(const sys_sendto_t*) override;
virtual BAN::ErrorOr<size_t> recvfrom_impl(sys_recvfrom_t*) override;
virtual BAN::ErrorOr<long> ioctl_impl(int request, void* arg) override;
protected:
NetworkLayer& m_network_layer;
NetworkInterface* m_interface = nullptr;
uint16_t m_port = PORT_NONE;
};

View File

@ -22,10 +22,10 @@ namespace Kernel
class UDPSocket final : public NetworkSocket
{
public:
static BAN::ErrorOr<BAN::RefPtr<UDPSocket>> create(mode_t, uid_t, gid_t);
static BAN::ErrorOr<BAN::RefPtr<UDPSocket>> create(NetworkLayer&, ino_t, const TmpInodeInfo&);
virtual size_t protocol_header_size() const override { return sizeof(UDPHeader); }
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t src_port, uint16_t dst_port) override;
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) override;
virtual NetworkProtocol protocol() const override { return NetworkProtocol::UDP; }
protected:
@ -33,7 +33,7 @@ namespace Kernel
virtual BAN::ErrorOr<size_t> read_packet(BAN::ByteSpan, sockaddr_in* sender_address) override;
private:
UDPSocket(mode_t, uid_t, gid_t);
UDPSocket(NetworkLayer&, ino_t, const TmpInodeInfo&);
struct PacketInfo
{

View File

@ -0,0 +1,67 @@
#pragma once
#include <BAN/Queue.h>
#include <BAN/WeakPtr.h>
#include <kernel/FS/Socket.h>
#include <kernel/FS/TmpFS/Inode.h>
namespace Kernel
{
class UnixDomainSocket final : public TmpInode, public BAN::Weakable<UnixDomainSocket>
{
BAN_NON_COPYABLE(UnixDomainSocket);
BAN_NON_MOVABLE(UnixDomainSocket);
public:
static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(SocketType, ino_t, const TmpInodeInfo&);
protected:
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) override;
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<void> listen_impl(int) override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<size_t> sendto_impl(const sys_sendto_t*) override;
virtual BAN::ErrorOr<size_t> recvfrom_impl(sys_recvfrom_t*) override;
private:
UnixDomainSocket(SocketType, ino_t, const TmpInodeInfo&);
BAN::ErrorOr<void> add_packet(BAN::ConstByteSpan);
bool is_bound() const { return !m_bound_path.empty(); }
bool is_bound_to_unused() const { return m_bound_path == "X"sv; }
bool is_streaming() const;
private:
struct ConnectionInfo
{
bool listening { false };
BAN::Atomic<bool> connection_done { false };
BAN::WeakPtr<UnixDomainSocket> connection;
BAN::Queue<BAN::RefPtr<UnixDomainSocket>> pending_connections;
Semaphore pending_semaphore;
SpinLock pending_lock;
};
struct ConnectionlessInfo
{
};
private:
const SocketType m_socket_type;
BAN::String m_bound_path;
BAN::Variant<ConnectionInfo, ConnectionlessInfo> m_info;
BAN::CircularQueue<size_t, 128> m_packet_sizes;
size_t m_packet_size_total { 0 };
BAN::UniqPtr<VirtualRange> m_packet_buffer;
Semaphore m_packet_semaphore;
friend class BAN::RefPtr<UnixDomainSocket>;
};
}

View File

@ -21,6 +21,7 @@ namespace Kernel
BAN::ErrorOr<void> clone_from(const OpenFileDescriptorSet&);
BAN::ErrorOr<int> open(BAN::RefPtr<Inode>, int flags);
BAN::ErrorOr<int> open(BAN::StringView absolute_path, int flags);
BAN::ErrorOr<int> socket(int domain, int type, int protocol);

View File

@ -62,6 +62,8 @@ namespace Kernel
bool is_session_leader() const { return pid() == sid(); }
const char* name() const { return m_cmdline.empty() ? "" : m_cmdline.front().data(); }
const Credentials& credentials() const { return m_credentials; }
BAN::ErrorOr<long> sys_exit(int status);
@ -93,8 +95,10 @@ namespace Kernel
BAN::ErrorOr<long> sys_getegid() const { return m_credentials.egid(); }
BAN::ErrorOr<long> sys_getpgid(pid_t);
BAN::ErrorOr<long> open_inode(BAN::RefPtr<Inode>, int flags);
BAN::ErrorOr<void> create_file_or_dir(BAN::StringView name, mode_t mode);
BAN::ErrorOr<long> open_file(BAN::StringView path, int, mode_t = 0);
BAN::ErrorOr<long> open_file(BAN::StringView path, int oflag, mode_t = 0);
BAN::ErrorOr<long> sys_open(const char* path, int, mode_t);
BAN::ErrorOr<long> sys_openat(int, const char* path, int, mode_t);
BAN::ErrorOr<long> sys_close(int fd);
@ -113,7 +117,10 @@ namespace Kernel
BAN::ErrorOr<long> sys_chown(const char*, uid_t, gid_t);
BAN::ErrorOr<long> sys_socket(int domain, int type, int protocol);
BAN::ErrorOr<long> sys_accept(int socket, sockaddr* address, socklen_t* address_len);
BAN::ErrorOr<long> sys_bind(int socket, const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<long> sys_connect(int socket, const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<long> sys_listen(int socket, int backlog);
BAN::ErrorOr<long> sys_sendto(const sys_sendto_t*);
BAN::ErrorOr<long> sys_recvfrom(sys_recvfrom_t*);
@ -175,6 +182,8 @@ namespace Kernel
// Return false if access was page violation (segfault)
BAN::ErrorOr<bool> allocate_page_for_demand_paging(vaddr_t addr);
BAN::ErrorOr<BAN::String> absolute_path_of(BAN::StringView) const;
private:
Process(const Credentials&, pid_t pid, pid_t parent, pid_t sid, pid_t pgrp);
static Process* create_process(const Credentials&, pid_t parent, pid_t sid = 0, pid_t pgrp = 0);
@ -184,8 +193,6 @@ namespace Kernel
BAN::ErrorOr<int> block_until_exit(pid_t pid);
BAN::ErrorOr<BAN::String> absolute_path_of(BAN::StringView) const;
BAN::ErrorOr<void> validate_string_access(const char*);
BAN::ErrorOr<void> validate_pointer_access(const void*, size_t);

View File

@ -0,0 +1,32 @@
#include <kernel/Device/DebugDevice.h>
#include <kernel/FS/DevFS/FileSystem.h>
#include <kernel/Process.h>
#include <kernel/Timer/Timer.h>
namespace Kernel
{
BAN::ErrorOr<BAN::RefPtr<DebugDevice>> DebugDevice::create(mode_t mode, uid_t uid, gid_t gid)
{
auto* result = new DebugDevice(mode, uid, gid, DevFileSystem::get().get_next_dev());
if (result == nullptr)
return BAN::Error::from_errno(ENOMEM);
return BAN::RefPtr<DebugDevice>::adopt(result);
}
BAN::ErrorOr<size_t> DebugDevice::write_impl(off_t, BAN::ConstByteSpan buffer)
{
auto ms_since_boot = SystemTimer::get().ms_since_boot();
Debug::DebugLock::lock();
BAN::Formatter::print(Debug::putchar, "[{5}.{3}] {}: ",
ms_since_boot / 1000,
ms_since_boot % 1000,
Kernel::Process::current().name()
);
for (size_t i = 0; i < buffer.size(); i++)
Debug::putchar(buffer[i]);
Debug::DebugLock::unlock();
return buffer.size();
}
}

View File

@ -1,4 +1,5 @@
#include <BAN/ScopeGuard.h>
#include <kernel/Device/DebugDevice.h>
#include <kernel/Device/FramebufferDevice.h>
#include <kernel/Device/NullDevice.h>
#include <kernel/Device/ZeroDevice.h>
@ -22,6 +23,7 @@ namespace Kernel
ASSERT(s_instance);
MUST(s_instance->TmpFileSystem::initialize(0755, 0, 0));
s_instance->add_device(MUST(DebugDevice::create(0666, 0, 0)));
s_instance->add_device(MUST(NullDevice::create(0666, 0, 0)));
s_instance->add_device(MUST(ZeroDevice::create(0666, 0, 0)));
}

View File

@ -116,6 +116,14 @@ namespace Kernel
return link_target_impl();
}
BAN::ErrorOr<long> Inode::accept(sockaddr* address, socklen_t* address_len)
{
LockGuard _(m_lock);
if (!mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
return accept_impl(address, address_len);
}
BAN::ErrorOr<void> Inode::bind(const sockaddr* address, socklen_t address_len)
{
LockGuard _(m_lock);
@ -124,7 +132,23 @@ namespace Kernel
return bind_impl(address, address_len);
}
BAN::ErrorOr<ssize_t> Inode::sendto(const sys_sendto_t* arguments)
BAN::ErrorOr<void> Inode::connect(const sockaddr* address, socklen_t address_len)
{
LockGuard _(m_lock);
if (!mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
return connect_impl(address, address_len);
}
BAN::ErrorOr<void> Inode::listen(int backlog)
{
LockGuard _(m_lock);
if (!mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
return listen_impl(backlog);
}
BAN::ErrorOr<size_t> Inode::sendto(const sys_sendto_t* arguments)
{
LockGuard _(m_lock);
if (!mode().ifsock())
@ -132,7 +156,7 @@ namespace Kernel
return sendto_impl(arguments);
};
BAN::ErrorOr<ssize_t> Inode::recvfrom(sys_recvfrom_t* arguments)
BAN::ErrorOr<size_t> Inode::recvfrom(sys_recvfrom_t* arguments)
{
LockGuard _(m_lock);
if (!mode().ifsock())

View File

@ -215,6 +215,36 @@ namespace Kernel
return {};
}
/* SOCKET INODE */
BAN::ErrorOr<BAN::RefPtr<TmpSocketInode>> TmpSocketInode::create_new(TmpFileSystem& fs, mode_t mode, uid_t uid, gid_t gid)
{
auto info = create_inode_info(Mode::IFSOCK | mode, uid, gid);
ino_t ino = TRY(fs.allocate_inode(info));
auto* inode_ptr = new TmpSocketInode(fs, ino, info);
if (inode_ptr == nullptr)
return BAN::Error::from_errno(ENOMEM);
return BAN::RefPtr<TmpSocketInode>::adopt(inode_ptr);
}
TmpSocketInode::TmpSocketInode(TmpFileSystem& fs, ino_t ino, const TmpInodeInfo& info)
: TmpInode(fs, ino, info)
{
ASSERT(mode().ifsock());
}
TmpSocketInode::~TmpSocketInode()
{
}
BAN::ErrorOr<void> TmpSocketInode::chmod_impl(mode_t new_mode)
{
m_inode_info.mode = new_mode;
return {};
}
/* SYMLINK INODE */
BAN::ErrorOr<BAN::RefPtr<TmpSymlinkInode>> TmpSymlinkInode::create_new(TmpFileSystem& fs, mode_t mode, uid_t uid, gid_t gid, BAN::StringView target)
@ -446,7 +476,19 @@ namespace Kernel
BAN::ErrorOr<void> TmpDirectoryInode::create_file_impl(BAN::StringView name, mode_t mode, uid_t uid, gid_t gid)
{
auto new_inode = TRY(TmpFileInode::create_new(m_fs, mode, uid, gid));
BAN::RefPtr<TmpInode> new_inode;
switch (mode & Mode::TYPE_MASK)
{
case Mode::IFREG:
new_inode = TRY(TmpFileInode::create_new(m_fs, mode, uid, gid));
break;
case Mode::IFSOCK:
new_inode = TRY(TmpSocketInode::create_new(m_fs, mode, uid, gid));
break;
default:
dprintln("Creating with mode {o} is not supported", mode);
return BAN::Error::from_errno(ENOTSUP);
}
TRY(link_inode(*new_inode, name));
return {};
}

View File

@ -294,6 +294,9 @@ void* kmalloc(size_t size, size_t align, bool force_identity_map)
// currently kmalloc is always identity mapped
(void)force_identity_map;
if (size == 0)
size = 1;
const kmalloc_info& info = s_kmalloc_info;
ASSERT(is_power_of_two(align));

View File

@ -1,5 +1,6 @@
#include <kernel/LockGuard.h>
#include <kernel/Networking/ARPTable.h>
#include <kernel/Scheduler.h>
#include <kernel/Timer/Timer.h>
namespace Kernel
@ -32,26 +33,31 @@ namespace Kernel
{
}
ARPTable::~ARPTable()
{
if (m_process)
m_process->exit(0, SIGKILL);
m_process = nullptr;
}
BAN::ErrorOr<BAN::MACAddress> ARPTable::get_mac_from_ipv4(NetworkInterface& interface, BAN::IPv4Address ipv4_address)
{
if (ipv4_address == s_broadcast_ipv4)
return s_broadcast_mac;
if (interface.get_ipv4_address() == BAN::IPv4Address { 0 })
return BAN::Error::from_errno(EINVAL);
if (interface.get_ipv4_address().mask(interface.get_netmask()) != ipv4_address.mask(interface.get_netmask()))
ipv4_address = interface.get_gateway();
{
LockGuard _(m_lock);
if (ipv4_address == s_broadcast_ipv4)
return s_broadcast_mac;
if (m_arp_table.contains(ipv4_address))
return m_arp_table[ipv4_address];
}
BAN::Vector<uint8_t> full_packet_buffer;
TRY(full_packet_buffer.resize(sizeof(ARPPacket) + sizeof(EthernetHeader)));
auto full_packet = BAN::ByteSpan { full_packet_buffer.span() };
auto& ethernet_header = full_packet.as<EthernetHeader>();
ethernet_header.dst_mac = s_broadcast_mac;
ethernet_header.src_mac = interface.get_mac_address();
ethernet_header.ether_type = EtherType::ARP;
auto& arp_request = full_packet.slice(sizeof(EthernetHeader)).as<ARPPacket>();
ARPPacket arp_request;
arp_request.htype = 0x0001;
arp_request.ptype = EtherType::IPv4;
arp_request.hlen = 0x06;
@ -62,9 +68,9 @@ namespace Kernel
arp_request.tha = {{ 0, 0, 0, 0, 0, 0 }};
arp_request.tpa = ipv4_address;
TRY(interface.send_raw_bytes(full_packet));
TRY(interface.send_bytes(s_broadcast_mac, EtherType::ARP, BAN::ConstByteSpan::from(arp_request)));
uint64_t timeout = SystemTimer::get().ms_since_boot() + 5'000;
uint64_t timeout = SystemTimer::get().ms_since_boot() + 1'000;
while (SystemTimer::get().ms_since_boot() < timeout)
{
{
@ -72,10 +78,10 @@ namespace Kernel
if (m_arp_table.contains(ipv4_address))
return m_arp_table[ipv4_address];
}
TRY(Thread::current().block_or_eintr(m_pending_semaphore));
Scheduler::get().reschedule();
}
return BAN::Error::from_errno(EINVAL);
return BAN::Error::from_errno(ETIMEDOUT);
}
BAN::ErrorOr<void> ARPTable::handle_arp_packet(NetworkInterface& interface, const ARPPacket& packet)
@ -92,27 +98,17 @@ namespace Kernel
{
if (packet.tpa == interface.get_ipv4_address())
{
BAN::Vector<uint8_t> full_packet_buffer;
TRY(full_packet_buffer.resize(sizeof(ARPPacket) + sizeof(EthernetHeader)));
auto full_packet = BAN::ByteSpan { full_packet_buffer.span() };
auto& ethernet_header = full_packet.as<EthernetHeader>();
ethernet_header.dst_mac = packet.sha;
ethernet_header.src_mac = interface.get_mac_address();
ethernet_header.ether_type = EtherType::ARP;
auto& arp_request = full_packet.slice(sizeof(EthernetHeader)).as<ARPPacket>();
arp_request.htype = 0x0001;
arp_request.ptype = EtherType::IPv4;
arp_request.hlen = 0x06;
arp_request.plen = 0x04;
arp_request.oper = ARPOperation::Reply;
arp_request.sha = interface.get_mac_address();
arp_request.spa = interface.get_ipv4_address();
arp_request.tha = packet.sha;
arp_request.tpa = packet.spa;
TRY(interface.send_raw_bytes(full_packet));
ARPPacket arp_reply;
arp_reply.htype = 0x0001;
arp_reply.ptype = EtherType::IPv4;
arp_reply.hlen = 0x06;
arp_reply.plen = 0x04;
arp_reply.oper = ARPOperation::Reply;
arp_reply.sha = interface.get_mac_address();
arp_reply.spa = interface.get_ipv4_address();
arp_reply.tha = packet.sha;
arp_reply.tpa = packet.spa;
TRY(interface.send_bytes(packet.sha, EtherType::ARP, BAN::ConstByteSpan::from(arp_reply)));
}
break;
}

View File

@ -256,19 +256,26 @@ namespace Kernel
return {};
}
BAN::ErrorOr<void> E1000::send_raw_bytes(BAN::ConstByteSpan buffer)
BAN::ErrorOr<void> E1000::send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan buffer)
{
ASSERT_LTE(buffer.size(), E1000_TX_BUFFER_SIZE);
ASSERT_LTE(buffer.size() + sizeof(EthernetHeader), E1000_TX_BUFFER_SIZE);
CriticalScope _;
size_t tx_current = read32(REG_TDT) % E1000_TX_DESCRIPTOR_COUNT;
auto* tx_buffer = reinterpret_cast<void*>(m_tx_buffer_region->vaddr() + E1000_TX_BUFFER_SIZE * tx_current);
memcpy(tx_buffer, buffer.data(), buffer.size());
auto* tx_buffer = reinterpret_cast<uint8_t*>(m_tx_buffer_region->vaddr() + E1000_TX_BUFFER_SIZE * tx_current);
auto& ethernet_header = *reinterpret_cast<EthernetHeader*>(tx_buffer);
ethernet_header.dst_mac = destination;
ethernet_header.src_mac = get_mac_address();
ethernet_header.ether_type = protocol;
memcpy(tx_buffer + sizeof(EthernetHeader), buffer.data(), buffer.size());
auto& descriptor = reinterpret_cast<volatile e1000_tx_desc*>(m_tx_descriptor_region->vaddr())[tx_current];
descriptor.length = buffer.size();
descriptor.length = sizeof(EthernetHeader) + buffer.size();
descriptor.status = 0;
descriptor.cmd = CMD_EOP | CMD_IFCS | CMD_RS;

View File

@ -1,22 +0,0 @@
#include <BAN/Endianness.h>
#include <kernel/Networking/IPv4.h>
namespace Kernel
{
void add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol)
{
auto& header = packet.as<IPv4Header>();
header.version_IHL = 0x45;
header.DSCP_ECN = 0x10;
header.total_length = packet.size();
header.identification = 1;
header.flags_frament = 0x00;
header.time_to_live = 0x40;
header.protocol = protocol;
header.checksum = header.calculate_checksum();
header.src_address = src_ipv4;
header.dst_address = dst_ipv4;
}
}

View File

@ -0,0 +1,284 @@
#include <kernel/Memory/Heap.h>
#include <kernel/Memory/PageTable.h>
#include <kernel/Networking/ICMP.h>
#include <kernel/Networking/IPv4Layer.h>
#include <kernel/Networking/NetworkManager.h>
#include <kernel/Networking/UDPSocket.h>
#include <netinet/in.h>
#define DEBUG_IPV4 0
namespace Kernel
{
BAN::ErrorOr<BAN::UniqPtr<IPv4Layer>> IPv4Layer::create()
{
auto ipv4_manager = TRY(BAN::UniqPtr<IPv4Layer>::create());
ipv4_manager->m_process = Process::create_kernel(
[](void* ipv4_manager_ptr)
{
auto& ipv4_manager = *reinterpret_cast<IPv4Layer*>(ipv4_manager_ptr);
ipv4_manager.packet_handle_task();
}, ipv4_manager.ptr()
);
ASSERT(ipv4_manager->m_process);
ipv4_manager->m_pending_packet_buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
~(uintptr_t)0,
pending_packet_buffer_size,
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
true
));
ipv4_manager->m_arp_table = TRY(ARPTable::create());
return ipv4_manager;
}
IPv4Layer::IPv4Layer()
{ }
IPv4Layer::~IPv4Layer()
{
if (m_process)
m_process->exit(0, SIGKILL);
m_process = nullptr;
}
void IPv4Layer::add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol) const
{
auto& header = packet.as<IPv4Header>();
header.version_IHL = 0x45;
header.DSCP_ECN = 0x00;
header.total_length = packet.size();
header.identification = 1;
header.flags_frament = 0x00;
header.time_to_live = 0x40;
header.protocol = protocol;
header.src_address = src_ipv4;
header.dst_address = dst_ipv4;
header.checksum = header.calculate_checksum();
}
void IPv4Layer::unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket> socket)
{
LockGuard _(m_lock);
if (m_bound_sockets.contains(port))
{
ASSERT(m_bound_sockets[port].valid());
ASSERT(m_bound_sockets[port].lock() == socket);
m_bound_sockets.remove(port);
}
NetworkManager::get().TmpFileSystem::remove_from_cache(socket);
}
BAN::ErrorOr<void> IPv4Layer::bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket> socket)
{
if (NetworkManager::get().interfaces().empty())
return BAN::Error::from_errno(EADDRNOTAVAIL);
LockGuard _(m_lock);
if (port == NetworkSocket::PORT_NONE)
{
for (uint32_t temp = 0xC000; temp < 0xFFFF; temp++)
{
if (!m_bound_sockets.contains(temp))
{
port = temp;
break;
}
}
if (port == NetworkSocket::PORT_NONE)
{
dwarnln("No ports available");
return BAN::Error::from_errno(EAGAIN);
}
}
if (m_bound_sockets.contains(port))
return BAN::Error::from_errno(EADDRINUSE);
TRY(m_bound_sockets.insert(port, socket));
// FIXME: actually determine proper interface
auto interface = NetworkManager::get().interfaces().front();
socket->bind_interface_and_port(interface.ptr(), port);
return {};
}
BAN::ErrorOr<size_t> IPv4Layer::sendto(NetworkSocket& socket, const sys_sendto_t* arguments)
{
if (arguments->dest_addr->sa_family != AF_INET)
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(arguments->dest_addr);
auto dst_port = BAN::host_to_network_endian(sockaddr_in.sin_port);
auto dst_ipv4 = BAN::IPv4Address { sockaddr_in.sin_addr.s_addr };
auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(socket.interface(), dst_ipv4));
BAN::Vector<uint8_t> packet_buffer;
TRY(packet_buffer.resize(arguments->length + sizeof(IPv4Header) + socket.protocol_header_size()));
auto packet = BAN::ByteSpan { packet_buffer.span() };
memcpy(
packet.slice(sizeof(IPv4Header)).slice(socket.protocol_header_size()).data(),
arguments->message,
arguments->length
);
socket.add_protocol_header(
packet.slice(sizeof(IPv4Header)),
dst_port
);
add_ipv4_header(
packet,
socket.interface().get_ipv4_address(),
dst_ipv4,
socket.protocol()
);
TRY(socket.interface().send_bytes(dst_mac, EtherType::IPv4, packet));
return arguments->length;
}
static uint16_t calculate_internet_checksum(BAN::ConstByteSpan packet)
{
uint32_t checksum = 0;
for (size_t i = 0; i < packet.size() / sizeof(uint16_t); i++)
checksum += BAN::host_to_network_endian(reinterpret_cast<const uint16_t*>(packet.data())[i]);
while (checksum >> 16)
checksum = (checksum >> 16) | (checksum & 0xFFFF);
return ~(uint16_t)checksum;
}
BAN::ErrorOr<void> IPv4Layer::handle_ipv4_packet(NetworkInterface& interface, BAN::ByteSpan packet)
{
auto& ipv4_header = packet.as<const IPv4Header>();
auto ipv4_data = packet.slice(sizeof(IPv4Header));
ASSERT(ipv4_header.is_valid_checksum());
auto src_ipv4 = ipv4_header.src_address;
switch (ipv4_header.protocol)
{
case NetworkProtocol::ICMP:
{
auto& icmp_header = ipv4_data.as<const ICMPHeader>();
switch (icmp_header.type)
{
case ICMPType::EchoRequest:
{
auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(interface, src_ipv4));
auto& reply_icmp_header = ipv4_data.as<ICMPHeader>();
reply_icmp_header.type = ICMPType::EchoReply;
reply_icmp_header.checksum = 0;
reply_icmp_header.checksum = calculate_internet_checksum(ipv4_data);
add_ipv4_header(packet, interface.get_ipv4_address(), src_ipv4, NetworkProtocol::ICMP);
TRY(interface.send_bytes(dst_mac, EtherType::IPv4, packet));
break;
}
default:
dprintln("Unhandleded ICMP packet (type {2H})", icmp_header.type);
break;
}
break;
}
case NetworkProtocol::UDP:
{
auto& udp_header = ipv4_data.as<const UDPHeader>();
uint16_t src_port = udp_header.src_port;
uint16_t dst_port = udp_header.dst_port;
LockGuard _(m_lock);
if (!m_bound_sockets.contains(dst_port) || !m_bound_sockets[dst_port].valid())
{
dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port);
return {};
}
auto udp_data = ipv4_data.slice(sizeof(UDPHeader));
m_bound_sockets[dst_port].lock()->add_packet(udp_data, src_ipv4, src_port);
break;
}
default:
dprintln_if(DEBUG_IPV4, "Unknown network protocol 0x{2H}", ipv4_header.protocol);
break;
}
return {};
}
void IPv4Layer::packet_handle_task()
{
for (;;)
{
BAN::Optional<PendingIPv4Packet> pending;
{
CriticalScope _;
if (!m_pending_packets.empty())
{
pending = m_pending_packets.front();
m_pending_packets.pop();
}
}
if (!pending.has_value())
{
m_pending_semaphore.block();
continue;
}
uint8_t* buffer_start = reinterpret_cast<uint8_t*>(m_pending_packet_buffer->vaddr());
const size_t ipv4_packet_size = reinterpret_cast<const IPv4Header*>(buffer_start)->total_length;
if (auto ret = handle_ipv4_packet(pending->interface, BAN::ByteSpan(buffer_start, ipv4_packet_size)); ret.is_error())
dwarnln("{}", ret.error());
CriticalScope _;
m_pending_total_size -= ipv4_packet_size;
if (m_pending_total_size)
memmove(buffer_start, buffer_start + ipv4_packet_size, m_pending_total_size);
}
}
void IPv4Layer::add_ipv4_packet(NetworkInterface& interface, BAN::ConstByteSpan buffer)
{
if (m_pending_packets.full())
{
dwarnln("IPv4 packet queue full");
return;
}
if (m_pending_total_size + buffer.size() > m_pending_packet_buffer->size())
{
dwarnln("IPv4 packet queue full");
return;
}
auto& ipv4_header = buffer.as<const IPv4Header>();
if (!ipv4_header.is_valid_checksum())
{
dwarnln("Invalid IPv4 packet");
return;
}
if (ipv4_header.total_length > buffer.size())
{
dwarnln("Too short IPv4 packet");
return;
}
uint8_t* buffer_start = reinterpret_cast<uint8_t*>(m_pending_packet_buffer->vaddr());
memcpy(buffer_start + m_pending_total_size, buffer.data(), ipv4_header.total_length);
m_pending_total_size += ipv4_header.total_length;
m_pending_packets.push({ .interface = interface });
m_pending_semaphore.unblock();
}
}

View File

@ -32,19 +32,4 @@ namespace Kernel
m_name[3] = minor(m_rdev) + '0';
}
size_t NetworkInterface::interface_header_size() const
{
ASSERT(m_type == Type::Ethernet);
return sizeof(EthernetHeader);
}
void NetworkInterface::add_interface_header(BAN::ByteSpan packet, BAN::MACAddress destination)
{
ASSERT(m_type == Type::Ethernet);
auto& header = packet.as<EthernetHeader>();
header.dst_mac = destination;
header.src_mac = get_mac_address();
header.ether_type = 0x0800;
}
}

View File

@ -3,9 +3,12 @@
#include <kernel/FS/DevFS/FileSystem.h>
#include <kernel/Networking/E1000/E1000.h>
#include <kernel/Networking/E1000/E1000E.h>
#include <kernel/Networking/IPv4.h>
#include <kernel/Networking/ICMP.h>
#include <kernel/Networking/NetworkManager.h>
#include <kernel/Networking/UDPSocket.h>
#include <kernel/Networking/UNIX/Socket.h>
#define DEBUG_ETHERTYPE 0
namespace Kernel
{
@ -19,8 +22,8 @@ namespace Kernel
if (manager_ptr == nullptr)
return BAN::Error::from_errno(ENOMEM);
auto manager = BAN::UniqPtr<NetworkManager>::adopt(manager_ptr);
manager->m_arp_table = TRY(ARPTable::create());
TRY(manager->TmpFileSystem::initialize(0777, 0, 0));
manager->m_ipv4_layer = TRY(IPv4Layer::create());
s_instance = BAN::move(manager);
return {};
}
@ -68,45 +71,50 @@ namespace Kernel
return {};
}
BAN::ErrorOr<BAN::RefPtr<NetworkSocket>> NetworkManager::create_socket(SocketType type, mode_t mode, uid_t uid, gid_t gid)
BAN::ErrorOr<BAN::RefPtr<TmpInode>> NetworkManager::create_socket(SocketDomain domain, SocketType type, mode_t mode, uid_t uid, gid_t gid)
{
switch (domain)
{
case SocketDomain::INET:
{
if (type != SocketType::DGRAM)
return BAN::Error::from_errno(EPROTOTYPE);
break;
}
case SocketDomain::UNIX:
{
break;
}
default:
return BAN::Error::from_errno(EAFNOSUPPORT);
}
ASSERT((mode & Inode::Mode::TYPE_MASK) == 0);
mode |= Inode::Mode::IFSOCK;
if (type != SocketType::DGRAM)
return BAN::Error::from_errno(EPROTOTYPE);
auto inode_info = create_inode_info(mode, uid, gid);
ino_t ino = TRY(allocate_inode(inode_info));
auto udp_socket = TRY(UDPSocket::create(mode | Inode::Mode::IFSOCK, uid, gid));
return BAN::RefPtr<NetworkSocket>(udp_socket);
}
void NetworkManager::unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket> socket)
{
if (m_bound_sockets.contains(port))
BAN::RefPtr<TmpInode> socket;
switch (domain)
{
ASSERT(m_bound_sockets[port].valid());
ASSERT(m_bound_sockets[port].lock() == socket);
m_bound_sockets.remove(port);
}
NetworkManager::get().remove_from_cache(socket);
}
BAN::ErrorOr<void> NetworkManager::bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket> socket)
{
if (m_interfaces.empty())
return BAN::Error::from_errno(EADDRNOTAVAIL);
if (port != NetworkSocket::PORT_NONE)
{
if (m_bound_sockets.contains(port))
return BAN::Error::from_errno(EADDRINUSE);
TRY(m_bound_sockets.insert(port, socket));
case SocketDomain::INET:
{
if (type == SocketType::DGRAM)
socket = TRY(UDPSocket::create(*m_ipv4_layer, ino, inode_info));
break;
}
case SocketDomain::UNIX:
{
socket = TRY(UnixDomainSocket::create(type, ino, inode_info));
break;
}
default:
ASSERT_NOT_REACHED();
}
// FIXME: actually determine proper interface
auto interface = m_interfaces.front();
socket->bind_interface_and_port(interface.ptr(), port);
return {};
ASSERT(socket);
return socket;
}
void NetworkManager::on_receive(NetworkInterface& interface, BAN::ConstByteSpan packet)
@ -117,41 +125,16 @@ namespace Kernel
{
case EtherType::ARP:
{
m_arp_table->add_arp_packet(interface, packet.slice(sizeof(EthernetHeader)));
m_ipv4_layer->arp_table().add_arp_packet(interface, packet.slice(sizeof(EthernetHeader)));
break;
}
case EtherType::IPv4:
{
auto ipv4 = packet.slice(sizeof(EthernetHeader));
auto& ipv4_header = ipv4.as<const IPv4Header>();
auto src_ipv4 = ipv4_header.src_address;
switch (ipv4_header.protocol)
{
case NetworkProtocol::UDP:
{
auto udp = ipv4.slice(sizeof(IPv4Header));
auto& udp_header = udp.as<const UDPHeader>();
uint16_t src_port = udp_header.src_port;
uint16_t dst_port = udp_header.dst_port;
if (!m_bound_sockets.contains(dst_port))
{
dprintln("no one is listening on port {}", dst_port);
return;
}
auto raw = udp.slice(8);
m_bound_sockets[dst_port].lock()->add_packet(raw, src_ipv4, src_port);
break;
}
default:
dprintln("Unknown network protocol 0x{2H}", ipv4_header.protocol);
break;
}
m_ipv4_layer->add_ipv4_packet(interface, packet.slice(sizeof(EthernetHeader)));
break;
}
default:
dprintln("Unknown EtherType 0x{4H}", (uint16_t)ethernet_header.ether_type);
dprintln_if(DEBUG_ETHERTYPE, "Unknown EtherType 0x{4H}", (uint16_t)ethernet_header.ether_type);
break;
}
}

View File

@ -1,4 +1,3 @@
#include <kernel/Networking/IPv4.h>
#include <kernel/Networking/NetworkManager.h>
#include <kernel/Networking/NetworkSocket.h>
@ -7,13 +6,9 @@
namespace Kernel
{
NetworkSocket::NetworkSocket(mode_t mode, uid_t uid, gid_t gid)
// FIXME: what the fuck is this
: TmpInode(
NetworkManager::get(),
MUST(NetworkManager::get().allocate_inode(create_inode_info(mode, uid, gid))),
create_inode_info(mode, uid, gid)
)
NetworkSocket::NetworkSocket(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info)
: TmpInode(NetworkManager::get(), ino, inode_info)
, m_network_layer(network_layer)
{ }
NetworkSocket::~NetworkSocket()
@ -23,7 +18,7 @@ namespace Kernel
void NetworkSocket::on_close_impl()
{
if (m_interface)
NetworkManager::get().unbind_socket(m_port, this);
m_network_layer.unbind_socket(m_port, this);
}
void NetworkSocket::bind_interface_and_port(NetworkInterface* interface, uint16_t port)
@ -36,16 +31,15 @@ namespace Kernel
BAN::ErrorOr<void> NetworkSocket::bind_impl(const sockaddr* address, socklen_t address_len)
{
if (address_len != sizeof(sockaddr_in))
if (m_interface || address_len != sizeof(sockaddr_in))
return BAN::Error::from_errno(EINVAL);
auto* addr_in = reinterpret_cast<const sockaddr_in*>(address);
return NetworkManager::get().bind_socket(addr_in->sin_port, this);
uint16_t dst_port = BAN::host_to_network_endian(addr_in->sin_port);
return m_network_layer.bind_socket(dst_port, this);
}
BAN::ErrorOr<ssize_t> NetworkSocket::sendto_impl(const sys_sendto_t* arguments)
BAN::ErrorOr<size_t> NetworkSocket::sendto_impl(const sys_sendto_t* arguments)
{
if (arguments->dest_len != sizeof(sockaddr_in))
return BAN::Error::from_errno(EINVAL);
if (arguments->flags)
{
dprintln("flags not supported");
@ -53,45 +47,12 @@ namespace Kernel
}
if (!m_interface)
TRY(NetworkManager::get().bind_socket(PORT_NONE, this));
TRY(m_network_layer.bind_socket(PORT_NONE, this));
auto* destination = reinterpret_cast<const sockaddr_in*>(arguments->dest_addr);
auto message = BAN::ConstByteSpan((const uint8_t*)arguments->message, arguments->length);
uint16_t dst_port = destination->sin_port;
if (dst_port == PORT_NONE)
return BAN::Error::from_errno(EINVAL);
auto dst_addr = BAN::IPv4Address(destination->sin_addr.s_addr);
auto dst_mac = TRY(NetworkManager::get().arp_table().get_mac_from_ipv4(*m_interface, dst_addr));
const size_t interface_header_offset = 0;
const size_t interface_header_size = m_interface->interface_header_size();
const size_t ipv4_header_offset = interface_header_offset + interface_header_size;
const size_t ipv4_header_size = sizeof(IPv4Header);
const size_t protocol_header_offset = ipv4_header_offset + ipv4_header_size;
const size_t protocol_header_size = this->protocol_header_size();
const size_t payload_offset = protocol_header_offset + protocol_header_size;
const size_t payload_size = message.size();
BAN::Vector<uint8_t> full_packet;
TRY(full_packet.resize(payload_offset + payload_size));
BAN::ByteSpan packet_bytespan { full_packet.span() };
memcpy(full_packet.data() + payload_offset, message.data(), payload_size);
add_protocol_header(packet_bytespan.slice(protocol_header_offset), m_port, dst_port);
add_ipv4_header(packet_bytespan.slice(ipv4_header_offset), m_interface->get_ipv4_address(), dst_addr, protocol());
m_interface->add_interface_header(packet_bytespan.slice(interface_header_offset), dst_mac);
TRY(m_interface->send_raw_bytes(packet_bytespan));
return arguments->length;
return TRY(m_network_layer.sendto(*this, arguments));
}
BAN::ErrorOr<ssize_t> NetworkSocket::recvfrom_impl(sys_recvfrom_t* arguments)
BAN::ErrorOr<size_t> NetworkSocket::recvfrom_impl(sys_recvfrom_t* arguments)
{
sockaddr_in* sender_addr = nullptr;
if (arguments->address)
@ -140,36 +101,52 @@ namespace Kernel
{
case SIOCGIFADDR:
{
auto ipv4_address = m_interface->get_ipv4_address();
ifreq->ifr_ifru.ifru_addr.sa_family = AF_INET;
memcpy(ifreq->ifr_ifru.ifru_addr.sa_data, &ipv4_address, sizeof(ipv4_address));
auto& ifru_addr = *reinterpret_cast<sockaddr_in*>(&ifreq->ifr_ifru.ifru_addr);
ifru_addr.sin_family = AF_INET;
ifru_addr.sin_addr.s_addr = m_interface->get_ipv4_address().raw;
return 0;
}
case SIOCSIFADDR:
{
if (ifreq->ifr_ifru.ifru_addr.sa_family != AF_INET)
auto& ifru_addr = *reinterpret_cast<const sockaddr_in*>(&ifreq->ifr_ifru.ifru_addr);
if (ifru_addr.sin_family != AF_INET)
return BAN::Error::from_errno(EADDRNOTAVAIL);
BAN::IPv4Address ipv4_address { *reinterpret_cast<uint32_t*>(ifreq->ifr_ifru.ifru_addr.sa_data) };
m_interface->set_ipv4_address(ipv4_address);
m_interface->set_ipv4_address(BAN::IPv4Address { ifru_addr.sin_addr.s_addr });
dprintln("IPv4 address set to {}", m_interface->get_ipv4_address());
return 0;
}
case SIOCGIFNETMASK:
{
auto netmask_address = m_interface->get_netmask();
ifreq->ifr_ifru.ifru_netmask.sa_family = AF_INET;
memcpy(ifreq->ifr_ifru.ifru_netmask.sa_data, &netmask_address, sizeof(netmask_address));
auto& ifru_netmask = *reinterpret_cast<sockaddr_in*>(&ifreq->ifr_ifru.ifru_netmask);
ifru_netmask.sin_family = AF_INET;
ifru_netmask.sin_addr.s_addr = m_interface->get_netmask().raw;
return 0;
}
case SIOCSIFNETMASK:
{
if (ifreq->ifr_ifru.ifru_netmask.sa_family != AF_INET)
auto& ifru_netmask = *reinterpret_cast<const sockaddr_in*>(&ifreq->ifr_ifru.ifru_netmask);
if (ifru_netmask.sin_family != AF_INET)
return BAN::Error::from_errno(EADDRNOTAVAIL);
BAN::IPv4Address netmask { *reinterpret_cast<uint32_t*>(ifreq->ifr_ifru.ifru_netmask.sa_data) };
m_interface->set_netmask(netmask);
m_interface->set_netmask(BAN::IPv4Address { ifru_netmask.sin_addr.s_addr });
dprintln("Netmask set to {}", m_interface->get_netmask());
return 0;
}
case SIOCGIFGWADDR:
{
auto& ifru_gwaddr = *reinterpret_cast<sockaddr_in*>(&ifreq->ifr_ifru.ifru_gwaddr);
ifru_gwaddr.sin_family = AF_INET;
ifru_gwaddr.sin_addr.s_addr = m_interface->get_gateway().raw;
return 0;
}
case SIOCSIFGWADDR:
{
auto& ifru_gwaddr = *reinterpret_cast<const sockaddr_in*>(&ifreq->ifr_ifru.ifru_gwaddr);
if (ifru_gwaddr.sin_family != AF_INET)
return BAN::Error::from_errno(EADDRNOTAVAIL);
m_interface->set_gateway(BAN::IPv4Address { ifru_gwaddr.sin_addr.s_addr });
dprintln("Gateway set to {}", m_interface->get_gateway());
return 0;
}
case SIOCGIFHWADDR:
{
auto mac_address = m_interface->get_mac_address();

View File

@ -5,9 +5,9 @@
namespace Kernel
{
BAN::ErrorOr<BAN::RefPtr<UDPSocket>> UDPSocket::create(mode_t mode, uid_t uid, gid_t gid)
BAN::ErrorOr<BAN::RefPtr<UDPSocket>> UDPSocket::create(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info)
{
auto socket = TRY(BAN::RefPtr<UDPSocket>::create(mode, uid, gid));
auto socket = TRY(BAN::RefPtr<UDPSocket>::create(network_layer, ino, inode_info));
socket->m_packet_buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
@ -19,14 +19,14 @@ namespace Kernel
return socket;
}
UDPSocket::UDPSocket(mode_t mode, uid_t uid, gid_t gid)
: NetworkSocket(mode, uid, gid)
UDPSocket::UDPSocket(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info)
: NetworkSocket(network_layer, ino, inode_info)
{ }
void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t src_port, uint16_t dst_port)
void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port)
{
auto& header = packet.as<UDPHeader>();
header.src_port = src_port;
header.src_port = m_port;
header.dst_port = dst_port;
header.length = packet.size();
header.checksum = 0;
@ -91,8 +91,8 @@ namespace Kernel
if (sender_addr)
{
sender_addr->sin_family = AF_INET;
sender_addr->sin_port = packet_info.sender_port;
sender_addr->sin_addr.s_addr = packet_info.sender_addr.as_u32();
sender_addr->sin_port = BAN::NetworkEndian(packet_info.sender_port);
sender_addr->sin_addr.s_addr = packet_info.sender_addr.raw;
}
return nread;

View File

@ -0,0 +1,323 @@
#include <BAN/HashMap.h>
#include <kernel/FS/VirtualFileSystem.h>
#include <kernel/Networking/NetworkManager.h>
#include <kernel/Networking/UNIX/Socket.h>
#include <kernel/Scheduler.h>
#include <fcntl.h>
#include <sys/un.h>
namespace Kernel
{
static BAN::HashMap<BAN::String, BAN::RefPtr<UnixDomainSocket>> s_bound_sockets;
static SpinLock s_bound_socket_lock;
static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE;
BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> UnixDomainSocket::create(SocketType socket_type, ino_t ino, const TmpInodeInfo& inode_info)
{
auto socket = TRY(BAN::RefPtr<UnixDomainSocket>::create(socket_type, ino, inode_info));
socket->m_packet_buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
~(uintptr_t)0,
s_packet_buffer_size,
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
true
));
return socket;
}
UnixDomainSocket::UnixDomainSocket(SocketType socket_type, ino_t ino, const TmpInodeInfo& inode_info)
: TmpInode(NetworkManager::get(), ino, inode_info)
, m_socket_type(socket_type)
{
switch (socket_type)
{
case SocketType::STREAM:
case SocketType::SEQPACKET:
m_info.emplace<ConnectionInfo>();
break;
case SocketType::DGRAM:
m_info.emplace<ConnectionlessInfo>();
break;
default:
ASSERT_NOT_REACHED();
}
}
BAN::ErrorOr<long> UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len)
{
if (!m_info.has<ConnectionInfo>())
return BAN::Error::from_errno(EOPNOTSUPP);
auto& connection_info = m_info.get<ConnectionInfo>();
if (!connection_info.listening)
return BAN::Error::from_errno(EINVAL);
while (connection_info.pending_connections.empty())
TRY(Thread::current().block_or_eintr(connection_info.pending_semaphore));
BAN::RefPtr<UnixDomainSocket> pending;
{
LockGuard _(connection_info.pending_lock);
pending = connection_info.pending_connections.front();
connection_info.pending_connections.pop();
connection_info.pending_semaphore.unblock();
}
BAN::RefPtr<UnixDomainSocket> return_inode;
{
auto return_inode_tmp = TRY(NetworkManager::get().create_socket(SocketDomain::UNIX, m_socket_type, mode().mode & ~Mode::TYPE_MASK, uid(), gid()));
return_inode = reinterpret_cast<UnixDomainSocket*>(return_inode_tmp.ptr());
}
TRY(return_inode->m_bound_path.append(m_bound_path));
return_inode->m_info.get<ConnectionInfo>().connection = TRY(pending->get_weak_ptr());
pending->m_info.get<ConnectionInfo>().connection = TRY(return_inode->get_weak_ptr());
pending->m_info.get<ConnectionInfo>().connection_done = true;
if (address && address_len && !is_bound_to_unused())
{
size_t copy_len = BAN::Math::min<size_t>(*address_len, sizeof(sockaddr) + m_bound_path.size() + 1);
auto& sockaddr_un = *reinterpret_cast<struct sockaddr_un*>(address);
sockaddr_un.sun_family = AF_UNIX;
strncpy(sockaddr_un.sun_path, pending->m_bound_path.data(), copy_len);
}
return TRY(Process::current().open_inode(return_inode, O_RDWR));
}
BAN::ErrorOr<void> UnixDomainSocket::connect_impl(const sockaddr* address, socklen_t address_len)
{
if (address_len != sizeof(sockaddr_un))
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_un = *reinterpret_cast<const struct sockaddr_un*>(address);
if (sockaddr_un.sun_family != AF_UNIX)
return BAN::Error::from_errno(EAFNOSUPPORT);
if (!is_bound())
TRY(m_bound_path.push_back('X'));
auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path));
auto file = TRY(VirtualFileSystem::get().file_from_absolute_path(
Process::current().credentials(),
absolute_path,
O_RDWR
));
BAN::RefPtr<UnixDomainSocket> target;
{
LockGuard _(s_bound_socket_lock);
if (!s_bound_sockets.contains(file.canonical_path))
return BAN::Error::from_errno(ECONNREFUSED);
target = s_bound_sockets[file.canonical_path];
}
if (m_socket_type != target->m_socket_type)
return BAN::Error::from_errno(EPROTOTYPE);
if (m_info.has<ConnectionInfo>())
{
auto& connection_info = m_info.get<ConnectionInfo>();
if (connection_info.connection)
return BAN::Error::from_errno(ECONNREFUSED);
if (connection_info.listening)
return BAN::Error::from_errno(EOPNOTSUPP);
connection_info.connection_done = false;
for (;;)
{
auto& target_info = target->m_info.get<ConnectionInfo>();
{
LockGuard _(target_info.pending_lock);
if (target_info.pending_connections.size() < target_info.pending_connections.capacity())
{
MUST(target_info.pending_connections.push(this));
target_info.pending_semaphore.unblock();
break;
}
}
TRY(Thread::current().block_or_eintr(target_info.pending_semaphore));
}
while (!connection_info.connection_done)
Scheduler::get().reschedule();
return {};
}
else
{
return BAN::Error::from_errno(ENOTSUP);
}
}
BAN::ErrorOr<void> UnixDomainSocket::listen_impl(int backlog)
{
backlog = BAN::Math::clamp(backlog, 1, SOMAXCONN);
if (!is_bound())
return BAN::Error::from_errno(EDESTADDRREQ);
if (!m_info.has<ConnectionInfo>())
return BAN::Error::from_errno(EOPNOTSUPP);
auto& connection_info = m_info.get<ConnectionInfo>();
if (connection_info.connection)
return BAN::Error::from_errno(EINVAL);
TRY(connection_info.pending_connections.reserve(backlog));
connection_info.listening = true;
return {};
}
BAN::ErrorOr<void> UnixDomainSocket::bind_impl(const sockaddr* address, socklen_t address_len)
{
if (is_bound())
return BAN::Error::from_errno(EINVAL);
if (address_len != sizeof(sockaddr_un))
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_un = *reinterpret_cast<const struct sockaddr_un*>(address);
if (sockaddr_un.sun_family != AF_UNIX)
return BAN::Error::from_errno(EAFNOSUPPORT);
auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path));
if (auto ret = Process::current().create_file_or_dir(absolute_path, 0755 | S_IFSOCK); ret.is_error())
{
if (ret.error().get_error_code() == EEXIST)
return BAN::Error::from_errno(EADDRINUSE);
return ret.release_error();
}
auto file = TRY(VirtualFileSystem::get().file_from_absolute_path(
Process::current().credentials(),
absolute_path,
O_RDWR
));
LockGuard _(s_bound_socket_lock);
ASSERT(!s_bound_sockets.contains(file.canonical_path));
TRY(s_bound_sockets.emplace(file.canonical_path, this));
m_bound_path = BAN::move(file.canonical_path);
return {};
}
bool UnixDomainSocket::is_streaming() const
{
switch (m_socket_type)
{
case SocketType::STREAM:
return true;
case SocketType::SEQPACKET:
case SocketType::DGRAM:
return false;
default:
ASSERT_NOT_REACHED();
}
}
// This to feels too hacky to expose out of here
struct LockFreeGuard
{
LockFreeGuard(RecursivePrioritySpinLock& lock)
: m_lock(lock)
, m_depth(lock.lock_depth())
{
for (uint32_t i = 0; i < m_depth; i++)
m_lock.unlock();
}
~LockFreeGuard()
{
for (uint32_t i = 0; i < m_depth; i++)
m_lock.lock();
}
private:
RecursivePrioritySpinLock& m_lock;
const uint32_t m_depth;
};
BAN::ErrorOr<void> UnixDomainSocket::add_packet(BAN::ConstByteSpan packet)
{
LockGuard _(m_lock);
while (m_packet_sizes.full() || m_packet_size_total + packet.size() > s_packet_buffer_size)
{
LockFreeGuard _(m_lock);
TRY(Thread::current().block_or_eintr(m_packet_semaphore));
}
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total);
memcpy(packet_buffer, packet.data(), packet.size());
m_packet_size_total += packet.size();
if (!is_streaming())
m_packet_sizes.push(packet.size());
m_packet_semaphore.unblock();
return {};
}
BAN::ErrorOr<size_t> UnixDomainSocket::sendto_impl(const sys_sendto_t* arguments)
{
if (arguments->flags)
return BAN::Error::from_errno(ENOTSUP);
if (arguments->length > s_packet_buffer_size)
return BAN::Error::from_errno(ENOBUFS);
if (m_info.has<ConnectionInfo>())
{
auto& connection_info = m_info.get<ConnectionInfo>();
if (arguments->dest_addr)
return BAN::Error::from_errno(EISCONN);
auto target = connection_info.connection.lock();
if (!target)
return BAN::Error::from_errno(ENOTCONN);
TRY(target->add_packet({ reinterpret_cast<const uint8_t*>(arguments->message), arguments->length }));
return arguments->length;
}
else
{
return BAN::Error::from_errno(ENOTSUP);
}
}
BAN::ErrorOr<size_t> UnixDomainSocket::recvfrom_impl(sys_recvfrom_t* arguments)
{
if (arguments->flags)
return BAN::Error::from_errno(ENOTSUP);
if (m_info.has<ConnectionInfo>())
{
auto& connection_info = m_info.get<ConnectionInfo>();
if (!connection_info.connection)
return BAN::Error::from_errno(ENOTCONN);
}
while (m_packet_size_total == 0)
{
LockFreeGuard _(m_lock);
TRY(Thread::current().block_or_eintr(m_packet_semaphore));
}
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
size_t nread = 0;
if (is_streaming())
nread = BAN::Math::min(arguments->length, m_packet_size_total);
else
{
nread = BAN::Math::min(arguments->length, m_packet_sizes.front());
m_packet_sizes.pop();
}
memcpy(arguments->buffer, packet_buffer, nread);
memmove(packet_buffer, packet_buffer + nread, m_packet_size_total - nread);
m_packet_size_total -= nread;
m_packet_semaphore.unblock();
return nread;
}
}

View File

@ -55,6 +55,21 @@ namespace Kernel
return {};
}
BAN::ErrorOr<int> OpenFileDescriptorSet::open(BAN::RefPtr<Inode> inode, int flags)
{
ASSERT(inode);
ASSERT(!inode->mode().ifdir());
if (flags & ~(O_RDONLY | O_WRONLY))
return BAN::Error::from_errno(ENOTSUP);
int fd = TRY(get_free_fd());
// FIXME: path?
m_open_files[fd] = TRY(BAN::RefPtr<OpenFileDescription>::create(inode, ""sv, 0, flags));
return fd;
}
BAN::ErrorOr<int> OpenFileDescriptorSet::open(BAN::StringView absolute_path, int flags)
{
if (flags & ~(O_RDONLY | O_WRONLY | O_NOFOLLOW | O_SEARCH | O_APPEND | O_TRUNC | O_CLOEXEC | O_TTY_INIT | O_DIRECTORY | O_NONBLOCK))
@ -80,13 +95,25 @@ namespace Kernel
BAN::ErrorOr<int> OpenFileDescriptorSet::socket(int domain, int type, int protocol)
{
using SocketType = NetworkManager::SocketType;
if (domain != AF_INET)
return BAN::Error::from_errno(EAFNOSUPPORT);
if (protocol != 0)
return BAN::Error::from_errno(EPROTONOSUPPORT);
SocketDomain sock_domain;
switch (domain)
{
case AF_INET:
sock_domain = SocketDomain::INET;
break;
case AF_INET6:
sock_domain = SocketDomain::INET6;
break;
case AF_UNIX:
sock_domain = SocketDomain::UNIX;
break;
default:
return BAN::Error::from_errno(EPROTOTYPE);
}
SocketType sock_type;
switch (type)
{
@ -103,7 +130,7 @@ namespace Kernel
return BAN::Error::from_errno(EPROTOTYPE);
}
auto socket = TRY(NetworkManager::get().create_socket(sock_type, 0777, m_credentials.euid(), m_credentials.egid()));
auto socket = TRY(NetworkManager::get().create_socket(sock_domain, sock_type, 0777, m_credentials.euid(), m_credentials.egid()));
int fd = TRY(get_free_fd());
m_open_files[fd] = TRY(BAN::RefPtr<OpenFileDescription>::create(socket, "no-path"sv, 0, O_RDWR));

View File

@ -649,8 +649,9 @@ namespace Kernel
case Inode::Mode::IFREG: break;
case Inode::Mode::IFDIR: break;
case Inode::Mode::IFIFO: break;
case Inode::Mode::IFSOCK: break;
default:
return BAN::Error::from_errno(EINVAL);
return BAN::Error::from_errno(ENOTSUP);
}
LockGuard _(m_lock);
@ -707,8 +708,17 @@ namespace Kernel
return false;
}
BAN::ErrorOr<long> Process::open_inode(BAN::RefPtr<Inode> inode, int flags)
{
ASSERT(inode);
LockGuard _(m_lock);
return TRY(m_open_file_descriptors.open(inode, flags));
}
BAN::ErrorOr<long> Process::open_file(BAN::StringView path, int flags, mode_t mode)
{
LockGuard _(m_lock);
BAN::String absolute_path = TRY(absolute_path_of(path));
if (flags & O_CREAT)
@ -716,6 +726,8 @@ namespace Kernel
if (flags & O_DIRECTORY)
return BAN::Error::from_errno(ENOTSUP);
auto file_or_error = VirtualFileSystem::get().file_from_absolute_path(m_credentials, absolute_path, O_WRONLY);
if (!file_or_error.is_error() && (flags & O_EXCL))
return BAN::Error::from_errno(EEXIST);
if (file_or_error.is_error())
{
if (file_or_error.error().get_error_code() == ENOENT)
@ -901,6 +913,26 @@ namespace Kernel
return TRY(m_open_file_descriptors.socket(domain, type, protocol));
}
BAN::ErrorOr<long> Process::sys_accept(int socket, sockaddr* address, socklen_t* address_len)
{
if (address && !address_len)
return BAN::Error::from_errno(EINVAL);
if (!address && address_len)
return BAN::Error::from_errno(EINVAL);
LockGuard _(m_lock);
if (address)
{
TRY(validate_pointer_access(address_len, sizeof(*address_len)));
TRY(validate_pointer_access(address, *address_len));
}
auto inode = TRY(m_open_file_descriptors.inode_of(socket));
if (!inode->mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
return TRY(inode->accept(address, address_len));
}
BAN::ErrorOr<long> Process::sys_bind(int socket, const sockaddr* address, socklen_t address_len)
{
@ -915,6 +947,31 @@ namespace Kernel
return 0;
}
BAN::ErrorOr<long> Process::sys_connect(int socket, const sockaddr* address, socklen_t address_len)
{
LockGuard _(m_lock);
TRY(validate_pointer_access(address, address_len));
auto inode = TRY(m_open_file_descriptors.inode_of(socket));
if (!inode->mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
TRY(inode->connect(address, address_len));
return 0;
}
BAN::ErrorOr<long> Process::sys_listen(int socket, int backlog)
{
LockGuard _(m_lock);
auto inode = TRY(m_open_file_descriptors.inode_of(socket));
if (!inode->mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
TRY(inode->listen(backlog));
return 0;
}
BAN::ErrorOr<long> Process::sys_sendto(const sys_sendto_t* arguments)
{
LockGuard _(m_lock);
@ -1697,7 +1754,7 @@ namespace Kernel
BAN::ErrorOr<BAN::String> Process::absolute_path_of(BAN::StringView path) const
{
ASSERT(m_lock.is_locked());
LockGuard _(m_lock);
if (path.empty() || path == "."sv)
return m_working_directory;

View File

@ -228,6 +228,15 @@ namespace Kernel
case SYS_IOCTL:
ret = Process::current().sys_ioctl((int)arg1, (int)arg2, (void*)arg3);
break;
case SYS_ACCEPT:
ret = Process::current().sys_accept((int)arg1, (sockaddr*)arg2, (socklen_t*)arg3);
break;
case SYS_CONNECT:
ret = Process::current().sys_connect((int)arg1, (const sockaddr*)arg2, (socklen_t)arg3);
break;
case SYS_LISTEN:
ret = Process::current().sys_listen((int)arg1, (int)arg2);
break;
default:
dwarnln("Unknown syscall {}", syscall);
break;

View File

@ -3,6 +3,7 @@ cmake_minimum_required(VERSION 3.26)
project(libc CXX ASM)
set(LIBC_SOURCES
arpa/inet.cpp
assert.cpp
ctype.cpp
dirent.cpp

52
libc/arpa/inet.cpp Normal file
View File

@ -0,0 +1,52 @@
#include <BAN/Endianness.h>
#include <arpa/inet.h>
#include <errno.h>
#include <stdio.h>
uint32_t htonl(uint32_t hostlong)
{
return BAN::host_to_network_endian(hostlong);
}
uint16_t htons(uint16_t hostshort)
{
return BAN::host_to_network_endian(hostshort);
}
uint32_t ntohl(uint32_t netlong)
{
return BAN::host_to_network_endian(netlong);
}
uint16_t ntohs(uint16_t netshort)
{
return BAN::host_to_network_endian(netshort);
}
in_addr_t inet_addr(const char* cp)
{
uint32_t a = 0, b = 0, c = 0, d = 0;
int ret = sscanf(cp, "%u.%u.%u.%u", &a, &b, &c, &d);
if (ret < 1 || ret > 4)
return (in_addr_t)(-1);
uint32_t result = 0;
result |= (ret == 1) ? a : a << 24;
result |= (ret == 2) ? b : b << 16;
result |= (ret == 3) ? c : c << 8;
result |= (ret == 4) ? d : d << 0;
return htonl(result);
}
char* inet_ntoa(struct in_addr in)
{
static char buffer[16];
uint32_t he = ntohl(in.s_addr);
sprintf(buffer, "%u.%u.%u.%u",
(he >> 24) & 0xFF,
(he >> 16) & 0xFF,
(he >> 8) & 0xFF,
(he >> 0) & 0xFF
);
return buffer;
}

View File

@ -22,6 +22,7 @@ struct ifreq
union {
struct sockaddr ifru_addr;
struct sockaddr ifru_netmask;
struct sockaddr ifru_gwaddr;
struct sockaddr ifru_hwaddr;
unsigned char __min_storage[sizeof(sockaddr) + 6];
} ifr_ifru;
@ -31,7 +32,9 @@ struct ifreq
#define SIOCSIFADDR 2 /* Set interface address */
#define SIOCGIFNETMASK 3 /* Get network mask */
#define SIOCSIFNETMASK 4 /* Set network mask */
#define SIOCGIFHWADDR 5 /* Get hardware address */
#define SIOCGIFGWADDR 5 /* Get gateway address */
#define SIOCSIFGWADDR 6 /* Set gateway address */
#define SIOCGIFHWADDR 7 /* Get hardware address */
void if_freenameindex(struct if_nameindex* ptr);
char* if_indextoname(unsigned ifindex, char* ifname);

View File

@ -53,6 +53,8 @@ extern FILE* __stdout;
#define stdout __stdout
extern FILE* __stderr;
#define stderr __stderr
extern FILE* __stddbg;
#define stddbg __stddbg
void clearerr(FILE* stream);
char* ctermid(char* s);

View File

@ -16,6 +16,12 @@ __BEGIN_DECLS
#include <bits/types/sa_family_t.h>
typedef long socklen_t;
#if !defined(FILENAME_MAX)
#define FILENAME_MAX 256
#elif FILENAME_MAX != 256
#error "invalid FILENAME_MAX"
#endif
struct sockaddr
{
sa_family_t sa_family; /* Address family. */
@ -24,8 +30,8 @@ struct sockaddr
struct sockaddr_storage
{
// FIXME
sa_family_t ss_family;
char ss_storage[FILENAME_MAX];
};
struct msghdr

View File

@ -68,6 +68,9 @@ __BEGIN_DECLS
#define SYS_SENDTO 67
#define SYS_RECVFROM 68
#define SYS_IOCTL 69
#define SYS_ACCEPT 70
#define SYS_CONNECT 71
#define SYS_LISTEN 72
__END_DECLS

View File

@ -11,8 +11,8 @@ __BEGIN_DECLS
struct sockaddr_un
{
sa_family_t sun_family; /* Address family. */
char sun_path[]; /* Socket pathname. */
sa_family_t sun_family; /* Address family. */
char sun_path[FILENAME_MAX]; /* Socket pathname. */
};
__END_DECLS

View File

@ -120,6 +120,7 @@ __BEGIN_DECLS
#define STDIN_FILENO 0
#define STDOUT_FILENO 1
#define STDERR_FILENO 2
#define STDDBG_FILENO 3
#define _POSIX_VDISABLE 0

View File

@ -22,11 +22,13 @@ static FILE s_files[FOPEN_MAX] {
{ .fd = STDIN_FILENO },
{ .fd = STDOUT_FILENO },
{ .fd = STDERR_FILENO },
{ .fd = STDDBG_FILENO },
};
FILE* stdin = &s_files[0];
FILE* stdout = &s_files[1];
FILE* stderr = &s_files[2];
FILE* stddbg = &s_files[3];
void clearerr(FILE* file)
{

View File

@ -2,11 +2,31 @@
#include <sys/syscall.h>
#include <unistd.h>
int accept(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len)
{
return syscall(SYS_ACCEPT, socket, address, address_len);
}
int bind(int socket, const struct sockaddr* address, socklen_t address_len)
{
return syscall(SYS_BIND, socket, address, address_len);
}
int connect(int socket, const struct sockaddr* address, socklen_t address_len)
{
return syscall(SYS_CONNECT, socket, address, address_len);
}
int listen(int socket, int backlog)
{
return syscall(SYS_LISTEN, socket, backlog);
}
ssize_t recv(int socket, void* __restrict buffer, size_t length, int flags)
{
return recvfrom(socket, buffer, length, flags, nullptr, nullptr);
}
ssize_t recvfrom(int socket, void* __restrict buffer, size_t length, int flags, struct sockaddr* __restrict address, socklen_t* __restrict address_len)
{
sys_recvfrom_t arguments {
@ -20,6 +40,10 @@ ssize_t recvfrom(int socket, void* __restrict buffer, size_t length, int flags,
return syscall(SYS_RECVFROM, &arguments);
}
ssize_t send(int socket, const void* message, size_t length, int flags)
{
return sendto(socket, message, length, flags, nullptr, 0);
}
ssize_t sendto(int socket, const void* message, size_t length, int flags, const struct sockaddr* dest_addr, socklen_t dest_len)
{

View File

@ -11,14 +11,16 @@ set(USERSPACE_PROJECTS
dhcp-client
echo
id
init
image
init
loadkeys
ls
meminfo
mkdir
mmap-shared-test
nslookup
poweroff
resolver
rm
Shell
sleep

View File

@ -72,9 +72,9 @@ static constexpr Position s_dir_offset[] {
i64 solve_general(FILE* fp, auto parse_dir, auto parse_count)
{
BAN::HashSetUnstable<Position, PositionHash> path;
BAN::HashSetUnstable<Position, PositionHash> lpath;
BAN::HashSetUnstable<Position, PositionHash> rpath;
BAN::HashSet<Position, PositionHash> path;
BAN::HashSet<Position, PositionHash> lpath;
BAN::HashSet<Position, PositionHash> rpath;
Position current_pos { 0, 0 };
MUST(path.insert(current_pos));
@ -157,8 +157,8 @@ i64 solve_general(FILE* fp, auto parse_dir, auto parse_count)
ASSERT(lmin_x != rmin_x);
auto& expand = (lmin_x < rmin_x) ? rpath : lpath;
BAN::HashSetUnstable<Position, PositionHash> visited;
BAN::HashSetUnstable<Position, PositionHash> inner_area;
BAN::HashSet<Position, PositionHash> visited;
BAN::HashSet<Position, PositionHash> inner_area;
while (!expand.empty())
{

View File

@ -33,7 +33,7 @@ struct Rule
BAN::String target;
};
using Workflows = BAN::HashMapUnstable<BAN::String, BAN::Vector<Rule>>;
using Workflows = BAN::HashMap<BAN::String, BAN::Vector<Rule>>;
struct Item
{

View File

@ -72,9 +72,9 @@ struct ConjunctionModule : public Module
}
};
BAN::HashMapUnstable<BAN::String, BAN::UniqPtr<Module>> parse_modules(FILE* fp)
BAN::HashMap<BAN::String, BAN::UniqPtr<Module>> parse_modules(FILE* fp)
{
BAN::HashMapUnstable<BAN::String, BAN::UniqPtr<Module>> modules;
BAN::HashMap<BAN::String, BAN::UniqPtr<Module>> modules;
char buffer[128];
while (fgets(buffer, sizeof(buffer), fp))

View File

@ -88,13 +88,13 @@ i64 puzzle1(FILE* fp)
{
auto garden = parse_garden(fp);
BAN::HashSetUnstable<Position> visited, reachable, pending;
BAN::HashSet<Position> visited, reachable, pending;
MUST(pending.insert(garden.start));
for (i32 i = 0; i <= 64; i++)
{
auto temp = BAN::move(pending);
pending = BAN::HashSetUnstable<Position>();
pending = BAN::HashSet<Position>();
while (!temp.empty())
{

View File

@ -1,8 +1,10 @@
#include <BAN/Debug.h>
#include <BAN/Endianness.h>
#include <BAN/IPv4.h>
#include <BAN/MAC.h>
#include <BAN/Vector.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <net/if.h>
#include <netinet/in.h>
@ -12,7 +14,7 @@
#include <stropts.h>
#include <sys/socket.h>
#define DEBUG_DHCP 0
#define DEBUG_DHCP 1
struct DHCPPacket
{
@ -23,10 +25,10 @@ struct DHCPPacket
BAN::NetworkEndian<uint32_t> xid { 0x3903F326 };
BAN::NetworkEndian<uint16_t> secs { 0x0000 };
BAN::NetworkEndian<uint16_t> flags { 0x0000 };
BAN::NetworkEndian<uint32_t> ciaddr { 0 };
BAN::NetworkEndian<uint32_t> yiaddr { 0 };
BAN::NetworkEndian<uint32_t> siaddr { 0 };
BAN::NetworkEndian<uint32_t> giaddr { 0 };
BAN::IPv4Address ciaddr { 0 };
BAN::IPv4Address yiaddr { 0 };
BAN::IPv4Address siaddr { 0 };
BAN::IPv4Address giaddr { 0 };
BAN::MACAddress chaddr;
uint8_t padding[10] {};
uint8_t legacy[192] {};
@ -71,12 +73,13 @@ BAN::MACAddress get_mac_address(int socket)
return mac_address;
}
void update_ipv4_info(int socket, BAN::IPv4Address address, BAN::IPv4Address subnet)
void update_ipv4_info(int socket, BAN::IPv4Address address, BAN::IPv4Address netmask, BAN::IPv4Address gateway)
{
{
ifreq ifreq;
ifreq.ifr_ifru.ifru_addr.sa_family = AF_INET;
*(uint32_t*)ifreq.ifr_ifru.ifru_addr.sa_data = address.as_u32();
auto& ifru_addr = *reinterpret_cast<sockaddr_in*>(&ifreq.ifr_ifru.ifru_addr);
ifru_addr.sin_family = AF_INET;
ifru_addr.sin_addr.s_addr = address.raw;
if (ioctl(socket, SIOCSIFADDR, &ifreq) == -1)
{
perror("ioctl");
@ -86,22 +89,36 @@ void update_ipv4_info(int socket, BAN::IPv4Address address, BAN::IPv4Address sub
{
ifreq ifreq;
ifreq.ifr_ifru.ifru_netmask.sa_family = AF_INET;
*(uint32_t*)ifreq.ifr_ifru.ifru_netmask.sa_data = subnet.as_u32();
auto& ifru_netmask = *reinterpret_cast<sockaddr_in*>(&ifreq.ifr_ifru.ifru_netmask);
ifru_netmask.sin_family = AF_INET;
ifru_netmask.sin_addr.s_addr = netmask.raw;
if (ioctl(socket, SIOCSIFNETMASK, &ifreq) == -1)
{
perror("ioctl");
exit(1);
}
}
if (gateway.raw)
{
ifreq ifreq;
auto& ifru_gwaddr = *reinterpret_cast<sockaddr_in*>(&ifreq.ifr_ifru.ifru_gwaddr);
ifru_gwaddr.sin_family = AF_INET;
ifru_gwaddr.sin_addr.s_addr = gateway.raw;
if (ioctl(socket, SIOCSIFGWADDR, &ifreq) == -1)
{
perror("ioctl");
exit(1);
}
}
}
void send_dhcp_packet(int socket, const DHCPPacket& dhcp_packet, BAN::IPv4Address server_ipv4)
{
sockaddr_in server_addr;
server_addr.sin_family = AF_INET;
server_addr.sin_port = 67;
server_addr.sin_addr.s_addr = server_ipv4.as_u32();;
server_addr.sin_port = htons(67);
server_addr.sin_addr.s_addr = server_ipv4.raw;
if (sendto(socket, &dhcp_packet, sizeof(DHCPPacket), 0, (sockaddr*)&server_addr, sizeof(server_addr)) == -1)
{
@ -137,7 +154,7 @@ void send_dhcp_request(int socket, BAN::MACAddress mac_address, BAN::IPv4Address
{
DHCPPacket dhcp_packet;
dhcp_packet.op = 0x01;
dhcp_packet.siaddr = server_ipv4.as_u32();
dhcp_packet.siaddr = server_ipv4.raw;
dhcp_packet.chaddr = mac_address;
size_t idx = 0;
@ -148,10 +165,10 @@ void send_dhcp_request(int socket, BAN::MACAddress mac_address, BAN::IPv4Address
dhcp_packet.options[idx++] = RequestedIPv4Address;
dhcp_packet.options[idx++] = 0x04;
dhcp_packet.options[idx++] = offered_ipv4.address[0];
dhcp_packet.options[idx++] = offered_ipv4.address[1];
dhcp_packet.options[idx++] = offered_ipv4.address[2];
dhcp_packet.options[idx++] = offered_ipv4.address[3];
dhcp_packet.options[idx++] = offered_ipv4.octets[0];
dhcp_packet.options[idx++] = offered_ipv4.octets[1];
dhcp_packet.options[idx++] = offered_ipv4.octets[2];
dhcp_packet.options[idx++] = offered_ipv4.octets[3];
dhcp_packet.options[idx++] = 0xFF;
@ -188,7 +205,7 @@ DHCPPacketInfo parse_dhcp_packet(const DHCPPacket& packet)
fprintf(stderr, "Subnet mask with invalid length %hhu\n", length);
break;
}
uint32_t raw = *reinterpret_cast<const BAN::NetworkEndian<uint32_t>*>(options);
uint32_t raw = *reinterpret_cast<const uint32_t*>(options);
packet_info.subnet = BAN::IPv4Address(raw);
break;
}
@ -201,7 +218,7 @@ DHCPPacketInfo parse_dhcp_packet(const DHCPPacket& packet)
}
for (int i = 0; i < length; i += 4)
{
uint32_t raw = *reinterpret_cast<const BAN::NetworkEndian<uint32_t>*>(options + i);
uint32_t raw = *reinterpret_cast<const uint32_t*>(options + i);
MUST(packet_info.routers.emplace_back(raw));
}
break;
@ -215,7 +232,7 @@ DHCPPacketInfo parse_dhcp_packet(const DHCPPacket& packet)
}
for (int i = 0; i < length; i += 4)
{
uint32_t raw = *reinterpret_cast<const BAN::NetworkEndian<uint32_t>*>(options + i);
uint32_t raw = *reinterpret_cast<const uint32_t*>(options + i);
MUST(packet_info.dns.emplace_back(raw));
}
break;
@ -244,7 +261,7 @@ DHCPPacketInfo parse_dhcp_packet(const DHCPPacket& packet)
fprintf(stderr, "Server identifier with invalid length %hhu\n", length);
break;
}
uint32_t raw = *reinterpret_cast<const BAN::NetworkEndian<uint32_t>*>(options);
uint32_t raw = *reinterpret_cast<const uint32_t*>(options);
packet_info.server = BAN::IPv4Address(raw);
break;
}
@ -293,8 +310,8 @@ int main()
sockaddr_in client_addr;
client_addr.sin_family = AF_INET;
client_addr.sin_port = 68;
client_addr.sin_addr.s_addr = 0x00000000;
client_addr.sin_port = htons(68);
client_addr.sin_addr.s_addr = INADDR_ANY;
if (bind(socket, (sockaddr*)&client_addr, sizeof(client_addr)) == -1)
{
@ -304,12 +321,12 @@ int main()
auto mac_address = get_mac_address(socket);
#if DEBUG_DHCP
BAN::Formatter::println(putchar, "MAC: {}", mac_address);
dprintln("MAC: {}", mac_address);
#endif
send_dhcp_discover(socket, mac_address);
#if DEBUG_DHCP
printf("DHCPDISCOVER sent\n");
dprintln("DHCPDISCOVER sent");
#endif
auto dhcp_offer = read_dhcp_packet(socket);
@ -322,15 +339,15 @@ int main()
}
#if DEBUG_DHCP
BAN::Formatter::println(putchar, "DHCPOFFER");
BAN::Formatter::println(putchar, " IP {}", dhcp_offer->address);
BAN::Formatter::println(putchar, " SUBNET {}", dhcp_offer->subnet);
BAN::Formatter::println(putchar, " SERVER {}", dhcp_offer->server);
dprintln("DHCPOFFER");
dprintln(" IP {}", dhcp_offer->address);
dprintln(" SUBNET {}", dhcp_offer->subnet);
dprintln(" SERVER {}", dhcp_offer->server);
#endif
send_dhcp_request(socket, mac_address, dhcp_offer->address, dhcp_offer->server);
#if DEBUG_DHCP
printf("DHCPREQUEST sent\n");
dprintln("DHCPREQUEST sent");
#endif
auto dhcp_ack = read_dhcp_packet(socket);
@ -343,10 +360,10 @@ int main()
}
#if DEBUG_DHCP
BAN::Formatter::println(putchar, "DHCPACK");
BAN::Formatter::println(putchar, " IP {}", dhcp_ack->address);
BAN::Formatter::println(putchar, " SUBNET {}", dhcp_ack->subnet);
BAN::Formatter::println(putchar, " SERVER {}", dhcp_ack->server);
dprintln("DHCPACK");
dprintln(" IP {}", dhcp_ack->address);
dprintln(" SUBNET {}", dhcp_ack->subnet);
dprintln(" SERVER {}", dhcp_ack->server);
#endif
if (dhcp_offer->address != dhcp_ack->address)
@ -355,7 +372,11 @@ int main()
return 1;
}
update_ipv4_info(socket, dhcp_ack->address, dhcp_ack->subnet);
BAN::IPv4Address gateway { 0 };
if (!dhcp_ack->routers.empty())
gateway = dhcp_ack->routers.front();
update_ipv4_info(socket, dhcp_ack->address, dhcp_ack->subnet, gateway);
close(socket);

View File

@ -17,6 +17,7 @@ void initialize_stdio()
if (open(tty, O_RDONLY | O_TTY_INIT) != 0) _exit(1);
if (open(tty, O_WRONLY) != 1) _exit(1);
if (open(tty, O_WRONLY) != 2) _exit(1);
if (open("/dev/debug", O_WRONLY) != 3) _exit(1);
}
int main()
@ -35,6 +36,12 @@ int main()
exit(1);
}
if (fork() == 0)
{
execl("/bin/resolver", "resolver", NULL);
exit(1);
}
bool first = true;
termios termios;

View File

@ -0,0 +1,16 @@
cmake_minimum_required(VERSION 3.26)
project(nslookup CXX)
set(SOURCES
main.cpp
)
add_executable(nslookup ${SOURCES})
target_compile_options(nslookup PUBLIC -O2 -g)
target_link_libraries(nslookup PUBLIC libc)
add_custom_target(nslookup-install
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/nslookup ${BANAN_BIN}/
DEPENDS nslookup
)

View File

@ -0,0 +1,47 @@
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/un.h>
int main(int argc, char** argv)
{
if (argc != 2)
{
fprintf(stderr, "usage: %s DOMAIN\n", argv[0]);
return 1;
}
int socket = ::socket(AF_UNIX, SOCK_SEQPACKET, 0);
if (socket == -1)
{
perror("socket");
return 1;
}
sockaddr_un addr;
addr.sun_family = AF_UNIX;
strcpy(addr.sun_path, "/tmp/resolver.sock");
if (connect(socket, (sockaddr*)&addr, sizeof(addr)) == -1)
{
perror("connect");
return 1;
}
if (send(socket, argv[1], strlen(argv[1]), 0) == -1)
{
perror("send");
return 1;
}
char buffer[128];
ssize_t nrecv = recv(socket, buffer, sizeof(buffer), 0);
if (nrecv == -1)
{
perror("recv");
return 1;
}
buffer[nrecv] = '\0';
printf("%s\n", buffer);
return 0;
}

View File

@ -0,0 +1,16 @@
cmake_minimum_required(VERSION 3.26)
project(resolver CXX)
set(SOURCES
main.cpp
)
add_executable(resolver ${SOURCES})
target_compile_options(resolver PUBLIC -O2 -g)
target_link_libraries(resolver PUBLIC libc ban)
add_custom_target(resolver-install
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/resolver ${BANAN_BIN}/
DEPENDS resolver
)

218
userspace/resolver/main.cpp Normal file
View File

@ -0,0 +1,218 @@
#include <BAN/ByteSpan.h>
#include <BAN/Endianness.h>
#include <BAN/HashMap.h>
#include <BAN/IPv4.h>
#include <BAN/String.h>
#include <BAN/StringView.h>
#include <BAN/Vector.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <unistd.h>
struct DNSPacket
{
BAN::NetworkEndian<uint16_t> identification { 0 };
BAN::NetworkEndian<uint16_t> flags { 0 };
BAN::NetworkEndian<uint16_t> question_count { 0 };
BAN::NetworkEndian<uint16_t> answer_count { 0 };
BAN::NetworkEndian<uint16_t> authority_RR_count { 0 };
BAN::NetworkEndian<uint16_t> additional_RR_count { 0 };
uint8_t data[];
};
static_assert(sizeof(DNSPacket) == 12);
struct DNSAnswer
{
uint8_t __storage[12];
BAN::NetworkEndian<uint16_t>& name() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x00); };
BAN::NetworkEndian<uint16_t>& type() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x02); };
BAN::NetworkEndian<uint16_t>& class_() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x04); };
BAN::NetworkEndian<uint32_t>& ttl() { return *reinterpret_cast<BAN::NetworkEndian<uint32_t>*>(__storage + 0x06); };
BAN::NetworkEndian<uint16_t>& data_len() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x0A); };
uint8_t data[];
};
static_assert(sizeof(DNSAnswer) == 12);
bool send_dns_query(int socket, BAN::StringView domain, uint16_t id)
{
static uint8_t buffer[4096];
memset(buffer, 0, sizeof(buffer));
DNSPacket& request = *reinterpret_cast<DNSPacket*>(buffer);
request.identification = id;
request.flags = 0x0100;
request.question_count = 1;
size_t idx = 0;
auto labels = MUST(BAN::StringView(domain).split('.'));
for (auto label : labels)
{
ASSERT(label.size() <= 0xFF);
request.data[idx++] = label.size();
for (char c : label)
request.data[idx++] = c;
}
request.data[idx++] = 0x00;
*(uint16_t*)&request.data[idx] = htons(0x01); idx += 2;
*(uint16_t*)&request.data[idx] = htons(0x01); idx += 2;
sockaddr_in nameserver;
nameserver.sin_family = AF_INET;
nameserver.sin_port = htons(53);
nameserver.sin_addr.s_addr = inet_addr("8.8.8.8");
if (sendto(socket, &request, sizeof(DNSPacket) + idx, 0, (sockaddr*)&nameserver, sizeof(nameserver)) == -1)
{
perror("sendto");
return false;
}
return true;
}
BAN::Optional<BAN::String> read_dns_response(int socket, uint16_t id)
{
static uint8_t buffer[4096];
ssize_t nrecv = recvfrom(socket, buffer, sizeof(buffer), 0, nullptr, nullptr);
if (nrecv == -1)
{
perror("recvfrom");
return {};
}
DNSPacket& reply = *reinterpret_cast<DNSPacket*>(buffer);
if (reply.identification != id)
{
fprintf(stderr, "Reply to invalid packet\n");
return {};
}
if (reply.flags & 0x0F)
{
fprintf(stderr, "DNS error (rcode %u)\n", (unsigned)(reply.flags & 0xF));
return {};
}
size_t idx = 0;
for (size_t i = 0; i < reply.question_count; i++)
{
while (reply.data[idx])
idx += reply.data[idx] + 1;
idx += 5;
}
DNSAnswer& answer = *reinterpret_cast<DNSAnswer*>(&reply.data[idx]);
if (answer.data_len() != 4)
{
fprintf(stderr, "Not IPv4\n");
return {};
}
return inet_ntoa({ .s_addr = *reinterpret_cast<uint32_t*>(answer.data) });
}
int create_service_socket()
{
int socket = ::socket(AF_UNIX, SOCK_SEQPACKET, 0);
if (socket == -1)
{
perror("socket");
return -1;
}
sockaddr_un addr;
addr.sun_family = AF_UNIX;
strcpy(addr.sun_path, "/tmp/resolver.sock");
if (bind(socket, (sockaddr*)&addr, sizeof(addr)) == -1)
{
perror("bind");
close(socket);
return -1;
}
if (chmod("/tmp/resolver.sock", 0777) == -1)
{
perror("chmod");
close(socket);
return -1;
}
if (listen(socket, 10) == -1)
{
perror("listen");
close(socket);
return -1;
}
return socket;
}
BAN::Optional<BAN::String> read_service_query(int socket)
{
static char buffer[4096];
ssize_t nrecv = recv(socket, buffer, sizeof(buffer), 0);
if (nrecv == -1)
{
perror("recv");
return {};
}
buffer[nrecv] = '\0';
return BAN::String(buffer);
}
int main(int, char**)
{
srand(time(nullptr));
int service_socket = create_service_socket();
if (service_socket == -1)
return 1;
int dns_socket = socket(AF_INET, SOCK_DGRAM, 0);
if (dns_socket == -1)
{
perror("socket");
return 1;
}
for (;;)
{
int client = accept(service_socket, nullptr, nullptr);
if (client == -1)
{
perror("accept");
continue;
}
auto query = read_service_query(client);
if (!query.has_value())
continue;
uint16_t id = rand() % 0xFFFF;
if (send_dns_query(dns_socket, *query, id))
{
auto response = read_dns_response(dns_socket, id);
if (response.has_value())
{
if (send(client, response->data(), response->size() + 1, 0) == -1)
perror("send");
close(client);
continue;
}
}
char message[] = "unavailable";
send(client, message, sizeof(message), 0);
close(client);
}
return 0;
}