Compare commits

...

22 Commits

Author SHA1 Message Date
Bananymous 49889858fa Kernel: Allow chmod on TmpSocketInode 2024-02-08 03:16:01 +02:00
Bananymous 2424f38a62 Userspace: Implement super simple DNS resolver in userspace
You connect to this service using unix domain sockets and send the
asked domain name. It will respond with ip address or 'unavailable'

There is no DNS cache implemented so all calls ask the nameserver.
2024-02-08 03:14:00 +02:00
Bananymous 218456d127 BAN: Fix some includes 2024-02-08 03:13:21 +02:00
Bananymous e7dd03e551 Kernel: Implement basic connection-mode unix domain sockets 2024-02-08 02:28:19 +02:00
Bananymous 0c8e9fe095 Kernel: Add operator bool() for WeakPtr 2024-02-08 02:26:46 +02:00
Bananymous 5b4acec4ca BAN: Add capacity() getter for Queue 2024-02-07 22:53:56 +02:00
Bananymous e26f360d93 Kernel: allow kmalloc of size 0 2024-02-07 22:36:24 +02:00
Bananymous 2cc9534570 BAN: Add emplace for Variant
This allows variant to store values that are not copy/move
constructible.
2024-02-07 22:33:16 +02:00
Bananymous 572c4052f6 Kernel: Fix Process APIs 2024-02-07 15:57:45 +02:00
Bananymous 132286895f Kernel: Implement Socket inodes for tmpfs 2024-02-07 15:57:45 +02:00
Bananymous 454bee3f02 LibC: Fix sockaddr_un implementation 2024-02-07 15:57:45 +02:00
Bananymous 41cad88d6e Kernel/LibC: Implement dummy syscalls for accept, connect, listen 2024-02-07 15:57:45 +02:00
Bananymous 40e341b0ee BAN: Remove unstable hash map and set
These can now be implemented safely with new linked list api
2024-02-06 17:35:15 +02:00
Bananymous 5da59c9151 Kernel: Make better abstractions for networking 2024-02-06 16:45:39 +02:00
Bananymous f804e87f7d Kernel: Implement basic gateway for network interfaces 2024-02-05 18:18:56 +02:00
Bananymous dd3641f054 Kernel: Cleanup ARPTable code
Packet process is now killed if ARPTable dies.

ARP wait loop now just reschecules so timeout actually works.
2024-02-05 18:18:56 +02:00
Bananymous b2291ce162 Kernel/BAN: Fix network strucute endianness 2024-02-05 18:18:56 +02:00
Bananymous c35ed6570b LibC: Implement endiannes and ip address functions 2024-02-05 18:18:56 +02:00
Bananymous d15cbb2d6a Kernel: Fix IPv4 header checksum calculation 2024-02-05 18:18:56 +02:00
Bananymous b8cf6432ef BAN: Implement host_to_network_endian 2024-02-05 17:29:24 +02:00
Bananymous 89805fb092 dhcp-client: Use dprintln for debug printing 2024-02-05 01:24:45 +02:00
Bananymous 692cec8458 Kernel/Userspace/LibC: Implement basic dprintln for userspace 2024-02-05 01:24:09 +02:00
67 changed files with 1920 additions and 468 deletions

View File

@ -2,6 +2,8 @@
#include <BAN/Span.h> #include <BAN/Span.h>
#include <stdint.h>
namespace BAN namespace BAN
{ {

31
BAN/include/BAN/Debug.h Normal file
View File

@ -0,0 +1,31 @@
#pragma once
#if __is_kernel
#error "This is userspace only file"
#endif
#include <BAN/Formatter.h>
#include <stdio.h>
#define __debug_putchar [](int c) { putc(c, stddbg); }
#define dprintln(...) \
do { \
BAN::Formatter::print(__debug_putchar, __VA_ARGS__); \
BAN::Formatter::print(__debug_putchar,"\r\n"); \
fflush(stddbg); \
} while (false)
#define dwarnln(...) \
do { \
BAN::Formatter::print(__debug_putchar, "\e[33m"); \
dprintln(__VA_ARGS__); \
BAN::Formatter::print(__debug_putchar, "\e[m"); \
} while(false)
#define derrorln(...) \
do { \
BAN::Formatter::print(__debug_putchar, "\e[31m"); \
dprintln(__VA_ARGS__); \
BAN::Formatter::print(__debug_putchar, "\e[m"); \
} while(false)

View File

@ -90,4 +90,10 @@ namespace BAN
template<integral T> template<integral T>
using NetworkEndian = BigEndian<T>; using NetworkEndian = BigEndian<T>;
template<integral T>
constexpr T host_to_network_endian(T value)
{
return host_to_big_endian(value);
}
} }

View File

@ -7,7 +7,7 @@
namespace BAN namespace BAN
{ {
template<typename Key, typename T, typename HASH = BAN::hash<Key>, bool STABLE = true> template<typename Key, typename T, typename HASH = BAN::hash<Key>>
class HashMap class HashMap
{ {
public: public:
@ -32,12 +32,12 @@ namespace BAN
public: public:
HashMap() = default; HashMap() = default;
HashMap(const HashMap<Key, T, HASH, STABLE>&); HashMap(const HashMap<Key, T, HASH>&);
HashMap(HashMap<Key, T, HASH, STABLE>&&); HashMap(HashMap<Key, T, HASH>&&);
~HashMap(); ~HashMap();
HashMap<Key, T, HASH, STABLE>& operator=(const HashMap<Key, T, HASH, STABLE>&); HashMap<Key, T, HASH>& operator=(const HashMap<Key, T, HASH>&);
HashMap<Key, T, HASH, STABLE>& operator=(HashMap<Key, T, HASH, STABLE>&&); HashMap<Key, T, HASH>& operator=(HashMap<Key, T, HASH>&&);
ErrorOr<void> insert(const Key&, const T&); ErrorOr<void> insert(const Key&, const T&);
ErrorOr<void> insert(const Key&, T&&); ErrorOr<void> insert(const Key&, T&&);
@ -74,26 +74,26 @@ namespace BAN
friend iterator; friend iterator;
}; };
template<typename Key, typename T, typename HASH, bool STABLE> template<typename Key, typename T, typename HASH>
HashMap<Key, T, HASH, STABLE>::HashMap(const HashMap<Key, T, HASH, STABLE>& other) HashMap<Key, T, HASH>::HashMap(const HashMap<Key, T, HASH>& other)
{ {
*this = other; *this = other;
} }
template<typename Key, typename T, typename HASH, bool STABLE> template<typename Key, typename T, typename HASH>
HashMap<Key, T, HASH, STABLE>::HashMap(HashMap<Key, T, HASH, STABLE>&& other) HashMap<Key, T, HASH>::HashMap(HashMap<Key, T, HASH>&& other)
{ {
*this = move(other); *this = move(other);
} }
template<typename Key, typename T, typename HASH, bool STABLE> template<typename Key, typename T, typename HASH>
HashMap<Key, T, HASH, STABLE>::~HashMap() HashMap<Key, T, HASH>::~HashMap()
{ {
clear(); clear();
} }
template<typename Key, typename T, typename HASH, bool STABLE> template<typename Key, typename T, typename HASH>
HashMap<Key, T, HASH, STABLE>& HashMap<Key, T, HASH, STABLE>::operator=(const HashMap<Key, T, HASH, STABLE>& other) HashMap<Key, T, HASH>& HashMap<Key, T, HASH>::operator=(const HashMap<Key, T, HASH>& other)
{ {
clear(); clear();
m_buckets = other.m_buckets; m_buckets = other.m_buckets;
@ -101,8 +101,8 @@ namespace BAN
return *this; return *this;
} }
template<typename Key, typename T, typename HASH, bool STABLE> template<typename Key, typename T, typename HASH>
HashMap<Key, T, HASH, STABLE>& HashMap<Key, T, HASH, STABLE>::operator=(HashMap<Key, T, HASH, STABLE>&& other) HashMap<Key, T, HASH>& HashMap<Key, T, HASH>::operator=(HashMap<Key, T, HASH>&& other)
{ {
clear(); clear();
m_buckets = move(other.m_buckets); m_buckets = move(other.m_buckets);
@ -111,21 +111,21 @@ namespace BAN
return *this; return *this;
} }
template<typename Key, typename T, typename HASH, bool STABLE> template<typename Key, typename T, typename HASH>
ErrorOr<void> HashMap<Key, T, HASH, STABLE>::insert(const Key& key, const T& value) ErrorOr<void> HashMap<Key, T, HASH>::insert(const Key& key, const T& value)
{ {
return insert(key, move(T(value))); return insert(key, move(T(value)));
} }
template<typename Key, typename T, typename HASH, bool STABLE> template<typename Key, typename T, typename HASH>
ErrorOr<void> HashMap<Key, T, HASH, STABLE>::insert(const Key& key, T&& value) ErrorOr<void> HashMap<Key, T, HASH>::insert(const Key& key, T&& value)
{ {
return emplace(key, move(value)); return emplace(key, move(value));
} }
template<typename Key, typename T, typename HASH, bool STABLE> template<typename Key, typename T, typename HASH>
template<typename... Args> template<typename... Args>
ErrorOr<void> HashMap<Key, T, HASH, STABLE>::emplace(const Key& key, Args&&... args) ErrorOr<void> HashMap<Key, T, HASH>::emplace(const Key& key, Args&&... args)
{ {
ASSERT(!contains(key)); ASSERT(!contains(key));
TRY(rebucket(m_size + 1)); TRY(rebucket(m_size + 1));
@ -135,15 +135,15 @@ namespace BAN
return {}; return {};
} }
template<typename Key, typename T, typename HASH, bool STABLE> template<typename Key, typename T, typename HASH>
ErrorOr<void> HashMap<Key, T, HASH, STABLE>::reserve(size_type size) ErrorOr<void> HashMap<Key, T, HASH>::reserve(size_type size)
{ {
TRY(rebucket(size)); TRY(rebucket(size));
return {}; return {};
} }
template<typename Key, typename T, typename HASH, bool STABLE> template<typename Key, typename T, typename HASH>
void HashMap<Key, T, HASH, STABLE>::remove(const Key& key) void HashMap<Key, T, HASH>::remove(const Key& key)
{ {
if (empty()) return; if (empty()) return;
auto& bucket = get_bucket(key); auto& bucket = get_bucket(key);
@ -158,15 +158,15 @@ namespace BAN
} }
} }
template<typename Key, typename T, typename HASH, bool STABLE> template<typename Key, typename T, typename HASH>
void HashMap<Key, T, HASH, STABLE>::clear() void HashMap<Key, T, HASH>::clear()
{ {
m_buckets.clear(); m_buckets.clear();
m_size = 0; m_size = 0;
} }
template<typename Key, typename T, typename HASH, bool STABLE> template<typename Key, typename T, typename HASH>
T& HashMap<Key, T, HASH, STABLE>::operator[](const Key& key) T& HashMap<Key, T, HASH>::operator[](const Key& key)
{ {
ASSERT(!empty()); ASSERT(!empty());
auto& bucket = get_bucket(key); auto& bucket = get_bucket(key);
@ -176,8 +176,8 @@ namespace BAN
ASSERT(false); ASSERT(false);
} }
template<typename Key, typename T, typename HASH, bool STABLE> template<typename Key, typename T, typename HASH>
const T& HashMap<Key, T, HASH, STABLE>::operator[](const Key& key) const const T& HashMap<Key, T, HASH>::operator[](const Key& key) const
{ {
ASSERT(!empty()); ASSERT(!empty());
const auto& bucket = get_bucket(key); const auto& bucket = get_bucket(key);
@ -187,8 +187,8 @@ namespace BAN
ASSERT(false); ASSERT(false);
} }
template<typename Key, typename T, typename HASH, bool STABLE> template<typename Key, typename T, typename HASH>
bool HashMap<Key, T, HASH, STABLE>::contains(const Key& key) const bool HashMap<Key, T, HASH>::contains(const Key& key) const
{ {
if (empty()) return false; if (empty()) return false;
const auto& bucket = get_bucket(key); const auto& bucket = get_bucket(key);
@ -198,20 +198,20 @@ namespace BAN
return false; return false;
} }
template<typename Key, typename T, typename HASH, bool STABLE> template<typename Key, typename T, typename HASH>
bool HashMap<Key, T, HASH, STABLE>::empty() const bool HashMap<Key, T, HASH>::empty() const
{ {
return m_size == 0; return m_size == 0;
} }
template<typename Key, typename T, typename HASH, bool STABLE> template<typename Key, typename T, typename HASH>
typename HashMap<Key, T, HASH, STABLE>::size_type HashMap<Key, T, HASH, STABLE>::size() const typename HashMap<Key, T, HASH>::size_type HashMap<Key, T, HASH>::size() const
{ {
return m_size; return m_size;
} }
template<typename Key, typename T, typename HASH, bool STABLE> template<typename Key, typename T, typename HASH>
ErrorOr<void> HashMap<Key, T, HASH, STABLE>::rebucket(size_type bucket_count) ErrorOr<void> HashMap<Key, T, HASH>::rebucket(size_type bucket_count)
{ {
if (m_buckets.size() >= bucket_count) if (m_buckets.size() >= bucket_count)
return {}; return {};
@ -222,13 +222,10 @@ namespace BAN
for (auto& bucket : m_buckets) for (auto& bucket : m_buckets)
{ {
for (Entry& entry : bucket) for (auto it = bucket.begin(); it != bucket.end();)
{ {
size_type bucket_index = HASH()(entry.key) % new_buckets.size(); size_type new_bucket_index = HASH()(it->key) % new_buckets.size();
if constexpr(STABLE) it = bucket.move_element_to_other_linked_list(new_buckets[new_bucket_index], new_buckets[new_bucket_index].end(), it);
TRY(new_buckets[bucket_index].push_back(entry));
else
TRY(new_buckets[bucket_index].push_back(move(entry)));
} }
} }
@ -236,27 +233,20 @@ namespace BAN
return {}; return {};
} }
template<typename Key, typename T, typename HASH, bool STABLE> template<typename Key, typename T, typename HASH>
LinkedList<typename HashMap<Key, T, HASH, STABLE>::Entry>& HashMap<Key, T, HASH, STABLE>::get_bucket(const Key& key) LinkedList<typename HashMap<Key, T, HASH>::Entry>& HashMap<Key, T, HASH>::get_bucket(const Key& key)
{ {
ASSERT(!m_buckets.empty()); ASSERT(!m_buckets.empty());
auto index = HASH()(key) % m_buckets.size(); auto index = HASH()(key) % m_buckets.size();
return m_buckets[index]; return m_buckets[index];
} }
template<typename Key, typename T, typename HASH, bool STABLE> template<typename Key, typename T, typename HASH>
const LinkedList<typename HashMap<Key, T, HASH, STABLE>::Entry>& HashMap<Key, T, HASH, STABLE>::get_bucket(const Key& key) const const LinkedList<typename HashMap<Key, T, HASH>::Entry>& HashMap<Key, T, HASH>::get_bucket(const Key& key) const
{ {
ASSERT(!m_buckets.empty()); ASSERT(!m_buckets.empty());
auto index = HASH()(key) % m_buckets.size(); auto index = HASH()(key) % m_buckets.size();
return m_buckets[index]; return m_buckets[index];
} }
// Unstable hash map moves values between container during rebucketing.
// This means that if insertion to map fails, elements could be in invalid state
// and that container is no longer usable. This is better if either way you are
// going to stop using the hash map after insertion fails.
template<typename Key, typename T, typename HASH = BAN::hash<Key>>
using HashMapUnstable = HashMap<Key, T, HASH, false>;
} }

View File

@ -11,7 +11,7 @@
namespace BAN namespace BAN
{ {
template<typename T, typename HASH = hash<T>, bool STABLE = true> template<typename T, typename HASH = hash<T>>
class HashSet class HashSet
{ {
public: public:
@ -55,23 +55,23 @@ namespace BAN
size_type m_size = 0; size_type m_size = 0;
}; };
template<typename T, typename HASH, bool STABLE> template<typename T, typename HASH>
HashSet<T, HASH, STABLE>::HashSet(const HashSet& other) HashSet<T, HASH>::HashSet(const HashSet& other)
: m_buckets(other.m_buckets) : m_buckets(other.m_buckets)
, m_size(other.m_size) , m_size(other.m_size)
{ {
} }
template<typename T, typename HASH, bool STABLE> template<typename T, typename HASH>
HashSet<T, HASH, STABLE>::HashSet(HashSet&& other) HashSet<T, HASH>::HashSet(HashSet&& other)
: m_buckets(move(other.m_buckets)) : m_buckets(move(other.m_buckets))
, m_size(other.m_size) , m_size(other.m_size)
{ {
other.clear(); other.clear();
} }
template<typename T, typename HASH, bool STABLE> template<typename T, typename HASH>
HashSet<T, HASH, STABLE>& HashSet<T, HASH, STABLE>::operator=(const HashSet& other) HashSet<T, HASH>& HashSet<T, HASH>::operator=(const HashSet& other)
{ {
clear(); clear();
m_buckets = other.m_buckets; m_buckets = other.m_buckets;
@ -79,8 +79,8 @@ namespace BAN
return *this; return *this;
} }
template<typename T, typename HASH, bool STABLE> template<typename T, typename HASH>
HashSet<T, HASH, STABLE>& HashSet<T, HASH, STABLE>::operator=(HashSet&& other) HashSet<T, HASH>& HashSet<T, HASH>::operator=(HashSet&& other)
{ {
clear(); clear();
m_buckets = move(other.m_buckets); m_buckets = move(other.m_buckets);
@ -89,14 +89,14 @@ namespace BAN
return *this; return *this;
} }
template<typename T, typename HASH, bool STABLE> template<typename T, typename HASH>
ErrorOr<void> HashSet<T, HASH, STABLE>::insert(const T& key) ErrorOr<void> HashSet<T, HASH>::insert(const T& key)
{ {
return insert(move(T(key))); return insert(move(T(key)));
} }
template<typename T, typename HASH, bool STABLE> template<typename T, typename HASH>
ErrorOr<void> HashSet<T, HASH, STABLE>::insert(T&& key) ErrorOr<void> HashSet<T, HASH>::insert(T&& key)
{ {
if (!empty() && get_bucket(key).contains(key)) if (!empty() && get_bucket(key).contains(key))
return {}; return {};
@ -107,8 +107,8 @@ namespace BAN
return {}; return {};
} }
template<typename T, typename HASH, bool STABLE> template<typename T, typename HASH>
void HashSet<T, HASH, STABLE>::remove(const T& key) void HashSet<T, HASH>::remove(const T& key)
{ {
if (empty()) return; if (empty()) return;
auto& bucket = get_bucket(key); auto& bucket = get_bucket(key);
@ -123,41 +123,41 @@ namespace BAN
} }
} }
template<typename T, typename HASH, bool STABLE> template<typename T, typename HASH>
void HashSet<T, HASH, STABLE>::clear() void HashSet<T, HASH>::clear()
{ {
m_buckets.clear(); m_buckets.clear();
m_size = 0; m_size = 0;
} }
template<typename T, typename HASH, bool STABLE> template<typename T, typename HASH>
ErrorOr<void> HashSet<T, HASH, STABLE>::reserve(size_type size) ErrorOr<void> HashSet<T, HASH>::reserve(size_type size)
{ {
TRY(rebucket(size)); TRY(rebucket(size));
return {}; return {};
} }
template<typename T, typename HASH, bool STABLE> template<typename T, typename HASH>
bool HashSet<T, HASH, STABLE>::contains(const T& key) const bool HashSet<T, HASH>::contains(const T& key) const
{ {
if (empty()) return false; if (empty()) return false;
return get_bucket(key).contains(key); return get_bucket(key).contains(key);
} }
template<typename T, typename HASH, bool STABLE> template<typename T, typename HASH>
typename HashSet<T, HASH, STABLE>::size_type HashSet<T, HASH, STABLE>::size() const typename HashSet<T, HASH>::size_type HashSet<T, HASH>::size() const
{ {
return m_size; return m_size;
} }
template<typename T, typename HASH, bool STABLE> template<typename T, typename HASH>
bool HashSet<T, HASH, STABLE>::empty() const bool HashSet<T, HASH>::empty() const
{ {
return m_size == 0; return m_size == 0;
} }
template<typename T, typename HASH, bool STABLE> template<typename T, typename HASH>
ErrorOr<void> HashSet<T, HASH, STABLE>::rebucket(size_type bucket_count) ErrorOr<void> HashSet<T, HASH>::rebucket(size_type bucket_count)
{ {
if (m_buckets.size() >= bucket_count) if (m_buckets.size() >= bucket_count)
return {}; return {};
@ -169,13 +169,10 @@ namespace BAN
for (auto& bucket : m_buckets) for (auto& bucket : m_buckets)
{ {
for (T& key : bucket) for (auto it = bucket.begin(); it != bucket.end();)
{ {
size_type bucket_index = HASH()(key) % new_buckets.size(); size_type new_bucket_index = HASH()(*it) % new_buckets.size();
if constexpr(STABLE) it = bucket.move_element_to_other_linked_list(new_buckets[new_bucket_index], new_buckets[new_bucket_index].end(), it);
TRY(new_buckets[bucket_index].push_back(key));
else
TRY(new_buckets[bucket_index].push_back(move(key)));
} }
} }
@ -183,27 +180,20 @@ namespace BAN
return {}; return {};
} }
template<typename T, typename HASH, bool STABLE> template<typename T, typename HASH>
LinkedList<T>& HashSet<T, HASH, STABLE>::get_bucket(const T& key) LinkedList<T>& HashSet<T, HASH>::get_bucket(const T& key)
{ {
ASSERT(!m_buckets.empty()); ASSERT(!m_buckets.empty());
size_type index = HASH()(key) % m_buckets.size(); size_type index = HASH()(key) % m_buckets.size();
return m_buckets[index]; return m_buckets[index];
} }
template<typename T, typename HASH, bool STABLE> template<typename T, typename HASH>
const LinkedList<T>& HashSet<T, HASH, STABLE>::get_bucket(const T& key) const const LinkedList<T>& HashSet<T, HASH>::get_bucket(const T& key) const
{ {
ASSERT(!m_buckets.empty()); ASSERT(!m_buckets.empty());
size_type index = HASH()(key) % m_buckets.size(); size_type index = HASH()(key) % m_buckets.size();
return m_buckets[index]; return m_buckets[index];
} }
// Unstable hash set moves values between container during rebucketing.
// This means that if insertion to set fails, elements could be in invalid state
// and that container is no longer usable. This is better if either way you are
// going to stop using the hash set after insertion fails.
template<typename T, typename HASH = hash<T>>
using HashSetUnstable = HashSet<T, HASH, false>;
} }

View File

@ -1,5 +1,6 @@
#pragma once #pragma once
#include <BAN/Endianness.h>
#include <BAN/Formatter.h> #include <BAN/Formatter.h>
#include <BAN/Hash.h> #include <BAN/Hash.h>
@ -10,31 +11,32 @@ namespace BAN
{ {
constexpr IPv4Address(uint32_t u32_address) constexpr IPv4Address(uint32_t u32_address)
{ {
address[0] = u32_address >> 24; raw = u32_address;
address[1] = u32_address >> 16;
address[2] = u32_address >> 8;
address[3] = u32_address >> 0;
} }
constexpr uint32_t as_u32() const constexpr IPv4Address(uint8_t oct1, uint8_t oct2, uint8_t oct3, uint8_t oct4)
{ {
return octets[0] = oct1;
((uint32_t)address[0] << 24) | octets[1] = oct2;
((uint32_t)address[1] << 16) | octets[2] = oct3;
((uint32_t)address[2] << 8) | octets[3] = oct4;
((uint32_t)address[3] << 0);
} }
constexpr bool operator==(const IPv4Address& other) const constexpr bool operator==(const IPv4Address& other) const
{ {
return return raw == other.raw;
address[0] == other.address[0] &&
address[1] == other.address[1] &&
address[2] == other.address[2] &&
address[3] == other.address[3];
} }
uint8_t address[4]; constexpr IPv4Address mask(const IPv4Address& other) const
{
return IPv4Address(raw & other.raw);
}
union
{
uint8_t octets[4];
uint32_t raw;
} __attribute__((packed));
}; };
static_assert(sizeof(IPv4Address) == 4); static_assert(sizeof(IPv4Address) == 4);
@ -43,7 +45,7 @@ namespace BAN
{ {
constexpr hash_t operator()(IPv4Address ipv4) const constexpr hash_t operator()(IPv4Address ipv4) const
{ {
return hash<uint32_t>()(ipv4.as_u32()); return hash<uint32_t>()(ipv4.raw);
} }
}; };
@ -62,11 +64,11 @@ namespace BAN::Formatter
.upper = false, .upper = false,
}; };
print_argument(putc, ipv4.address[0], format); print_argument(putc, ipv4.octets[0], format);
for (size_t i = 1; i < 4; i++) for (size_t i = 1; i < 4; i++)
{ {
putc('.'); putc('.');
print_argument(putc, ipv4.address[i], format); print_argument(putc, ipv4.octets[i], format);
} }
} }

View File

@ -3,6 +3,8 @@
#include <BAN/Assert.h> #include <BAN/Assert.h>
#include <BAN/Traits.h> #include <BAN/Traits.h>
#include <stddef.h>
namespace BAN namespace BAN
{ {

View File

@ -45,6 +45,7 @@ namespace BAN
void clear(); void clear();
bool empty() const; bool empty() const;
size_type capacity() const;
size_type size() const; size_type size() const;
const T& front() const; const T& front() const;
@ -186,6 +187,12 @@ namespace BAN
return m_size == 0; return m_size == 0;
} }
template<typename T>
typename Queue<T>::size_type Queue<T>::capacity() const
{
return m_capacity;
}
template<typename T> template<typename T>
typename Queue<T>::size_type Queue<T>::size() const typename Queue<T>::size_type Queue<T>::size() const
{ {

View File

@ -216,6 +216,14 @@ namespace BAN
return m_index == detail::index<T, Ts...>(); return m_index == detail::index<T, Ts...>();
} }
template<typename T, typename... Args>
void emplace(Args&&... args) requires (can_have<T>())
{
clear();
m_index = detail::index<T, Ts...>();
new (m_storage) T(BAN::forward<Args>(args)...);
}
template<typename T> template<typename T>
void set(T&& value) requires (can_have<T>() && !is_lvalue_reference_v<T>) void set(T&& value) requires (can_have<T>() && !is_lvalue_reference_v<T>)
{ {

View File

@ -91,6 +91,8 @@ namespace BAN
bool valid() const { return m_link && m_link->valid(); } bool valid() const { return m_link && m_link->valid(); }
operator bool() const { return valid(); }
private: private:
WeakPtr(const RefPtr<WeakLink<T>>& link) WeakPtr(const RefPtr<WeakLink<T>>& link)
: m_link(link) : m_link(link)

View File

@ -16,6 +16,7 @@ set(KERNEL_SOURCES
kernel/CPUID.cpp kernel/CPUID.cpp
kernel/Credentials.cpp kernel/Credentials.cpp
kernel/Debug.cpp kernel/Debug.cpp
kernel/Device/DebugDevice.cpp
kernel/Device/Device.cpp kernel/Device/Device.cpp
kernel/Device/FramebufferDevice.cpp kernel/Device/FramebufferDevice.cpp
kernel/Device/NullDevice.cpp kernel/Device/NullDevice.cpp
@ -52,11 +53,12 @@ set(KERNEL_SOURCES
kernel/Networking/ARPTable.cpp kernel/Networking/ARPTable.cpp
kernel/Networking/E1000/E1000.cpp kernel/Networking/E1000/E1000.cpp
kernel/Networking/E1000/E1000E.cpp kernel/Networking/E1000/E1000E.cpp
kernel/Networking/IPv4.cpp kernel/Networking/IPv4Layer.cpp
kernel/Networking/NetworkInterface.cpp kernel/Networking/NetworkInterface.cpp
kernel/Networking/NetworkManager.cpp kernel/Networking/NetworkManager.cpp
kernel/Networking/NetworkSocket.cpp kernel/Networking/NetworkSocket.cpp
kernel/Networking/UDPSocket.cpp kernel/Networking/UDPSocket.cpp
kernel/Networking/UNIX/Socket.cpp
kernel/OpenFileDescriptorSet.cpp kernel/OpenFileDescriptorSet.cpp
kernel/Panic.cpp kernel/Panic.cpp
kernel/PCI.cpp kernel/PCI.cpp

View File

@ -5,7 +5,7 @@
#define dprintln(...) \ #define dprintln(...) \
do { \ do { \
Debug::DebugLock::lock(); \ Debug::DebugLock::lock(); \
Debug::print_prefix(__FILE__, __LINE__); \ Debug::print_prefix(__FILE__, __LINE__); \
BAN::Formatter::print(Debug::putchar, __VA_ARGS__); \ BAN::Formatter::print(Debug::putchar, __VA_ARGS__); \
BAN::Formatter::print(Debug::putchar, "\r\n"); \ BAN::Formatter::print(Debug::putchar, "\r\n"); \
Debug::DebugLock::unlock(); \ Debug::DebugLock::unlock(); \

View File

@ -0,0 +1,28 @@
#include <kernel/Device/Device.h>
namespace Kernel
{
class DebugDevice : public CharacterDevice
{
public:
static BAN::ErrorOr<BAN::RefPtr<DebugDevice>> create(mode_t, uid_t, gid_t);
virtual dev_t rdev() const override { return m_rdev; }
virtual BAN::StringView name() const override { return "debug"sv; }
protected:
DebugDevice(mode_t mode, uid_t uid, gid_t gid, dev_t rdev)
: CharacterDevice(mode, uid, gid)
, m_rdev(rdev)
{ }
virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override { return 0; }
virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan buffer) override;
private:
const dev_t m_rdev;
};
}

View File

@ -100,9 +100,12 @@ namespace Kernel
BAN::ErrorOr<BAN::String> link_target(); BAN::ErrorOr<BAN::String> link_target();
// Socket API // Socket API
BAN::ErrorOr<long> accept(sockaddr* address, socklen_t* address_len);
BAN::ErrorOr<void> bind(const sockaddr* address, socklen_t address_len); BAN::ErrorOr<void> bind(const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<ssize_t> sendto(const sys_sendto_t*); BAN::ErrorOr<void> connect(const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<ssize_t> recvfrom(sys_recvfrom_t*); BAN::ErrorOr<void> listen(int backlog);
BAN::ErrorOr<size_t> sendto(const sys_sendto_t*);
BAN::ErrorOr<size_t> recvfrom(sys_recvfrom_t*);
// General API // General API
BAN::ErrorOr<size_t> read(off_t, BAN::ByteSpan buffer); BAN::ErrorOr<size_t> read(off_t, BAN::ByteSpan buffer);
@ -128,9 +131,12 @@ namespace Kernel
virtual BAN::ErrorOr<BAN::String> link_target_impl() { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<BAN::String> link_target_impl() { return BAN::Error::from_errno(ENOTSUP); }
// Socket API // Socket API
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> listen_impl(int) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<ssize_t> sendto_impl(const sys_sendto_t*) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<size_t> sendto_impl(const sys_sendto_t*) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<ssize_t> recvfrom_impl(sys_recvfrom_t*) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<size_t> recvfrom_impl(sys_recvfrom_t*) { return BAN::Error::from_errno(ENOTSUP); }
// General API // General API
virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) { return BAN::Error::from_errno(ENOTSUP); }

View File

@ -0,0 +1,20 @@
#pragma once
namespace Kernel
{
enum class SocketDomain
{
INET,
INET6,
UNIX,
};
enum class SocketType
{
STREAM,
DGRAM,
SEQPACKET,
};
}

View File

@ -80,6 +80,25 @@ namespace Kernel
friend class TmpInode; friend class TmpInode;
}; };
class TmpSocketInode : public TmpInode
{
public:
static BAN::ErrorOr<BAN::RefPtr<TmpSocketInode>> create_new(TmpFileSystem&, mode_t, uid_t, gid_t);
~TmpSocketInode();
protected:
virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override { return BAN::Error::from_errno(ENODEV); }
virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan) override { return BAN::Error::from_errno(ENODEV); }
virtual BAN::ErrorOr<void> truncate_impl(size_t) override { return BAN::Error::from_errno(ENODEV); }
virtual BAN::ErrorOr<void> chmod_impl(mode_t) override;
virtual bool has_data_impl() const override { return true; }
private:
TmpSocketInode(TmpFileSystem&, ino_t, const TmpInodeInfo&);
friend class TmpInode;
};
class TmpSymlinkInode : public TmpInode class TmpSymlinkInode : public TmpInode
{ {
public: public:

View File

@ -31,6 +31,7 @@ namespace Kernel
public: public:
static BAN::ErrorOr<BAN::UniqPtr<ARPTable>> create(); static BAN::ErrorOr<BAN::UniqPtr<ARPTable>> create();
~ARPTable();
BAN::ErrorOr<BAN::MACAddress> get_mac_from_ipv4(NetworkInterface&, BAN::IPv4Address); BAN::ErrorOr<BAN::MACAddress> get_mac_from_ipv4(NetworkInterface&, BAN::IPv4Address);

View File

@ -42,7 +42,7 @@ namespace Kernel
uint32_t read32(uint16_t reg); uint32_t read32(uint16_t reg);
void write32(uint16_t reg, uint32_t value); void write32(uint16_t reg, uint32_t value);
virtual BAN::ErrorOr<void> send_raw_bytes(BAN::ConstByteSpan) override; virtual BAN::ErrorOr<void> send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan) override;
private: private:
BAN::ErrorOr<void> read_mac_address(); BAN::ErrorOr<void> read_mac_address();

View File

@ -0,0 +1,24 @@
#pragma once
#include <BAN/Endianness.h>
#include <stdint.h>
namespace Kernel
{
struct ICMPHeader
{
uint8_t type;
uint8_t code;
BAN::NetworkEndian<uint16_t> checksum;
BAN::NetworkEndian<uint32_t> rest;
};
static_assert(sizeof(ICMPHeader) == 8);
enum ICMPType : uint8_t
{
EchoReply = 0x00,
EchoRequest = 0x08,
};
}

View File

@ -1,38 +0,0 @@
#pragma once
#include <BAN/ByteSpan.h>
#include <BAN/Endianness.h>
#include <BAN/IPv4.h>
#include <BAN/Vector.h>
namespace Kernel
{
struct IPv4Header
{
uint8_t version_IHL;
uint8_t DSCP_ECN;
BAN::NetworkEndian<uint16_t> total_length { 0 };
BAN::NetworkEndian<uint16_t> identification { 0 };
BAN::NetworkEndian<uint16_t> flags_frament { 0 };
uint8_t time_to_live;
uint8_t protocol;
BAN::NetworkEndian<uint16_t> checksum { 0 };
BAN::IPv4Address src_address;
BAN::IPv4Address dst_address;
constexpr uint16_t calculate_checksum() const
{
return 0xFFFF
- (((uint16_t)version_IHL << 8) | DSCP_ECN)
- total_length
- identification
- flags_frament
- (((uint16_t)time_to_live << 8) | protocol);
}
};
static_assert(sizeof(IPv4Header) == 20);
void add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol);
}

View File

@ -0,0 +1,105 @@
#pragma once
#include <BAN/Array.h>
#include <BAN/ByteSpan.h>
#include <BAN/CircularQueue.h>
#include <BAN/Endianness.h>
#include <BAN/IPv4.h>
#include <BAN/NoCopyMove.h>
#include <BAN/UniqPtr.h>
#include <kernel/Networking/ARPTable.h>
#include <kernel/Networking/NetworkInterface.h>
#include <kernel/Networking/NetworkLayer.h>
#include <kernel/Networking/NetworkSocket.h>
#include <kernel/Process.h>
#include <kernel/SpinLock.h>
namespace Kernel
{
struct IPv4Header
{
uint8_t version_IHL;
uint8_t DSCP_ECN;
BAN::NetworkEndian<uint16_t> total_length { 0 };
BAN::NetworkEndian<uint16_t> identification { 0 };
BAN::NetworkEndian<uint16_t> flags_frament { 0 };
uint8_t time_to_live;
uint8_t protocol;
BAN::NetworkEndian<uint16_t> checksum { 0 };
BAN::IPv4Address src_address;
BAN::IPv4Address dst_address;
constexpr uint16_t calculate_checksum() const
{
uint32_t total_sum = 0;
for (size_t i = 0; i < sizeof(IPv4Header) / sizeof(uint16_t); i++)
total_sum += reinterpret_cast<const BAN::NetworkEndian<uint16_t>*>(this)[i];
total_sum -= checksum;
while (total_sum >> 16)
total_sum = (total_sum >> 16) + (total_sum & 0xFFFF);
return ~(uint16_t)total_sum;
}
constexpr bool is_valid_checksum() const
{
uint32_t total_sum = 0;
for (size_t i = 0; i < sizeof(IPv4Header) / sizeof(uint16_t); i++)
total_sum += reinterpret_cast<const BAN::NetworkEndian<uint16_t>*>(this)[i];
while (total_sum >> 16)
total_sum = (total_sum >> 16) + (total_sum & 0xFFFF);
return total_sum == 0xFFFF;
}
};
static_assert(sizeof(IPv4Header) == 20);
class IPv4Layer : public NetworkLayer
{
BAN_NON_COPYABLE(IPv4Layer);
BAN_NON_MOVABLE(IPv4Layer);
public:
static BAN::ErrorOr<BAN::UniqPtr<IPv4Layer>> create();
~IPv4Layer();
ARPTable& arp_table() { return *m_arp_table; }
void add_ipv4_packet(NetworkInterface&, BAN::ConstByteSpan);
virtual void unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) override;
virtual BAN::ErrorOr<void> bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) override;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, const sys_sendto_t*) override;
private:
IPv4Layer();
void add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol) const;
void packet_handle_task();
BAN::ErrorOr<void> handle_ipv4_packet(NetworkInterface&, BAN::ByteSpan);
private:
struct PendingIPv4Packet
{
NetworkInterface& interface;
};
private:
SpinLock m_lock;
BAN::UniqPtr<ARPTable> m_arp_table;
Process* m_process { nullptr };
static constexpr size_t pending_packet_buffer_size = 128 * PAGE_SIZE;
BAN::UniqPtr<VirtualRange> m_pending_packet_buffer;
BAN::CircularQueue<PendingIPv4Packet, 128> m_pending_packets;
Semaphore m_pending_semaphore;
size_t m_pending_total_size { 0 };
BAN::HashMap<int, BAN::WeakPtr<NetworkSocket>> m_bound_sockets;
friend class BAN::UniqPtr<IPv4Layer>;
};
}

View File

@ -1,10 +1,10 @@
#pragma once #pragma once
#include <BAN/Errors.h>
#include <BAN/ByteSpan.h> #include <BAN/ByteSpan.h>
#include <BAN/Errors.h>
#include <BAN/IPv4.h>
#include <BAN/MAC.h> #include <BAN/MAC.h>
#include <kernel/Device/Device.h> #include <kernel/Device/Device.h>
#include <kernel/Networking/IPv4.h>
namespace Kernel namespace Kernel
{ {
@ -46,16 +46,16 @@ namespace Kernel
BAN::IPv4Address get_netmask() const { return m_netmask; } BAN::IPv4Address get_netmask() const { return m_netmask; }
void set_netmask(BAN::IPv4Address new_netmask) { m_netmask = new_netmask; } void set_netmask(BAN::IPv4Address new_netmask) { m_netmask = new_netmask; }
BAN::IPv4Address get_gateway() const { return m_gateway; }
void set_gateway(BAN::IPv4Address new_gateway) { m_gateway = new_gateway; }
virtual bool link_up() = 0; virtual bool link_up() = 0;
virtual int link_speed() = 0; virtual int link_speed() = 0;
size_t interface_header_size() const;
void add_interface_header(BAN::ByteSpan packet, BAN::MACAddress destination);
virtual dev_t rdev() const override { return m_rdev; } virtual dev_t rdev() const override { return m_rdev; }
virtual BAN::StringView name() const override { return m_name; } virtual BAN::StringView name() const override { return m_name; }
virtual BAN::ErrorOr<void> send_raw_bytes(BAN::ConstByteSpan) = 0; virtual BAN::ErrorOr<void> send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan) = 0;
private: private:
const Type m_type; const Type m_type;
@ -65,6 +65,7 @@ namespace Kernel
BAN::IPv4Address m_ipv4_address { 0 }; BAN::IPv4Address m_ipv4_address { 0 };
BAN::IPv4Address m_netmask { 0 }; BAN::IPv4Address m_netmask { 0 };
BAN::IPv4Address m_gateway { 0 };
}; };
} }

View File

@ -0,0 +1,25 @@
#pragma once
#include <kernel/Networking/NetworkInterface.h>
namespace Kernel
{
class NetworkSocket;
enum class SocketType;
class NetworkLayer
{
public:
virtual ~NetworkLayer() {}
virtual void unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) = 0;
virtual BAN::ErrorOr<void> bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) = 0;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, const sys_sendto_t*) = 0;
protected:
NetworkLayer() = default;
};
}

View File

@ -2,7 +2,7 @@
#include <BAN/Vector.h> #include <BAN/Vector.h>
#include <kernel/FS/TmpFS/FileSystem.h> #include <kernel/FS/TmpFS/FileSystem.h>
#include <kernel/Networking/ARPTable.h> #include <kernel/Networking/IPv4Layer.h>
#include <kernel/Networking/NetworkInterface.h> #include <kernel/Networking/NetworkInterface.h>
#include <kernel/Networking/NetworkSocket.h> #include <kernel/Networking/NetworkSocket.h>
#include <kernel/PCI.h> #include <kernel/PCI.h>
@ -17,26 +17,15 @@ namespace Kernel
BAN_NON_COPYABLE(NetworkManager); BAN_NON_COPYABLE(NetworkManager);
BAN_NON_MOVABLE(NetworkManager); BAN_NON_MOVABLE(NetworkManager);
public:
enum class SocketType
{
STREAM,
DGRAM,
SEQPACKET,
};
public: public:
static BAN::ErrorOr<void> initialize(); static BAN::ErrorOr<void> initialize();
static NetworkManager& get(); static NetworkManager& get();
ARPTable& arp_table() { return *m_arp_table; }
BAN::ErrorOr<void> add_interface(PCI::Device& pci_device); BAN::ErrorOr<void> add_interface(PCI::Device& pci_device);
void unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>); BAN::Vector<BAN::RefPtr<NetworkInterface>> interfaces() { return m_interfaces; }
BAN::ErrorOr<void> bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>);
BAN::ErrorOr<BAN::RefPtr<NetworkSocket>> create_socket(SocketType, mode_t, uid_t, gid_t); BAN::ErrorOr<BAN::RefPtr<TmpInode>> create_socket(SocketDomain, SocketType, mode_t, uid_t, gid_t);
void on_receive(NetworkInterface&, BAN::ConstByteSpan); void on_receive(NetworkInterface&, BAN::ConstByteSpan);
@ -44,9 +33,8 @@ namespace Kernel
NetworkManager(); NetworkManager();
private: private:
BAN::UniqPtr<ARPTable> m_arp_table; BAN::UniqPtr<IPv4Layer> m_ipv4_layer;
BAN::Vector<BAN::RefPtr<NetworkInterface>> m_interfaces; BAN::Vector<BAN::RefPtr<NetworkInterface>> m_interfaces;
BAN::HashMap<int, BAN::WeakPtr<NetworkSocket>> m_bound_sockets;
}; };
} }

View File

@ -1,8 +1,10 @@
#pragma once #pragma once
#include <BAN/WeakPtr.h> #include <BAN/WeakPtr.h>
#include <kernel/FS/Socket.h>
#include <kernel/FS/TmpFS/Inode.h> #include <kernel/FS/TmpFS/Inode.h>
#include <kernel/Networking/NetworkInterface.h> #include <kernel/Networking/NetworkInterface.h>
#include <kernel/Networking/NetworkLayer.h>
#include <netinet/in.h> #include <netinet/in.h>
@ -11,6 +13,7 @@ namespace Kernel
enum NetworkProtocol : uint8_t enum NetworkProtocol : uint8_t
{ {
ICMP = 0x01,
UDP = 0x11, UDP = 0x11,
}; };
@ -26,26 +29,29 @@ namespace Kernel
void bind_interface_and_port(NetworkInterface*, uint16_t port); void bind_interface_and_port(NetworkInterface*, uint16_t port);
~NetworkSocket(); ~NetworkSocket();
NetworkInterface& interface() { ASSERT(m_interface); return *m_interface; }
virtual size_t protocol_header_size() const = 0; virtual size_t protocol_header_size() const = 0;
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t src_port, uint16_t dst_port) = 0; virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) = 0;
virtual NetworkProtocol protocol() const = 0; virtual NetworkProtocol protocol() const = 0;
virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_address, uint16_t sender_port) = 0; virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_address, uint16_t sender_port) = 0;
protected: protected:
NetworkSocket(mode_t mode, uid_t uid, gid_t gid); NetworkSocket(NetworkLayer&, ino_t, const TmpInodeInfo&);
virtual BAN::ErrorOr<size_t> read_packet(BAN::ByteSpan, sockaddr_in* sender_address) = 0; virtual BAN::ErrorOr<size_t> read_packet(BAN::ByteSpan, sockaddr_in* sender_address) = 0;
virtual void on_close_impl() override; virtual void on_close_impl() override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr* address, socklen_t address_len) override; virtual BAN::ErrorOr<void> bind_impl(const sockaddr* address, socklen_t address_len) override;
virtual BAN::ErrorOr<ssize_t> sendto_impl(const sys_sendto_t*) override; virtual BAN::ErrorOr<size_t> sendto_impl(const sys_sendto_t*) override;
virtual BAN::ErrorOr<ssize_t> recvfrom_impl(sys_recvfrom_t*) override; virtual BAN::ErrorOr<size_t> recvfrom_impl(sys_recvfrom_t*) override;
virtual BAN::ErrorOr<long> ioctl_impl(int request, void* arg) override; virtual BAN::ErrorOr<long> ioctl_impl(int request, void* arg) override;
protected: protected:
NetworkLayer& m_network_layer;
NetworkInterface* m_interface = nullptr; NetworkInterface* m_interface = nullptr;
uint16_t m_port = PORT_NONE; uint16_t m_port = PORT_NONE;
}; };

View File

@ -22,10 +22,10 @@ namespace Kernel
class UDPSocket final : public NetworkSocket class UDPSocket final : public NetworkSocket
{ {
public: public:
static BAN::ErrorOr<BAN::RefPtr<UDPSocket>> create(mode_t, uid_t, gid_t); static BAN::ErrorOr<BAN::RefPtr<UDPSocket>> create(NetworkLayer&, ino_t, const TmpInodeInfo&);
virtual size_t protocol_header_size() const override { return sizeof(UDPHeader); } virtual size_t protocol_header_size() const override { return sizeof(UDPHeader); }
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t src_port, uint16_t dst_port) override; virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) override;
virtual NetworkProtocol protocol() const override { return NetworkProtocol::UDP; } virtual NetworkProtocol protocol() const override { return NetworkProtocol::UDP; }
protected: protected:
@ -33,7 +33,7 @@ namespace Kernel
virtual BAN::ErrorOr<size_t> read_packet(BAN::ByteSpan, sockaddr_in* sender_address) override; virtual BAN::ErrorOr<size_t> read_packet(BAN::ByteSpan, sockaddr_in* sender_address) override;
private: private:
UDPSocket(mode_t, uid_t, gid_t); UDPSocket(NetworkLayer&, ino_t, const TmpInodeInfo&);
struct PacketInfo struct PacketInfo
{ {

View File

@ -0,0 +1,67 @@
#pragma once
#include <BAN/Queue.h>
#include <BAN/WeakPtr.h>
#include <kernel/FS/Socket.h>
#include <kernel/FS/TmpFS/Inode.h>
namespace Kernel
{
class UnixDomainSocket final : public TmpInode, public BAN::Weakable<UnixDomainSocket>
{
BAN_NON_COPYABLE(UnixDomainSocket);
BAN_NON_MOVABLE(UnixDomainSocket);
public:
static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(SocketType, ino_t, const TmpInodeInfo&);
protected:
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) override;
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<void> listen_impl(int) override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<size_t> sendto_impl(const sys_sendto_t*) override;
virtual BAN::ErrorOr<size_t> recvfrom_impl(sys_recvfrom_t*) override;
private:
UnixDomainSocket(SocketType, ino_t, const TmpInodeInfo&);
BAN::ErrorOr<void> add_packet(BAN::ConstByteSpan);
bool is_bound() const { return !m_bound_path.empty(); }
bool is_bound_to_unused() const { return m_bound_path == "X"sv; }
bool is_streaming() const;
private:
struct ConnectionInfo
{
bool listening { false };
BAN::Atomic<bool> connection_done { false };
BAN::WeakPtr<UnixDomainSocket> connection;
BAN::Queue<BAN::RefPtr<UnixDomainSocket>> pending_connections;
Semaphore pending_semaphore;
SpinLock pending_lock;
};
struct ConnectionlessInfo
{
};
private:
const SocketType m_socket_type;
BAN::String m_bound_path;
BAN::Variant<ConnectionInfo, ConnectionlessInfo> m_info;
BAN::CircularQueue<size_t, 128> m_packet_sizes;
size_t m_packet_size_total { 0 };
BAN::UniqPtr<VirtualRange> m_packet_buffer;
Semaphore m_packet_semaphore;
friend class BAN::RefPtr<UnixDomainSocket>;
};
}

View File

@ -21,6 +21,7 @@ namespace Kernel
BAN::ErrorOr<void> clone_from(const OpenFileDescriptorSet&); BAN::ErrorOr<void> clone_from(const OpenFileDescriptorSet&);
BAN::ErrorOr<int> open(BAN::RefPtr<Inode>, int flags);
BAN::ErrorOr<int> open(BAN::StringView absolute_path, int flags); BAN::ErrorOr<int> open(BAN::StringView absolute_path, int flags);
BAN::ErrorOr<int> socket(int domain, int type, int protocol); BAN::ErrorOr<int> socket(int domain, int type, int protocol);

View File

@ -62,6 +62,8 @@ namespace Kernel
bool is_session_leader() const { return pid() == sid(); } bool is_session_leader() const { return pid() == sid(); }
const char* name() const { return m_cmdline.empty() ? "" : m_cmdline.front().data(); }
const Credentials& credentials() const { return m_credentials; } const Credentials& credentials() const { return m_credentials; }
BAN::ErrorOr<long> sys_exit(int status); BAN::ErrorOr<long> sys_exit(int status);
@ -93,8 +95,10 @@ namespace Kernel
BAN::ErrorOr<long> sys_getegid() const { return m_credentials.egid(); } BAN::ErrorOr<long> sys_getegid() const { return m_credentials.egid(); }
BAN::ErrorOr<long> sys_getpgid(pid_t); BAN::ErrorOr<long> sys_getpgid(pid_t);
BAN::ErrorOr<long> open_inode(BAN::RefPtr<Inode>, int flags);
BAN::ErrorOr<void> create_file_or_dir(BAN::StringView name, mode_t mode); BAN::ErrorOr<void> create_file_or_dir(BAN::StringView name, mode_t mode);
BAN::ErrorOr<long> open_file(BAN::StringView path, int, mode_t = 0); BAN::ErrorOr<long> open_file(BAN::StringView path, int oflag, mode_t = 0);
BAN::ErrorOr<long> sys_open(const char* path, int, mode_t); BAN::ErrorOr<long> sys_open(const char* path, int, mode_t);
BAN::ErrorOr<long> sys_openat(int, const char* path, int, mode_t); BAN::ErrorOr<long> sys_openat(int, const char* path, int, mode_t);
BAN::ErrorOr<long> sys_close(int fd); BAN::ErrorOr<long> sys_close(int fd);
@ -113,7 +117,10 @@ namespace Kernel
BAN::ErrorOr<long> sys_chown(const char*, uid_t, gid_t); BAN::ErrorOr<long> sys_chown(const char*, uid_t, gid_t);
BAN::ErrorOr<long> sys_socket(int domain, int type, int protocol); BAN::ErrorOr<long> sys_socket(int domain, int type, int protocol);
BAN::ErrorOr<long> sys_accept(int socket, sockaddr* address, socklen_t* address_len);
BAN::ErrorOr<long> sys_bind(int socket, const sockaddr* address, socklen_t address_len); BAN::ErrorOr<long> sys_bind(int socket, const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<long> sys_connect(int socket, const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<long> sys_listen(int socket, int backlog);
BAN::ErrorOr<long> sys_sendto(const sys_sendto_t*); BAN::ErrorOr<long> sys_sendto(const sys_sendto_t*);
BAN::ErrorOr<long> sys_recvfrom(sys_recvfrom_t*); BAN::ErrorOr<long> sys_recvfrom(sys_recvfrom_t*);
@ -175,6 +182,8 @@ namespace Kernel
// Return false if access was page violation (segfault) // Return false if access was page violation (segfault)
BAN::ErrorOr<bool> allocate_page_for_demand_paging(vaddr_t addr); BAN::ErrorOr<bool> allocate_page_for_demand_paging(vaddr_t addr);
BAN::ErrorOr<BAN::String> absolute_path_of(BAN::StringView) const;
private: private:
Process(const Credentials&, pid_t pid, pid_t parent, pid_t sid, pid_t pgrp); Process(const Credentials&, pid_t pid, pid_t parent, pid_t sid, pid_t pgrp);
static Process* create_process(const Credentials&, pid_t parent, pid_t sid = 0, pid_t pgrp = 0); static Process* create_process(const Credentials&, pid_t parent, pid_t sid = 0, pid_t pgrp = 0);
@ -184,8 +193,6 @@ namespace Kernel
BAN::ErrorOr<int> block_until_exit(pid_t pid); BAN::ErrorOr<int> block_until_exit(pid_t pid);
BAN::ErrorOr<BAN::String> absolute_path_of(BAN::StringView) const;
BAN::ErrorOr<void> validate_string_access(const char*); BAN::ErrorOr<void> validate_string_access(const char*);
BAN::ErrorOr<void> validate_pointer_access(const void*, size_t); BAN::ErrorOr<void> validate_pointer_access(const void*, size_t);

View File

@ -0,0 +1,32 @@
#include <kernel/Device/DebugDevice.h>
#include <kernel/FS/DevFS/FileSystem.h>
#include <kernel/Process.h>
#include <kernel/Timer/Timer.h>
namespace Kernel
{
BAN::ErrorOr<BAN::RefPtr<DebugDevice>> DebugDevice::create(mode_t mode, uid_t uid, gid_t gid)
{
auto* result = new DebugDevice(mode, uid, gid, DevFileSystem::get().get_next_dev());
if (result == nullptr)
return BAN::Error::from_errno(ENOMEM);
return BAN::RefPtr<DebugDevice>::adopt(result);
}
BAN::ErrorOr<size_t> DebugDevice::write_impl(off_t, BAN::ConstByteSpan buffer)
{
auto ms_since_boot = SystemTimer::get().ms_since_boot();
Debug::DebugLock::lock();
BAN::Formatter::print(Debug::putchar, "[{5}.{3}] {}: ",
ms_since_boot / 1000,
ms_since_boot % 1000,
Kernel::Process::current().name()
);
for (size_t i = 0; i < buffer.size(); i++)
Debug::putchar(buffer[i]);
Debug::DebugLock::unlock();
return buffer.size();
}
}

View File

@ -1,4 +1,5 @@
#include <BAN/ScopeGuard.h> #include <BAN/ScopeGuard.h>
#include <kernel/Device/DebugDevice.h>
#include <kernel/Device/FramebufferDevice.h> #include <kernel/Device/FramebufferDevice.h>
#include <kernel/Device/NullDevice.h> #include <kernel/Device/NullDevice.h>
#include <kernel/Device/ZeroDevice.h> #include <kernel/Device/ZeroDevice.h>
@ -22,6 +23,7 @@ namespace Kernel
ASSERT(s_instance); ASSERT(s_instance);
MUST(s_instance->TmpFileSystem::initialize(0755, 0, 0)); MUST(s_instance->TmpFileSystem::initialize(0755, 0, 0));
s_instance->add_device(MUST(DebugDevice::create(0666, 0, 0)));
s_instance->add_device(MUST(NullDevice::create(0666, 0, 0))); s_instance->add_device(MUST(NullDevice::create(0666, 0, 0)));
s_instance->add_device(MUST(ZeroDevice::create(0666, 0, 0))); s_instance->add_device(MUST(ZeroDevice::create(0666, 0, 0)));
} }

View File

@ -116,6 +116,14 @@ namespace Kernel
return link_target_impl(); return link_target_impl();
} }
BAN::ErrorOr<long> Inode::accept(sockaddr* address, socklen_t* address_len)
{
LockGuard _(m_lock);
if (!mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
return accept_impl(address, address_len);
}
BAN::ErrorOr<void> Inode::bind(const sockaddr* address, socklen_t address_len) BAN::ErrorOr<void> Inode::bind(const sockaddr* address, socklen_t address_len)
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
@ -124,7 +132,23 @@ namespace Kernel
return bind_impl(address, address_len); return bind_impl(address, address_len);
} }
BAN::ErrorOr<ssize_t> Inode::sendto(const sys_sendto_t* arguments) BAN::ErrorOr<void> Inode::connect(const sockaddr* address, socklen_t address_len)
{
LockGuard _(m_lock);
if (!mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
return connect_impl(address, address_len);
}
BAN::ErrorOr<void> Inode::listen(int backlog)
{
LockGuard _(m_lock);
if (!mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
return listen_impl(backlog);
}
BAN::ErrorOr<size_t> Inode::sendto(const sys_sendto_t* arguments)
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
if (!mode().ifsock()) if (!mode().ifsock())
@ -132,7 +156,7 @@ namespace Kernel
return sendto_impl(arguments); return sendto_impl(arguments);
}; };
BAN::ErrorOr<ssize_t> Inode::recvfrom(sys_recvfrom_t* arguments) BAN::ErrorOr<size_t> Inode::recvfrom(sys_recvfrom_t* arguments)
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
if (!mode().ifsock()) if (!mode().ifsock())

View File

@ -215,6 +215,36 @@ namespace Kernel
return {}; return {};
} }
/* SOCKET INODE */
BAN::ErrorOr<BAN::RefPtr<TmpSocketInode>> TmpSocketInode::create_new(TmpFileSystem& fs, mode_t mode, uid_t uid, gid_t gid)
{
auto info = create_inode_info(Mode::IFSOCK | mode, uid, gid);
ino_t ino = TRY(fs.allocate_inode(info));
auto* inode_ptr = new TmpSocketInode(fs, ino, info);
if (inode_ptr == nullptr)
return BAN::Error::from_errno(ENOMEM);
return BAN::RefPtr<TmpSocketInode>::adopt(inode_ptr);
}
TmpSocketInode::TmpSocketInode(TmpFileSystem& fs, ino_t ino, const TmpInodeInfo& info)
: TmpInode(fs, ino, info)
{
ASSERT(mode().ifsock());
}
TmpSocketInode::~TmpSocketInode()
{
}
BAN::ErrorOr<void> TmpSocketInode::chmod_impl(mode_t new_mode)
{
m_inode_info.mode = new_mode;
return {};
}
/* SYMLINK INODE */ /* SYMLINK INODE */
BAN::ErrorOr<BAN::RefPtr<TmpSymlinkInode>> TmpSymlinkInode::create_new(TmpFileSystem& fs, mode_t mode, uid_t uid, gid_t gid, BAN::StringView target) BAN::ErrorOr<BAN::RefPtr<TmpSymlinkInode>> TmpSymlinkInode::create_new(TmpFileSystem& fs, mode_t mode, uid_t uid, gid_t gid, BAN::StringView target)
@ -446,7 +476,19 @@ namespace Kernel
BAN::ErrorOr<void> TmpDirectoryInode::create_file_impl(BAN::StringView name, mode_t mode, uid_t uid, gid_t gid) BAN::ErrorOr<void> TmpDirectoryInode::create_file_impl(BAN::StringView name, mode_t mode, uid_t uid, gid_t gid)
{ {
auto new_inode = TRY(TmpFileInode::create_new(m_fs, mode, uid, gid)); BAN::RefPtr<TmpInode> new_inode;
switch (mode & Mode::TYPE_MASK)
{
case Mode::IFREG:
new_inode = TRY(TmpFileInode::create_new(m_fs, mode, uid, gid));
break;
case Mode::IFSOCK:
new_inode = TRY(TmpSocketInode::create_new(m_fs, mode, uid, gid));
break;
default:
dprintln("Creating with mode {o} is not supported", mode);
return BAN::Error::from_errno(ENOTSUP);
}
TRY(link_inode(*new_inode, name)); TRY(link_inode(*new_inode, name));
return {}; return {};
} }

View File

@ -294,6 +294,9 @@ void* kmalloc(size_t size, size_t align, bool force_identity_map)
// currently kmalloc is always identity mapped // currently kmalloc is always identity mapped
(void)force_identity_map; (void)force_identity_map;
if (size == 0)
size = 1;
const kmalloc_info& info = s_kmalloc_info; const kmalloc_info& info = s_kmalloc_info;
ASSERT(is_power_of_two(align)); ASSERT(is_power_of_two(align));

View File

@ -1,5 +1,6 @@
#include <kernel/LockGuard.h> #include <kernel/LockGuard.h>
#include <kernel/Networking/ARPTable.h> #include <kernel/Networking/ARPTable.h>
#include <kernel/Scheduler.h>
#include <kernel/Timer/Timer.h> #include <kernel/Timer/Timer.h>
namespace Kernel namespace Kernel
@ -32,26 +33,31 @@ namespace Kernel
{ {
} }
ARPTable::~ARPTable()
{
if (m_process)
m_process->exit(0, SIGKILL);
m_process = nullptr;
}
BAN::ErrorOr<BAN::MACAddress> ARPTable::get_mac_from_ipv4(NetworkInterface& interface, BAN::IPv4Address ipv4_address) BAN::ErrorOr<BAN::MACAddress> ARPTable::get_mac_from_ipv4(NetworkInterface& interface, BAN::IPv4Address ipv4_address)
{ {
if (ipv4_address == s_broadcast_ipv4)
return s_broadcast_mac;
if (interface.get_ipv4_address() == BAN::IPv4Address { 0 })
return BAN::Error::from_errno(EINVAL);
if (interface.get_ipv4_address().mask(interface.get_netmask()) != ipv4_address.mask(interface.get_netmask()))
ipv4_address = interface.get_gateway();
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
if (ipv4_address == s_broadcast_ipv4)
return s_broadcast_mac;
if (m_arp_table.contains(ipv4_address)) if (m_arp_table.contains(ipv4_address))
return m_arp_table[ipv4_address]; return m_arp_table[ipv4_address];
} }
BAN::Vector<uint8_t> full_packet_buffer; ARPPacket arp_request;
TRY(full_packet_buffer.resize(sizeof(ARPPacket) + sizeof(EthernetHeader)));
auto full_packet = BAN::ByteSpan { full_packet_buffer.span() };
auto& ethernet_header = full_packet.as<EthernetHeader>();
ethernet_header.dst_mac = s_broadcast_mac;
ethernet_header.src_mac = interface.get_mac_address();
ethernet_header.ether_type = EtherType::ARP;
auto& arp_request = full_packet.slice(sizeof(EthernetHeader)).as<ARPPacket>();
arp_request.htype = 0x0001; arp_request.htype = 0x0001;
arp_request.ptype = EtherType::IPv4; arp_request.ptype = EtherType::IPv4;
arp_request.hlen = 0x06; arp_request.hlen = 0x06;
@ -62,9 +68,9 @@ namespace Kernel
arp_request.tha = {{ 0, 0, 0, 0, 0, 0 }}; arp_request.tha = {{ 0, 0, 0, 0, 0, 0 }};
arp_request.tpa = ipv4_address; arp_request.tpa = ipv4_address;
TRY(interface.send_raw_bytes(full_packet)); TRY(interface.send_bytes(s_broadcast_mac, EtherType::ARP, BAN::ConstByteSpan::from(arp_request)));
uint64_t timeout = SystemTimer::get().ms_since_boot() + 5'000; uint64_t timeout = SystemTimer::get().ms_since_boot() + 1'000;
while (SystemTimer::get().ms_since_boot() < timeout) while (SystemTimer::get().ms_since_boot() < timeout)
{ {
{ {
@ -72,10 +78,10 @@ namespace Kernel
if (m_arp_table.contains(ipv4_address)) if (m_arp_table.contains(ipv4_address))
return m_arp_table[ipv4_address]; return m_arp_table[ipv4_address];
} }
TRY(Thread::current().block_or_eintr(m_pending_semaphore)); Scheduler::get().reschedule();
} }
return BAN::Error::from_errno(EINVAL); return BAN::Error::from_errno(ETIMEDOUT);
} }
BAN::ErrorOr<void> ARPTable::handle_arp_packet(NetworkInterface& interface, const ARPPacket& packet) BAN::ErrorOr<void> ARPTable::handle_arp_packet(NetworkInterface& interface, const ARPPacket& packet)
@ -92,27 +98,17 @@ namespace Kernel
{ {
if (packet.tpa == interface.get_ipv4_address()) if (packet.tpa == interface.get_ipv4_address())
{ {
BAN::Vector<uint8_t> full_packet_buffer; ARPPacket arp_reply;
TRY(full_packet_buffer.resize(sizeof(ARPPacket) + sizeof(EthernetHeader))); arp_reply.htype = 0x0001;
auto full_packet = BAN::ByteSpan { full_packet_buffer.span() }; arp_reply.ptype = EtherType::IPv4;
arp_reply.hlen = 0x06;
auto& ethernet_header = full_packet.as<EthernetHeader>(); arp_reply.plen = 0x04;
ethernet_header.dst_mac = packet.sha; arp_reply.oper = ARPOperation::Reply;
ethernet_header.src_mac = interface.get_mac_address(); arp_reply.sha = interface.get_mac_address();
ethernet_header.ether_type = EtherType::ARP; arp_reply.spa = interface.get_ipv4_address();
arp_reply.tha = packet.sha;
auto& arp_request = full_packet.slice(sizeof(EthernetHeader)).as<ARPPacket>(); arp_reply.tpa = packet.spa;
arp_request.htype = 0x0001; TRY(interface.send_bytes(packet.sha, EtherType::ARP, BAN::ConstByteSpan::from(arp_reply)));
arp_request.ptype = EtherType::IPv4;
arp_request.hlen = 0x06;
arp_request.plen = 0x04;
arp_request.oper = ARPOperation::Reply;
arp_request.sha = interface.get_mac_address();
arp_request.spa = interface.get_ipv4_address();
arp_request.tha = packet.sha;
arp_request.tpa = packet.spa;
TRY(interface.send_raw_bytes(full_packet));
} }
break; break;
} }

View File

@ -256,19 +256,26 @@ namespace Kernel
return {}; return {};
} }
BAN::ErrorOr<void> E1000::send_raw_bytes(BAN::ConstByteSpan buffer)
BAN::ErrorOr<void> E1000::send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan buffer)
{ {
ASSERT_LTE(buffer.size(), E1000_TX_BUFFER_SIZE); ASSERT_LTE(buffer.size() + sizeof(EthernetHeader), E1000_TX_BUFFER_SIZE);
CriticalScope _; CriticalScope _;
size_t tx_current = read32(REG_TDT) % E1000_TX_DESCRIPTOR_COUNT; size_t tx_current = read32(REG_TDT) % E1000_TX_DESCRIPTOR_COUNT;
auto* tx_buffer = reinterpret_cast<void*>(m_tx_buffer_region->vaddr() + E1000_TX_BUFFER_SIZE * tx_current); auto* tx_buffer = reinterpret_cast<uint8_t*>(m_tx_buffer_region->vaddr() + E1000_TX_BUFFER_SIZE * tx_current);
memcpy(tx_buffer, buffer.data(), buffer.size());
auto& ethernet_header = *reinterpret_cast<EthernetHeader*>(tx_buffer);
ethernet_header.dst_mac = destination;
ethernet_header.src_mac = get_mac_address();
ethernet_header.ether_type = protocol;
memcpy(tx_buffer + sizeof(EthernetHeader), buffer.data(), buffer.size());
auto& descriptor = reinterpret_cast<volatile e1000_tx_desc*>(m_tx_descriptor_region->vaddr())[tx_current]; auto& descriptor = reinterpret_cast<volatile e1000_tx_desc*>(m_tx_descriptor_region->vaddr())[tx_current];
descriptor.length = buffer.size(); descriptor.length = sizeof(EthernetHeader) + buffer.size();
descriptor.status = 0; descriptor.status = 0;
descriptor.cmd = CMD_EOP | CMD_IFCS | CMD_RS; descriptor.cmd = CMD_EOP | CMD_IFCS | CMD_RS;

View File

@ -1,22 +0,0 @@
#include <BAN/Endianness.h>
#include <kernel/Networking/IPv4.h>
namespace Kernel
{
void add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol)
{
auto& header = packet.as<IPv4Header>();
header.version_IHL = 0x45;
header.DSCP_ECN = 0x10;
header.total_length = packet.size();
header.identification = 1;
header.flags_frament = 0x00;
header.time_to_live = 0x40;
header.protocol = protocol;
header.checksum = header.calculate_checksum();
header.src_address = src_ipv4;
header.dst_address = dst_ipv4;
}
}

View File

@ -0,0 +1,284 @@
#include <kernel/Memory/Heap.h>
#include <kernel/Memory/PageTable.h>
#include <kernel/Networking/ICMP.h>
#include <kernel/Networking/IPv4Layer.h>
#include <kernel/Networking/NetworkManager.h>
#include <kernel/Networking/UDPSocket.h>
#include <netinet/in.h>
#define DEBUG_IPV4 0
namespace Kernel
{
BAN::ErrorOr<BAN::UniqPtr<IPv4Layer>> IPv4Layer::create()
{
auto ipv4_manager = TRY(BAN::UniqPtr<IPv4Layer>::create());
ipv4_manager->m_process = Process::create_kernel(
[](void* ipv4_manager_ptr)
{
auto& ipv4_manager = *reinterpret_cast<IPv4Layer*>(ipv4_manager_ptr);
ipv4_manager.packet_handle_task();
}, ipv4_manager.ptr()
);
ASSERT(ipv4_manager->m_process);
ipv4_manager->m_pending_packet_buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
~(uintptr_t)0,
pending_packet_buffer_size,
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
true
));
ipv4_manager->m_arp_table = TRY(ARPTable::create());
return ipv4_manager;
}
IPv4Layer::IPv4Layer()
{ }
IPv4Layer::~IPv4Layer()
{
if (m_process)
m_process->exit(0, SIGKILL);
m_process = nullptr;
}
void IPv4Layer::add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol) const
{
auto& header = packet.as<IPv4Header>();
header.version_IHL = 0x45;
header.DSCP_ECN = 0x00;
header.total_length = packet.size();
header.identification = 1;
header.flags_frament = 0x00;
header.time_to_live = 0x40;
header.protocol = protocol;
header.src_address = src_ipv4;
header.dst_address = dst_ipv4;
header.checksum = header.calculate_checksum();
}
void IPv4Layer::unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket> socket)
{
LockGuard _(m_lock);
if (m_bound_sockets.contains(port))
{
ASSERT(m_bound_sockets[port].valid());
ASSERT(m_bound_sockets[port].lock() == socket);
m_bound_sockets.remove(port);
}
NetworkManager::get().TmpFileSystem::remove_from_cache(socket);
}
BAN::ErrorOr<void> IPv4Layer::bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket> socket)
{
if (NetworkManager::get().interfaces().empty())
return BAN::Error::from_errno(EADDRNOTAVAIL);
LockGuard _(m_lock);
if (port == NetworkSocket::PORT_NONE)
{
for (uint32_t temp = 0xC000; temp < 0xFFFF; temp++)
{
if (!m_bound_sockets.contains(temp))
{
port = temp;
break;
}
}
if (port == NetworkSocket::PORT_NONE)
{
dwarnln("No ports available");
return BAN::Error::from_errno(EAGAIN);
}
}
if (m_bound_sockets.contains(port))
return BAN::Error::from_errno(EADDRINUSE);
TRY(m_bound_sockets.insert(port, socket));
// FIXME: actually determine proper interface
auto interface = NetworkManager::get().interfaces().front();
socket->bind_interface_and_port(interface.ptr(), port);
return {};
}
BAN::ErrorOr<size_t> IPv4Layer::sendto(NetworkSocket& socket, const sys_sendto_t* arguments)
{
if (arguments->dest_addr->sa_family != AF_INET)
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(arguments->dest_addr);
auto dst_port = BAN::host_to_network_endian(sockaddr_in.sin_port);
auto dst_ipv4 = BAN::IPv4Address { sockaddr_in.sin_addr.s_addr };
auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(socket.interface(), dst_ipv4));
BAN::Vector<uint8_t> packet_buffer;
TRY(packet_buffer.resize(arguments->length + sizeof(IPv4Header) + socket.protocol_header_size()));
auto packet = BAN::ByteSpan { packet_buffer.span() };
memcpy(
packet.slice(sizeof(IPv4Header)).slice(socket.protocol_header_size()).data(),
arguments->message,
arguments->length
);
socket.add_protocol_header(
packet.slice(sizeof(IPv4Header)),
dst_port
);
add_ipv4_header(
packet,
socket.interface().get_ipv4_address(),
dst_ipv4,
socket.protocol()
);
TRY(socket.interface().send_bytes(dst_mac, EtherType::IPv4, packet));
return arguments->length;
}
static uint16_t calculate_internet_checksum(BAN::ConstByteSpan packet)
{
uint32_t checksum = 0;
for (size_t i = 0; i < packet.size() / sizeof(uint16_t); i++)
checksum += BAN::host_to_network_endian(reinterpret_cast<const uint16_t*>(packet.data())[i]);
while (checksum >> 16)
checksum = (checksum >> 16) | (checksum & 0xFFFF);
return ~(uint16_t)checksum;
}
BAN::ErrorOr<void> IPv4Layer::handle_ipv4_packet(NetworkInterface& interface, BAN::ByteSpan packet)
{
auto& ipv4_header = packet.as<const IPv4Header>();
auto ipv4_data = packet.slice(sizeof(IPv4Header));
ASSERT(ipv4_header.is_valid_checksum());
auto src_ipv4 = ipv4_header.src_address;
switch (ipv4_header.protocol)
{
case NetworkProtocol::ICMP:
{
auto& icmp_header = ipv4_data.as<const ICMPHeader>();
switch (icmp_header.type)
{
case ICMPType::EchoRequest:
{
auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(interface, src_ipv4));
auto& reply_icmp_header = ipv4_data.as<ICMPHeader>();
reply_icmp_header.type = ICMPType::EchoReply;
reply_icmp_header.checksum = 0;
reply_icmp_header.checksum = calculate_internet_checksum(ipv4_data);
add_ipv4_header(packet, interface.get_ipv4_address(), src_ipv4, NetworkProtocol::ICMP);
TRY(interface.send_bytes(dst_mac, EtherType::IPv4, packet));
break;
}
default:
dprintln("Unhandleded ICMP packet (type {2H})", icmp_header.type);
break;
}
break;
}
case NetworkProtocol::UDP:
{
auto& udp_header = ipv4_data.as<const UDPHeader>();
uint16_t src_port = udp_header.src_port;
uint16_t dst_port = udp_header.dst_port;
LockGuard _(m_lock);
if (!m_bound_sockets.contains(dst_port) || !m_bound_sockets[dst_port].valid())
{
dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port);
return {};
}
auto udp_data = ipv4_data.slice(sizeof(UDPHeader));
m_bound_sockets[dst_port].lock()->add_packet(udp_data, src_ipv4, src_port);
break;
}
default:
dprintln_if(DEBUG_IPV4, "Unknown network protocol 0x{2H}", ipv4_header.protocol);
break;
}
return {};
}
void IPv4Layer::packet_handle_task()
{
for (;;)
{
BAN::Optional<PendingIPv4Packet> pending;
{
CriticalScope _;
if (!m_pending_packets.empty())
{
pending = m_pending_packets.front();
m_pending_packets.pop();
}
}
if (!pending.has_value())
{
m_pending_semaphore.block();
continue;
}
uint8_t* buffer_start = reinterpret_cast<uint8_t*>(m_pending_packet_buffer->vaddr());
const size_t ipv4_packet_size = reinterpret_cast<const IPv4Header*>(buffer_start)->total_length;
if (auto ret = handle_ipv4_packet(pending->interface, BAN::ByteSpan(buffer_start, ipv4_packet_size)); ret.is_error())
dwarnln("{}", ret.error());
CriticalScope _;
m_pending_total_size -= ipv4_packet_size;
if (m_pending_total_size)
memmove(buffer_start, buffer_start + ipv4_packet_size, m_pending_total_size);
}
}
void IPv4Layer::add_ipv4_packet(NetworkInterface& interface, BAN::ConstByteSpan buffer)
{
if (m_pending_packets.full())
{
dwarnln("IPv4 packet queue full");
return;
}
if (m_pending_total_size + buffer.size() > m_pending_packet_buffer->size())
{
dwarnln("IPv4 packet queue full");
return;
}
auto& ipv4_header = buffer.as<const IPv4Header>();
if (!ipv4_header.is_valid_checksum())
{
dwarnln("Invalid IPv4 packet");
return;
}
if (ipv4_header.total_length > buffer.size())
{
dwarnln("Too short IPv4 packet");
return;
}
uint8_t* buffer_start = reinterpret_cast<uint8_t*>(m_pending_packet_buffer->vaddr());
memcpy(buffer_start + m_pending_total_size, buffer.data(), ipv4_header.total_length);
m_pending_total_size += ipv4_header.total_length;
m_pending_packets.push({ .interface = interface });
m_pending_semaphore.unblock();
}
}

View File

@ -32,19 +32,4 @@ namespace Kernel
m_name[3] = minor(m_rdev) + '0'; m_name[3] = minor(m_rdev) + '0';
} }
size_t NetworkInterface::interface_header_size() const
{
ASSERT(m_type == Type::Ethernet);
return sizeof(EthernetHeader);
}
void NetworkInterface::add_interface_header(BAN::ByteSpan packet, BAN::MACAddress destination)
{
ASSERT(m_type == Type::Ethernet);
auto& header = packet.as<EthernetHeader>();
header.dst_mac = destination;
header.src_mac = get_mac_address();
header.ether_type = 0x0800;
}
} }

View File

@ -3,9 +3,12 @@
#include <kernel/FS/DevFS/FileSystem.h> #include <kernel/FS/DevFS/FileSystem.h>
#include <kernel/Networking/E1000/E1000.h> #include <kernel/Networking/E1000/E1000.h>
#include <kernel/Networking/E1000/E1000E.h> #include <kernel/Networking/E1000/E1000E.h>
#include <kernel/Networking/IPv4.h> #include <kernel/Networking/ICMP.h>
#include <kernel/Networking/NetworkManager.h> #include <kernel/Networking/NetworkManager.h>
#include <kernel/Networking/UDPSocket.h> #include <kernel/Networking/UDPSocket.h>
#include <kernel/Networking/UNIX/Socket.h>
#define DEBUG_ETHERTYPE 0
namespace Kernel namespace Kernel
{ {
@ -19,8 +22,8 @@ namespace Kernel
if (manager_ptr == nullptr) if (manager_ptr == nullptr)
return BAN::Error::from_errno(ENOMEM); return BAN::Error::from_errno(ENOMEM);
auto manager = BAN::UniqPtr<NetworkManager>::adopt(manager_ptr); auto manager = BAN::UniqPtr<NetworkManager>::adopt(manager_ptr);
manager->m_arp_table = TRY(ARPTable::create());
TRY(manager->TmpFileSystem::initialize(0777, 0, 0)); TRY(manager->TmpFileSystem::initialize(0777, 0, 0));
manager->m_ipv4_layer = TRY(IPv4Layer::create());
s_instance = BAN::move(manager); s_instance = BAN::move(manager);
return {}; return {};
} }
@ -68,45 +71,50 @@ namespace Kernel
return {}; return {};
} }
BAN::ErrorOr<BAN::RefPtr<NetworkSocket>> NetworkManager::create_socket(SocketType type, mode_t mode, uid_t uid, gid_t gid) BAN::ErrorOr<BAN::RefPtr<TmpInode>> NetworkManager::create_socket(SocketDomain domain, SocketType type, mode_t mode, uid_t uid, gid_t gid)
{ {
switch (domain)
{
case SocketDomain::INET:
{
if (type != SocketType::DGRAM)
return BAN::Error::from_errno(EPROTOTYPE);
break;
}
case SocketDomain::UNIX:
{
break;
}
default:
return BAN::Error::from_errno(EAFNOSUPPORT);
}
ASSERT((mode & Inode::Mode::TYPE_MASK) == 0); ASSERT((mode & Inode::Mode::TYPE_MASK) == 0);
mode |= Inode::Mode::IFSOCK;
if (type != SocketType::DGRAM) auto inode_info = create_inode_info(mode, uid, gid);
return BAN::Error::from_errno(EPROTOTYPE); ino_t ino = TRY(allocate_inode(inode_info));
auto udp_socket = TRY(UDPSocket::create(mode | Inode::Mode::IFSOCK, uid, gid)); BAN::RefPtr<TmpInode> socket;
return BAN::RefPtr<NetworkSocket>(udp_socket); switch (domain)
}
void NetworkManager::unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket> socket)
{
if (m_bound_sockets.contains(port))
{ {
ASSERT(m_bound_sockets[port].valid()); case SocketDomain::INET:
ASSERT(m_bound_sockets[port].lock() == socket); {
m_bound_sockets.remove(port); if (type == SocketType::DGRAM)
} socket = TRY(UDPSocket::create(*m_ipv4_layer, ino, inode_info));
NetworkManager::get().remove_from_cache(socket); break;
} }
case SocketDomain::UNIX:
BAN::ErrorOr<void> NetworkManager::bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket> socket) {
{ socket = TRY(UnixDomainSocket::create(type, ino, inode_info));
if (m_interfaces.empty()) break;
return BAN::Error::from_errno(EADDRNOTAVAIL); }
default:
if (port != NetworkSocket::PORT_NONE) ASSERT_NOT_REACHED();
{
if (m_bound_sockets.contains(port))
return BAN::Error::from_errno(EADDRINUSE);
TRY(m_bound_sockets.insert(port, socket));
} }
// FIXME: actually determine proper interface ASSERT(socket);
auto interface = m_interfaces.front(); return socket;
socket->bind_interface_and_port(interface.ptr(), port);
return {};
} }
void NetworkManager::on_receive(NetworkInterface& interface, BAN::ConstByteSpan packet) void NetworkManager::on_receive(NetworkInterface& interface, BAN::ConstByteSpan packet)
@ -117,41 +125,16 @@ namespace Kernel
{ {
case EtherType::ARP: case EtherType::ARP:
{ {
m_arp_table->add_arp_packet(interface, packet.slice(sizeof(EthernetHeader))); m_ipv4_layer->arp_table().add_arp_packet(interface, packet.slice(sizeof(EthernetHeader)));
break; break;
} }
case EtherType::IPv4: case EtherType::IPv4:
{ {
auto ipv4 = packet.slice(sizeof(EthernetHeader)); m_ipv4_layer->add_ipv4_packet(interface, packet.slice(sizeof(EthernetHeader)));
auto& ipv4_header = ipv4.as<const IPv4Header>();
auto src_ipv4 = ipv4_header.src_address;
switch (ipv4_header.protocol)
{
case NetworkProtocol::UDP:
{
auto udp = ipv4.slice(sizeof(IPv4Header));
auto& udp_header = udp.as<const UDPHeader>();
uint16_t src_port = udp_header.src_port;
uint16_t dst_port = udp_header.dst_port;
if (!m_bound_sockets.contains(dst_port))
{
dprintln("no one is listening on port {}", dst_port);
return;
}
auto raw = udp.slice(8);
m_bound_sockets[dst_port].lock()->add_packet(raw, src_ipv4, src_port);
break;
}
default:
dprintln("Unknown network protocol 0x{2H}", ipv4_header.protocol);
break;
}
break; break;
} }
default: default:
dprintln("Unknown EtherType 0x{4H}", (uint16_t)ethernet_header.ether_type); dprintln_if(DEBUG_ETHERTYPE, "Unknown EtherType 0x{4H}", (uint16_t)ethernet_header.ether_type);
break; break;
} }
} }

View File

@ -1,4 +1,3 @@
#include <kernel/Networking/IPv4.h>
#include <kernel/Networking/NetworkManager.h> #include <kernel/Networking/NetworkManager.h>
#include <kernel/Networking/NetworkSocket.h> #include <kernel/Networking/NetworkSocket.h>
@ -7,13 +6,9 @@
namespace Kernel namespace Kernel
{ {
NetworkSocket::NetworkSocket(mode_t mode, uid_t uid, gid_t gid) NetworkSocket::NetworkSocket(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info)
// FIXME: what the fuck is this : TmpInode(NetworkManager::get(), ino, inode_info)
: TmpInode( , m_network_layer(network_layer)
NetworkManager::get(),
MUST(NetworkManager::get().allocate_inode(create_inode_info(mode, uid, gid))),
create_inode_info(mode, uid, gid)
)
{ } { }
NetworkSocket::~NetworkSocket() NetworkSocket::~NetworkSocket()
@ -23,7 +18,7 @@ namespace Kernel
void NetworkSocket::on_close_impl() void NetworkSocket::on_close_impl()
{ {
if (m_interface) if (m_interface)
NetworkManager::get().unbind_socket(m_port, this); m_network_layer.unbind_socket(m_port, this);
} }
void NetworkSocket::bind_interface_and_port(NetworkInterface* interface, uint16_t port) void NetworkSocket::bind_interface_and_port(NetworkInterface* interface, uint16_t port)
@ -36,16 +31,15 @@ namespace Kernel
BAN::ErrorOr<void> NetworkSocket::bind_impl(const sockaddr* address, socklen_t address_len) BAN::ErrorOr<void> NetworkSocket::bind_impl(const sockaddr* address, socklen_t address_len)
{ {
if (address_len != sizeof(sockaddr_in)) if (m_interface || address_len != sizeof(sockaddr_in))
return BAN::Error::from_errno(EINVAL); return BAN::Error::from_errno(EINVAL);
auto* addr_in = reinterpret_cast<const sockaddr_in*>(address); auto* addr_in = reinterpret_cast<const sockaddr_in*>(address);
return NetworkManager::get().bind_socket(addr_in->sin_port, this); uint16_t dst_port = BAN::host_to_network_endian(addr_in->sin_port);
return m_network_layer.bind_socket(dst_port, this);
} }
BAN::ErrorOr<ssize_t> NetworkSocket::sendto_impl(const sys_sendto_t* arguments) BAN::ErrorOr<size_t> NetworkSocket::sendto_impl(const sys_sendto_t* arguments)
{ {
if (arguments->dest_len != sizeof(sockaddr_in))
return BAN::Error::from_errno(EINVAL);
if (arguments->flags) if (arguments->flags)
{ {
dprintln("flags not supported"); dprintln("flags not supported");
@ -53,45 +47,12 @@ namespace Kernel
} }
if (!m_interface) if (!m_interface)
TRY(NetworkManager::get().bind_socket(PORT_NONE, this)); TRY(m_network_layer.bind_socket(PORT_NONE, this));
auto* destination = reinterpret_cast<const sockaddr_in*>(arguments->dest_addr); return TRY(m_network_layer.sendto(*this, arguments));
auto message = BAN::ConstByteSpan((const uint8_t*)arguments->message, arguments->length);
uint16_t dst_port = destination->sin_port;
if (dst_port == PORT_NONE)
return BAN::Error::from_errno(EINVAL);
auto dst_addr = BAN::IPv4Address(destination->sin_addr.s_addr);
auto dst_mac = TRY(NetworkManager::get().arp_table().get_mac_from_ipv4(*m_interface, dst_addr));
const size_t interface_header_offset = 0;
const size_t interface_header_size = m_interface->interface_header_size();
const size_t ipv4_header_offset = interface_header_offset + interface_header_size;
const size_t ipv4_header_size = sizeof(IPv4Header);
const size_t protocol_header_offset = ipv4_header_offset + ipv4_header_size;
const size_t protocol_header_size = this->protocol_header_size();
const size_t payload_offset = protocol_header_offset + protocol_header_size;
const size_t payload_size = message.size();
BAN::Vector<uint8_t> full_packet;
TRY(full_packet.resize(payload_offset + payload_size));
BAN::ByteSpan packet_bytespan { full_packet.span() };
memcpy(full_packet.data() + payload_offset, message.data(), payload_size);
add_protocol_header(packet_bytespan.slice(protocol_header_offset), m_port, dst_port);
add_ipv4_header(packet_bytespan.slice(ipv4_header_offset), m_interface->get_ipv4_address(), dst_addr, protocol());
m_interface->add_interface_header(packet_bytespan.slice(interface_header_offset), dst_mac);
TRY(m_interface->send_raw_bytes(packet_bytespan));
return arguments->length;
} }
BAN::ErrorOr<ssize_t> NetworkSocket::recvfrom_impl(sys_recvfrom_t* arguments) BAN::ErrorOr<size_t> NetworkSocket::recvfrom_impl(sys_recvfrom_t* arguments)
{ {
sockaddr_in* sender_addr = nullptr; sockaddr_in* sender_addr = nullptr;
if (arguments->address) if (arguments->address)
@ -140,36 +101,52 @@ namespace Kernel
{ {
case SIOCGIFADDR: case SIOCGIFADDR:
{ {
auto ipv4_address = m_interface->get_ipv4_address(); auto& ifru_addr = *reinterpret_cast<sockaddr_in*>(&ifreq->ifr_ifru.ifru_addr);
ifreq->ifr_ifru.ifru_addr.sa_family = AF_INET; ifru_addr.sin_family = AF_INET;
memcpy(ifreq->ifr_ifru.ifru_addr.sa_data, &ipv4_address, sizeof(ipv4_address)); ifru_addr.sin_addr.s_addr = m_interface->get_ipv4_address().raw;
return 0; return 0;
} }
case SIOCSIFADDR: case SIOCSIFADDR:
{ {
if (ifreq->ifr_ifru.ifru_addr.sa_family != AF_INET) auto& ifru_addr = *reinterpret_cast<const sockaddr_in*>(&ifreq->ifr_ifru.ifru_addr);
if (ifru_addr.sin_family != AF_INET)
return BAN::Error::from_errno(EADDRNOTAVAIL); return BAN::Error::from_errno(EADDRNOTAVAIL);
BAN::IPv4Address ipv4_address { *reinterpret_cast<uint32_t*>(ifreq->ifr_ifru.ifru_addr.sa_data) }; m_interface->set_ipv4_address(BAN::IPv4Address { ifru_addr.sin_addr.s_addr });
m_interface->set_ipv4_address(ipv4_address);
dprintln("IPv4 address set to {}", m_interface->get_ipv4_address()); dprintln("IPv4 address set to {}", m_interface->get_ipv4_address());
return 0; return 0;
} }
case SIOCGIFNETMASK: case SIOCGIFNETMASK:
{ {
auto netmask_address = m_interface->get_netmask(); auto& ifru_netmask = *reinterpret_cast<sockaddr_in*>(&ifreq->ifr_ifru.ifru_netmask);
ifreq->ifr_ifru.ifru_netmask.sa_family = AF_INET; ifru_netmask.sin_family = AF_INET;
memcpy(ifreq->ifr_ifru.ifru_netmask.sa_data, &netmask_address, sizeof(netmask_address)); ifru_netmask.sin_addr.s_addr = m_interface->get_netmask().raw;
return 0; return 0;
} }
case SIOCSIFNETMASK: case SIOCSIFNETMASK:
{ {
if (ifreq->ifr_ifru.ifru_netmask.sa_family != AF_INET) auto& ifru_netmask = *reinterpret_cast<const sockaddr_in*>(&ifreq->ifr_ifru.ifru_netmask);
if (ifru_netmask.sin_family != AF_INET)
return BAN::Error::from_errno(EADDRNOTAVAIL); return BAN::Error::from_errno(EADDRNOTAVAIL);
BAN::IPv4Address netmask { *reinterpret_cast<uint32_t*>(ifreq->ifr_ifru.ifru_netmask.sa_data) }; m_interface->set_netmask(BAN::IPv4Address { ifru_netmask.sin_addr.s_addr });
m_interface->set_netmask(netmask);
dprintln("Netmask set to {}", m_interface->get_netmask()); dprintln("Netmask set to {}", m_interface->get_netmask());
return 0; return 0;
} }
case SIOCGIFGWADDR:
{
auto& ifru_gwaddr = *reinterpret_cast<sockaddr_in*>(&ifreq->ifr_ifru.ifru_gwaddr);
ifru_gwaddr.sin_family = AF_INET;
ifru_gwaddr.sin_addr.s_addr = m_interface->get_gateway().raw;
return 0;
}
case SIOCSIFGWADDR:
{
auto& ifru_gwaddr = *reinterpret_cast<const sockaddr_in*>(&ifreq->ifr_ifru.ifru_gwaddr);
if (ifru_gwaddr.sin_family != AF_INET)
return BAN::Error::from_errno(EADDRNOTAVAIL);
m_interface->set_gateway(BAN::IPv4Address { ifru_gwaddr.sin_addr.s_addr });
dprintln("Gateway set to {}", m_interface->get_gateway());
return 0;
}
case SIOCGIFHWADDR: case SIOCGIFHWADDR:
{ {
auto mac_address = m_interface->get_mac_address(); auto mac_address = m_interface->get_mac_address();

View File

@ -5,9 +5,9 @@
namespace Kernel namespace Kernel
{ {
BAN::ErrorOr<BAN::RefPtr<UDPSocket>> UDPSocket::create(mode_t mode, uid_t uid, gid_t gid) BAN::ErrorOr<BAN::RefPtr<UDPSocket>> UDPSocket::create(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info)
{ {
auto socket = TRY(BAN::RefPtr<UDPSocket>::create(mode, uid, gid)); auto socket = TRY(BAN::RefPtr<UDPSocket>::create(network_layer, ino, inode_info));
socket->m_packet_buffer = TRY(VirtualRange::create_to_vaddr_range( socket->m_packet_buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(), PageTable::kernel(),
KERNEL_OFFSET, KERNEL_OFFSET,
@ -19,14 +19,14 @@ namespace Kernel
return socket; return socket;
} }
UDPSocket::UDPSocket(mode_t mode, uid_t uid, gid_t gid) UDPSocket::UDPSocket(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info)
: NetworkSocket(mode, uid, gid) : NetworkSocket(network_layer, ino, inode_info)
{ } { }
void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t src_port, uint16_t dst_port) void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port)
{ {
auto& header = packet.as<UDPHeader>(); auto& header = packet.as<UDPHeader>();
header.src_port = src_port; header.src_port = m_port;
header.dst_port = dst_port; header.dst_port = dst_port;
header.length = packet.size(); header.length = packet.size();
header.checksum = 0; header.checksum = 0;
@ -91,8 +91,8 @@ namespace Kernel
if (sender_addr) if (sender_addr)
{ {
sender_addr->sin_family = AF_INET; sender_addr->sin_family = AF_INET;
sender_addr->sin_port = packet_info.sender_port; sender_addr->sin_port = BAN::NetworkEndian(packet_info.sender_port);
sender_addr->sin_addr.s_addr = packet_info.sender_addr.as_u32(); sender_addr->sin_addr.s_addr = packet_info.sender_addr.raw;
} }
return nread; return nread;

View File

@ -0,0 +1,323 @@
#include <BAN/HashMap.h>
#include <kernel/FS/VirtualFileSystem.h>
#include <kernel/Networking/NetworkManager.h>
#include <kernel/Networking/UNIX/Socket.h>
#include <kernel/Scheduler.h>
#include <fcntl.h>
#include <sys/un.h>
namespace Kernel
{
static BAN::HashMap<BAN::String, BAN::RefPtr<UnixDomainSocket>> s_bound_sockets;
static SpinLock s_bound_socket_lock;
static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE;
BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> UnixDomainSocket::create(SocketType socket_type, ino_t ino, const TmpInodeInfo& inode_info)
{
auto socket = TRY(BAN::RefPtr<UnixDomainSocket>::create(socket_type, ino, inode_info));
socket->m_packet_buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
~(uintptr_t)0,
s_packet_buffer_size,
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
true
));
return socket;
}
UnixDomainSocket::UnixDomainSocket(SocketType socket_type, ino_t ino, const TmpInodeInfo& inode_info)
: TmpInode(NetworkManager::get(), ino, inode_info)
, m_socket_type(socket_type)
{
switch (socket_type)
{
case SocketType::STREAM:
case SocketType::SEQPACKET:
m_info.emplace<ConnectionInfo>();
break;
case SocketType::DGRAM:
m_info.emplace<ConnectionlessInfo>();
break;
default:
ASSERT_NOT_REACHED();
}
}
BAN::ErrorOr<long> UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len)
{
if (!m_info.has<ConnectionInfo>())
return BAN::Error::from_errno(EOPNOTSUPP);
auto& connection_info = m_info.get<ConnectionInfo>();
if (!connection_info.listening)
return BAN::Error::from_errno(EINVAL);
while (connection_info.pending_connections.empty())
TRY(Thread::current().block_or_eintr(connection_info.pending_semaphore));
BAN::RefPtr<UnixDomainSocket> pending;
{
LockGuard _(connection_info.pending_lock);
pending = connection_info.pending_connections.front();
connection_info.pending_connections.pop();
connection_info.pending_semaphore.unblock();
}
BAN::RefPtr<UnixDomainSocket> return_inode;
{
auto return_inode_tmp = TRY(NetworkManager::get().create_socket(SocketDomain::UNIX, m_socket_type, mode().mode & ~Mode::TYPE_MASK, uid(), gid()));
return_inode = reinterpret_cast<UnixDomainSocket*>(return_inode_tmp.ptr());
}
TRY(return_inode->m_bound_path.append(m_bound_path));
return_inode->m_info.get<ConnectionInfo>().connection = TRY(pending->get_weak_ptr());
pending->m_info.get<ConnectionInfo>().connection = TRY(return_inode->get_weak_ptr());
pending->m_info.get<ConnectionInfo>().connection_done = true;
if (address && address_len && !is_bound_to_unused())
{
size_t copy_len = BAN::Math::min<size_t>(*address_len, sizeof(sockaddr) + m_bound_path.size() + 1);
auto& sockaddr_un = *reinterpret_cast<struct sockaddr_un*>(address);
sockaddr_un.sun_family = AF_UNIX;
strncpy(sockaddr_un.sun_path, pending->m_bound_path.data(), copy_len);
}
return TRY(Process::current().open_inode(return_inode, O_RDWR));
}
BAN::ErrorOr<void> UnixDomainSocket::connect_impl(const sockaddr* address, socklen_t address_len)
{
if (address_len != sizeof(sockaddr_un))
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_un = *reinterpret_cast<const struct sockaddr_un*>(address);
if (sockaddr_un.sun_family != AF_UNIX)
return BAN::Error::from_errno(EAFNOSUPPORT);
if (!is_bound())
TRY(m_bound_path.push_back('X'));
auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path));
auto file = TRY(VirtualFileSystem::get().file_from_absolute_path(
Process::current().credentials(),
absolute_path,
O_RDWR
));
BAN::RefPtr<UnixDomainSocket> target;
{
LockGuard _(s_bound_socket_lock);
if (!s_bound_sockets.contains(file.canonical_path))
return BAN::Error::from_errno(ECONNREFUSED);
target = s_bound_sockets[file.canonical_path];
}
if (m_socket_type != target->m_socket_type)
return BAN::Error::from_errno(EPROTOTYPE);
if (m_info.has<ConnectionInfo>())
{
auto& connection_info = m_info.get<ConnectionInfo>();
if (connection_info.connection)
return BAN::Error::from_errno(ECONNREFUSED);
if (connection_info.listening)
return BAN::Error::from_errno(EOPNOTSUPP);
connection_info.connection_done = false;
for (;;)
{
auto& target_info = target->m_info.get<ConnectionInfo>();
{
LockGuard _(target_info.pending_lock);
if (target_info.pending_connections.size() < target_info.pending_connections.capacity())
{
MUST(target_info.pending_connections.push(this));
target_info.pending_semaphore.unblock();
break;
}
}
TRY(Thread::current().block_or_eintr(target_info.pending_semaphore));
}
while (!connection_info.connection_done)
Scheduler::get().reschedule();
return {};
}
else
{
return BAN::Error::from_errno(ENOTSUP);
}
}
BAN::ErrorOr<void> UnixDomainSocket::listen_impl(int backlog)
{
backlog = BAN::Math::clamp(backlog, 1, SOMAXCONN);
if (!is_bound())
return BAN::Error::from_errno(EDESTADDRREQ);
if (!m_info.has<ConnectionInfo>())
return BAN::Error::from_errno(EOPNOTSUPP);
auto& connection_info = m_info.get<ConnectionInfo>();
if (connection_info.connection)
return BAN::Error::from_errno(EINVAL);
TRY(connection_info.pending_connections.reserve(backlog));
connection_info.listening = true;
return {};
}
BAN::ErrorOr<void> UnixDomainSocket::bind_impl(const sockaddr* address, socklen_t address_len)
{
if (is_bound())
return BAN::Error::from_errno(EINVAL);
if (address_len != sizeof(sockaddr_un))
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_un = *reinterpret_cast<const struct sockaddr_un*>(address);
if (sockaddr_un.sun_family != AF_UNIX)
return BAN::Error::from_errno(EAFNOSUPPORT);
auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path));
if (auto ret = Process::current().create_file_or_dir(absolute_path, 0755 | S_IFSOCK); ret.is_error())
{
if (ret.error().get_error_code() == EEXIST)
return BAN::Error::from_errno(EADDRINUSE);
return ret.release_error();
}
auto file = TRY(VirtualFileSystem::get().file_from_absolute_path(
Process::current().credentials(),
absolute_path,
O_RDWR
));
LockGuard _(s_bound_socket_lock);
ASSERT(!s_bound_sockets.contains(file.canonical_path));
TRY(s_bound_sockets.emplace(file.canonical_path, this));
m_bound_path = BAN::move(file.canonical_path);
return {};
}
bool UnixDomainSocket::is_streaming() const
{
switch (m_socket_type)
{
case SocketType::STREAM:
return true;
case SocketType::SEQPACKET:
case SocketType::DGRAM:
return false;
default:
ASSERT_NOT_REACHED();
}
}
// This to feels too hacky to expose out of here
struct LockFreeGuard
{
LockFreeGuard(RecursivePrioritySpinLock& lock)
: m_lock(lock)
, m_depth(lock.lock_depth())
{
for (uint32_t i = 0; i < m_depth; i++)
m_lock.unlock();
}
~LockFreeGuard()
{
for (uint32_t i = 0; i < m_depth; i++)
m_lock.lock();
}
private:
RecursivePrioritySpinLock& m_lock;
const uint32_t m_depth;
};
BAN::ErrorOr<void> UnixDomainSocket::add_packet(BAN::ConstByteSpan packet)
{
LockGuard _(m_lock);
while (m_packet_sizes.full() || m_packet_size_total + packet.size() > s_packet_buffer_size)
{
LockFreeGuard _(m_lock);
TRY(Thread::current().block_or_eintr(m_packet_semaphore));
}
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total);
memcpy(packet_buffer, packet.data(), packet.size());
m_packet_size_total += packet.size();
if (!is_streaming())
m_packet_sizes.push(packet.size());
m_packet_semaphore.unblock();
return {};
}
BAN::ErrorOr<size_t> UnixDomainSocket::sendto_impl(const sys_sendto_t* arguments)
{
if (arguments->flags)
return BAN::Error::from_errno(ENOTSUP);
if (arguments->length > s_packet_buffer_size)
return BAN::Error::from_errno(ENOBUFS);
if (m_info.has<ConnectionInfo>())
{
auto& connection_info = m_info.get<ConnectionInfo>();
if (arguments->dest_addr)
return BAN::Error::from_errno(EISCONN);
auto target = connection_info.connection.lock();
if (!target)
return BAN::Error::from_errno(ENOTCONN);
TRY(target->add_packet({ reinterpret_cast<const uint8_t*>(arguments->message), arguments->length }));
return arguments->length;
}
else
{
return BAN::Error::from_errno(ENOTSUP);
}
}
BAN::ErrorOr<size_t> UnixDomainSocket::recvfrom_impl(sys_recvfrom_t* arguments)
{
if (arguments->flags)
return BAN::Error::from_errno(ENOTSUP);
if (m_info.has<ConnectionInfo>())
{
auto& connection_info = m_info.get<ConnectionInfo>();
if (!connection_info.connection)
return BAN::Error::from_errno(ENOTCONN);
}
while (m_packet_size_total == 0)
{
LockFreeGuard _(m_lock);
TRY(Thread::current().block_or_eintr(m_packet_semaphore));
}
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
size_t nread = 0;
if (is_streaming())
nread = BAN::Math::min(arguments->length, m_packet_size_total);
else
{
nread = BAN::Math::min(arguments->length, m_packet_sizes.front());
m_packet_sizes.pop();
}
memcpy(arguments->buffer, packet_buffer, nread);
memmove(packet_buffer, packet_buffer + nread, m_packet_size_total - nread);
m_packet_size_total -= nread;
m_packet_semaphore.unblock();
return nread;
}
}

View File

@ -55,6 +55,21 @@ namespace Kernel
return {}; return {};
} }
BAN::ErrorOr<int> OpenFileDescriptorSet::open(BAN::RefPtr<Inode> inode, int flags)
{
ASSERT(inode);
ASSERT(!inode->mode().ifdir());
if (flags & ~(O_RDONLY | O_WRONLY))
return BAN::Error::from_errno(ENOTSUP);
int fd = TRY(get_free_fd());
// FIXME: path?
m_open_files[fd] = TRY(BAN::RefPtr<OpenFileDescription>::create(inode, ""sv, 0, flags));
return fd;
}
BAN::ErrorOr<int> OpenFileDescriptorSet::open(BAN::StringView absolute_path, int flags) BAN::ErrorOr<int> OpenFileDescriptorSet::open(BAN::StringView absolute_path, int flags)
{ {
if (flags & ~(O_RDONLY | O_WRONLY | O_NOFOLLOW | O_SEARCH | O_APPEND | O_TRUNC | O_CLOEXEC | O_TTY_INIT | O_DIRECTORY | O_NONBLOCK)) if (flags & ~(O_RDONLY | O_WRONLY | O_NOFOLLOW | O_SEARCH | O_APPEND | O_TRUNC | O_CLOEXEC | O_TTY_INIT | O_DIRECTORY | O_NONBLOCK))
@ -80,13 +95,25 @@ namespace Kernel
BAN::ErrorOr<int> OpenFileDescriptorSet::socket(int domain, int type, int protocol) BAN::ErrorOr<int> OpenFileDescriptorSet::socket(int domain, int type, int protocol)
{ {
using SocketType = NetworkManager::SocketType;
if (domain != AF_INET)
return BAN::Error::from_errno(EAFNOSUPPORT);
if (protocol != 0) if (protocol != 0)
return BAN::Error::from_errno(EPROTONOSUPPORT); return BAN::Error::from_errno(EPROTONOSUPPORT);
SocketDomain sock_domain;
switch (domain)
{
case AF_INET:
sock_domain = SocketDomain::INET;
break;
case AF_INET6:
sock_domain = SocketDomain::INET6;
break;
case AF_UNIX:
sock_domain = SocketDomain::UNIX;
break;
default:
return BAN::Error::from_errno(EPROTOTYPE);
}
SocketType sock_type; SocketType sock_type;
switch (type) switch (type)
{ {
@ -103,7 +130,7 @@ namespace Kernel
return BAN::Error::from_errno(EPROTOTYPE); return BAN::Error::from_errno(EPROTOTYPE);
} }
auto socket = TRY(NetworkManager::get().create_socket(sock_type, 0777, m_credentials.euid(), m_credentials.egid())); auto socket = TRY(NetworkManager::get().create_socket(sock_domain, sock_type, 0777, m_credentials.euid(), m_credentials.egid()));
int fd = TRY(get_free_fd()); int fd = TRY(get_free_fd());
m_open_files[fd] = TRY(BAN::RefPtr<OpenFileDescription>::create(socket, "no-path"sv, 0, O_RDWR)); m_open_files[fd] = TRY(BAN::RefPtr<OpenFileDescription>::create(socket, "no-path"sv, 0, O_RDWR));

View File

@ -649,8 +649,9 @@ namespace Kernel
case Inode::Mode::IFREG: break; case Inode::Mode::IFREG: break;
case Inode::Mode::IFDIR: break; case Inode::Mode::IFDIR: break;
case Inode::Mode::IFIFO: break; case Inode::Mode::IFIFO: break;
case Inode::Mode::IFSOCK: break;
default: default:
return BAN::Error::from_errno(EINVAL); return BAN::Error::from_errno(ENOTSUP);
} }
LockGuard _(m_lock); LockGuard _(m_lock);
@ -707,8 +708,17 @@ namespace Kernel
return false; return false;
} }
BAN::ErrorOr<long> Process::open_inode(BAN::RefPtr<Inode> inode, int flags)
{
ASSERT(inode);
LockGuard _(m_lock);
return TRY(m_open_file_descriptors.open(inode, flags));
}
BAN::ErrorOr<long> Process::open_file(BAN::StringView path, int flags, mode_t mode) BAN::ErrorOr<long> Process::open_file(BAN::StringView path, int flags, mode_t mode)
{ {
LockGuard _(m_lock);
BAN::String absolute_path = TRY(absolute_path_of(path)); BAN::String absolute_path = TRY(absolute_path_of(path));
if (flags & O_CREAT) if (flags & O_CREAT)
@ -716,6 +726,8 @@ namespace Kernel
if (flags & O_DIRECTORY) if (flags & O_DIRECTORY)
return BAN::Error::from_errno(ENOTSUP); return BAN::Error::from_errno(ENOTSUP);
auto file_or_error = VirtualFileSystem::get().file_from_absolute_path(m_credentials, absolute_path, O_WRONLY); auto file_or_error = VirtualFileSystem::get().file_from_absolute_path(m_credentials, absolute_path, O_WRONLY);
if (!file_or_error.is_error() && (flags & O_EXCL))
return BAN::Error::from_errno(EEXIST);
if (file_or_error.is_error()) if (file_or_error.is_error())
{ {
if (file_or_error.error().get_error_code() == ENOENT) if (file_or_error.error().get_error_code() == ENOENT)
@ -901,6 +913,26 @@ namespace Kernel
return TRY(m_open_file_descriptors.socket(domain, type, protocol)); return TRY(m_open_file_descriptors.socket(domain, type, protocol));
} }
BAN::ErrorOr<long> Process::sys_accept(int socket, sockaddr* address, socklen_t* address_len)
{
if (address && !address_len)
return BAN::Error::from_errno(EINVAL);
if (!address && address_len)
return BAN::Error::from_errno(EINVAL);
LockGuard _(m_lock);
if (address)
{
TRY(validate_pointer_access(address_len, sizeof(*address_len)));
TRY(validate_pointer_access(address, *address_len));
}
auto inode = TRY(m_open_file_descriptors.inode_of(socket));
if (!inode->mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
return TRY(inode->accept(address, address_len));
}
BAN::ErrorOr<long> Process::sys_bind(int socket, const sockaddr* address, socklen_t address_len) BAN::ErrorOr<long> Process::sys_bind(int socket, const sockaddr* address, socklen_t address_len)
{ {
@ -915,6 +947,31 @@ namespace Kernel
return 0; return 0;
} }
BAN::ErrorOr<long> Process::sys_connect(int socket, const sockaddr* address, socklen_t address_len)
{
LockGuard _(m_lock);
TRY(validate_pointer_access(address, address_len));
auto inode = TRY(m_open_file_descriptors.inode_of(socket));
if (!inode->mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
TRY(inode->connect(address, address_len));
return 0;
}
BAN::ErrorOr<long> Process::sys_listen(int socket, int backlog)
{
LockGuard _(m_lock);
auto inode = TRY(m_open_file_descriptors.inode_of(socket));
if (!inode->mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
TRY(inode->listen(backlog));
return 0;
}
BAN::ErrorOr<long> Process::sys_sendto(const sys_sendto_t* arguments) BAN::ErrorOr<long> Process::sys_sendto(const sys_sendto_t* arguments)
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
@ -1697,7 +1754,7 @@ namespace Kernel
BAN::ErrorOr<BAN::String> Process::absolute_path_of(BAN::StringView path) const BAN::ErrorOr<BAN::String> Process::absolute_path_of(BAN::StringView path) const
{ {
ASSERT(m_lock.is_locked()); LockGuard _(m_lock);
if (path.empty() || path == "."sv) if (path.empty() || path == "."sv)
return m_working_directory; return m_working_directory;

View File

@ -228,6 +228,15 @@ namespace Kernel
case SYS_IOCTL: case SYS_IOCTL:
ret = Process::current().sys_ioctl((int)arg1, (int)arg2, (void*)arg3); ret = Process::current().sys_ioctl((int)arg1, (int)arg2, (void*)arg3);
break; break;
case SYS_ACCEPT:
ret = Process::current().sys_accept((int)arg1, (sockaddr*)arg2, (socklen_t*)arg3);
break;
case SYS_CONNECT:
ret = Process::current().sys_connect((int)arg1, (const sockaddr*)arg2, (socklen_t)arg3);
break;
case SYS_LISTEN:
ret = Process::current().sys_listen((int)arg1, (int)arg2);
break;
default: default:
dwarnln("Unknown syscall {}", syscall); dwarnln("Unknown syscall {}", syscall);
break; break;

View File

@ -3,6 +3,7 @@ cmake_minimum_required(VERSION 3.26)
project(libc CXX ASM) project(libc CXX ASM)
set(LIBC_SOURCES set(LIBC_SOURCES
arpa/inet.cpp
assert.cpp assert.cpp
ctype.cpp ctype.cpp
dirent.cpp dirent.cpp

52
libc/arpa/inet.cpp Normal file
View File

@ -0,0 +1,52 @@
#include <BAN/Endianness.h>
#include <arpa/inet.h>
#include <errno.h>
#include <stdio.h>
uint32_t htonl(uint32_t hostlong)
{
return BAN::host_to_network_endian(hostlong);
}
uint16_t htons(uint16_t hostshort)
{
return BAN::host_to_network_endian(hostshort);
}
uint32_t ntohl(uint32_t netlong)
{
return BAN::host_to_network_endian(netlong);
}
uint16_t ntohs(uint16_t netshort)
{
return BAN::host_to_network_endian(netshort);
}
in_addr_t inet_addr(const char* cp)
{
uint32_t a = 0, b = 0, c = 0, d = 0;
int ret = sscanf(cp, "%u.%u.%u.%u", &a, &b, &c, &d);
if (ret < 1 || ret > 4)
return (in_addr_t)(-1);
uint32_t result = 0;
result |= (ret == 1) ? a : a << 24;
result |= (ret == 2) ? b : b << 16;
result |= (ret == 3) ? c : c << 8;
result |= (ret == 4) ? d : d << 0;
return htonl(result);
}
char* inet_ntoa(struct in_addr in)
{
static char buffer[16];
uint32_t he = ntohl(in.s_addr);
sprintf(buffer, "%u.%u.%u.%u",
(he >> 24) & 0xFF,
(he >> 16) & 0xFF,
(he >> 8) & 0xFF,
(he >> 0) & 0xFF
);
return buffer;
}

View File

@ -22,6 +22,7 @@ struct ifreq
union { union {
struct sockaddr ifru_addr; struct sockaddr ifru_addr;
struct sockaddr ifru_netmask; struct sockaddr ifru_netmask;
struct sockaddr ifru_gwaddr;
struct sockaddr ifru_hwaddr; struct sockaddr ifru_hwaddr;
unsigned char __min_storage[sizeof(sockaddr) + 6]; unsigned char __min_storage[sizeof(sockaddr) + 6];
} ifr_ifru; } ifr_ifru;
@ -31,7 +32,9 @@ struct ifreq
#define SIOCSIFADDR 2 /* Set interface address */ #define SIOCSIFADDR 2 /* Set interface address */
#define SIOCGIFNETMASK 3 /* Get network mask */ #define SIOCGIFNETMASK 3 /* Get network mask */
#define SIOCSIFNETMASK 4 /* Set network mask */ #define SIOCSIFNETMASK 4 /* Set network mask */
#define SIOCGIFHWADDR 5 /* Get hardware address */ #define SIOCGIFGWADDR 5 /* Get gateway address */
#define SIOCSIFGWADDR 6 /* Set gateway address */
#define SIOCGIFHWADDR 7 /* Get hardware address */
void if_freenameindex(struct if_nameindex* ptr); void if_freenameindex(struct if_nameindex* ptr);
char* if_indextoname(unsigned ifindex, char* ifname); char* if_indextoname(unsigned ifindex, char* ifname);

View File

@ -53,6 +53,8 @@ extern FILE* __stdout;
#define stdout __stdout #define stdout __stdout
extern FILE* __stderr; extern FILE* __stderr;
#define stderr __stderr #define stderr __stderr
extern FILE* __stddbg;
#define stddbg __stddbg
void clearerr(FILE* stream); void clearerr(FILE* stream);
char* ctermid(char* s); char* ctermid(char* s);

View File

@ -16,6 +16,12 @@ __BEGIN_DECLS
#include <bits/types/sa_family_t.h> #include <bits/types/sa_family_t.h>
typedef long socklen_t; typedef long socklen_t;
#if !defined(FILENAME_MAX)
#define FILENAME_MAX 256
#elif FILENAME_MAX != 256
#error "invalid FILENAME_MAX"
#endif
struct sockaddr struct sockaddr
{ {
sa_family_t sa_family; /* Address family. */ sa_family_t sa_family; /* Address family. */
@ -24,8 +30,8 @@ struct sockaddr
struct sockaddr_storage struct sockaddr_storage
{ {
// FIXME
sa_family_t ss_family; sa_family_t ss_family;
char ss_storage[FILENAME_MAX];
}; };
struct msghdr struct msghdr

View File

@ -68,6 +68,9 @@ __BEGIN_DECLS
#define SYS_SENDTO 67 #define SYS_SENDTO 67
#define SYS_RECVFROM 68 #define SYS_RECVFROM 68
#define SYS_IOCTL 69 #define SYS_IOCTL 69
#define SYS_ACCEPT 70
#define SYS_CONNECT 71
#define SYS_LISTEN 72
__END_DECLS __END_DECLS

View File

@ -11,8 +11,8 @@ __BEGIN_DECLS
struct sockaddr_un struct sockaddr_un
{ {
sa_family_t sun_family; /* Address family. */ sa_family_t sun_family; /* Address family. */
char sun_path[]; /* Socket pathname. */ char sun_path[FILENAME_MAX]; /* Socket pathname. */
}; };
__END_DECLS __END_DECLS

View File

@ -120,6 +120,7 @@ __BEGIN_DECLS
#define STDIN_FILENO 0 #define STDIN_FILENO 0
#define STDOUT_FILENO 1 #define STDOUT_FILENO 1
#define STDERR_FILENO 2 #define STDERR_FILENO 2
#define STDDBG_FILENO 3
#define _POSIX_VDISABLE 0 #define _POSIX_VDISABLE 0

View File

@ -22,11 +22,13 @@ static FILE s_files[FOPEN_MAX] {
{ .fd = STDIN_FILENO }, { .fd = STDIN_FILENO },
{ .fd = STDOUT_FILENO }, { .fd = STDOUT_FILENO },
{ .fd = STDERR_FILENO }, { .fd = STDERR_FILENO },
{ .fd = STDDBG_FILENO },
}; };
FILE* stdin = &s_files[0]; FILE* stdin = &s_files[0];
FILE* stdout = &s_files[1]; FILE* stdout = &s_files[1];
FILE* stderr = &s_files[2]; FILE* stderr = &s_files[2];
FILE* stddbg = &s_files[3];
void clearerr(FILE* file) void clearerr(FILE* file)
{ {

View File

@ -2,11 +2,31 @@
#include <sys/syscall.h> #include <sys/syscall.h>
#include <unistd.h> #include <unistd.h>
int accept(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len)
{
return syscall(SYS_ACCEPT, socket, address, address_len);
}
int bind(int socket, const struct sockaddr* address, socklen_t address_len) int bind(int socket, const struct sockaddr* address, socklen_t address_len)
{ {
return syscall(SYS_BIND, socket, address, address_len); return syscall(SYS_BIND, socket, address, address_len);
} }
int connect(int socket, const struct sockaddr* address, socklen_t address_len)
{
return syscall(SYS_CONNECT, socket, address, address_len);
}
int listen(int socket, int backlog)
{
return syscall(SYS_LISTEN, socket, backlog);
}
ssize_t recv(int socket, void* __restrict buffer, size_t length, int flags)
{
return recvfrom(socket, buffer, length, flags, nullptr, nullptr);
}
ssize_t recvfrom(int socket, void* __restrict buffer, size_t length, int flags, struct sockaddr* __restrict address, socklen_t* __restrict address_len) ssize_t recvfrom(int socket, void* __restrict buffer, size_t length, int flags, struct sockaddr* __restrict address, socklen_t* __restrict address_len)
{ {
sys_recvfrom_t arguments { sys_recvfrom_t arguments {
@ -20,6 +40,10 @@ ssize_t recvfrom(int socket, void* __restrict buffer, size_t length, int flags,
return syscall(SYS_RECVFROM, &arguments); return syscall(SYS_RECVFROM, &arguments);
} }
ssize_t send(int socket, const void* message, size_t length, int flags)
{
return sendto(socket, message, length, flags, nullptr, 0);
}
ssize_t sendto(int socket, const void* message, size_t length, int flags, const struct sockaddr* dest_addr, socklen_t dest_len) ssize_t sendto(int socket, const void* message, size_t length, int flags, const struct sockaddr* dest_addr, socklen_t dest_len)
{ {

View File

@ -11,14 +11,16 @@ set(USERSPACE_PROJECTS
dhcp-client dhcp-client
echo echo
id id
init
image image
init
loadkeys loadkeys
ls ls
meminfo meminfo
mkdir mkdir
mmap-shared-test mmap-shared-test
nslookup
poweroff poweroff
resolver
rm rm
Shell Shell
sleep sleep

View File

@ -72,9 +72,9 @@ static constexpr Position s_dir_offset[] {
i64 solve_general(FILE* fp, auto parse_dir, auto parse_count) i64 solve_general(FILE* fp, auto parse_dir, auto parse_count)
{ {
BAN::HashSetUnstable<Position, PositionHash> path; BAN::HashSet<Position, PositionHash> path;
BAN::HashSetUnstable<Position, PositionHash> lpath; BAN::HashSet<Position, PositionHash> lpath;
BAN::HashSetUnstable<Position, PositionHash> rpath; BAN::HashSet<Position, PositionHash> rpath;
Position current_pos { 0, 0 }; Position current_pos { 0, 0 };
MUST(path.insert(current_pos)); MUST(path.insert(current_pos));
@ -157,8 +157,8 @@ i64 solve_general(FILE* fp, auto parse_dir, auto parse_count)
ASSERT(lmin_x != rmin_x); ASSERT(lmin_x != rmin_x);
auto& expand = (lmin_x < rmin_x) ? rpath : lpath; auto& expand = (lmin_x < rmin_x) ? rpath : lpath;
BAN::HashSetUnstable<Position, PositionHash> visited; BAN::HashSet<Position, PositionHash> visited;
BAN::HashSetUnstable<Position, PositionHash> inner_area; BAN::HashSet<Position, PositionHash> inner_area;
while (!expand.empty()) while (!expand.empty())
{ {

View File

@ -33,7 +33,7 @@ struct Rule
BAN::String target; BAN::String target;
}; };
using Workflows = BAN::HashMapUnstable<BAN::String, BAN::Vector<Rule>>; using Workflows = BAN::HashMap<BAN::String, BAN::Vector<Rule>>;
struct Item struct Item
{ {

View File

@ -72,9 +72,9 @@ struct ConjunctionModule : public Module
} }
}; };
BAN::HashMapUnstable<BAN::String, BAN::UniqPtr<Module>> parse_modules(FILE* fp) BAN::HashMap<BAN::String, BAN::UniqPtr<Module>> parse_modules(FILE* fp)
{ {
BAN::HashMapUnstable<BAN::String, BAN::UniqPtr<Module>> modules; BAN::HashMap<BAN::String, BAN::UniqPtr<Module>> modules;
char buffer[128]; char buffer[128];
while (fgets(buffer, sizeof(buffer), fp)) while (fgets(buffer, sizeof(buffer), fp))

View File

@ -88,13 +88,13 @@ i64 puzzle1(FILE* fp)
{ {
auto garden = parse_garden(fp); auto garden = parse_garden(fp);
BAN::HashSetUnstable<Position> visited, reachable, pending; BAN::HashSet<Position> visited, reachable, pending;
MUST(pending.insert(garden.start)); MUST(pending.insert(garden.start));
for (i32 i = 0; i <= 64; i++) for (i32 i = 0; i <= 64; i++)
{ {
auto temp = BAN::move(pending); auto temp = BAN::move(pending);
pending = BAN::HashSetUnstable<Position>(); pending = BAN::HashSet<Position>();
while (!temp.empty()) while (!temp.empty())
{ {

View File

@ -1,8 +1,10 @@
#include <BAN/Debug.h>
#include <BAN/Endianness.h> #include <BAN/Endianness.h>
#include <BAN/IPv4.h> #include <BAN/IPv4.h>
#include <BAN/MAC.h> #include <BAN/MAC.h>
#include <BAN/Vector.h> #include <BAN/Vector.h>
#include <arpa/inet.h>
#include <fcntl.h> #include <fcntl.h>
#include <net/if.h> #include <net/if.h>
#include <netinet/in.h> #include <netinet/in.h>
@ -12,7 +14,7 @@
#include <stropts.h> #include <stropts.h>
#include <sys/socket.h> #include <sys/socket.h>
#define DEBUG_DHCP 0 #define DEBUG_DHCP 1
struct DHCPPacket struct DHCPPacket
{ {
@ -23,10 +25,10 @@ struct DHCPPacket
BAN::NetworkEndian<uint32_t> xid { 0x3903F326 }; BAN::NetworkEndian<uint32_t> xid { 0x3903F326 };
BAN::NetworkEndian<uint16_t> secs { 0x0000 }; BAN::NetworkEndian<uint16_t> secs { 0x0000 };
BAN::NetworkEndian<uint16_t> flags { 0x0000 }; BAN::NetworkEndian<uint16_t> flags { 0x0000 };
BAN::NetworkEndian<uint32_t> ciaddr { 0 }; BAN::IPv4Address ciaddr { 0 };
BAN::NetworkEndian<uint32_t> yiaddr { 0 }; BAN::IPv4Address yiaddr { 0 };
BAN::NetworkEndian<uint32_t> siaddr { 0 }; BAN::IPv4Address siaddr { 0 };
BAN::NetworkEndian<uint32_t> giaddr { 0 }; BAN::IPv4Address giaddr { 0 };
BAN::MACAddress chaddr; BAN::MACAddress chaddr;
uint8_t padding[10] {}; uint8_t padding[10] {};
uint8_t legacy[192] {}; uint8_t legacy[192] {};
@ -71,12 +73,13 @@ BAN::MACAddress get_mac_address(int socket)
return mac_address; return mac_address;
} }
void update_ipv4_info(int socket, BAN::IPv4Address address, BAN::IPv4Address subnet) void update_ipv4_info(int socket, BAN::IPv4Address address, BAN::IPv4Address netmask, BAN::IPv4Address gateway)
{ {
{ {
ifreq ifreq; ifreq ifreq;
ifreq.ifr_ifru.ifru_addr.sa_family = AF_INET; auto& ifru_addr = *reinterpret_cast<sockaddr_in*>(&ifreq.ifr_ifru.ifru_addr);
*(uint32_t*)ifreq.ifr_ifru.ifru_addr.sa_data = address.as_u32(); ifru_addr.sin_family = AF_INET;
ifru_addr.sin_addr.s_addr = address.raw;
if (ioctl(socket, SIOCSIFADDR, &ifreq) == -1) if (ioctl(socket, SIOCSIFADDR, &ifreq) == -1)
{ {
perror("ioctl"); perror("ioctl");
@ -86,22 +89,36 @@ void update_ipv4_info(int socket, BAN::IPv4Address address, BAN::IPv4Address sub
{ {
ifreq ifreq; ifreq ifreq;
ifreq.ifr_ifru.ifru_netmask.sa_family = AF_INET; auto& ifru_netmask = *reinterpret_cast<sockaddr_in*>(&ifreq.ifr_ifru.ifru_netmask);
*(uint32_t*)ifreq.ifr_ifru.ifru_netmask.sa_data = subnet.as_u32(); ifru_netmask.sin_family = AF_INET;
ifru_netmask.sin_addr.s_addr = netmask.raw;
if (ioctl(socket, SIOCSIFNETMASK, &ifreq) == -1) if (ioctl(socket, SIOCSIFNETMASK, &ifreq) == -1)
{ {
perror("ioctl"); perror("ioctl");
exit(1); exit(1);
} }
} }
if (gateway.raw)
{
ifreq ifreq;
auto& ifru_gwaddr = *reinterpret_cast<sockaddr_in*>(&ifreq.ifr_ifru.ifru_gwaddr);
ifru_gwaddr.sin_family = AF_INET;
ifru_gwaddr.sin_addr.s_addr = gateway.raw;
if (ioctl(socket, SIOCSIFGWADDR, &ifreq) == -1)
{
perror("ioctl");
exit(1);
}
}
} }
void send_dhcp_packet(int socket, const DHCPPacket& dhcp_packet, BAN::IPv4Address server_ipv4) void send_dhcp_packet(int socket, const DHCPPacket& dhcp_packet, BAN::IPv4Address server_ipv4)
{ {
sockaddr_in server_addr; sockaddr_in server_addr;
server_addr.sin_family = AF_INET; server_addr.sin_family = AF_INET;
server_addr.sin_port = 67; server_addr.sin_port = htons(67);
server_addr.sin_addr.s_addr = server_ipv4.as_u32();; server_addr.sin_addr.s_addr = server_ipv4.raw;
if (sendto(socket, &dhcp_packet, sizeof(DHCPPacket), 0, (sockaddr*)&server_addr, sizeof(server_addr)) == -1) if (sendto(socket, &dhcp_packet, sizeof(DHCPPacket), 0, (sockaddr*)&server_addr, sizeof(server_addr)) == -1)
{ {
@ -137,7 +154,7 @@ void send_dhcp_request(int socket, BAN::MACAddress mac_address, BAN::IPv4Address
{ {
DHCPPacket dhcp_packet; DHCPPacket dhcp_packet;
dhcp_packet.op = 0x01; dhcp_packet.op = 0x01;
dhcp_packet.siaddr = server_ipv4.as_u32(); dhcp_packet.siaddr = server_ipv4.raw;
dhcp_packet.chaddr = mac_address; dhcp_packet.chaddr = mac_address;
size_t idx = 0; size_t idx = 0;
@ -148,10 +165,10 @@ void send_dhcp_request(int socket, BAN::MACAddress mac_address, BAN::IPv4Address
dhcp_packet.options[idx++] = RequestedIPv4Address; dhcp_packet.options[idx++] = RequestedIPv4Address;
dhcp_packet.options[idx++] = 0x04; dhcp_packet.options[idx++] = 0x04;
dhcp_packet.options[idx++] = offered_ipv4.address[0]; dhcp_packet.options[idx++] = offered_ipv4.octets[0];
dhcp_packet.options[idx++] = offered_ipv4.address[1]; dhcp_packet.options[idx++] = offered_ipv4.octets[1];
dhcp_packet.options[idx++] = offered_ipv4.address[2]; dhcp_packet.options[idx++] = offered_ipv4.octets[2];
dhcp_packet.options[idx++] = offered_ipv4.address[3]; dhcp_packet.options[idx++] = offered_ipv4.octets[3];
dhcp_packet.options[idx++] = 0xFF; dhcp_packet.options[idx++] = 0xFF;
@ -188,7 +205,7 @@ DHCPPacketInfo parse_dhcp_packet(const DHCPPacket& packet)
fprintf(stderr, "Subnet mask with invalid length %hhu\n", length); fprintf(stderr, "Subnet mask with invalid length %hhu\n", length);
break; break;
} }
uint32_t raw = *reinterpret_cast<const BAN::NetworkEndian<uint32_t>*>(options); uint32_t raw = *reinterpret_cast<const uint32_t*>(options);
packet_info.subnet = BAN::IPv4Address(raw); packet_info.subnet = BAN::IPv4Address(raw);
break; break;
} }
@ -201,7 +218,7 @@ DHCPPacketInfo parse_dhcp_packet(const DHCPPacket& packet)
} }
for (int i = 0; i < length; i += 4) for (int i = 0; i < length; i += 4)
{ {
uint32_t raw = *reinterpret_cast<const BAN::NetworkEndian<uint32_t>*>(options + i); uint32_t raw = *reinterpret_cast<const uint32_t*>(options + i);
MUST(packet_info.routers.emplace_back(raw)); MUST(packet_info.routers.emplace_back(raw));
} }
break; break;
@ -215,7 +232,7 @@ DHCPPacketInfo parse_dhcp_packet(const DHCPPacket& packet)
} }
for (int i = 0; i < length; i += 4) for (int i = 0; i < length; i += 4)
{ {
uint32_t raw = *reinterpret_cast<const BAN::NetworkEndian<uint32_t>*>(options + i); uint32_t raw = *reinterpret_cast<const uint32_t*>(options + i);
MUST(packet_info.dns.emplace_back(raw)); MUST(packet_info.dns.emplace_back(raw));
} }
break; break;
@ -244,7 +261,7 @@ DHCPPacketInfo parse_dhcp_packet(const DHCPPacket& packet)
fprintf(stderr, "Server identifier with invalid length %hhu\n", length); fprintf(stderr, "Server identifier with invalid length %hhu\n", length);
break; break;
} }
uint32_t raw = *reinterpret_cast<const BAN::NetworkEndian<uint32_t>*>(options); uint32_t raw = *reinterpret_cast<const uint32_t*>(options);
packet_info.server = BAN::IPv4Address(raw); packet_info.server = BAN::IPv4Address(raw);
break; break;
} }
@ -293,8 +310,8 @@ int main()
sockaddr_in client_addr; sockaddr_in client_addr;
client_addr.sin_family = AF_INET; client_addr.sin_family = AF_INET;
client_addr.sin_port = 68; client_addr.sin_port = htons(68);
client_addr.sin_addr.s_addr = 0x00000000; client_addr.sin_addr.s_addr = INADDR_ANY;
if (bind(socket, (sockaddr*)&client_addr, sizeof(client_addr)) == -1) if (bind(socket, (sockaddr*)&client_addr, sizeof(client_addr)) == -1)
{ {
@ -304,12 +321,12 @@ int main()
auto mac_address = get_mac_address(socket); auto mac_address = get_mac_address(socket);
#if DEBUG_DHCP #if DEBUG_DHCP
BAN::Formatter::println(putchar, "MAC: {}", mac_address); dprintln("MAC: {}", mac_address);
#endif #endif
send_dhcp_discover(socket, mac_address); send_dhcp_discover(socket, mac_address);
#if DEBUG_DHCP #if DEBUG_DHCP
printf("DHCPDISCOVER sent\n"); dprintln("DHCPDISCOVER sent");
#endif #endif
auto dhcp_offer = read_dhcp_packet(socket); auto dhcp_offer = read_dhcp_packet(socket);
@ -322,15 +339,15 @@ int main()
} }
#if DEBUG_DHCP #if DEBUG_DHCP
BAN::Formatter::println(putchar, "DHCPOFFER"); dprintln("DHCPOFFER");
BAN::Formatter::println(putchar, " IP {}", dhcp_offer->address); dprintln(" IP {}", dhcp_offer->address);
BAN::Formatter::println(putchar, " SUBNET {}", dhcp_offer->subnet); dprintln(" SUBNET {}", dhcp_offer->subnet);
BAN::Formatter::println(putchar, " SERVER {}", dhcp_offer->server); dprintln(" SERVER {}", dhcp_offer->server);
#endif #endif
send_dhcp_request(socket, mac_address, dhcp_offer->address, dhcp_offer->server); send_dhcp_request(socket, mac_address, dhcp_offer->address, dhcp_offer->server);
#if DEBUG_DHCP #if DEBUG_DHCP
printf("DHCPREQUEST sent\n"); dprintln("DHCPREQUEST sent");
#endif #endif
auto dhcp_ack = read_dhcp_packet(socket); auto dhcp_ack = read_dhcp_packet(socket);
@ -343,10 +360,10 @@ int main()
} }
#if DEBUG_DHCP #if DEBUG_DHCP
BAN::Formatter::println(putchar, "DHCPACK"); dprintln("DHCPACK");
BAN::Formatter::println(putchar, " IP {}", dhcp_ack->address); dprintln(" IP {}", dhcp_ack->address);
BAN::Formatter::println(putchar, " SUBNET {}", dhcp_ack->subnet); dprintln(" SUBNET {}", dhcp_ack->subnet);
BAN::Formatter::println(putchar, " SERVER {}", dhcp_ack->server); dprintln(" SERVER {}", dhcp_ack->server);
#endif #endif
if (dhcp_offer->address != dhcp_ack->address) if (dhcp_offer->address != dhcp_ack->address)
@ -355,7 +372,11 @@ int main()
return 1; return 1;
} }
update_ipv4_info(socket, dhcp_ack->address, dhcp_ack->subnet); BAN::IPv4Address gateway { 0 };
if (!dhcp_ack->routers.empty())
gateway = dhcp_ack->routers.front();
update_ipv4_info(socket, dhcp_ack->address, dhcp_ack->subnet, gateway);
close(socket); close(socket);

View File

@ -17,6 +17,7 @@ void initialize_stdio()
if (open(tty, O_RDONLY | O_TTY_INIT) != 0) _exit(1); if (open(tty, O_RDONLY | O_TTY_INIT) != 0) _exit(1);
if (open(tty, O_WRONLY) != 1) _exit(1); if (open(tty, O_WRONLY) != 1) _exit(1);
if (open(tty, O_WRONLY) != 2) _exit(1); if (open(tty, O_WRONLY) != 2) _exit(1);
if (open("/dev/debug", O_WRONLY) != 3) _exit(1);
} }
int main() int main()
@ -35,6 +36,12 @@ int main()
exit(1); exit(1);
} }
if (fork() == 0)
{
execl("/bin/resolver", "resolver", NULL);
exit(1);
}
bool first = true; bool first = true;
termios termios; termios termios;

View File

@ -0,0 +1,16 @@
cmake_minimum_required(VERSION 3.26)
project(nslookup CXX)
set(SOURCES
main.cpp
)
add_executable(nslookup ${SOURCES})
target_compile_options(nslookup PUBLIC -O2 -g)
target_link_libraries(nslookup PUBLIC libc)
add_custom_target(nslookup-install
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/nslookup ${BANAN_BIN}/
DEPENDS nslookup
)

View File

@ -0,0 +1,47 @@
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/un.h>
int main(int argc, char** argv)
{
if (argc != 2)
{
fprintf(stderr, "usage: %s DOMAIN\n", argv[0]);
return 1;
}
int socket = ::socket(AF_UNIX, SOCK_SEQPACKET, 0);
if (socket == -1)
{
perror("socket");
return 1;
}
sockaddr_un addr;
addr.sun_family = AF_UNIX;
strcpy(addr.sun_path, "/tmp/resolver.sock");
if (connect(socket, (sockaddr*)&addr, sizeof(addr)) == -1)
{
perror("connect");
return 1;
}
if (send(socket, argv[1], strlen(argv[1]), 0) == -1)
{
perror("send");
return 1;
}
char buffer[128];
ssize_t nrecv = recv(socket, buffer, sizeof(buffer), 0);
if (nrecv == -1)
{
perror("recv");
return 1;
}
buffer[nrecv] = '\0';
printf("%s\n", buffer);
return 0;
}

View File

@ -0,0 +1,16 @@
cmake_minimum_required(VERSION 3.26)
project(resolver CXX)
set(SOURCES
main.cpp
)
add_executable(resolver ${SOURCES})
target_compile_options(resolver PUBLIC -O2 -g)
target_link_libraries(resolver PUBLIC libc ban)
add_custom_target(resolver-install
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/resolver ${BANAN_BIN}/
DEPENDS resolver
)

218
userspace/resolver/main.cpp Normal file
View File

@ -0,0 +1,218 @@
#include <BAN/ByteSpan.h>
#include <BAN/Endianness.h>
#include <BAN/HashMap.h>
#include <BAN/IPv4.h>
#include <BAN/String.h>
#include <BAN/StringView.h>
#include <BAN/Vector.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <unistd.h>
struct DNSPacket
{
BAN::NetworkEndian<uint16_t> identification { 0 };
BAN::NetworkEndian<uint16_t> flags { 0 };
BAN::NetworkEndian<uint16_t> question_count { 0 };
BAN::NetworkEndian<uint16_t> answer_count { 0 };
BAN::NetworkEndian<uint16_t> authority_RR_count { 0 };
BAN::NetworkEndian<uint16_t> additional_RR_count { 0 };
uint8_t data[];
};
static_assert(sizeof(DNSPacket) == 12);
struct DNSAnswer
{
uint8_t __storage[12];
BAN::NetworkEndian<uint16_t>& name() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x00); };
BAN::NetworkEndian<uint16_t>& type() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x02); };
BAN::NetworkEndian<uint16_t>& class_() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x04); };
BAN::NetworkEndian<uint32_t>& ttl() { return *reinterpret_cast<BAN::NetworkEndian<uint32_t>*>(__storage + 0x06); };
BAN::NetworkEndian<uint16_t>& data_len() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x0A); };
uint8_t data[];
};
static_assert(sizeof(DNSAnswer) == 12);
bool send_dns_query(int socket, BAN::StringView domain, uint16_t id)
{
static uint8_t buffer[4096];
memset(buffer, 0, sizeof(buffer));
DNSPacket& request = *reinterpret_cast<DNSPacket*>(buffer);
request.identification = id;
request.flags = 0x0100;
request.question_count = 1;
size_t idx = 0;
auto labels = MUST(BAN::StringView(domain).split('.'));
for (auto label : labels)
{
ASSERT(label.size() <= 0xFF);
request.data[idx++] = label.size();
for (char c : label)
request.data[idx++] = c;
}
request.data[idx++] = 0x00;
*(uint16_t*)&request.data[idx] = htons(0x01); idx += 2;
*(uint16_t*)&request.data[idx] = htons(0x01); idx += 2;
sockaddr_in nameserver;
nameserver.sin_family = AF_INET;
nameserver.sin_port = htons(53);
nameserver.sin_addr.s_addr = inet_addr("8.8.8.8");
if (sendto(socket, &request, sizeof(DNSPacket) + idx, 0, (sockaddr*)&nameserver, sizeof(nameserver)) == -1)
{
perror("sendto");
return false;
}
return true;
}
BAN::Optional<BAN::String> read_dns_response(int socket, uint16_t id)
{
static uint8_t buffer[4096];
ssize_t nrecv = recvfrom(socket, buffer, sizeof(buffer), 0, nullptr, nullptr);
if (nrecv == -1)
{
perror("recvfrom");
return {};
}
DNSPacket& reply = *reinterpret_cast<DNSPacket*>(buffer);
if (reply.identification != id)
{
fprintf(stderr, "Reply to invalid packet\n");
return {};
}
if (reply.flags & 0x0F)
{
fprintf(stderr, "DNS error (rcode %u)\n", (unsigned)(reply.flags & 0xF));
return {};
}
size_t idx = 0;
for (size_t i = 0; i < reply.question_count; i++)
{
while (reply.data[idx])
idx += reply.data[idx] + 1;
idx += 5;
}
DNSAnswer& answer = *reinterpret_cast<DNSAnswer*>(&reply.data[idx]);
if (answer.data_len() != 4)
{
fprintf(stderr, "Not IPv4\n");
return {};
}
return inet_ntoa({ .s_addr = *reinterpret_cast<uint32_t*>(answer.data) });
}
int create_service_socket()
{
int socket = ::socket(AF_UNIX, SOCK_SEQPACKET, 0);
if (socket == -1)
{
perror("socket");
return -1;
}
sockaddr_un addr;
addr.sun_family = AF_UNIX;
strcpy(addr.sun_path, "/tmp/resolver.sock");
if (bind(socket, (sockaddr*)&addr, sizeof(addr)) == -1)
{
perror("bind");
close(socket);
return -1;
}
if (chmod("/tmp/resolver.sock", 0777) == -1)
{
perror("chmod");
close(socket);
return -1;
}
if (listen(socket, 10) == -1)
{
perror("listen");
close(socket);
return -1;
}
return socket;
}
BAN::Optional<BAN::String> read_service_query(int socket)
{
static char buffer[4096];
ssize_t nrecv = recv(socket, buffer, sizeof(buffer), 0);
if (nrecv == -1)
{
perror("recv");
return {};
}
buffer[nrecv] = '\0';
return BAN::String(buffer);
}
int main(int, char**)
{
srand(time(nullptr));
int service_socket = create_service_socket();
if (service_socket == -1)
return 1;
int dns_socket = socket(AF_INET, SOCK_DGRAM, 0);
if (dns_socket == -1)
{
perror("socket");
return 1;
}
for (;;)
{
int client = accept(service_socket, nullptr, nullptr);
if (client == -1)
{
perror("accept");
continue;
}
auto query = read_service_query(client);
if (!query.has_value())
continue;
uint16_t id = rand() % 0xFFFF;
if (send_dns_query(dns_socket, *query, id))
{
auto response = read_dns_response(dns_socket, id);
if (response.has_value())
{
if (send(client, response->data(), response->size() + 1, 0) == -1)
perror("send");
close(client);
continue;
}
}
char message[] = "unavailable";
send(client, message, sizeof(message), 0);
close(client);
}
return 0;
}