Compare commits

..

No commits in common. "49889858fa36448d3abeb6b54a74a7ead244db63" and "79897e77dcb7d19aded77b5a2760cc87216cd329" have entirely different histories.

67 changed files with 468 additions and 1920 deletions

View File

@ -2,8 +2,6 @@
#include <BAN/Span.h> #include <BAN/Span.h>
#include <stdint.h>
namespace BAN namespace BAN
{ {

View File

@ -1,31 +0,0 @@
#pragma once
#if __is_kernel
#error "This is userspace only file"
#endif
#include <BAN/Formatter.h>
#include <stdio.h>
#define __debug_putchar [](int c) { putc(c, stddbg); }
#define dprintln(...) \
do { \
BAN::Formatter::print(__debug_putchar, __VA_ARGS__); \
BAN::Formatter::print(__debug_putchar,"\r\n"); \
fflush(stddbg); \
} while (false)
#define dwarnln(...) \
do { \
BAN::Formatter::print(__debug_putchar, "\e[33m"); \
dprintln(__VA_ARGS__); \
BAN::Formatter::print(__debug_putchar, "\e[m"); \
} while(false)
#define derrorln(...) \
do { \
BAN::Formatter::print(__debug_putchar, "\e[31m"); \
dprintln(__VA_ARGS__); \
BAN::Formatter::print(__debug_putchar, "\e[m"); \
} while(false)

View File

@ -90,10 +90,4 @@ namespace BAN
template<integral T> template<integral T>
using NetworkEndian = BigEndian<T>; using NetworkEndian = BigEndian<T>;
template<integral T>
constexpr T host_to_network_endian(T value)
{
return host_to_big_endian(value);
}
} }

View File

@ -7,7 +7,7 @@
namespace BAN namespace BAN
{ {
template<typename Key, typename T, typename HASH = BAN::hash<Key>> template<typename Key, typename T, typename HASH = BAN::hash<Key>, bool STABLE = true>
class HashMap class HashMap
{ {
public: public:
@ -32,12 +32,12 @@ namespace BAN
public: public:
HashMap() = default; HashMap() = default;
HashMap(const HashMap<Key, T, HASH>&); HashMap(const HashMap<Key, T, HASH, STABLE>&);
HashMap(HashMap<Key, T, HASH>&&); HashMap(HashMap<Key, T, HASH, STABLE>&&);
~HashMap(); ~HashMap();
HashMap<Key, T, HASH>& operator=(const HashMap<Key, T, HASH>&); HashMap<Key, T, HASH, STABLE>& operator=(const HashMap<Key, T, HASH, STABLE>&);
HashMap<Key, T, HASH>& operator=(HashMap<Key, T, HASH>&&); HashMap<Key, T, HASH, STABLE>& operator=(HashMap<Key, T, HASH, STABLE>&&);
ErrorOr<void> insert(const Key&, const T&); ErrorOr<void> insert(const Key&, const T&);
ErrorOr<void> insert(const Key&, T&&); ErrorOr<void> insert(const Key&, T&&);
@ -74,26 +74,26 @@ namespace BAN
friend iterator; friend iterator;
}; };
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH, bool STABLE>
HashMap<Key, T, HASH>::HashMap(const HashMap<Key, T, HASH>& other) HashMap<Key, T, HASH, STABLE>::HashMap(const HashMap<Key, T, HASH, STABLE>& other)
{ {
*this = other; *this = other;
} }
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH, bool STABLE>
HashMap<Key, T, HASH>::HashMap(HashMap<Key, T, HASH>&& other) HashMap<Key, T, HASH, STABLE>::HashMap(HashMap<Key, T, HASH, STABLE>&& other)
{ {
*this = move(other); *this = move(other);
} }
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH, bool STABLE>
HashMap<Key, T, HASH>::~HashMap() HashMap<Key, T, HASH, STABLE>::~HashMap()
{ {
clear(); clear();
} }
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH, bool STABLE>
HashMap<Key, T, HASH>& HashMap<Key, T, HASH>::operator=(const HashMap<Key, T, HASH>& other) HashMap<Key, T, HASH, STABLE>& HashMap<Key, T, HASH, STABLE>::operator=(const HashMap<Key, T, HASH, STABLE>& other)
{ {
clear(); clear();
m_buckets = other.m_buckets; m_buckets = other.m_buckets;
@ -101,8 +101,8 @@ namespace BAN
return *this; return *this;
} }
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH, bool STABLE>
HashMap<Key, T, HASH>& HashMap<Key, T, HASH>::operator=(HashMap<Key, T, HASH>&& other) HashMap<Key, T, HASH, STABLE>& HashMap<Key, T, HASH, STABLE>::operator=(HashMap<Key, T, HASH, STABLE>&& other)
{ {
clear(); clear();
m_buckets = move(other.m_buckets); m_buckets = move(other.m_buckets);
@ -111,21 +111,21 @@ namespace BAN
return *this; return *this;
} }
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH, bool STABLE>
ErrorOr<void> HashMap<Key, T, HASH>::insert(const Key& key, const T& value) ErrorOr<void> HashMap<Key, T, HASH, STABLE>::insert(const Key& key, const T& value)
{ {
return insert(key, move(T(value))); return insert(key, move(T(value)));
} }
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH, bool STABLE>
ErrorOr<void> HashMap<Key, T, HASH>::insert(const Key& key, T&& value) ErrorOr<void> HashMap<Key, T, HASH, STABLE>::insert(const Key& key, T&& value)
{ {
return emplace(key, move(value)); return emplace(key, move(value));
} }
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH, bool STABLE>
template<typename... Args> template<typename... Args>
ErrorOr<void> HashMap<Key, T, HASH>::emplace(const Key& key, Args&&... args) ErrorOr<void> HashMap<Key, T, HASH, STABLE>::emplace(const Key& key, Args&&... args)
{ {
ASSERT(!contains(key)); ASSERT(!contains(key));
TRY(rebucket(m_size + 1)); TRY(rebucket(m_size + 1));
@ -135,15 +135,15 @@ namespace BAN
return {}; return {};
} }
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH, bool STABLE>
ErrorOr<void> HashMap<Key, T, HASH>::reserve(size_type size) ErrorOr<void> HashMap<Key, T, HASH, STABLE>::reserve(size_type size)
{ {
TRY(rebucket(size)); TRY(rebucket(size));
return {}; return {};
} }
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH, bool STABLE>
void HashMap<Key, T, HASH>::remove(const Key& key) void HashMap<Key, T, HASH, STABLE>::remove(const Key& key)
{ {
if (empty()) return; if (empty()) return;
auto& bucket = get_bucket(key); auto& bucket = get_bucket(key);
@ -158,15 +158,15 @@ namespace BAN
} }
} }
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH, bool STABLE>
void HashMap<Key, T, HASH>::clear() void HashMap<Key, T, HASH, STABLE>::clear()
{ {
m_buckets.clear(); m_buckets.clear();
m_size = 0; m_size = 0;
} }
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH, bool STABLE>
T& HashMap<Key, T, HASH>::operator[](const Key& key) T& HashMap<Key, T, HASH, STABLE>::operator[](const Key& key)
{ {
ASSERT(!empty()); ASSERT(!empty());
auto& bucket = get_bucket(key); auto& bucket = get_bucket(key);
@ -176,8 +176,8 @@ namespace BAN
ASSERT(false); ASSERT(false);
} }
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH, bool STABLE>
const T& HashMap<Key, T, HASH>::operator[](const Key& key) const const T& HashMap<Key, T, HASH, STABLE>::operator[](const Key& key) const
{ {
ASSERT(!empty()); ASSERT(!empty());
const auto& bucket = get_bucket(key); const auto& bucket = get_bucket(key);
@ -187,8 +187,8 @@ namespace BAN
ASSERT(false); ASSERT(false);
} }
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH, bool STABLE>
bool HashMap<Key, T, HASH>::contains(const Key& key) const bool HashMap<Key, T, HASH, STABLE>::contains(const Key& key) const
{ {
if (empty()) return false; if (empty()) return false;
const auto& bucket = get_bucket(key); const auto& bucket = get_bucket(key);
@ -198,20 +198,20 @@ namespace BAN
return false; return false;
} }
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH, bool STABLE>
bool HashMap<Key, T, HASH>::empty() const bool HashMap<Key, T, HASH, STABLE>::empty() const
{ {
return m_size == 0; return m_size == 0;
} }
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH, bool STABLE>
typename HashMap<Key, T, HASH>::size_type HashMap<Key, T, HASH>::size() const typename HashMap<Key, T, HASH, STABLE>::size_type HashMap<Key, T, HASH, STABLE>::size() const
{ {
return m_size; return m_size;
} }
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH, bool STABLE>
ErrorOr<void> HashMap<Key, T, HASH>::rebucket(size_type bucket_count) ErrorOr<void> HashMap<Key, T, HASH, STABLE>::rebucket(size_type bucket_count)
{ {
if (m_buckets.size() >= bucket_count) if (m_buckets.size() >= bucket_count)
return {}; return {};
@ -222,10 +222,13 @@ namespace BAN
for (auto& bucket : m_buckets) for (auto& bucket : m_buckets)
{ {
for (auto it = bucket.begin(); it != bucket.end();) for (Entry& entry : bucket)
{ {
size_type new_bucket_index = HASH()(it->key) % new_buckets.size(); size_type bucket_index = HASH()(entry.key) % new_buckets.size();
it = bucket.move_element_to_other_linked_list(new_buckets[new_bucket_index], new_buckets[new_bucket_index].end(), it); if constexpr(STABLE)
TRY(new_buckets[bucket_index].push_back(entry));
else
TRY(new_buckets[bucket_index].push_back(move(entry)));
} }
} }
@ -233,20 +236,27 @@ namespace BAN
return {}; return {};
} }
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH, bool STABLE>
LinkedList<typename HashMap<Key, T, HASH>::Entry>& HashMap<Key, T, HASH>::get_bucket(const Key& key) LinkedList<typename HashMap<Key, T, HASH, STABLE>::Entry>& HashMap<Key, T, HASH, STABLE>::get_bucket(const Key& key)
{ {
ASSERT(!m_buckets.empty()); ASSERT(!m_buckets.empty());
auto index = HASH()(key) % m_buckets.size(); auto index = HASH()(key) % m_buckets.size();
return m_buckets[index]; return m_buckets[index];
} }
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH, bool STABLE>
const LinkedList<typename HashMap<Key, T, HASH>::Entry>& HashMap<Key, T, HASH>::get_bucket(const Key& key) const const LinkedList<typename HashMap<Key, T, HASH, STABLE>::Entry>& HashMap<Key, T, HASH, STABLE>::get_bucket(const Key& key) const
{ {
ASSERT(!m_buckets.empty()); ASSERT(!m_buckets.empty());
auto index = HASH()(key) % m_buckets.size(); auto index = HASH()(key) % m_buckets.size();
return m_buckets[index]; return m_buckets[index];
} }
// Unstable hash map moves values between container during rebucketing.
// This means that if insertion to map fails, elements could be in invalid state
// and that container is no longer usable. This is better if either way you are
// going to stop using the hash map after insertion fails.
template<typename Key, typename T, typename HASH = BAN::hash<Key>>
using HashMapUnstable = HashMap<Key, T, HASH, false>;
} }

View File

@ -11,7 +11,7 @@
namespace BAN namespace BAN
{ {
template<typename T, typename HASH = hash<T>> template<typename T, typename HASH = hash<T>, bool STABLE = true>
class HashSet class HashSet
{ {
public: public:
@ -55,23 +55,23 @@ namespace BAN
size_type m_size = 0; size_type m_size = 0;
}; };
template<typename T, typename HASH> template<typename T, typename HASH, bool STABLE>
HashSet<T, HASH>::HashSet(const HashSet& other) HashSet<T, HASH, STABLE>::HashSet(const HashSet& other)
: m_buckets(other.m_buckets) : m_buckets(other.m_buckets)
, m_size(other.m_size) , m_size(other.m_size)
{ {
} }
template<typename T, typename HASH> template<typename T, typename HASH, bool STABLE>
HashSet<T, HASH>::HashSet(HashSet&& other) HashSet<T, HASH, STABLE>::HashSet(HashSet&& other)
: m_buckets(move(other.m_buckets)) : m_buckets(move(other.m_buckets))
, m_size(other.m_size) , m_size(other.m_size)
{ {
other.clear(); other.clear();
} }
template<typename T, typename HASH> template<typename T, typename HASH, bool STABLE>
HashSet<T, HASH>& HashSet<T, HASH>::operator=(const HashSet& other) HashSet<T, HASH, STABLE>& HashSet<T, HASH, STABLE>::operator=(const HashSet& other)
{ {
clear(); clear();
m_buckets = other.m_buckets; m_buckets = other.m_buckets;
@ -79,8 +79,8 @@ namespace BAN
return *this; return *this;
} }
template<typename T, typename HASH> template<typename T, typename HASH, bool STABLE>
HashSet<T, HASH>& HashSet<T, HASH>::operator=(HashSet&& other) HashSet<T, HASH, STABLE>& HashSet<T, HASH, STABLE>::operator=(HashSet&& other)
{ {
clear(); clear();
m_buckets = move(other.m_buckets); m_buckets = move(other.m_buckets);
@ -89,14 +89,14 @@ namespace BAN
return *this; return *this;
} }
template<typename T, typename HASH> template<typename T, typename HASH, bool STABLE>
ErrorOr<void> HashSet<T, HASH>::insert(const T& key) ErrorOr<void> HashSet<T, HASH, STABLE>::insert(const T& key)
{ {
return insert(move(T(key))); return insert(move(T(key)));
} }
template<typename T, typename HASH> template<typename T, typename HASH, bool STABLE>
ErrorOr<void> HashSet<T, HASH>::insert(T&& key) ErrorOr<void> HashSet<T, HASH, STABLE>::insert(T&& key)
{ {
if (!empty() && get_bucket(key).contains(key)) if (!empty() && get_bucket(key).contains(key))
return {}; return {};
@ -107,8 +107,8 @@ namespace BAN
return {}; return {};
} }
template<typename T, typename HASH> template<typename T, typename HASH, bool STABLE>
void HashSet<T, HASH>::remove(const T& key) void HashSet<T, HASH, STABLE>::remove(const T& key)
{ {
if (empty()) return; if (empty()) return;
auto& bucket = get_bucket(key); auto& bucket = get_bucket(key);
@ -123,41 +123,41 @@ namespace BAN
} }
} }
template<typename T, typename HASH> template<typename T, typename HASH, bool STABLE>
void HashSet<T, HASH>::clear() void HashSet<T, HASH, STABLE>::clear()
{ {
m_buckets.clear(); m_buckets.clear();
m_size = 0; m_size = 0;
} }
template<typename T, typename HASH> template<typename T, typename HASH, bool STABLE>
ErrorOr<void> HashSet<T, HASH>::reserve(size_type size) ErrorOr<void> HashSet<T, HASH, STABLE>::reserve(size_type size)
{ {
TRY(rebucket(size)); TRY(rebucket(size));
return {}; return {};
} }
template<typename T, typename HASH> template<typename T, typename HASH, bool STABLE>
bool HashSet<T, HASH>::contains(const T& key) const bool HashSet<T, HASH, STABLE>::contains(const T& key) const
{ {
if (empty()) return false; if (empty()) return false;
return get_bucket(key).contains(key); return get_bucket(key).contains(key);
} }
template<typename T, typename HASH> template<typename T, typename HASH, bool STABLE>
typename HashSet<T, HASH>::size_type HashSet<T, HASH>::size() const typename HashSet<T, HASH, STABLE>::size_type HashSet<T, HASH, STABLE>::size() const
{ {
return m_size; return m_size;
} }
template<typename T, typename HASH> template<typename T, typename HASH, bool STABLE>
bool HashSet<T, HASH>::empty() const bool HashSet<T, HASH, STABLE>::empty() const
{ {
return m_size == 0; return m_size == 0;
} }
template<typename T, typename HASH> template<typename T, typename HASH, bool STABLE>
ErrorOr<void> HashSet<T, HASH>::rebucket(size_type bucket_count) ErrorOr<void> HashSet<T, HASH, STABLE>::rebucket(size_type bucket_count)
{ {
if (m_buckets.size() >= bucket_count) if (m_buckets.size() >= bucket_count)
return {}; return {};
@ -169,10 +169,13 @@ namespace BAN
for (auto& bucket : m_buckets) for (auto& bucket : m_buckets)
{ {
for (auto it = bucket.begin(); it != bucket.end();) for (T& key : bucket)
{ {
size_type new_bucket_index = HASH()(*it) % new_buckets.size(); size_type bucket_index = HASH()(key) % new_buckets.size();
it = bucket.move_element_to_other_linked_list(new_buckets[new_bucket_index], new_buckets[new_bucket_index].end(), it); if constexpr(STABLE)
TRY(new_buckets[bucket_index].push_back(key));
else
TRY(new_buckets[bucket_index].push_back(move(key)));
} }
} }
@ -180,20 +183,27 @@ namespace BAN
return {}; return {};
} }
template<typename T, typename HASH> template<typename T, typename HASH, bool STABLE>
LinkedList<T>& HashSet<T, HASH>::get_bucket(const T& key) LinkedList<T>& HashSet<T, HASH, STABLE>::get_bucket(const T& key)
{ {
ASSERT(!m_buckets.empty()); ASSERT(!m_buckets.empty());
size_type index = HASH()(key) % m_buckets.size(); size_type index = HASH()(key) % m_buckets.size();
return m_buckets[index]; return m_buckets[index];
} }
template<typename T, typename HASH> template<typename T, typename HASH, bool STABLE>
const LinkedList<T>& HashSet<T, HASH>::get_bucket(const T& key) const const LinkedList<T>& HashSet<T, HASH, STABLE>::get_bucket(const T& key) const
{ {
ASSERT(!m_buckets.empty()); ASSERT(!m_buckets.empty());
size_type index = HASH()(key) % m_buckets.size(); size_type index = HASH()(key) % m_buckets.size();
return m_buckets[index]; return m_buckets[index];
} }
// Unstable hash set moves values between container during rebucketing.
// This means that if insertion to set fails, elements could be in invalid state
// and that container is no longer usable. This is better if either way you are
// going to stop using the hash set after insertion fails.
template<typename T, typename HASH = hash<T>>
using HashSetUnstable = HashSet<T, HASH, false>;
} }

View File

@ -1,6 +1,5 @@
#pragma once #pragma once
#include <BAN/Endianness.h>
#include <BAN/Formatter.h> #include <BAN/Formatter.h>
#include <BAN/Hash.h> #include <BAN/Hash.h>
@ -11,32 +10,31 @@ namespace BAN
{ {
constexpr IPv4Address(uint32_t u32_address) constexpr IPv4Address(uint32_t u32_address)
{ {
raw = u32_address; address[0] = u32_address >> 24;
address[1] = u32_address >> 16;
address[2] = u32_address >> 8;
address[3] = u32_address >> 0;
} }
constexpr IPv4Address(uint8_t oct1, uint8_t oct2, uint8_t oct3, uint8_t oct4) constexpr uint32_t as_u32() const
{ {
octets[0] = oct1; return
octets[1] = oct2; ((uint32_t)address[0] << 24) |
octets[2] = oct3; ((uint32_t)address[1] << 16) |
octets[3] = oct4; ((uint32_t)address[2] << 8) |
((uint32_t)address[3] << 0);
} }
constexpr bool operator==(const IPv4Address& other) const constexpr bool operator==(const IPv4Address& other) const
{ {
return raw == other.raw; return
address[0] == other.address[0] &&
address[1] == other.address[1] &&
address[2] == other.address[2] &&
address[3] == other.address[3];
} }
constexpr IPv4Address mask(const IPv4Address& other) const uint8_t address[4];
{
return IPv4Address(raw & other.raw);
}
union
{
uint8_t octets[4];
uint32_t raw;
} __attribute__((packed));
}; };
static_assert(sizeof(IPv4Address) == 4); static_assert(sizeof(IPv4Address) == 4);
@ -45,7 +43,7 @@ namespace BAN
{ {
constexpr hash_t operator()(IPv4Address ipv4) const constexpr hash_t operator()(IPv4Address ipv4) const
{ {
return hash<uint32_t>()(ipv4.raw); return hash<uint32_t>()(ipv4.as_u32());
} }
}; };
@ -64,11 +62,11 @@ namespace BAN::Formatter
.upper = false, .upper = false,
}; };
print_argument(putc, ipv4.octets[0], format); print_argument(putc, ipv4.address[0], format);
for (size_t i = 1; i < 4; i++) for (size_t i = 1; i < 4; i++)
{ {
putc('.'); putc('.');
print_argument(putc, ipv4.octets[i], format); print_argument(putc, ipv4.address[i], format);
} }
} }

View File

@ -3,8 +3,6 @@
#include <BAN/Assert.h> #include <BAN/Assert.h>
#include <BAN/Traits.h> #include <BAN/Traits.h>
#include <stddef.h>
namespace BAN namespace BAN
{ {

View File

@ -45,7 +45,6 @@ namespace BAN
void clear(); void clear();
bool empty() const; bool empty() const;
size_type capacity() const;
size_type size() const; size_type size() const;
const T& front() const; const T& front() const;
@ -187,12 +186,6 @@ namespace BAN
return m_size == 0; return m_size == 0;
} }
template<typename T>
typename Queue<T>::size_type Queue<T>::capacity() const
{
return m_capacity;
}
template<typename T> template<typename T>
typename Queue<T>::size_type Queue<T>::size() const typename Queue<T>::size_type Queue<T>::size() const
{ {

View File

@ -216,14 +216,6 @@ namespace BAN
return m_index == detail::index<T, Ts...>(); return m_index == detail::index<T, Ts...>();
} }
template<typename T, typename... Args>
void emplace(Args&&... args) requires (can_have<T>())
{
clear();
m_index = detail::index<T, Ts...>();
new (m_storage) T(BAN::forward<Args>(args)...);
}
template<typename T> template<typename T>
void set(T&& value) requires (can_have<T>() && !is_lvalue_reference_v<T>) void set(T&& value) requires (can_have<T>() && !is_lvalue_reference_v<T>)
{ {

View File

@ -91,8 +91,6 @@ namespace BAN
bool valid() const { return m_link && m_link->valid(); } bool valid() const { return m_link && m_link->valid(); }
operator bool() const { return valid(); }
private: private:
WeakPtr(const RefPtr<WeakLink<T>>& link) WeakPtr(const RefPtr<WeakLink<T>>& link)
: m_link(link) : m_link(link)

View File

@ -16,7 +16,6 @@ set(KERNEL_SOURCES
kernel/CPUID.cpp kernel/CPUID.cpp
kernel/Credentials.cpp kernel/Credentials.cpp
kernel/Debug.cpp kernel/Debug.cpp
kernel/Device/DebugDevice.cpp
kernel/Device/Device.cpp kernel/Device/Device.cpp
kernel/Device/FramebufferDevice.cpp kernel/Device/FramebufferDevice.cpp
kernel/Device/NullDevice.cpp kernel/Device/NullDevice.cpp
@ -53,12 +52,11 @@ set(KERNEL_SOURCES
kernel/Networking/ARPTable.cpp kernel/Networking/ARPTable.cpp
kernel/Networking/E1000/E1000.cpp kernel/Networking/E1000/E1000.cpp
kernel/Networking/E1000/E1000E.cpp kernel/Networking/E1000/E1000E.cpp
kernel/Networking/IPv4Layer.cpp kernel/Networking/IPv4.cpp
kernel/Networking/NetworkInterface.cpp kernel/Networking/NetworkInterface.cpp
kernel/Networking/NetworkManager.cpp kernel/Networking/NetworkManager.cpp
kernel/Networking/NetworkSocket.cpp kernel/Networking/NetworkSocket.cpp
kernel/Networking/UDPSocket.cpp kernel/Networking/UDPSocket.cpp
kernel/Networking/UNIX/Socket.cpp
kernel/OpenFileDescriptorSet.cpp kernel/OpenFileDescriptorSet.cpp
kernel/Panic.cpp kernel/Panic.cpp
kernel/PCI.cpp kernel/PCI.cpp

View File

@ -5,7 +5,7 @@
#define dprintln(...) \ #define dprintln(...) \
do { \ do { \
Debug::DebugLock::lock(); \ Debug::DebugLock::lock(); \
Debug::print_prefix(__FILE__, __LINE__); \ Debug::print_prefix(__FILE__, __LINE__); \
BAN::Formatter::print(Debug::putchar, __VA_ARGS__); \ BAN::Formatter::print(Debug::putchar, __VA_ARGS__); \
BAN::Formatter::print(Debug::putchar, "\r\n"); \ BAN::Formatter::print(Debug::putchar, "\r\n"); \
Debug::DebugLock::unlock(); \ Debug::DebugLock::unlock(); \

View File

@ -1,28 +0,0 @@
#include <kernel/Device/Device.h>
namespace Kernel
{
class DebugDevice : public CharacterDevice
{
public:
static BAN::ErrorOr<BAN::RefPtr<DebugDevice>> create(mode_t, uid_t, gid_t);
virtual dev_t rdev() const override { return m_rdev; }
virtual BAN::StringView name() const override { return "debug"sv; }
protected:
DebugDevice(mode_t mode, uid_t uid, gid_t gid, dev_t rdev)
: CharacterDevice(mode, uid, gid)
, m_rdev(rdev)
{ }
virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override { return 0; }
virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan buffer) override;
private:
const dev_t m_rdev;
};
}

View File

@ -100,12 +100,9 @@ namespace Kernel
BAN::ErrorOr<BAN::String> link_target(); BAN::ErrorOr<BAN::String> link_target();
// Socket API // Socket API
BAN::ErrorOr<long> accept(sockaddr* address, socklen_t* address_len);
BAN::ErrorOr<void> bind(const sockaddr* address, socklen_t address_len); BAN::ErrorOr<void> bind(const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<void> connect(const sockaddr* address, socklen_t address_len); BAN::ErrorOr<ssize_t> sendto(const sys_sendto_t*);
BAN::ErrorOr<void> listen(int backlog); BAN::ErrorOr<ssize_t> recvfrom(sys_recvfrom_t*);
BAN::ErrorOr<size_t> sendto(const sys_sendto_t*);
BAN::ErrorOr<size_t> recvfrom(sys_recvfrom_t*);
// General API // General API
BAN::ErrorOr<size_t> read(off_t, BAN::ByteSpan buffer); BAN::ErrorOr<size_t> read(off_t, BAN::ByteSpan buffer);
@ -131,12 +128,9 @@ namespace Kernel
virtual BAN::ErrorOr<BAN::String> link_target_impl() { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<BAN::String> link_target_impl() { return BAN::Error::from_errno(ENOTSUP); }
// Socket API // Socket API
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> listen_impl(int) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<size_t> sendto_impl(const sys_sendto_t*) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<ssize_t> sendto_impl(const sys_sendto_t*) { return BAN::Error::from_errno(ENOTSUP); }
virtual BAN::ErrorOr<size_t> recvfrom_impl(sys_recvfrom_t*) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<ssize_t> recvfrom_impl(sys_recvfrom_t*) { return BAN::Error::from_errno(ENOTSUP); }
// General API // General API
virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) { return BAN::Error::from_errno(ENOTSUP); } virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) { return BAN::Error::from_errno(ENOTSUP); }

View File

@ -1,20 +0,0 @@
#pragma once
namespace Kernel
{
enum class SocketDomain
{
INET,
INET6,
UNIX,
};
enum class SocketType
{
STREAM,
DGRAM,
SEQPACKET,
};
}

View File

@ -80,25 +80,6 @@ namespace Kernel
friend class TmpInode; friend class TmpInode;
}; };
class TmpSocketInode : public TmpInode
{
public:
static BAN::ErrorOr<BAN::RefPtr<TmpSocketInode>> create_new(TmpFileSystem&, mode_t, uid_t, gid_t);
~TmpSocketInode();
protected:
virtual BAN::ErrorOr<size_t> read_impl(off_t, BAN::ByteSpan) override { return BAN::Error::from_errno(ENODEV); }
virtual BAN::ErrorOr<size_t> write_impl(off_t, BAN::ConstByteSpan) override { return BAN::Error::from_errno(ENODEV); }
virtual BAN::ErrorOr<void> truncate_impl(size_t) override { return BAN::Error::from_errno(ENODEV); }
virtual BAN::ErrorOr<void> chmod_impl(mode_t) override;
virtual bool has_data_impl() const override { return true; }
private:
TmpSocketInode(TmpFileSystem&, ino_t, const TmpInodeInfo&);
friend class TmpInode;
};
class TmpSymlinkInode : public TmpInode class TmpSymlinkInode : public TmpInode
{ {
public: public:

View File

@ -31,7 +31,6 @@ namespace Kernel
public: public:
static BAN::ErrorOr<BAN::UniqPtr<ARPTable>> create(); static BAN::ErrorOr<BAN::UniqPtr<ARPTable>> create();
~ARPTable();
BAN::ErrorOr<BAN::MACAddress> get_mac_from_ipv4(NetworkInterface&, BAN::IPv4Address); BAN::ErrorOr<BAN::MACAddress> get_mac_from_ipv4(NetworkInterface&, BAN::IPv4Address);

View File

@ -42,7 +42,7 @@ namespace Kernel
uint32_t read32(uint16_t reg); uint32_t read32(uint16_t reg);
void write32(uint16_t reg, uint32_t value); void write32(uint16_t reg, uint32_t value);
virtual BAN::ErrorOr<void> send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan) override; virtual BAN::ErrorOr<void> send_raw_bytes(BAN::ConstByteSpan) override;
private: private:
BAN::ErrorOr<void> read_mac_address(); BAN::ErrorOr<void> read_mac_address();

View File

@ -1,24 +0,0 @@
#pragma once
#include <BAN/Endianness.h>
#include <stdint.h>
namespace Kernel
{
struct ICMPHeader
{
uint8_t type;
uint8_t code;
BAN::NetworkEndian<uint16_t> checksum;
BAN::NetworkEndian<uint32_t> rest;
};
static_assert(sizeof(ICMPHeader) == 8);
enum ICMPType : uint8_t
{
EchoReply = 0x00,
EchoRequest = 0x08,
};
}

View File

@ -0,0 +1,38 @@
#pragma once
#include <BAN/ByteSpan.h>
#include <BAN/Endianness.h>
#include <BAN/IPv4.h>
#include <BAN/Vector.h>
namespace Kernel
{
struct IPv4Header
{
uint8_t version_IHL;
uint8_t DSCP_ECN;
BAN::NetworkEndian<uint16_t> total_length { 0 };
BAN::NetworkEndian<uint16_t> identification { 0 };
BAN::NetworkEndian<uint16_t> flags_frament { 0 };
uint8_t time_to_live;
uint8_t protocol;
BAN::NetworkEndian<uint16_t> checksum { 0 };
BAN::IPv4Address src_address;
BAN::IPv4Address dst_address;
constexpr uint16_t calculate_checksum() const
{
return 0xFFFF
- (((uint16_t)version_IHL << 8) | DSCP_ECN)
- total_length
- identification
- flags_frament
- (((uint16_t)time_to_live << 8) | protocol);
}
};
static_assert(sizeof(IPv4Header) == 20);
void add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol);
}

View File

@ -1,105 +0,0 @@
#pragma once
#include <BAN/Array.h>
#include <BAN/ByteSpan.h>
#include <BAN/CircularQueue.h>
#include <BAN/Endianness.h>
#include <BAN/IPv4.h>
#include <BAN/NoCopyMove.h>
#include <BAN/UniqPtr.h>
#include <kernel/Networking/ARPTable.h>
#include <kernel/Networking/NetworkInterface.h>
#include <kernel/Networking/NetworkLayer.h>
#include <kernel/Networking/NetworkSocket.h>
#include <kernel/Process.h>
#include <kernel/SpinLock.h>
namespace Kernel
{
struct IPv4Header
{
uint8_t version_IHL;
uint8_t DSCP_ECN;
BAN::NetworkEndian<uint16_t> total_length { 0 };
BAN::NetworkEndian<uint16_t> identification { 0 };
BAN::NetworkEndian<uint16_t> flags_frament { 0 };
uint8_t time_to_live;
uint8_t protocol;
BAN::NetworkEndian<uint16_t> checksum { 0 };
BAN::IPv4Address src_address;
BAN::IPv4Address dst_address;
constexpr uint16_t calculate_checksum() const
{
uint32_t total_sum = 0;
for (size_t i = 0; i < sizeof(IPv4Header) / sizeof(uint16_t); i++)
total_sum += reinterpret_cast<const BAN::NetworkEndian<uint16_t>*>(this)[i];
total_sum -= checksum;
while (total_sum >> 16)
total_sum = (total_sum >> 16) + (total_sum & 0xFFFF);
return ~(uint16_t)total_sum;
}
constexpr bool is_valid_checksum() const
{
uint32_t total_sum = 0;
for (size_t i = 0; i < sizeof(IPv4Header) / sizeof(uint16_t); i++)
total_sum += reinterpret_cast<const BAN::NetworkEndian<uint16_t>*>(this)[i];
while (total_sum >> 16)
total_sum = (total_sum >> 16) + (total_sum & 0xFFFF);
return total_sum == 0xFFFF;
}
};
static_assert(sizeof(IPv4Header) == 20);
class IPv4Layer : public NetworkLayer
{
BAN_NON_COPYABLE(IPv4Layer);
BAN_NON_MOVABLE(IPv4Layer);
public:
static BAN::ErrorOr<BAN::UniqPtr<IPv4Layer>> create();
~IPv4Layer();
ARPTable& arp_table() { return *m_arp_table; }
void add_ipv4_packet(NetworkInterface&, BAN::ConstByteSpan);
virtual void unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) override;
virtual BAN::ErrorOr<void> bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) override;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, const sys_sendto_t*) override;
private:
IPv4Layer();
void add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol) const;
void packet_handle_task();
BAN::ErrorOr<void> handle_ipv4_packet(NetworkInterface&, BAN::ByteSpan);
private:
struct PendingIPv4Packet
{
NetworkInterface& interface;
};
private:
SpinLock m_lock;
BAN::UniqPtr<ARPTable> m_arp_table;
Process* m_process { nullptr };
static constexpr size_t pending_packet_buffer_size = 128 * PAGE_SIZE;
BAN::UniqPtr<VirtualRange> m_pending_packet_buffer;
BAN::CircularQueue<PendingIPv4Packet, 128> m_pending_packets;
Semaphore m_pending_semaphore;
size_t m_pending_total_size { 0 };
BAN::HashMap<int, BAN::WeakPtr<NetworkSocket>> m_bound_sockets;
friend class BAN::UniqPtr<IPv4Layer>;
};
}

View File

@ -1,10 +1,10 @@
#pragma once #pragma once
#include <BAN/ByteSpan.h>
#include <BAN/Errors.h> #include <BAN/Errors.h>
#include <BAN/IPv4.h> #include <BAN/ByteSpan.h>
#include <BAN/MAC.h> #include <BAN/MAC.h>
#include <kernel/Device/Device.h> #include <kernel/Device/Device.h>
#include <kernel/Networking/IPv4.h>
namespace Kernel namespace Kernel
{ {
@ -46,16 +46,16 @@ namespace Kernel
BAN::IPv4Address get_netmask() const { return m_netmask; } BAN::IPv4Address get_netmask() const { return m_netmask; }
void set_netmask(BAN::IPv4Address new_netmask) { m_netmask = new_netmask; } void set_netmask(BAN::IPv4Address new_netmask) { m_netmask = new_netmask; }
BAN::IPv4Address get_gateway() const { return m_gateway; }
void set_gateway(BAN::IPv4Address new_gateway) { m_gateway = new_gateway; }
virtual bool link_up() = 0; virtual bool link_up() = 0;
virtual int link_speed() = 0; virtual int link_speed() = 0;
size_t interface_header_size() const;
void add_interface_header(BAN::ByteSpan packet, BAN::MACAddress destination);
virtual dev_t rdev() const override { return m_rdev; } virtual dev_t rdev() const override { return m_rdev; }
virtual BAN::StringView name() const override { return m_name; } virtual BAN::StringView name() const override { return m_name; }
virtual BAN::ErrorOr<void> send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan) = 0; virtual BAN::ErrorOr<void> send_raw_bytes(BAN::ConstByteSpan) = 0;
private: private:
const Type m_type; const Type m_type;
@ -65,7 +65,6 @@ namespace Kernel
BAN::IPv4Address m_ipv4_address { 0 }; BAN::IPv4Address m_ipv4_address { 0 };
BAN::IPv4Address m_netmask { 0 }; BAN::IPv4Address m_netmask { 0 };
BAN::IPv4Address m_gateway { 0 };
}; };
} }

View File

@ -1,25 +0,0 @@
#pragma once
#include <kernel/Networking/NetworkInterface.h>
namespace Kernel
{
class NetworkSocket;
enum class SocketType;
class NetworkLayer
{
public:
virtual ~NetworkLayer() {}
virtual void unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) = 0;
virtual BAN::ErrorOr<void> bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>) = 0;
virtual BAN::ErrorOr<size_t> sendto(NetworkSocket&, const sys_sendto_t*) = 0;
protected:
NetworkLayer() = default;
};
}

View File

@ -2,7 +2,7 @@
#include <BAN/Vector.h> #include <BAN/Vector.h>
#include <kernel/FS/TmpFS/FileSystem.h> #include <kernel/FS/TmpFS/FileSystem.h>
#include <kernel/Networking/IPv4Layer.h> #include <kernel/Networking/ARPTable.h>
#include <kernel/Networking/NetworkInterface.h> #include <kernel/Networking/NetworkInterface.h>
#include <kernel/Networking/NetworkSocket.h> #include <kernel/Networking/NetworkSocket.h>
#include <kernel/PCI.h> #include <kernel/PCI.h>
@ -17,15 +17,26 @@ namespace Kernel
BAN_NON_COPYABLE(NetworkManager); BAN_NON_COPYABLE(NetworkManager);
BAN_NON_MOVABLE(NetworkManager); BAN_NON_MOVABLE(NetworkManager);
public:
enum class SocketType
{
STREAM,
DGRAM,
SEQPACKET,
};
public: public:
static BAN::ErrorOr<void> initialize(); static BAN::ErrorOr<void> initialize();
static NetworkManager& get(); static NetworkManager& get();
ARPTable& arp_table() { return *m_arp_table; }
BAN::ErrorOr<void> add_interface(PCI::Device& pci_device); BAN::ErrorOr<void> add_interface(PCI::Device& pci_device);
BAN::Vector<BAN::RefPtr<NetworkInterface>> interfaces() { return m_interfaces; } void unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>);
BAN::ErrorOr<void> bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket>);
BAN::ErrorOr<BAN::RefPtr<TmpInode>> create_socket(SocketDomain, SocketType, mode_t, uid_t, gid_t); BAN::ErrorOr<BAN::RefPtr<NetworkSocket>> create_socket(SocketType, mode_t, uid_t, gid_t);
void on_receive(NetworkInterface&, BAN::ConstByteSpan); void on_receive(NetworkInterface&, BAN::ConstByteSpan);
@ -33,8 +44,9 @@ namespace Kernel
NetworkManager(); NetworkManager();
private: private:
BAN::UniqPtr<IPv4Layer> m_ipv4_layer; BAN::UniqPtr<ARPTable> m_arp_table;
BAN::Vector<BAN::RefPtr<NetworkInterface>> m_interfaces; BAN::Vector<BAN::RefPtr<NetworkInterface>> m_interfaces;
BAN::HashMap<int, BAN::WeakPtr<NetworkSocket>> m_bound_sockets;
}; };
} }

View File

@ -1,10 +1,8 @@
#pragma once #pragma once
#include <BAN/WeakPtr.h> #include <BAN/WeakPtr.h>
#include <kernel/FS/Socket.h>
#include <kernel/FS/TmpFS/Inode.h> #include <kernel/FS/TmpFS/Inode.h>
#include <kernel/Networking/NetworkInterface.h> #include <kernel/Networking/NetworkInterface.h>
#include <kernel/Networking/NetworkLayer.h>
#include <netinet/in.h> #include <netinet/in.h>
@ -13,7 +11,6 @@ namespace Kernel
enum NetworkProtocol : uint8_t enum NetworkProtocol : uint8_t
{ {
ICMP = 0x01,
UDP = 0x11, UDP = 0x11,
}; };
@ -29,29 +26,26 @@ namespace Kernel
void bind_interface_and_port(NetworkInterface*, uint16_t port); void bind_interface_and_port(NetworkInterface*, uint16_t port);
~NetworkSocket(); ~NetworkSocket();
NetworkInterface& interface() { ASSERT(m_interface); return *m_interface; }
virtual size_t protocol_header_size() const = 0; virtual size_t protocol_header_size() const = 0;
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) = 0; virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t src_port, uint16_t dst_port) = 0;
virtual NetworkProtocol protocol() const = 0; virtual NetworkProtocol protocol() const = 0;
virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_address, uint16_t sender_port) = 0; virtual void add_packet(BAN::ConstByteSpan, BAN::IPv4Address sender_address, uint16_t sender_port) = 0;
protected: protected:
NetworkSocket(NetworkLayer&, ino_t, const TmpInodeInfo&); NetworkSocket(mode_t mode, uid_t uid, gid_t gid);
virtual BAN::ErrorOr<size_t> read_packet(BAN::ByteSpan, sockaddr_in* sender_address) = 0; virtual BAN::ErrorOr<size_t> read_packet(BAN::ByteSpan, sockaddr_in* sender_address) = 0;
virtual void on_close_impl() override; virtual void on_close_impl() override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr* address, socklen_t address_len) override; virtual BAN::ErrorOr<void> bind_impl(const sockaddr* address, socklen_t address_len) override;
virtual BAN::ErrorOr<size_t> sendto_impl(const sys_sendto_t*) override; virtual BAN::ErrorOr<ssize_t> sendto_impl(const sys_sendto_t*) override;
virtual BAN::ErrorOr<size_t> recvfrom_impl(sys_recvfrom_t*) override; virtual BAN::ErrorOr<ssize_t> recvfrom_impl(sys_recvfrom_t*) override;
virtual BAN::ErrorOr<long> ioctl_impl(int request, void* arg) override; virtual BAN::ErrorOr<long> ioctl_impl(int request, void* arg) override;
protected: protected:
NetworkLayer& m_network_layer;
NetworkInterface* m_interface = nullptr; NetworkInterface* m_interface = nullptr;
uint16_t m_port = PORT_NONE; uint16_t m_port = PORT_NONE;
}; };

View File

@ -22,10 +22,10 @@ namespace Kernel
class UDPSocket final : public NetworkSocket class UDPSocket final : public NetworkSocket
{ {
public: public:
static BAN::ErrorOr<BAN::RefPtr<UDPSocket>> create(NetworkLayer&, ino_t, const TmpInodeInfo&); static BAN::ErrorOr<BAN::RefPtr<UDPSocket>> create(mode_t, uid_t, gid_t);
virtual size_t protocol_header_size() const override { return sizeof(UDPHeader); } virtual size_t protocol_header_size() const override { return sizeof(UDPHeader); }
virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) override; virtual void add_protocol_header(BAN::ByteSpan packet, uint16_t src_port, uint16_t dst_port) override;
virtual NetworkProtocol protocol() const override { return NetworkProtocol::UDP; } virtual NetworkProtocol protocol() const override { return NetworkProtocol::UDP; }
protected: protected:
@ -33,7 +33,7 @@ namespace Kernel
virtual BAN::ErrorOr<size_t> read_packet(BAN::ByteSpan, sockaddr_in* sender_address) override; virtual BAN::ErrorOr<size_t> read_packet(BAN::ByteSpan, sockaddr_in* sender_address) override;
private: private:
UDPSocket(NetworkLayer&, ino_t, const TmpInodeInfo&); UDPSocket(mode_t, uid_t, gid_t);
struct PacketInfo struct PacketInfo
{ {

View File

@ -1,67 +0,0 @@
#pragma once
#include <BAN/Queue.h>
#include <BAN/WeakPtr.h>
#include <kernel/FS/Socket.h>
#include <kernel/FS/TmpFS/Inode.h>
namespace Kernel
{
class UnixDomainSocket final : public TmpInode, public BAN::Weakable<UnixDomainSocket>
{
BAN_NON_COPYABLE(UnixDomainSocket);
BAN_NON_MOVABLE(UnixDomainSocket);
public:
static BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> create(SocketType, ino_t, const TmpInodeInfo&);
protected:
virtual BAN::ErrorOr<long> accept_impl(sockaddr*, socklen_t*) override;
virtual BAN::ErrorOr<void> connect_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<void> listen_impl(int) override;
virtual BAN::ErrorOr<void> bind_impl(const sockaddr*, socklen_t) override;
virtual BAN::ErrorOr<size_t> sendto_impl(const sys_sendto_t*) override;
virtual BAN::ErrorOr<size_t> recvfrom_impl(sys_recvfrom_t*) override;
private:
UnixDomainSocket(SocketType, ino_t, const TmpInodeInfo&);
BAN::ErrorOr<void> add_packet(BAN::ConstByteSpan);
bool is_bound() const { return !m_bound_path.empty(); }
bool is_bound_to_unused() const { return m_bound_path == "X"sv; }
bool is_streaming() const;
private:
struct ConnectionInfo
{
bool listening { false };
BAN::Atomic<bool> connection_done { false };
BAN::WeakPtr<UnixDomainSocket> connection;
BAN::Queue<BAN::RefPtr<UnixDomainSocket>> pending_connections;
Semaphore pending_semaphore;
SpinLock pending_lock;
};
struct ConnectionlessInfo
{
};
private:
const SocketType m_socket_type;
BAN::String m_bound_path;
BAN::Variant<ConnectionInfo, ConnectionlessInfo> m_info;
BAN::CircularQueue<size_t, 128> m_packet_sizes;
size_t m_packet_size_total { 0 };
BAN::UniqPtr<VirtualRange> m_packet_buffer;
Semaphore m_packet_semaphore;
friend class BAN::RefPtr<UnixDomainSocket>;
};
}

View File

@ -21,7 +21,6 @@ namespace Kernel
BAN::ErrorOr<void> clone_from(const OpenFileDescriptorSet&); BAN::ErrorOr<void> clone_from(const OpenFileDescriptorSet&);
BAN::ErrorOr<int> open(BAN::RefPtr<Inode>, int flags);
BAN::ErrorOr<int> open(BAN::StringView absolute_path, int flags); BAN::ErrorOr<int> open(BAN::StringView absolute_path, int flags);
BAN::ErrorOr<int> socket(int domain, int type, int protocol); BAN::ErrorOr<int> socket(int domain, int type, int protocol);

View File

@ -62,8 +62,6 @@ namespace Kernel
bool is_session_leader() const { return pid() == sid(); } bool is_session_leader() const { return pid() == sid(); }
const char* name() const { return m_cmdline.empty() ? "" : m_cmdline.front().data(); }
const Credentials& credentials() const { return m_credentials; } const Credentials& credentials() const { return m_credentials; }
BAN::ErrorOr<long> sys_exit(int status); BAN::ErrorOr<long> sys_exit(int status);
@ -95,10 +93,8 @@ namespace Kernel
BAN::ErrorOr<long> sys_getegid() const { return m_credentials.egid(); } BAN::ErrorOr<long> sys_getegid() const { return m_credentials.egid(); }
BAN::ErrorOr<long> sys_getpgid(pid_t); BAN::ErrorOr<long> sys_getpgid(pid_t);
BAN::ErrorOr<long> open_inode(BAN::RefPtr<Inode>, int flags);
BAN::ErrorOr<void> create_file_or_dir(BAN::StringView name, mode_t mode); BAN::ErrorOr<void> create_file_or_dir(BAN::StringView name, mode_t mode);
BAN::ErrorOr<long> open_file(BAN::StringView path, int oflag, mode_t = 0); BAN::ErrorOr<long> open_file(BAN::StringView path, int, mode_t = 0);
BAN::ErrorOr<long> sys_open(const char* path, int, mode_t); BAN::ErrorOr<long> sys_open(const char* path, int, mode_t);
BAN::ErrorOr<long> sys_openat(int, const char* path, int, mode_t); BAN::ErrorOr<long> sys_openat(int, const char* path, int, mode_t);
BAN::ErrorOr<long> sys_close(int fd); BAN::ErrorOr<long> sys_close(int fd);
@ -117,10 +113,7 @@ namespace Kernel
BAN::ErrorOr<long> sys_chown(const char*, uid_t, gid_t); BAN::ErrorOr<long> sys_chown(const char*, uid_t, gid_t);
BAN::ErrorOr<long> sys_socket(int domain, int type, int protocol); BAN::ErrorOr<long> sys_socket(int domain, int type, int protocol);
BAN::ErrorOr<long> sys_accept(int socket, sockaddr* address, socklen_t* address_len);
BAN::ErrorOr<long> sys_bind(int socket, const sockaddr* address, socklen_t address_len); BAN::ErrorOr<long> sys_bind(int socket, const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<long> sys_connect(int socket, const sockaddr* address, socklen_t address_len);
BAN::ErrorOr<long> sys_listen(int socket, int backlog);
BAN::ErrorOr<long> sys_sendto(const sys_sendto_t*); BAN::ErrorOr<long> sys_sendto(const sys_sendto_t*);
BAN::ErrorOr<long> sys_recvfrom(sys_recvfrom_t*); BAN::ErrorOr<long> sys_recvfrom(sys_recvfrom_t*);
@ -182,8 +175,6 @@ namespace Kernel
// Return false if access was page violation (segfault) // Return false if access was page violation (segfault)
BAN::ErrorOr<bool> allocate_page_for_demand_paging(vaddr_t addr); BAN::ErrorOr<bool> allocate_page_for_demand_paging(vaddr_t addr);
BAN::ErrorOr<BAN::String> absolute_path_of(BAN::StringView) const;
private: private:
Process(const Credentials&, pid_t pid, pid_t parent, pid_t sid, pid_t pgrp); Process(const Credentials&, pid_t pid, pid_t parent, pid_t sid, pid_t pgrp);
static Process* create_process(const Credentials&, pid_t parent, pid_t sid = 0, pid_t pgrp = 0); static Process* create_process(const Credentials&, pid_t parent, pid_t sid = 0, pid_t pgrp = 0);
@ -193,6 +184,8 @@ namespace Kernel
BAN::ErrorOr<int> block_until_exit(pid_t pid); BAN::ErrorOr<int> block_until_exit(pid_t pid);
BAN::ErrorOr<BAN::String> absolute_path_of(BAN::StringView) const;
BAN::ErrorOr<void> validate_string_access(const char*); BAN::ErrorOr<void> validate_string_access(const char*);
BAN::ErrorOr<void> validate_pointer_access(const void*, size_t); BAN::ErrorOr<void> validate_pointer_access(const void*, size_t);

View File

@ -1,32 +0,0 @@
#include <kernel/Device/DebugDevice.h>
#include <kernel/FS/DevFS/FileSystem.h>
#include <kernel/Process.h>
#include <kernel/Timer/Timer.h>
namespace Kernel
{
BAN::ErrorOr<BAN::RefPtr<DebugDevice>> DebugDevice::create(mode_t mode, uid_t uid, gid_t gid)
{
auto* result = new DebugDevice(mode, uid, gid, DevFileSystem::get().get_next_dev());
if (result == nullptr)
return BAN::Error::from_errno(ENOMEM);
return BAN::RefPtr<DebugDevice>::adopt(result);
}
BAN::ErrorOr<size_t> DebugDevice::write_impl(off_t, BAN::ConstByteSpan buffer)
{
auto ms_since_boot = SystemTimer::get().ms_since_boot();
Debug::DebugLock::lock();
BAN::Formatter::print(Debug::putchar, "[{5}.{3}] {}: ",
ms_since_boot / 1000,
ms_since_boot % 1000,
Kernel::Process::current().name()
);
for (size_t i = 0; i < buffer.size(); i++)
Debug::putchar(buffer[i]);
Debug::DebugLock::unlock();
return buffer.size();
}
}

View File

@ -1,5 +1,4 @@
#include <BAN/ScopeGuard.h> #include <BAN/ScopeGuard.h>
#include <kernel/Device/DebugDevice.h>
#include <kernel/Device/FramebufferDevice.h> #include <kernel/Device/FramebufferDevice.h>
#include <kernel/Device/NullDevice.h> #include <kernel/Device/NullDevice.h>
#include <kernel/Device/ZeroDevice.h> #include <kernel/Device/ZeroDevice.h>
@ -23,7 +22,6 @@ namespace Kernel
ASSERT(s_instance); ASSERT(s_instance);
MUST(s_instance->TmpFileSystem::initialize(0755, 0, 0)); MUST(s_instance->TmpFileSystem::initialize(0755, 0, 0));
s_instance->add_device(MUST(DebugDevice::create(0666, 0, 0)));
s_instance->add_device(MUST(NullDevice::create(0666, 0, 0))); s_instance->add_device(MUST(NullDevice::create(0666, 0, 0)));
s_instance->add_device(MUST(ZeroDevice::create(0666, 0, 0))); s_instance->add_device(MUST(ZeroDevice::create(0666, 0, 0)));
} }

View File

@ -116,14 +116,6 @@ namespace Kernel
return link_target_impl(); return link_target_impl();
} }
BAN::ErrorOr<long> Inode::accept(sockaddr* address, socklen_t* address_len)
{
LockGuard _(m_lock);
if (!mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
return accept_impl(address, address_len);
}
BAN::ErrorOr<void> Inode::bind(const sockaddr* address, socklen_t address_len) BAN::ErrorOr<void> Inode::bind(const sockaddr* address, socklen_t address_len)
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
@ -132,23 +124,7 @@ namespace Kernel
return bind_impl(address, address_len); return bind_impl(address, address_len);
} }
BAN::ErrorOr<void> Inode::connect(const sockaddr* address, socklen_t address_len) BAN::ErrorOr<ssize_t> Inode::sendto(const sys_sendto_t* arguments)
{
LockGuard _(m_lock);
if (!mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
return connect_impl(address, address_len);
}
BAN::ErrorOr<void> Inode::listen(int backlog)
{
LockGuard _(m_lock);
if (!mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
return listen_impl(backlog);
}
BAN::ErrorOr<size_t> Inode::sendto(const sys_sendto_t* arguments)
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
if (!mode().ifsock()) if (!mode().ifsock())
@ -156,7 +132,7 @@ namespace Kernel
return sendto_impl(arguments); return sendto_impl(arguments);
}; };
BAN::ErrorOr<size_t> Inode::recvfrom(sys_recvfrom_t* arguments) BAN::ErrorOr<ssize_t> Inode::recvfrom(sys_recvfrom_t* arguments)
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
if (!mode().ifsock()) if (!mode().ifsock())

View File

@ -215,36 +215,6 @@ namespace Kernel
return {}; return {};
} }
/* SOCKET INODE */
BAN::ErrorOr<BAN::RefPtr<TmpSocketInode>> TmpSocketInode::create_new(TmpFileSystem& fs, mode_t mode, uid_t uid, gid_t gid)
{
auto info = create_inode_info(Mode::IFSOCK | mode, uid, gid);
ino_t ino = TRY(fs.allocate_inode(info));
auto* inode_ptr = new TmpSocketInode(fs, ino, info);
if (inode_ptr == nullptr)
return BAN::Error::from_errno(ENOMEM);
return BAN::RefPtr<TmpSocketInode>::adopt(inode_ptr);
}
TmpSocketInode::TmpSocketInode(TmpFileSystem& fs, ino_t ino, const TmpInodeInfo& info)
: TmpInode(fs, ino, info)
{
ASSERT(mode().ifsock());
}
TmpSocketInode::~TmpSocketInode()
{
}
BAN::ErrorOr<void> TmpSocketInode::chmod_impl(mode_t new_mode)
{
m_inode_info.mode = new_mode;
return {};
}
/* SYMLINK INODE */ /* SYMLINK INODE */
BAN::ErrorOr<BAN::RefPtr<TmpSymlinkInode>> TmpSymlinkInode::create_new(TmpFileSystem& fs, mode_t mode, uid_t uid, gid_t gid, BAN::StringView target) BAN::ErrorOr<BAN::RefPtr<TmpSymlinkInode>> TmpSymlinkInode::create_new(TmpFileSystem& fs, mode_t mode, uid_t uid, gid_t gid, BAN::StringView target)
@ -476,19 +446,7 @@ namespace Kernel
BAN::ErrorOr<void> TmpDirectoryInode::create_file_impl(BAN::StringView name, mode_t mode, uid_t uid, gid_t gid) BAN::ErrorOr<void> TmpDirectoryInode::create_file_impl(BAN::StringView name, mode_t mode, uid_t uid, gid_t gid)
{ {
BAN::RefPtr<TmpInode> new_inode; auto new_inode = TRY(TmpFileInode::create_new(m_fs, mode, uid, gid));
switch (mode & Mode::TYPE_MASK)
{
case Mode::IFREG:
new_inode = TRY(TmpFileInode::create_new(m_fs, mode, uid, gid));
break;
case Mode::IFSOCK:
new_inode = TRY(TmpSocketInode::create_new(m_fs, mode, uid, gid));
break;
default:
dprintln("Creating with mode {o} is not supported", mode);
return BAN::Error::from_errno(ENOTSUP);
}
TRY(link_inode(*new_inode, name)); TRY(link_inode(*new_inode, name));
return {}; return {};
} }

View File

@ -294,9 +294,6 @@ void* kmalloc(size_t size, size_t align, bool force_identity_map)
// currently kmalloc is always identity mapped // currently kmalloc is always identity mapped
(void)force_identity_map; (void)force_identity_map;
if (size == 0)
size = 1;
const kmalloc_info& info = s_kmalloc_info; const kmalloc_info& info = s_kmalloc_info;
ASSERT(is_power_of_two(align)); ASSERT(is_power_of_two(align));

View File

@ -1,6 +1,5 @@
#include <kernel/LockGuard.h> #include <kernel/LockGuard.h>
#include <kernel/Networking/ARPTable.h> #include <kernel/Networking/ARPTable.h>
#include <kernel/Scheduler.h>
#include <kernel/Timer/Timer.h> #include <kernel/Timer/Timer.h>
namespace Kernel namespace Kernel
@ -33,31 +32,26 @@ namespace Kernel
{ {
} }
ARPTable::~ARPTable()
{
if (m_process)
m_process->exit(0, SIGKILL);
m_process = nullptr;
}
BAN::ErrorOr<BAN::MACAddress> ARPTable::get_mac_from_ipv4(NetworkInterface& interface, BAN::IPv4Address ipv4_address) BAN::ErrorOr<BAN::MACAddress> ARPTable::get_mac_from_ipv4(NetworkInterface& interface, BAN::IPv4Address ipv4_address)
{ {
if (ipv4_address == s_broadcast_ipv4)
return s_broadcast_mac;
if (interface.get_ipv4_address() == BAN::IPv4Address { 0 })
return BAN::Error::from_errno(EINVAL);
if (interface.get_ipv4_address().mask(interface.get_netmask()) != ipv4_address.mask(interface.get_netmask()))
ipv4_address = interface.get_gateway();
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
if (ipv4_address == s_broadcast_ipv4)
return s_broadcast_mac;
if (m_arp_table.contains(ipv4_address)) if (m_arp_table.contains(ipv4_address))
return m_arp_table[ipv4_address]; return m_arp_table[ipv4_address];
} }
ARPPacket arp_request; BAN::Vector<uint8_t> full_packet_buffer;
TRY(full_packet_buffer.resize(sizeof(ARPPacket) + sizeof(EthernetHeader)));
auto full_packet = BAN::ByteSpan { full_packet_buffer.span() };
auto& ethernet_header = full_packet.as<EthernetHeader>();
ethernet_header.dst_mac = s_broadcast_mac;
ethernet_header.src_mac = interface.get_mac_address();
ethernet_header.ether_type = EtherType::ARP;
auto& arp_request = full_packet.slice(sizeof(EthernetHeader)).as<ARPPacket>();
arp_request.htype = 0x0001; arp_request.htype = 0x0001;
arp_request.ptype = EtherType::IPv4; arp_request.ptype = EtherType::IPv4;
arp_request.hlen = 0x06; arp_request.hlen = 0x06;
@ -68,9 +62,9 @@ namespace Kernel
arp_request.tha = {{ 0, 0, 0, 0, 0, 0 }}; arp_request.tha = {{ 0, 0, 0, 0, 0, 0 }};
arp_request.tpa = ipv4_address; arp_request.tpa = ipv4_address;
TRY(interface.send_bytes(s_broadcast_mac, EtherType::ARP, BAN::ConstByteSpan::from(arp_request))); TRY(interface.send_raw_bytes(full_packet));
uint64_t timeout = SystemTimer::get().ms_since_boot() + 1'000; uint64_t timeout = SystemTimer::get().ms_since_boot() + 5'000;
while (SystemTimer::get().ms_since_boot() < timeout) while (SystemTimer::get().ms_since_boot() < timeout)
{ {
{ {
@ -78,10 +72,10 @@ namespace Kernel
if (m_arp_table.contains(ipv4_address)) if (m_arp_table.contains(ipv4_address))
return m_arp_table[ipv4_address]; return m_arp_table[ipv4_address];
} }
Scheduler::get().reschedule(); TRY(Thread::current().block_or_eintr(m_pending_semaphore));
} }
return BAN::Error::from_errno(ETIMEDOUT); return BAN::Error::from_errno(EINVAL);
} }
BAN::ErrorOr<void> ARPTable::handle_arp_packet(NetworkInterface& interface, const ARPPacket& packet) BAN::ErrorOr<void> ARPTable::handle_arp_packet(NetworkInterface& interface, const ARPPacket& packet)
@ -98,17 +92,27 @@ namespace Kernel
{ {
if (packet.tpa == interface.get_ipv4_address()) if (packet.tpa == interface.get_ipv4_address())
{ {
ARPPacket arp_reply; BAN::Vector<uint8_t> full_packet_buffer;
arp_reply.htype = 0x0001; TRY(full_packet_buffer.resize(sizeof(ARPPacket) + sizeof(EthernetHeader)));
arp_reply.ptype = EtherType::IPv4; auto full_packet = BAN::ByteSpan { full_packet_buffer.span() };
arp_reply.hlen = 0x06;
arp_reply.plen = 0x04; auto& ethernet_header = full_packet.as<EthernetHeader>();
arp_reply.oper = ARPOperation::Reply; ethernet_header.dst_mac = packet.sha;
arp_reply.sha = interface.get_mac_address(); ethernet_header.src_mac = interface.get_mac_address();
arp_reply.spa = interface.get_ipv4_address(); ethernet_header.ether_type = EtherType::ARP;
arp_reply.tha = packet.sha;
arp_reply.tpa = packet.spa; auto& arp_request = full_packet.slice(sizeof(EthernetHeader)).as<ARPPacket>();
TRY(interface.send_bytes(packet.sha, EtherType::ARP, BAN::ConstByteSpan::from(arp_reply))); arp_request.htype = 0x0001;
arp_request.ptype = EtherType::IPv4;
arp_request.hlen = 0x06;
arp_request.plen = 0x04;
arp_request.oper = ARPOperation::Reply;
arp_request.sha = interface.get_mac_address();
arp_request.spa = interface.get_ipv4_address();
arp_request.tha = packet.sha;
arp_request.tpa = packet.spa;
TRY(interface.send_raw_bytes(full_packet));
} }
break; break;
} }

View File

@ -256,26 +256,19 @@ namespace Kernel
return {}; return {};
} }
BAN::ErrorOr<void> E1000::send_raw_bytes(BAN::ConstByteSpan buffer)
BAN::ErrorOr<void> E1000::send_bytes(BAN::MACAddress destination, EtherType protocol, BAN::ConstByteSpan buffer)
{ {
ASSERT_LTE(buffer.size() + sizeof(EthernetHeader), E1000_TX_BUFFER_SIZE); ASSERT_LTE(buffer.size(), E1000_TX_BUFFER_SIZE);
CriticalScope _; CriticalScope _;
size_t tx_current = read32(REG_TDT) % E1000_TX_DESCRIPTOR_COUNT; size_t tx_current = read32(REG_TDT) % E1000_TX_DESCRIPTOR_COUNT;
auto* tx_buffer = reinterpret_cast<uint8_t*>(m_tx_buffer_region->vaddr() + E1000_TX_BUFFER_SIZE * tx_current); auto* tx_buffer = reinterpret_cast<void*>(m_tx_buffer_region->vaddr() + E1000_TX_BUFFER_SIZE * tx_current);
memcpy(tx_buffer, buffer.data(), buffer.size());
auto& ethernet_header = *reinterpret_cast<EthernetHeader*>(tx_buffer);
ethernet_header.dst_mac = destination;
ethernet_header.src_mac = get_mac_address();
ethernet_header.ether_type = protocol;
memcpy(tx_buffer + sizeof(EthernetHeader), buffer.data(), buffer.size());
auto& descriptor = reinterpret_cast<volatile e1000_tx_desc*>(m_tx_descriptor_region->vaddr())[tx_current]; auto& descriptor = reinterpret_cast<volatile e1000_tx_desc*>(m_tx_descriptor_region->vaddr())[tx_current];
descriptor.length = sizeof(EthernetHeader) + buffer.size(); descriptor.length = buffer.size();
descriptor.status = 0; descriptor.status = 0;
descriptor.cmd = CMD_EOP | CMD_IFCS | CMD_RS; descriptor.cmd = CMD_EOP | CMD_IFCS | CMD_RS;

View File

@ -0,0 +1,22 @@
#include <BAN/Endianness.h>
#include <kernel/Networking/IPv4.h>
namespace Kernel
{
void add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol)
{
auto& header = packet.as<IPv4Header>();
header.version_IHL = 0x45;
header.DSCP_ECN = 0x10;
header.total_length = packet.size();
header.identification = 1;
header.flags_frament = 0x00;
header.time_to_live = 0x40;
header.protocol = protocol;
header.checksum = header.calculate_checksum();
header.src_address = src_ipv4;
header.dst_address = dst_ipv4;
}
}

View File

@ -1,284 +0,0 @@
#include <kernel/Memory/Heap.h>
#include <kernel/Memory/PageTable.h>
#include <kernel/Networking/ICMP.h>
#include <kernel/Networking/IPv4Layer.h>
#include <kernel/Networking/NetworkManager.h>
#include <kernel/Networking/UDPSocket.h>
#include <netinet/in.h>
#define DEBUG_IPV4 0
namespace Kernel
{
BAN::ErrorOr<BAN::UniqPtr<IPv4Layer>> IPv4Layer::create()
{
auto ipv4_manager = TRY(BAN::UniqPtr<IPv4Layer>::create());
ipv4_manager->m_process = Process::create_kernel(
[](void* ipv4_manager_ptr)
{
auto& ipv4_manager = *reinterpret_cast<IPv4Layer*>(ipv4_manager_ptr);
ipv4_manager.packet_handle_task();
}, ipv4_manager.ptr()
);
ASSERT(ipv4_manager->m_process);
ipv4_manager->m_pending_packet_buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
~(uintptr_t)0,
pending_packet_buffer_size,
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
true
));
ipv4_manager->m_arp_table = TRY(ARPTable::create());
return ipv4_manager;
}
IPv4Layer::IPv4Layer()
{ }
IPv4Layer::~IPv4Layer()
{
if (m_process)
m_process->exit(0, SIGKILL);
m_process = nullptr;
}
void IPv4Layer::add_ipv4_header(BAN::ByteSpan packet, BAN::IPv4Address src_ipv4, BAN::IPv4Address dst_ipv4, uint8_t protocol) const
{
auto& header = packet.as<IPv4Header>();
header.version_IHL = 0x45;
header.DSCP_ECN = 0x00;
header.total_length = packet.size();
header.identification = 1;
header.flags_frament = 0x00;
header.time_to_live = 0x40;
header.protocol = protocol;
header.src_address = src_ipv4;
header.dst_address = dst_ipv4;
header.checksum = header.calculate_checksum();
}
void IPv4Layer::unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket> socket)
{
LockGuard _(m_lock);
if (m_bound_sockets.contains(port))
{
ASSERT(m_bound_sockets[port].valid());
ASSERT(m_bound_sockets[port].lock() == socket);
m_bound_sockets.remove(port);
}
NetworkManager::get().TmpFileSystem::remove_from_cache(socket);
}
BAN::ErrorOr<void> IPv4Layer::bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket> socket)
{
if (NetworkManager::get().interfaces().empty())
return BAN::Error::from_errno(EADDRNOTAVAIL);
LockGuard _(m_lock);
if (port == NetworkSocket::PORT_NONE)
{
for (uint32_t temp = 0xC000; temp < 0xFFFF; temp++)
{
if (!m_bound_sockets.contains(temp))
{
port = temp;
break;
}
}
if (port == NetworkSocket::PORT_NONE)
{
dwarnln("No ports available");
return BAN::Error::from_errno(EAGAIN);
}
}
if (m_bound_sockets.contains(port))
return BAN::Error::from_errno(EADDRINUSE);
TRY(m_bound_sockets.insert(port, socket));
// FIXME: actually determine proper interface
auto interface = NetworkManager::get().interfaces().front();
socket->bind_interface_and_port(interface.ptr(), port);
return {};
}
BAN::ErrorOr<size_t> IPv4Layer::sendto(NetworkSocket& socket, const sys_sendto_t* arguments)
{
if (arguments->dest_addr->sa_family != AF_INET)
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_in = *reinterpret_cast<const struct sockaddr_in*>(arguments->dest_addr);
auto dst_port = BAN::host_to_network_endian(sockaddr_in.sin_port);
auto dst_ipv4 = BAN::IPv4Address { sockaddr_in.sin_addr.s_addr };
auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(socket.interface(), dst_ipv4));
BAN::Vector<uint8_t> packet_buffer;
TRY(packet_buffer.resize(arguments->length + sizeof(IPv4Header) + socket.protocol_header_size()));
auto packet = BAN::ByteSpan { packet_buffer.span() };
memcpy(
packet.slice(sizeof(IPv4Header)).slice(socket.protocol_header_size()).data(),
arguments->message,
arguments->length
);
socket.add_protocol_header(
packet.slice(sizeof(IPv4Header)),
dst_port
);
add_ipv4_header(
packet,
socket.interface().get_ipv4_address(),
dst_ipv4,
socket.protocol()
);
TRY(socket.interface().send_bytes(dst_mac, EtherType::IPv4, packet));
return arguments->length;
}
static uint16_t calculate_internet_checksum(BAN::ConstByteSpan packet)
{
uint32_t checksum = 0;
for (size_t i = 0; i < packet.size() / sizeof(uint16_t); i++)
checksum += BAN::host_to_network_endian(reinterpret_cast<const uint16_t*>(packet.data())[i]);
while (checksum >> 16)
checksum = (checksum >> 16) | (checksum & 0xFFFF);
return ~(uint16_t)checksum;
}
BAN::ErrorOr<void> IPv4Layer::handle_ipv4_packet(NetworkInterface& interface, BAN::ByteSpan packet)
{
auto& ipv4_header = packet.as<const IPv4Header>();
auto ipv4_data = packet.slice(sizeof(IPv4Header));
ASSERT(ipv4_header.is_valid_checksum());
auto src_ipv4 = ipv4_header.src_address;
switch (ipv4_header.protocol)
{
case NetworkProtocol::ICMP:
{
auto& icmp_header = ipv4_data.as<const ICMPHeader>();
switch (icmp_header.type)
{
case ICMPType::EchoRequest:
{
auto dst_mac = TRY(m_arp_table->get_mac_from_ipv4(interface, src_ipv4));
auto& reply_icmp_header = ipv4_data.as<ICMPHeader>();
reply_icmp_header.type = ICMPType::EchoReply;
reply_icmp_header.checksum = 0;
reply_icmp_header.checksum = calculate_internet_checksum(ipv4_data);
add_ipv4_header(packet, interface.get_ipv4_address(), src_ipv4, NetworkProtocol::ICMP);
TRY(interface.send_bytes(dst_mac, EtherType::IPv4, packet));
break;
}
default:
dprintln("Unhandleded ICMP packet (type {2H})", icmp_header.type);
break;
}
break;
}
case NetworkProtocol::UDP:
{
auto& udp_header = ipv4_data.as<const UDPHeader>();
uint16_t src_port = udp_header.src_port;
uint16_t dst_port = udp_header.dst_port;
LockGuard _(m_lock);
if (!m_bound_sockets.contains(dst_port) || !m_bound_sockets[dst_port].valid())
{
dprintln_if(DEBUG_IPV4, "no one is listening on port {}", dst_port);
return {};
}
auto udp_data = ipv4_data.slice(sizeof(UDPHeader));
m_bound_sockets[dst_port].lock()->add_packet(udp_data, src_ipv4, src_port);
break;
}
default:
dprintln_if(DEBUG_IPV4, "Unknown network protocol 0x{2H}", ipv4_header.protocol);
break;
}
return {};
}
void IPv4Layer::packet_handle_task()
{
for (;;)
{
BAN::Optional<PendingIPv4Packet> pending;
{
CriticalScope _;
if (!m_pending_packets.empty())
{
pending = m_pending_packets.front();
m_pending_packets.pop();
}
}
if (!pending.has_value())
{
m_pending_semaphore.block();
continue;
}
uint8_t* buffer_start = reinterpret_cast<uint8_t*>(m_pending_packet_buffer->vaddr());
const size_t ipv4_packet_size = reinterpret_cast<const IPv4Header*>(buffer_start)->total_length;
if (auto ret = handle_ipv4_packet(pending->interface, BAN::ByteSpan(buffer_start, ipv4_packet_size)); ret.is_error())
dwarnln("{}", ret.error());
CriticalScope _;
m_pending_total_size -= ipv4_packet_size;
if (m_pending_total_size)
memmove(buffer_start, buffer_start + ipv4_packet_size, m_pending_total_size);
}
}
void IPv4Layer::add_ipv4_packet(NetworkInterface& interface, BAN::ConstByteSpan buffer)
{
if (m_pending_packets.full())
{
dwarnln("IPv4 packet queue full");
return;
}
if (m_pending_total_size + buffer.size() > m_pending_packet_buffer->size())
{
dwarnln("IPv4 packet queue full");
return;
}
auto& ipv4_header = buffer.as<const IPv4Header>();
if (!ipv4_header.is_valid_checksum())
{
dwarnln("Invalid IPv4 packet");
return;
}
if (ipv4_header.total_length > buffer.size())
{
dwarnln("Too short IPv4 packet");
return;
}
uint8_t* buffer_start = reinterpret_cast<uint8_t*>(m_pending_packet_buffer->vaddr());
memcpy(buffer_start + m_pending_total_size, buffer.data(), ipv4_header.total_length);
m_pending_total_size += ipv4_header.total_length;
m_pending_packets.push({ .interface = interface });
m_pending_semaphore.unblock();
}
}

View File

@ -32,4 +32,19 @@ namespace Kernel
m_name[3] = minor(m_rdev) + '0'; m_name[3] = minor(m_rdev) + '0';
} }
size_t NetworkInterface::interface_header_size() const
{
ASSERT(m_type == Type::Ethernet);
return sizeof(EthernetHeader);
}
void NetworkInterface::add_interface_header(BAN::ByteSpan packet, BAN::MACAddress destination)
{
ASSERT(m_type == Type::Ethernet);
auto& header = packet.as<EthernetHeader>();
header.dst_mac = destination;
header.src_mac = get_mac_address();
header.ether_type = 0x0800;
}
} }

View File

@ -3,12 +3,9 @@
#include <kernel/FS/DevFS/FileSystem.h> #include <kernel/FS/DevFS/FileSystem.h>
#include <kernel/Networking/E1000/E1000.h> #include <kernel/Networking/E1000/E1000.h>
#include <kernel/Networking/E1000/E1000E.h> #include <kernel/Networking/E1000/E1000E.h>
#include <kernel/Networking/ICMP.h> #include <kernel/Networking/IPv4.h>
#include <kernel/Networking/NetworkManager.h> #include <kernel/Networking/NetworkManager.h>
#include <kernel/Networking/UDPSocket.h> #include <kernel/Networking/UDPSocket.h>
#include <kernel/Networking/UNIX/Socket.h>
#define DEBUG_ETHERTYPE 0
namespace Kernel namespace Kernel
{ {
@ -22,8 +19,8 @@ namespace Kernel
if (manager_ptr == nullptr) if (manager_ptr == nullptr)
return BAN::Error::from_errno(ENOMEM); return BAN::Error::from_errno(ENOMEM);
auto manager = BAN::UniqPtr<NetworkManager>::adopt(manager_ptr); auto manager = BAN::UniqPtr<NetworkManager>::adopt(manager_ptr);
manager->m_arp_table = TRY(ARPTable::create());
TRY(manager->TmpFileSystem::initialize(0777, 0, 0)); TRY(manager->TmpFileSystem::initialize(0777, 0, 0));
manager->m_ipv4_layer = TRY(IPv4Layer::create());
s_instance = BAN::move(manager); s_instance = BAN::move(manager);
return {}; return {};
} }
@ -71,50 +68,45 @@ namespace Kernel
return {}; return {};
} }
BAN::ErrorOr<BAN::RefPtr<TmpInode>> NetworkManager::create_socket(SocketDomain domain, SocketType type, mode_t mode, uid_t uid, gid_t gid) BAN::ErrorOr<BAN::RefPtr<NetworkSocket>> NetworkManager::create_socket(SocketType type, mode_t mode, uid_t uid, gid_t gid)
{ {
switch (domain)
{
case SocketDomain::INET:
{
if (type != SocketType::DGRAM)
return BAN::Error::from_errno(EPROTOTYPE);
break;
}
case SocketDomain::UNIX:
{
break;
}
default:
return BAN::Error::from_errno(EAFNOSUPPORT);
}
ASSERT((mode & Inode::Mode::TYPE_MASK) == 0); ASSERT((mode & Inode::Mode::TYPE_MASK) == 0);
mode |= Inode::Mode::IFSOCK;
auto inode_info = create_inode_info(mode, uid, gid); if (type != SocketType::DGRAM)
ino_t ino = TRY(allocate_inode(inode_info)); return BAN::Error::from_errno(EPROTOTYPE);
BAN::RefPtr<TmpInode> socket; auto udp_socket = TRY(UDPSocket::create(mode | Inode::Mode::IFSOCK, uid, gid));
switch (domain) return BAN::RefPtr<NetworkSocket>(udp_socket);
}
void NetworkManager::unbind_socket(uint16_t port, BAN::RefPtr<NetworkSocket> socket)
{
if (m_bound_sockets.contains(port))
{ {
case SocketDomain::INET: ASSERT(m_bound_sockets[port].valid());
{ ASSERT(m_bound_sockets[port].lock() == socket);
if (type == SocketType::DGRAM) m_bound_sockets.remove(port);
socket = TRY(UDPSocket::create(*m_ipv4_layer, ino, inode_info)); }
break; NetworkManager::get().remove_from_cache(socket);
} }
case SocketDomain::UNIX:
{ BAN::ErrorOr<void> NetworkManager::bind_socket(uint16_t port, BAN::RefPtr<NetworkSocket> socket)
socket = TRY(UnixDomainSocket::create(type, ino, inode_info)); {
break; if (m_interfaces.empty())
} return BAN::Error::from_errno(EADDRNOTAVAIL);
default:
ASSERT_NOT_REACHED(); if (port != NetworkSocket::PORT_NONE)
{
if (m_bound_sockets.contains(port))
return BAN::Error::from_errno(EADDRINUSE);
TRY(m_bound_sockets.insert(port, socket));
} }
ASSERT(socket); // FIXME: actually determine proper interface
return socket; auto interface = m_interfaces.front();
socket->bind_interface_and_port(interface.ptr(), port);
return {};
} }
void NetworkManager::on_receive(NetworkInterface& interface, BAN::ConstByteSpan packet) void NetworkManager::on_receive(NetworkInterface& interface, BAN::ConstByteSpan packet)
@ -125,16 +117,41 @@ namespace Kernel
{ {
case EtherType::ARP: case EtherType::ARP:
{ {
m_ipv4_layer->arp_table().add_arp_packet(interface, packet.slice(sizeof(EthernetHeader))); m_arp_table->add_arp_packet(interface, packet.slice(sizeof(EthernetHeader)));
break; break;
} }
case EtherType::IPv4: case EtherType::IPv4:
{ {
m_ipv4_layer->add_ipv4_packet(interface, packet.slice(sizeof(EthernetHeader))); auto ipv4 = packet.slice(sizeof(EthernetHeader));
auto& ipv4_header = ipv4.as<const IPv4Header>();
auto src_ipv4 = ipv4_header.src_address;
switch (ipv4_header.protocol)
{
case NetworkProtocol::UDP:
{
auto udp = ipv4.slice(sizeof(IPv4Header));
auto& udp_header = udp.as<const UDPHeader>();
uint16_t src_port = udp_header.src_port;
uint16_t dst_port = udp_header.dst_port;
if (!m_bound_sockets.contains(dst_port))
{
dprintln("no one is listening on port {}", dst_port);
return;
}
auto raw = udp.slice(8);
m_bound_sockets[dst_port].lock()->add_packet(raw, src_ipv4, src_port);
break;
}
default:
dprintln("Unknown network protocol 0x{2H}", ipv4_header.protocol);
break;
}
break; break;
} }
default: default:
dprintln_if(DEBUG_ETHERTYPE, "Unknown EtherType 0x{4H}", (uint16_t)ethernet_header.ether_type); dprintln("Unknown EtherType 0x{4H}", (uint16_t)ethernet_header.ether_type);
break; break;
} }
} }

View File

@ -1,3 +1,4 @@
#include <kernel/Networking/IPv4.h>
#include <kernel/Networking/NetworkManager.h> #include <kernel/Networking/NetworkManager.h>
#include <kernel/Networking/NetworkSocket.h> #include <kernel/Networking/NetworkSocket.h>
@ -6,9 +7,13 @@
namespace Kernel namespace Kernel
{ {
NetworkSocket::NetworkSocket(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info) NetworkSocket::NetworkSocket(mode_t mode, uid_t uid, gid_t gid)
: TmpInode(NetworkManager::get(), ino, inode_info) // FIXME: what the fuck is this
, m_network_layer(network_layer) : TmpInode(
NetworkManager::get(),
MUST(NetworkManager::get().allocate_inode(create_inode_info(mode, uid, gid))),
create_inode_info(mode, uid, gid)
)
{ } { }
NetworkSocket::~NetworkSocket() NetworkSocket::~NetworkSocket()
@ -18,7 +23,7 @@ namespace Kernel
void NetworkSocket::on_close_impl() void NetworkSocket::on_close_impl()
{ {
if (m_interface) if (m_interface)
m_network_layer.unbind_socket(m_port, this); NetworkManager::get().unbind_socket(m_port, this);
} }
void NetworkSocket::bind_interface_and_port(NetworkInterface* interface, uint16_t port) void NetworkSocket::bind_interface_and_port(NetworkInterface* interface, uint16_t port)
@ -31,15 +36,16 @@ namespace Kernel
BAN::ErrorOr<void> NetworkSocket::bind_impl(const sockaddr* address, socklen_t address_len) BAN::ErrorOr<void> NetworkSocket::bind_impl(const sockaddr* address, socklen_t address_len)
{ {
if (m_interface || address_len != sizeof(sockaddr_in)) if (address_len != sizeof(sockaddr_in))
return BAN::Error::from_errno(EINVAL); return BAN::Error::from_errno(EINVAL);
auto* addr_in = reinterpret_cast<const sockaddr_in*>(address); auto* addr_in = reinterpret_cast<const sockaddr_in*>(address);
uint16_t dst_port = BAN::host_to_network_endian(addr_in->sin_port); return NetworkManager::get().bind_socket(addr_in->sin_port, this);
return m_network_layer.bind_socket(dst_port, this);
} }
BAN::ErrorOr<size_t> NetworkSocket::sendto_impl(const sys_sendto_t* arguments) BAN::ErrorOr<ssize_t> NetworkSocket::sendto_impl(const sys_sendto_t* arguments)
{ {
if (arguments->dest_len != sizeof(sockaddr_in))
return BAN::Error::from_errno(EINVAL);
if (arguments->flags) if (arguments->flags)
{ {
dprintln("flags not supported"); dprintln("flags not supported");
@ -47,12 +53,45 @@ namespace Kernel
} }
if (!m_interface) if (!m_interface)
TRY(m_network_layer.bind_socket(PORT_NONE, this)); TRY(NetworkManager::get().bind_socket(PORT_NONE, this));
return TRY(m_network_layer.sendto(*this, arguments)); auto* destination = reinterpret_cast<const sockaddr_in*>(arguments->dest_addr);
auto message = BAN::ConstByteSpan((const uint8_t*)arguments->message, arguments->length);
uint16_t dst_port = destination->sin_port;
if (dst_port == PORT_NONE)
return BAN::Error::from_errno(EINVAL);
auto dst_addr = BAN::IPv4Address(destination->sin_addr.s_addr);
auto dst_mac = TRY(NetworkManager::get().arp_table().get_mac_from_ipv4(*m_interface, dst_addr));
const size_t interface_header_offset = 0;
const size_t interface_header_size = m_interface->interface_header_size();
const size_t ipv4_header_offset = interface_header_offset + interface_header_size;
const size_t ipv4_header_size = sizeof(IPv4Header);
const size_t protocol_header_offset = ipv4_header_offset + ipv4_header_size;
const size_t protocol_header_size = this->protocol_header_size();
const size_t payload_offset = protocol_header_offset + protocol_header_size;
const size_t payload_size = message.size();
BAN::Vector<uint8_t> full_packet;
TRY(full_packet.resize(payload_offset + payload_size));
BAN::ByteSpan packet_bytespan { full_packet.span() };
memcpy(full_packet.data() + payload_offset, message.data(), payload_size);
add_protocol_header(packet_bytespan.slice(protocol_header_offset), m_port, dst_port);
add_ipv4_header(packet_bytespan.slice(ipv4_header_offset), m_interface->get_ipv4_address(), dst_addr, protocol());
m_interface->add_interface_header(packet_bytespan.slice(interface_header_offset), dst_mac);
TRY(m_interface->send_raw_bytes(packet_bytespan));
return arguments->length;
} }
BAN::ErrorOr<size_t> NetworkSocket::recvfrom_impl(sys_recvfrom_t* arguments) BAN::ErrorOr<ssize_t> NetworkSocket::recvfrom_impl(sys_recvfrom_t* arguments)
{ {
sockaddr_in* sender_addr = nullptr; sockaddr_in* sender_addr = nullptr;
if (arguments->address) if (arguments->address)
@ -101,52 +140,36 @@ namespace Kernel
{ {
case SIOCGIFADDR: case SIOCGIFADDR:
{ {
auto& ifru_addr = *reinterpret_cast<sockaddr_in*>(&ifreq->ifr_ifru.ifru_addr); auto ipv4_address = m_interface->get_ipv4_address();
ifru_addr.sin_family = AF_INET; ifreq->ifr_ifru.ifru_addr.sa_family = AF_INET;
ifru_addr.sin_addr.s_addr = m_interface->get_ipv4_address().raw; memcpy(ifreq->ifr_ifru.ifru_addr.sa_data, &ipv4_address, sizeof(ipv4_address));
return 0; return 0;
} }
case SIOCSIFADDR: case SIOCSIFADDR:
{ {
auto& ifru_addr = *reinterpret_cast<const sockaddr_in*>(&ifreq->ifr_ifru.ifru_addr); if (ifreq->ifr_ifru.ifru_addr.sa_family != AF_INET)
if (ifru_addr.sin_family != AF_INET)
return BAN::Error::from_errno(EADDRNOTAVAIL); return BAN::Error::from_errno(EADDRNOTAVAIL);
m_interface->set_ipv4_address(BAN::IPv4Address { ifru_addr.sin_addr.s_addr }); BAN::IPv4Address ipv4_address { *reinterpret_cast<uint32_t*>(ifreq->ifr_ifru.ifru_addr.sa_data) };
m_interface->set_ipv4_address(ipv4_address);
dprintln("IPv4 address set to {}", m_interface->get_ipv4_address()); dprintln("IPv4 address set to {}", m_interface->get_ipv4_address());
return 0; return 0;
} }
case SIOCGIFNETMASK: case SIOCGIFNETMASK:
{ {
auto& ifru_netmask = *reinterpret_cast<sockaddr_in*>(&ifreq->ifr_ifru.ifru_netmask); auto netmask_address = m_interface->get_netmask();
ifru_netmask.sin_family = AF_INET; ifreq->ifr_ifru.ifru_netmask.sa_family = AF_INET;
ifru_netmask.sin_addr.s_addr = m_interface->get_netmask().raw; memcpy(ifreq->ifr_ifru.ifru_netmask.sa_data, &netmask_address, sizeof(netmask_address));
return 0; return 0;
} }
case SIOCSIFNETMASK: case SIOCSIFNETMASK:
{ {
auto& ifru_netmask = *reinterpret_cast<const sockaddr_in*>(&ifreq->ifr_ifru.ifru_netmask); if (ifreq->ifr_ifru.ifru_netmask.sa_family != AF_INET)
if (ifru_netmask.sin_family != AF_INET)
return BAN::Error::from_errno(EADDRNOTAVAIL); return BAN::Error::from_errno(EADDRNOTAVAIL);
m_interface->set_netmask(BAN::IPv4Address { ifru_netmask.sin_addr.s_addr }); BAN::IPv4Address netmask { *reinterpret_cast<uint32_t*>(ifreq->ifr_ifru.ifru_netmask.sa_data) };
m_interface->set_netmask(netmask);
dprintln("Netmask set to {}", m_interface->get_netmask()); dprintln("Netmask set to {}", m_interface->get_netmask());
return 0; return 0;
} }
case SIOCGIFGWADDR:
{
auto& ifru_gwaddr = *reinterpret_cast<sockaddr_in*>(&ifreq->ifr_ifru.ifru_gwaddr);
ifru_gwaddr.sin_family = AF_INET;
ifru_gwaddr.sin_addr.s_addr = m_interface->get_gateway().raw;
return 0;
}
case SIOCSIFGWADDR:
{
auto& ifru_gwaddr = *reinterpret_cast<const sockaddr_in*>(&ifreq->ifr_ifru.ifru_gwaddr);
if (ifru_gwaddr.sin_family != AF_INET)
return BAN::Error::from_errno(EADDRNOTAVAIL);
m_interface->set_gateway(BAN::IPv4Address { ifru_gwaddr.sin_addr.s_addr });
dprintln("Gateway set to {}", m_interface->get_gateway());
return 0;
}
case SIOCGIFHWADDR: case SIOCGIFHWADDR:
{ {
auto mac_address = m_interface->get_mac_address(); auto mac_address = m_interface->get_mac_address();

View File

@ -5,9 +5,9 @@
namespace Kernel namespace Kernel
{ {
BAN::ErrorOr<BAN::RefPtr<UDPSocket>> UDPSocket::create(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info) BAN::ErrorOr<BAN::RefPtr<UDPSocket>> UDPSocket::create(mode_t mode, uid_t uid, gid_t gid)
{ {
auto socket = TRY(BAN::RefPtr<UDPSocket>::create(network_layer, ino, inode_info)); auto socket = TRY(BAN::RefPtr<UDPSocket>::create(mode, uid, gid));
socket->m_packet_buffer = TRY(VirtualRange::create_to_vaddr_range( socket->m_packet_buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(), PageTable::kernel(),
KERNEL_OFFSET, KERNEL_OFFSET,
@ -19,14 +19,14 @@ namespace Kernel
return socket; return socket;
} }
UDPSocket::UDPSocket(NetworkLayer& network_layer, ino_t ino, const TmpInodeInfo& inode_info) UDPSocket::UDPSocket(mode_t mode, uid_t uid, gid_t gid)
: NetworkSocket(network_layer, ino, inode_info) : NetworkSocket(mode, uid, gid)
{ } { }
void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t dst_port) void UDPSocket::add_protocol_header(BAN::ByteSpan packet, uint16_t src_port, uint16_t dst_port)
{ {
auto& header = packet.as<UDPHeader>(); auto& header = packet.as<UDPHeader>();
header.src_port = m_port; header.src_port = src_port;
header.dst_port = dst_port; header.dst_port = dst_port;
header.length = packet.size(); header.length = packet.size();
header.checksum = 0; header.checksum = 0;
@ -91,8 +91,8 @@ namespace Kernel
if (sender_addr) if (sender_addr)
{ {
sender_addr->sin_family = AF_INET; sender_addr->sin_family = AF_INET;
sender_addr->sin_port = BAN::NetworkEndian(packet_info.sender_port); sender_addr->sin_port = packet_info.sender_port;
sender_addr->sin_addr.s_addr = packet_info.sender_addr.raw; sender_addr->sin_addr.s_addr = packet_info.sender_addr.as_u32();
} }
return nread; return nread;

View File

@ -1,323 +0,0 @@
#include <BAN/HashMap.h>
#include <kernel/FS/VirtualFileSystem.h>
#include <kernel/Networking/NetworkManager.h>
#include <kernel/Networking/UNIX/Socket.h>
#include <kernel/Scheduler.h>
#include <fcntl.h>
#include <sys/un.h>
namespace Kernel
{
static BAN::HashMap<BAN::String, BAN::RefPtr<UnixDomainSocket>> s_bound_sockets;
static SpinLock s_bound_socket_lock;
static constexpr size_t s_packet_buffer_size = 10 * PAGE_SIZE;
BAN::ErrorOr<BAN::RefPtr<UnixDomainSocket>> UnixDomainSocket::create(SocketType socket_type, ino_t ino, const TmpInodeInfo& inode_info)
{
auto socket = TRY(BAN::RefPtr<UnixDomainSocket>::create(socket_type, ino, inode_info));
socket->m_packet_buffer = TRY(VirtualRange::create_to_vaddr_range(
PageTable::kernel(),
KERNEL_OFFSET,
~(uintptr_t)0,
s_packet_buffer_size,
PageTable::Flags::ReadWrite | PageTable::Flags::Present,
true
));
return socket;
}
UnixDomainSocket::UnixDomainSocket(SocketType socket_type, ino_t ino, const TmpInodeInfo& inode_info)
: TmpInode(NetworkManager::get(), ino, inode_info)
, m_socket_type(socket_type)
{
switch (socket_type)
{
case SocketType::STREAM:
case SocketType::SEQPACKET:
m_info.emplace<ConnectionInfo>();
break;
case SocketType::DGRAM:
m_info.emplace<ConnectionlessInfo>();
break;
default:
ASSERT_NOT_REACHED();
}
}
BAN::ErrorOr<long> UnixDomainSocket::accept_impl(sockaddr* address, socklen_t* address_len)
{
if (!m_info.has<ConnectionInfo>())
return BAN::Error::from_errno(EOPNOTSUPP);
auto& connection_info = m_info.get<ConnectionInfo>();
if (!connection_info.listening)
return BAN::Error::from_errno(EINVAL);
while (connection_info.pending_connections.empty())
TRY(Thread::current().block_or_eintr(connection_info.pending_semaphore));
BAN::RefPtr<UnixDomainSocket> pending;
{
LockGuard _(connection_info.pending_lock);
pending = connection_info.pending_connections.front();
connection_info.pending_connections.pop();
connection_info.pending_semaphore.unblock();
}
BAN::RefPtr<UnixDomainSocket> return_inode;
{
auto return_inode_tmp = TRY(NetworkManager::get().create_socket(SocketDomain::UNIX, m_socket_type, mode().mode & ~Mode::TYPE_MASK, uid(), gid()));
return_inode = reinterpret_cast<UnixDomainSocket*>(return_inode_tmp.ptr());
}
TRY(return_inode->m_bound_path.append(m_bound_path));
return_inode->m_info.get<ConnectionInfo>().connection = TRY(pending->get_weak_ptr());
pending->m_info.get<ConnectionInfo>().connection = TRY(return_inode->get_weak_ptr());
pending->m_info.get<ConnectionInfo>().connection_done = true;
if (address && address_len && !is_bound_to_unused())
{
size_t copy_len = BAN::Math::min<size_t>(*address_len, sizeof(sockaddr) + m_bound_path.size() + 1);
auto& sockaddr_un = *reinterpret_cast<struct sockaddr_un*>(address);
sockaddr_un.sun_family = AF_UNIX;
strncpy(sockaddr_un.sun_path, pending->m_bound_path.data(), copy_len);
}
return TRY(Process::current().open_inode(return_inode, O_RDWR));
}
BAN::ErrorOr<void> UnixDomainSocket::connect_impl(const sockaddr* address, socklen_t address_len)
{
if (address_len != sizeof(sockaddr_un))
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_un = *reinterpret_cast<const struct sockaddr_un*>(address);
if (sockaddr_un.sun_family != AF_UNIX)
return BAN::Error::from_errno(EAFNOSUPPORT);
if (!is_bound())
TRY(m_bound_path.push_back('X'));
auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path));
auto file = TRY(VirtualFileSystem::get().file_from_absolute_path(
Process::current().credentials(),
absolute_path,
O_RDWR
));
BAN::RefPtr<UnixDomainSocket> target;
{
LockGuard _(s_bound_socket_lock);
if (!s_bound_sockets.contains(file.canonical_path))
return BAN::Error::from_errno(ECONNREFUSED);
target = s_bound_sockets[file.canonical_path];
}
if (m_socket_type != target->m_socket_type)
return BAN::Error::from_errno(EPROTOTYPE);
if (m_info.has<ConnectionInfo>())
{
auto& connection_info = m_info.get<ConnectionInfo>();
if (connection_info.connection)
return BAN::Error::from_errno(ECONNREFUSED);
if (connection_info.listening)
return BAN::Error::from_errno(EOPNOTSUPP);
connection_info.connection_done = false;
for (;;)
{
auto& target_info = target->m_info.get<ConnectionInfo>();
{
LockGuard _(target_info.pending_lock);
if (target_info.pending_connections.size() < target_info.pending_connections.capacity())
{
MUST(target_info.pending_connections.push(this));
target_info.pending_semaphore.unblock();
break;
}
}
TRY(Thread::current().block_or_eintr(target_info.pending_semaphore));
}
while (!connection_info.connection_done)
Scheduler::get().reschedule();
return {};
}
else
{
return BAN::Error::from_errno(ENOTSUP);
}
}
BAN::ErrorOr<void> UnixDomainSocket::listen_impl(int backlog)
{
backlog = BAN::Math::clamp(backlog, 1, SOMAXCONN);
if (!is_bound())
return BAN::Error::from_errno(EDESTADDRREQ);
if (!m_info.has<ConnectionInfo>())
return BAN::Error::from_errno(EOPNOTSUPP);
auto& connection_info = m_info.get<ConnectionInfo>();
if (connection_info.connection)
return BAN::Error::from_errno(EINVAL);
TRY(connection_info.pending_connections.reserve(backlog));
connection_info.listening = true;
return {};
}
BAN::ErrorOr<void> UnixDomainSocket::bind_impl(const sockaddr* address, socklen_t address_len)
{
if (is_bound())
return BAN::Error::from_errno(EINVAL);
if (address_len != sizeof(sockaddr_un))
return BAN::Error::from_errno(EINVAL);
auto& sockaddr_un = *reinterpret_cast<const struct sockaddr_un*>(address);
if (sockaddr_un.sun_family != AF_UNIX)
return BAN::Error::from_errno(EAFNOSUPPORT);
auto absolute_path = TRY(Process::current().absolute_path_of(sockaddr_un.sun_path));
if (auto ret = Process::current().create_file_or_dir(absolute_path, 0755 | S_IFSOCK); ret.is_error())
{
if (ret.error().get_error_code() == EEXIST)
return BAN::Error::from_errno(EADDRINUSE);
return ret.release_error();
}
auto file = TRY(VirtualFileSystem::get().file_from_absolute_path(
Process::current().credentials(),
absolute_path,
O_RDWR
));
LockGuard _(s_bound_socket_lock);
ASSERT(!s_bound_sockets.contains(file.canonical_path));
TRY(s_bound_sockets.emplace(file.canonical_path, this));
m_bound_path = BAN::move(file.canonical_path);
return {};
}
bool UnixDomainSocket::is_streaming() const
{
switch (m_socket_type)
{
case SocketType::STREAM:
return true;
case SocketType::SEQPACKET:
case SocketType::DGRAM:
return false;
default:
ASSERT_NOT_REACHED();
}
}
// This to feels too hacky to expose out of here
struct LockFreeGuard
{
LockFreeGuard(RecursivePrioritySpinLock& lock)
: m_lock(lock)
, m_depth(lock.lock_depth())
{
for (uint32_t i = 0; i < m_depth; i++)
m_lock.unlock();
}
~LockFreeGuard()
{
for (uint32_t i = 0; i < m_depth; i++)
m_lock.lock();
}
private:
RecursivePrioritySpinLock& m_lock;
const uint32_t m_depth;
};
BAN::ErrorOr<void> UnixDomainSocket::add_packet(BAN::ConstByteSpan packet)
{
LockGuard _(m_lock);
while (m_packet_sizes.full() || m_packet_size_total + packet.size() > s_packet_buffer_size)
{
LockFreeGuard _(m_lock);
TRY(Thread::current().block_or_eintr(m_packet_semaphore));
}
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr() + m_packet_size_total);
memcpy(packet_buffer, packet.data(), packet.size());
m_packet_size_total += packet.size();
if (!is_streaming())
m_packet_sizes.push(packet.size());
m_packet_semaphore.unblock();
return {};
}
BAN::ErrorOr<size_t> UnixDomainSocket::sendto_impl(const sys_sendto_t* arguments)
{
if (arguments->flags)
return BAN::Error::from_errno(ENOTSUP);
if (arguments->length > s_packet_buffer_size)
return BAN::Error::from_errno(ENOBUFS);
if (m_info.has<ConnectionInfo>())
{
auto& connection_info = m_info.get<ConnectionInfo>();
if (arguments->dest_addr)
return BAN::Error::from_errno(EISCONN);
auto target = connection_info.connection.lock();
if (!target)
return BAN::Error::from_errno(ENOTCONN);
TRY(target->add_packet({ reinterpret_cast<const uint8_t*>(arguments->message), arguments->length }));
return arguments->length;
}
else
{
return BAN::Error::from_errno(ENOTSUP);
}
}
BAN::ErrorOr<size_t> UnixDomainSocket::recvfrom_impl(sys_recvfrom_t* arguments)
{
if (arguments->flags)
return BAN::Error::from_errno(ENOTSUP);
if (m_info.has<ConnectionInfo>())
{
auto& connection_info = m_info.get<ConnectionInfo>();
if (!connection_info.connection)
return BAN::Error::from_errno(ENOTCONN);
}
while (m_packet_size_total == 0)
{
LockFreeGuard _(m_lock);
TRY(Thread::current().block_or_eintr(m_packet_semaphore));
}
uint8_t* packet_buffer = reinterpret_cast<uint8_t*>(m_packet_buffer->vaddr());
size_t nread = 0;
if (is_streaming())
nread = BAN::Math::min(arguments->length, m_packet_size_total);
else
{
nread = BAN::Math::min(arguments->length, m_packet_sizes.front());
m_packet_sizes.pop();
}
memcpy(arguments->buffer, packet_buffer, nread);
memmove(packet_buffer, packet_buffer + nread, m_packet_size_total - nread);
m_packet_size_total -= nread;
m_packet_semaphore.unblock();
return nread;
}
}

View File

@ -55,21 +55,6 @@ namespace Kernel
return {}; return {};
} }
BAN::ErrorOr<int> OpenFileDescriptorSet::open(BAN::RefPtr<Inode> inode, int flags)
{
ASSERT(inode);
ASSERT(!inode->mode().ifdir());
if (flags & ~(O_RDONLY | O_WRONLY))
return BAN::Error::from_errno(ENOTSUP);
int fd = TRY(get_free_fd());
// FIXME: path?
m_open_files[fd] = TRY(BAN::RefPtr<OpenFileDescription>::create(inode, ""sv, 0, flags));
return fd;
}
BAN::ErrorOr<int> OpenFileDescriptorSet::open(BAN::StringView absolute_path, int flags) BAN::ErrorOr<int> OpenFileDescriptorSet::open(BAN::StringView absolute_path, int flags)
{ {
if (flags & ~(O_RDONLY | O_WRONLY | O_NOFOLLOW | O_SEARCH | O_APPEND | O_TRUNC | O_CLOEXEC | O_TTY_INIT | O_DIRECTORY | O_NONBLOCK)) if (flags & ~(O_RDONLY | O_WRONLY | O_NOFOLLOW | O_SEARCH | O_APPEND | O_TRUNC | O_CLOEXEC | O_TTY_INIT | O_DIRECTORY | O_NONBLOCK))
@ -95,25 +80,13 @@ namespace Kernel
BAN::ErrorOr<int> OpenFileDescriptorSet::socket(int domain, int type, int protocol) BAN::ErrorOr<int> OpenFileDescriptorSet::socket(int domain, int type, int protocol)
{ {
using SocketType = NetworkManager::SocketType;
if (domain != AF_INET)
return BAN::Error::from_errno(EAFNOSUPPORT);
if (protocol != 0) if (protocol != 0)
return BAN::Error::from_errno(EPROTONOSUPPORT); return BAN::Error::from_errno(EPROTONOSUPPORT);
SocketDomain sock_domain;
switch (domain)
{
case AF_INET:
sock_domain = SocketDomain::INET;
break;
case AF_INET6:
sock_domain = SocketDomain::INET6;
break;
case AF_UNIX:
sock_domain = SocketDomain::UNIX;
break;
default:
return BAN::Error::from_errno(EPROTOTYPE);
}
SocketType sock_type; SocketType sock_type;
switch (type) switch (type)
{ {
@ -130,7 +103,7 @@ namespace Kernel
return BAN::Error::from_errno(EPROTOTYPE); return BAN::Error::from_errno(EPROTOTYPE);
} }
auto socket = TRY(NetworkManager::get().create_socket(sock_domain, sock_type, 0777, m_credentials.euid(), m_credentials.egid())); auto socket = TRY(NetworkManager::get().create_socket(sock_type, 0777, m_credentials.euid(), m_credentials.egid()));
int fd = TRY(get_free_fd()); int fd = TRY(get_free_fd());
m_open_files[fd] = TRY(BAN::RefPtr<OpenFileDescription>::create(socket, "no-path"sv, 0, O_RDWR)); m_open_files[fd] = TRY(BAN::RefPtr<OpenFileDescription>::create(socket, "no-path"sv, 0, O_RDWR));

View File

@ -649,9 +649,8 @@ namespace Kernel
case Inode::Mode::IFREG: break; case Inode::Mode::IFREG: break;
case Inode::Mode::IFDIR: break; case Inode::Mode::IFDIR: break;
case Inode::Mode::IFIFO: break; case Inode::Mode::IFIFO: break;
case Inode::Mode::IFSOCK: break;
default: default:
return BAN::Error::from_errno(ENOTSUP); return BAN::Error::from_errno(EINVAL);
} }
LockGuard _(m_lock); LockGuard _(m_lock);
@ -708,17 +707,8 @@ namespace Kernel
return false; return false;
} }
BAN::ErrorOr<long> Process::open_inode(BAN::RefPtr<Inode> inode, int flags)
{
ASSERT(inode);
LockGuard _(m_lock);
return TRY(m_open_file_descriptors.open(inode, flags));
}
BAN::ErrorOr<long> Process::open_file(BAN::StringView path, int flags, mode_t mode) BAN::ErrorOr<long> Process::open_file(BAN::StringView path, int flags, mode_t mode)
{ {
LockGuard _(m_lock);
BAN::String absolute_path = TRY(absolute_path_of(path)); BAN::String absolute_path = TRY(absolute_path_of(path));
if (flags & O_CREAT) if (flags & O_CREAT)
@ -726,8 +716,6 @@ namespace Kernel
if (flags & O_DIRECTORY) if (flags & O_DIRECTORY)
return BAN::Error::from_errno(ENOTSUP); return BAN::Error::from_errno(ENOTSUP);
auto file_or_error = VirtualFileSystem::get().file_from_absolute_path(m_credentials, absolute_path, O_WRONLY); auto file_or_error = VirtualFileSystem::get().file_from_absolute_path(m_credentials, absolute_path, O_WRONLY);
if (!file_or_error.is_error() && (flags & O_EXCL))
return BAN::Error::from_errno(EEXIST);
if (file_or_error.is_error()) if (file_or_error.is_error())
{ {
if (file_or_error.error().get_error_code() == ENOENT) if (file_or_error.error().get_error_code() == ENOENT)
@ -913,26 +901,6 @@ namespace Kernel
return TRY(m_open_file_descriptors.socket(domain, type, protocol)); return TRY(m_open_file_descriptors.socket(domain, type, protocol));
} }
BAN::ErrorOr<long> Process::sys_accept(int socket, sockaddr* address, socklen_t* address_len)
{
if (address && !address_len)
return BAN::Error::from_errno(EINVAL);
if (!address && address_len)
return BAN::Error::from_errno(EINVAL);
LockGuard _(m_lock);
if (address)
{
TRY(validate_pointer_access(address_len, sizeof(*address_len)));
TRY(validate_pointer_access(address, *address_len));
}
auto inode = TRY(m_open_file_descriptors.inode_of(socket));
if (!inode->mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
return TRY(inode->accept(address, address_len));
}
BAN::ErrorOr<long> Process::sys_bind(int socket, const sockaddr* address, socklen_t address_len) BAN::ErrorOr<long> Process::sys_bind(int socket, const sockaddr* address, socklen_t address_len)
{ {
@ -947,31 +915,6 @@ namespace Kernel
return 0; return 0;
} }
BAN::ErrorOr<long> Process::sys_connect(int socket, const sockaddr* address, socklen_t address_len)
{
LockGuard _(m_lock);
TRY(validate_pointer_access(address, address_len));
auto inode = TRY(m_open_file_descriptors.inode_of(socket));
if (!inode->mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
TRY(inode->connect(address, address_len));
return 0;
}
BAN::ErrorOr<long> Process::sys_listen(int socket, int backlog)
{
LockGuard _(m_lock);
auto inode = TRY(m_open_file_descriptors.inode_of(socket));
if (!inode->mode().ifsock())
return BAN::Error::from_errno(ENOTSOCK);
TRY(inode->listen(backlog));
return 0;
}
BAN::ErrorOr<long> Process::sys_sendto(const sys_sendto_t* arguments) BAN::ErrorOr<long> Process::sys_sendto(const sys_sendto_t* arguments)
{ {
LockGuard _(m_lock); LockGuard _(m_lock);
@ -1754,7 +1697,7 @@ namespace Kernel
BAN::ErrorOr<BAN::String> Process::absolute_path_of(BAN::StringView path) const BAN::ErrorOr<BAN::String> Process::absolute_path_of(BAN::StringView path) const
{ {
LockGuard _(m_lock); ASSERT(m_lock.is_locked());
if (path.empty() || path == "."sv) if (path.empty() || path == "."sv)
return m_working_directory; return m_working_directory;

View File

@ -228,15 +228,6 @@ namespace Kernel
case SYS_IOCTL: case SYS_IOCTL:
ret = Process::current().sys_ioctl((int)arg1, (int)arg2, (void*)arg3); ret = Process::current().sys_ioctl((int)arg1, (int)arg2, (void*)arg3);
break; break;
case SYS_ACCEPT:
ret = Process::current().sys_accept((int)arg1, (sockaddr*)arg2, (socklen_t*)arg3);
break;
case SYS_CONNECT:
ret = Process::current().sys_connect((int)arg1, (const sockaddr*)arg2, (socklen_t)arg3);
break;
case SYS_LISTEN:
ret = Process::current().sys_listen((int)arg1, (int)arg2);
break;
default: default:
dwarnln("Unknown syscall {}", syscall); dwarnln("Unknown syscall {}", syscall);
break; break;

View File

@ -3,7 +3,6 @@ cmake_minimum_required(VERSION 3.26)
project(libc CXX ASM) project(libc CXX ASM)
set(LIBC_SOURCES set(LIBC_SOURCES
arpa/inet.cpp
assert.cpp assert.cpp
ctype.cpp ctype.cpp
dirent.cpp dirent.cpp

View File

@ -1,52 +0,0 @@
#include <BAN/Endianness.h>
#include <arpa/inet.h>
#include <errno.h>
#include <stdio.h>
uint32_t htonl(uint32_t hostlong)
{
return BAN::host_to_network_endian(hostlong);
}
uint16_t htons(uint16_t hostshort)
{
return BAN::host_to_network_endian(hostshort);
}
uint32_t ntohl(uint32_t netlong)
{
return BAN::host_to_network_endian(netlong);
}
uint16_t ntohs(uint16_t netshort)
{
return BAN::host_to_network_endian(netshort);
}
in_addr_t inet_addr(const char* cp)
{
uint32_t a = 0, b = 0, c = 0, d = 0;
int ret = sscanf(cp, "%u.%u.%u.%u", &a, &b, &c, &d);
if (ret < 1 || ret > 4)
return (in_addr_t)(-1);
uint32_t result = 0;
result |= (ret == 1) ? a : a << 24;
result |= (ret == 2) ? b : b << 16;
result |= (ret == 3) ? c : c << 8;
result |= (ret == 4) ? d : d << 0;
return htonl(result);
}
char* inet_ntoa(struct in_addr in)
{
static char buffer[16];
uint32_t he = ntohl(in.s_addr);
sprintf(buffer, "%u.%u.%u.%u",
(he >> 24) & 0xFF,
(he >> 16) & 0xFF,
(he >> 8) & 0xFF,
(he >> 0) & 0xFF
);
return buffer;
}

View File

@ -22,7 +22,6 @@ struct ifreq
union { union {
struct sockaddr ifru_addr; struct sockaddr ifru_addr;
struct sockaddr ifru_netmask; struct sockaddr ifru_netmask;
struct sockaddr ifru_gwaddr;
struct sockaddr ifru_hwaddr; struct sockaddr ifru_hwaddr;
unsigned char __min_storage[sizeof(sockaddr) + 6]; unsigned char __min_storage[sizeof(sockaddr) + 6];
} ifr_ifru; } ifr_ifru;
@ -32,9 +31,7 @@ struct ifreq
#define SIOCSIFADDR 2 /* Set interface address */ #define SIOCSIFADDR 2 /* Set interface address */
#define SIOCGIFNETMASK 3 /* Get network mask */ #define SIOCGIFNETMASK 3 /* Get network mask */
#define SIOCSIFNETMASK 4 /* Set network mask */ #define SIOCSIFNETMASK 4 /* Set network mask */
#define SIOCGIFGWADDR 5 /* Get gateway address */ #define SIOCGIFHWADDR 5 /* Get hardware address */
#define SIOCSIFGWADDR 6 /* Set gateway address */
#define SIOCGIFHWADDR 7 /* Get hardware address */
void if_freenameindex(struct if_nameindex* ptr); void if_freenameindex(struct if_nameindex* ptr);
char* if_indextoname(unsigned ifindex, char* ifname); char* if_indextoname(unsigned ifindex, char* ifname);

View File

@ -53,8 +53,6 @@ extern FILE* __stdout;
#define stdout __stdout #define stdout __stdout
extern FILE* __stderr; extern FILE* __stderr;
#define stderr __stderr #define stderr __stderr
extern FILE* __stddbg;
#define stddbg __stddbg
void clearerr(FILE* stream); void clearerr(FILE* stream);
char* ctermid(char* s); char* ctermid(char* s);

View File

@ -16,12 +16,6 @@ __BEGIN_DECLS
#include <bits/types/sa_family_t.h> #include <bits/types/sa_family_t.h>
typedef long socklen_t; typedef long socklen_t;
#if !defined(FILENAME_MAX)
#define FILENAME_MAX 256
#elif FILENAME_MAX != 256
#error "invalid FILENAME_MAX"
#endif
struct sockaddr struct sockaddr
{ {
sa_family_t sa_family; /* Address family. */ sa_family_t sa_family; /* Address family. */
@ -30,8 +24,8 @@ struct sockaddr
struct sockaddr_storage struct sockaddr_storage
{ {
// FIXME
sa_family_t ss_family; sa_family_t ss_family;
char ss_storage[FILENAME_MAX];
}; };
struct msghdr struct msghdr

View File

@ -68,9 +68,6 @@ __BEGIN_DECLS
#define SYS_SENDTO 67 #define SYS_SENDTO 67
#define SYS_RECVFROM 68 #define SYS_RECVFROM 68
#define SYS_IOCTL 69 #define SYS_IOCTL 69
#define SYS_ACCEPT 70
#define SYS_CONNECT 71
#define SYS_LISTEN 72
__END_DECLS __END_DECLS

View File

@ -11,8 +11,8 @@ __BEGIN_DECLS
struct sockaddr_un struct sockaddr_un
{ {
sa_family_t sun_family; /* Address family. */ sa_family_t sun_family; /* Address family. */
char sun_path[FILENAME_MAX]; /* Socket pathname. */ char sun_path[]; /* Socket pathname. */
}; };
__END_DECLS __END_DECLS

View File

@ -120,7 +120,6 @@ __BEGIN_DECLS
#define STDIN_FILENO 0 #define STDIN_FILENO 0
#define STDOUT_FILENO 1 #define STDOUT_FILENO 1
#define STDERR_FILENO 2 #define STDERR_FILENO 2
#define STDDBG_FILENO 3
#define _POSIX_VDISABLE 0 #define _POSIX_VDISABLE 0

View File

@ -22,13 +22,11 @@ static FILE s_files[FOPEN_MAX] {
{ .fd = STDIN_FILENO }, { .fd = STDIN_FILENO },
{ .fd = STDOUT_FILENO }, { .fd = STDOUT_FILENO },
{ .fd = STDERR_FILENO }, { .fd = STDERR_FILENO },
{ .fd = STDDBG_FILENO },
}; };
FILE* stdin = &s_files[0]; FILE* stdin = &s_files[0];
FILE* stdout = &s_files[1]; FILE* stdout = &s_files[1];
FILE* stderr = &s_files[2]; FILE* stderr = &s_files[2];
FILE* stddbg = &s_files[3];
void clearerr(FILE* file) void clearerr(FILE* file)
{ {

View File

@ -2,31 +2,11 @@
#include <sys/syscall.h> #include <sys/syscall.h>
#include <unistd.h> #include <unistd.h>
int accept(int socket, struct sockaddr* __restrict address, socklen_t* __restrict address_len)
{
return syscall(SYS_ACCEPT, socket, address, address_len);
}
int bind(int socket, const struct sockaddr* address, socklen_t address_len) int bind(int socket, const struct sockaddr* address, socklen_t address_len)
{ {
return syscall(SYS_BIND, socket, address, address_len); return syscall(SYS_BIND, socket, address, address_len);
} }
int connect(int socket, const struct sockaddr* address, socklen_t address_len)
{
return syscall(SYS_CONNECT, socket, address, address_len);
}
int listen(int socket, int backlog)
{
return syscall(SYS_LISTEN, socket, backlog);
}
ssize_t recv(int socket, void* __restrict buffer, size_t length, int flags)
{
return recvfrom(socket, buffer, length, flags, nullptr, nullptr);
}
ssize_t recvfrom(int socket, void* __restrict buffer, size_t length, int flags, struct sockaddr* __restrict address, socklen_t* __restrict address_len) ssize_t recvfrom(int socket, void* __restrict buffer, size_t length, int flags, struct sockaddr* __restrict address, socklen_t* __restrict address_len)
{ {
sys_recvfrom_t arguments { sys_recvfrom_t arguments {
@ -40,10 +20,6 @@ ssize_t recvfrom(int socket, void* __restrict buffer, size_t length, int flags,
return syscall(SYS_RECVFROM, &arguments); return syscall(SYS_RECVFROM, &arguments);
} }
ssize_t send(int socket, const void* message, size_t length, int flags)
{
return sendto(socket, message, length, flags, nullptr, 0);
}
ssize_t sendto(int socket, const void* message, size_t length, int flags, const struct sockaddr* dest_addr, socklen_t dest_len) ssize_t sendto(int socket, const void* message, size_t length, int flags, const struct sockaddr* dest_addr, socklen_t dest_len)
{ {

View File

@ -11,16 +11,14 @@ set(USERSPACE_PROJECTS
dhcp-client dhcp-client
echo echo
id id
image
init init
image
loadkeys loadkeys
ls ls
meminfo meminfo
mkdir mkdir
mmap-shared-test mmap-shared-test
nslookup
poweroff poweroff
resolver
rm rm
Shell Shell
sleep sleep

View File

@ -72,9 +72,9 @@ static constexpr Position s_dir_offset[] {
i64 solve_general(FILE* fp, auto parse_dir, auto parse_count) i64 solve_general(FILE* fp, auto parse_dir, auto parse_count)
{ {
BAN::HashSet<Position, PositionHash> path; BAN::HashSetUnstable<Position, PositionHash> path;
BAN::HashSet<Position, PositionHash> lpath; BAN::HashSetUnstable<Position, PositionHash> lpath;
BAN::HashSet<Position, PositionHash> rpath; BAN::HashSetUnstable<Position, PositionHash> rpath;
Position current_pos { 0, 0 }; Position current_pos { 0, 0 };
MUST(path.insert(current_pos)); MUST(path.insert(current_pos));
@ -157,8 +157,8 @@ i64 solve_general(FILE* fp, auto parse_dir, auto parse_count)
ASSERT(lmin_x != rmin_x); ASSERT(lmin_x != rmin_x);
auto& expand = (lmin_x < rmin_x) ? rpath : lpath; auto& expand = (lmin_x < rmin_x) ? rpath : lpath;
BAN::HashSet<Position, PositionHash> visited; BAN::HashSetUnstable<Position, PositionHash> visited;
BAN::HashSet<Position, PositionHash> inner_area; BAN::HashSetUnstable<Position, PositionHash> inner_area;
while (!expand.empty()) while (!expand.empty())
{ {

View File

@ -33,7 +33,7 @@ struct Rule
BAN::String target; BAN::String target;
}; };
using Workflows = BAN::HashMap<BAN::String, BAN::Vector<Rule>>; using Workflows = BAN::HashMapUnstable<BAN::String, BAN::Vector<Rule>>;
struct Item struct Item
{ {

View File

@ -72,9 +72,9 @@ struct ConjunctionModule : public Module
} }
}; };
BAN::HashMap<BAN::String, BAN::UniqPtr<Module>> parse_modules(FILE* fp) BAN::HashMapUnstable<BAN::String, BAN::UniqPtr<Module>> parse_modules(FILE* fp)
{ {
BAN::HashMap<BAN::String, BAN::UniqPtr<Module>> modules; BAN::HashMapUnstable<BAN::String, BAN::UniqPtr<Module>> modules;
char buffer[128]; char buffer[128];
while (fgets(buffer, sizeof(buffer), fp)) while (fgets(buffer, sizeof(buffer), fp))

View File

@ -88,13 +88,13 @@ i64 puzzle1(FILE* fp)
{ {
auto garden = parse_garden(fp); auto garden = parse_garden(fp);
BAN::HashSet<Position> visited, reachable, pending; BAN::HashSetUnstable<Position> visited, reachable, pending;
MUST(pending.insert(garden.start)); MUST(pending.insert(garden.start));
for (i32 i = 0; i <= 64; i++) for (i32 i = 0; i <= 64; i++)
{ {
auto temp = BAN::move(pending); auto temp = BAN::move(pending);
pending = BAN::HashSet<Position>(); pending = BAN::HashSetUnstable<Position>();
while (!temp.empty()) while (!temp.empty())
{ {

View File

@ -1,10 +1,8 @@
#include <BAN/Debug.h>
#include <BAN/Endianness.h> #include <BAN/Endianness.h>
#include <BAN/IPv4.h> #include <BAN/IPv4.h>
#include <BAN/MAC.h> #include <BAN/MAC.h>
#include <BAN/Vector.h> #include <BAN/Vector.h>
#include <arpa/inet.h>
#include <fcntl.h> #include <fcntl.h>
#include <net/if.h> #include <net/if.h>
#include <netinet/in.h> #include <netinet/in.h>
@ -14,7 +12,7 @@
#include <stropts.h> #include <stropts.h>
#include <sys/socket.h> #include <sys/socket.h>
#define DEBUG_DHCP 1 #define DEBUG_DHCP 0
struct DHCPPacket struct DHCPPacket
{ {
@ -25,10 +23,10 @@ struct DHCPPacket
BAN::NetworkEndian<uint32_t> xid { 0x3903F326 }; BAN::NetworkEndian<uint32_t> xid { 0x3903F326 };
BAN::NetworkEndian<uint16_t> secs { 0x0000 }; BAN::NetworkEndian<uint16_t> secs { 0x0000 };
BAN::NetworkEndian<uint16_t> flags { 0x0000 }; BAN::NetworkEndian<uint16_t> flags { 0x0000 };
BAN::IPv4Address ciaddr { 0 }; BAN::NetworkEndian<uint32_t> ciaddr { 0 };
BAN::IPv4Address yiaddr { 0 }; BAN::NetworkEndian<uint32_t> yiaddr { 0 };
BAN::IPv4Address siaddr { 0 }; BAN::NetworkEndian<uint32_t> siaddr { 0 };
BAN::IPv4Address giaddr { 0 }; BAN::NetworkEndian<uint32_t> giaddr { 0 };
BAN::MACAddress chaddr; BAN::MACAddress chaddr;
uint8_t padding[10] {}; uint8_t padding[10] {};
uint8_t legacy[192] {}; uint8_t legacy[192] {};
@ -73,13 +71,12 @@ BAN::MACAddress get_mac_address(int socket)
return mac_address; return mac_address;
} }
void update_ipv4_info(int socket, BAN::IPv4Address address, BAN::IPv4Address netmask, BAN::IPv4Address gateway) void update_ipv4_info(int socket, BAN::IPv4Address address, BAN::IPv4Address subnet)
{ {
{ {
ifreq ifreq; ifreq ifreq;
auto& ifru_addr = *reinterpret_cast<sockaddr_in*>(&ifreq.ifr_ifru.ifru_addr); ifreq.ifr_ifru.ifru_addr.sa_family = AF_INET;
ifru_addr.sin_family = AF_INET; *(uint32_t*)ifreq.ifr_ifru.ifru_addr.sa_data = address.as_u32();
ifru_addr.sin_addr.s_addr = address.raw;
if (ioctl(socket, SIOCSIFADDR, &ifreq) == -1) if (ioctl(socket, SIOCSIFADDR, &ifreq) == -1)
{ {
perror("ioctl"); perror("ioctl");
@ -89,36 +86,22 @@ void update_ipv4_info(int socket, BAN::IPv4Address address, BAN::IPv4Address net
{ {
ifreq ifreq; ifreq ifreq;
auto& ifru_netmask = *reinterpret_cast<sockaddr_in*>(&ifreq.ifr_ifru.ifru_netmask); ifreq.ifr_ifru.ifru_netmask.sa_family = AF_INET;
ifru_netmask.sin_family = AF_INET; *(uint32_t*)ifreq.ifr_ifru.ifru_netmask.sa_data = subnet.as_u32();
ifru_netmask.sin_addr.s_addr = netmask.raw;
if (ioctl(socket, SIOCSIFNETMASK, &ifreq) == -1) if (ioctl(socket, SIOCSIFNETMASK, &ifreq) == -1)
{ {
perror("ioctl"); perror("ioctl");
exit(1); exit(1);
} }
} }
if (gateway.raw)
{
ifreq ifreq;
auto& ifru_gwaddr = *reinterpret_cast<sockaddr_in*>(&ifreq.ifr_ifru.ifru_gwaddr);
ifru_gwaddr.sin_family = AF_INET;
ifru_gwaddr.sin_addr.s_addr = gateway.raw;
if (ioctl(socket, SIOCSIFGWADDR, &ifreq) == -1)
{
perror("ioctl");
exit(1);
}
}
} }
void send_dhcp_packet(int socket, const DHCPPacket& dhcp_packet, BAN::IPv4Address server_ipv4) void send_dhcp_packet(int socket, const DHCPPacket& dhcp_packet, BAN::IPv4Address server_ipv4)
{ {
sockaddr_in server_addr; sockaddr_in server_addr;
server_addr.sin_family = AF_INET; server_addr.sin_family = AF_INET;
server_addr.sin_port = htons(67); server_addr.sin_port = 67;
server_addr.sin_addr.s_addr = server_ipv4.raw; server_addr.sin_addr.s_addr = server_ipv4.as_u32();;
if (sendto(socket, &dhcp_packet, sizeof(DHCPPacket), 0, (sockaddr*)&server_addr, sizeof(server_addr)) == -1) if (sendto(socket, &dhcp_packet, sizeof(DHCPPacket), 0, (sockaddr*)&server_addr, sizeof(server_addr)) == -1)
{ {
@ -154,7 +137,7 @@ void send_dhcp_request(int socket, BAN::MACAddress mac_address, BAN::IPv4Address
{ {
DHCPPacket dhcp_packet; DHCPPacket dhcp_packet;
dhcp_packet.op = 0x01; dhcp_packet.op = 0x01;
dhcp_packet.siaddr = server_ipv4.raw; dhcp_packet.siaddr = server_ipv4.as_u32();
dhcp_packet.chaddr = mac_address; dhcp_packet.chaddr = mac_address;
size_t idx = 0; size_t idx = 0;
@ -165,10 +148,10 @@ void send_dhcp_request(int socket, BAN::MACAddress mac_address, BAN::IPv4Address
dhcp_packet.options[idx++] = RequestedIPv4Address; dhcp_packet.options[idx++] = RequestedIPv4Address;
dhcp_packet.options[idx++] = 0x04; dhcp_packet.options[idx++] = 0x04;
dhcp_packet.options[idx++] = offered_ipv4.octets[0]; dhcp_packet.options[idx++] = offered_ipv4.address[0];
dhcp_packet.options[idx++] = offered_ipv4.octets[1]; dhcp_packet.options[idx++] = offered_ipv4.address[1];
dhcp_packet.options[idx++] = offered_ipv4.octets[2]; dhcp_packet.options[idx++] = offered_ipv4.address[2];
dhcp_packet.options[idx++] = offered_ipv4.octets[3]; dhcp_packet.options[idx++] = offered_ipv4.address[3];
dhcp_packet.options[idx++] = 0xFF; dhcp_packet.options[idx++] = 0xFF;
@ -205,7 +188,7 @@ DHCPPacketInfo parse_dhcp_packet(const DHCPPacket& packet)
fprintf(stderr, "Subnet mask with invalid length %hhu\n", length); fprintf(stderr, "Subnet mask with invalid length %hhu\n", length);
break; break;
} }
uint32_t raw = *reinterpret_cast<const uint32_t*>(options); uint32_t raw = *reinterpret_cast<const BAN::NetworkEndian<uint32_t>*>(options);
packet_info.subnet = BAN::IPv4Address(raw); packet_info.subnet = BAN::IPv4Address(raw);
break; break;
} }
@ -218,7 +201,7 @@ DHCPPacketInfo parse_dhcp_packet(const DHCPPacket& packet)
} }
for (int i = 0; i < length; i += 4) for (int i = 0; i < length; i += 4)
{ {
uint32_t raw = *reinterpret_cast<const uint32_t*>(options + i); uint32_t raw = *reinterpret_cast<const BAN::NetworkEndian<uint32_t>*>(options + i);
MUST(packet_info.routers.emplace_back(raw)); MUST(packet_info.routers.emplace_back(raw));
} }
break; break;
@ -232,7 +215,7 @@ DHCPPacketInfo parse_dhcp_packet(const DHCPPacket& packet)
} }
for (int i = 0; i < length; i += 4) for (int i = 0; i < length; i += 4)
{ {
uint32_t raw = *reinterpret_cast<const uint32_t*>(options + i); uint32_t raw = *reinterpret_cast<const BAN::NetworkEndian<uint32_t>*>(options + i);
MUST(packet_info.dns.emplace_back(raw)); MUST(packet_info.dns.emplace_back(raw));
} }
break; break;
@ -261,7 +244,7 @@ DHCPPacketInfo parse_dhcp_packet(const DHCPPacket& packet)
fprintf(stderr, "Server identifier with invalid length %hhu\n", length); fprintf(stderr, "Server identifier with invalid length %hhu\n", length);
break; break;
} }
uint32_t raw = *reinterpret_cast<const uint32_t*>(options); uint32_t raw = *reinterpret_cast<const BAN::NetworkEndian<uint32_t>*>(options);
packet_info.server = BAN::IPv4Address(raw); packet_info.server = BAN::IPv4Address(raw);
break; break;
} }
@ -310,8 +293,8 @@ int main()
sockaddr_in client_addr; sockaddr_in client_addr;
client_addr.sin_family = AF_INET; client_addr.sin_family = AF_INET;
client_addr.sin_port = htons(68); client_addr.sin_port = 68;
client_addr.sin_addr.s_addr = INADDR_ANY; client_addr.sin_addr.s_addr = 0x00000000;
if (bind(socket, (sockaddr*)&client_addr, sizeof(client_addr)) == -1) if (bind(socket, (sockaddr*)&client_addr, sizeof(client_addr)) == -1)
{ {
@ -321,12 +304,12 @@ int main()
auto mac_address = get_mac_address(socket); auto mac_address = get_mac_address(socket);
#if DEBUG_DHCP #if DEBUG_DHCP
dprintln("MAC: {}", mac_address); BAN::Formatter::println(putchar, "MAC: {}", mac_address);
#endif #endif
send_dhcp_discover(socket, mac_address); send_dhcp_discover(socket, mac_address);
#if DEBUG_DHCP #if DEBUG_DHCP
dprintln("DHCPDISCOVER sent"); printf("DHCPDISCOVER sent\n");
#endif #endif
auto dhcp_offer = read_dhcp_packet(socket); auto dhcp_offer = read_dhcp_packet(socket);
@ -339,15 +322,15 @@ int main()
} }
#if DEBUG_DHCP #if DEBUG_DHCP
dprintln("DHCPOFFER"); BAN::Formatter::println(putchar, "DHCPOFFER");
dprintln(" IP {}", dhcp_offer->address); BAN::Formatter::println(putchar, " IP {}", dhcp_offer->address);
dprintln(" SUBNET {}", dhcp_offer->subnet); BAN::Formatter::println(putchar, " SUBNET {}", dhcp_offer->subnet);
dprintln(" SERVER {}", dhcp_offer->server); BAN::Formatter::println(putchar, " SERVER {}", dhcp_offer->server);
#endif #endif
send_dhcp_request(socket, mac_address, dhcp_offer->address, dhcp_offer->server); send_dhcp_request(socket, mac_address, dhcp_offer->address, dhcp_offer->server);
#if DEBUG_DHCP #if DEBUG_DHCP
dprintln("DHCPREQUEST sent"); printf("DHCPREQUEST sent\n");
#endif #endif
auto dhcp_ack = read_dhcp_packet(socket); auto dhcp_ack = read_dhcp_packet(socket);
@ -360,10 +343,10 @@ int main()
} }
#if DEBUG_DHCP #if DEBUG_DHCP
dprintln("DHCPACK"); BAN::Formatter::println(putchar, "DHCPACK");
dprintln(" IP {}", dhcp_ack->address); BAN::Formatter::println(putchar, " IP {}", dhcp_ack->address);
dprintln(" SUBNET {}", dhcp_ack->subnet); BAN::Formatter::println(putchar, " SUBNET {}", dhcp_ack->subnet);
dprintln(" SERVER {}", dhcp_ack->server); BAN::Formatter::println(putchar, " SERVER {}", dhcp_ack->server);
#endif #endif
if (dhcp_offer->address != dhcp_ack->address) if (dhcp_offer->address != dhcp_ack->address)
@ -372,11 +355,7 @@ int main()
return 1; return 1;
} }
BAN::IPv4Address gateway { 0 }; update_ipv4_info(socket, dhcp_ack->address, dhcp_ack->subnet);
if (!dhcp_ack->routers.empty())
gateway = dhcp_ack->routers.front();
update_ipv4_info(socket, dhcp_ack->address, dhcp_ack->subnet, gateway);
close(socket); close(socket);

View File

@ -17,7 +17,6 @@ void initialize_stdio()
if (open(tty, O_RDONLY | O_TTY_INIT) != 0) _exit(1); if (open(tty, O_RDONLY | O_TTY_INIT) != 0) _exit(1);
if (open(tty, O_WRONLY) != 1) _exit(1); if (open(tty, O_WRONLY) != 1) _exit(1);
if (open(tty, O_WRONLY) != 2) _exit(1); if (open(tty, O_WRONLY) != 2) _exit(1);
if (open("/dev/debug", O_WRONLY) != 3) _exit(1);
} }
int main() int main()
@ -36,12 +35,6 @@ int main()
exit(1); exit(1);
} }
if (fork() == 0)
{
execl("/bin/resolver", "resolver", NULL);
exit(1);
}
bool first = true; bool first = true;
termios termios; termios termios;

View File

@ -1,16 +0,0 @@
cmake_minimum_required(VERSION 3.26)
project(nslookup CXX)
set(SOURCES
main.cpp
)
add_executable(nslookup ${SOURCES})
target_compile_options(nslookup PUBLIC -O2 -g)
target_link_libraries(nslookup PUBLIC libc)
add_custom_target(nslookup-install
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/nslookup ${BANAN_BIN}/
DEPENDS nslookup
)

View File

@ -1,47 +0,0 @@
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/un.h>
int main(int argc, char** argv)
{
if (argc != 2)
{
fprintf(stderr, "usage: %s DOMAIN\n", argv[0]);
return 1;
}
int socket = ::socket(AF_UNIX, SOCK_SEQPACKET, 0);
if (socket == -1)
{
perror("socket");
return 1;
}
sockaddr_un addr;
addr.sun_family = AF_UNIX;
strcpy(addr.sun_path, "/tmp/resolver.sock");
if (connect(socket, (sockaddr*)&addr, sizeof(addr)) == -1)
{
perror("connect");
return 1;
}
if (send(socket, argv[1], strlen(argv[1]), 0) == -1)
{
perror("send");
return 1;
}
char buffer[128];
ssize_t nrecv = recv(socket, buffer, sizeof(buffer), 0);
if (nrecv == -1)
{
perror("recv");
return 1;
}
buffer[nrecv] = '\0';
printf("%s\n", buffer);
return 0;
}

View File

@ -1,16 +0,0 @@
cmake_minimum_required(VERSION 3.26)
project(resolver CXX)
set(SOURCES
main.cpp
)
add_executable(resolver ${SOURCES})
target_compile_options(resolver PUBLIC -O2 -g)
target_link_libraries(resolver PUBLIC libc ban)
add_custom_target(resolver-install
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/resolver ${BANAN_BIN}/
DEPENDS resolver
)

View File

@ -1,218 +0,0 @@
#include <BAN/ByteSpan.h>
#include <BAN/Endianness.h>
#include <BAN/HashMap.h>
#include <BAN/IPv4.h>
#include <BAN/String.h>
#include <BAN/StringView.h>
#include <BAN/Vector.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <unistd.h>
struct DNSPacket
{
BAN::NetworkEndian<uint16_t> identification { 0 };
BAN::NetworkEndian<uint16_t> flags { 0 };
BAN::NetworkEndian<uint16_t> question_count { 0 };
BAN::NetworkEndian<uint16_t> answer_count { 0 };
BAN::NetworkEndian<uint16_t> authority_RR_count { 0 };
BAN::NetworkEndian<uint16_t> additional_RR_count { 0 };
uint8_t data[];
};
static_assert(sizeof(DNSPacket) == 12);
struct DNSAnswer
{
uint8_t __storage[12];
BAN::NetworkEndian<uint16_t>& name() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x00); };
BAN::NetworkEndian<uint16_t>& type() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x02); };
BAN::NetworkEndian<uint16_t>& class_() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x04); };
BAN::NetworkEndian<uint32_t>& ttl() { return *reinterpret_cast<BAN::NetworkEndian<uint32_t>*>(__storage + 0x06); };
BAN::NetworkEndian<uint16_t>& data_len() { return *reinterpret_cast<BAN::NetworkEndian<uint16_t>*>(__storage + 0x0A); };
uint8_t data[];
};
static_assert(sizeof(DNSAnswer) == 12);
bool send_dns_query(int socket, BAN::StringView domain, uint16_t id)
{
static uint8_t buffer[4096];
memset(buffer, 0, sizeof(buffer));
DNSPacket& request = *reinterpret_cast<DNSPacket*>(buffer);
request.identification = id;
request.flags = 0x0100;
request.question_count = 1;
size_t idx = 0;
auto labels = MUST(BAN::StringView(domain).split('.'));
for (auto label : labels)
{
ASSERT(label.size() <= 0xFF);
request.data[idx++] = label.size();
for (char c : label)
request.data[idx++] = c;
}
request.data[idx++] = 0x00;
*(uint16_t*)&request.data[idx] = htons(0x01); idx += 2;
*(uint16_t*)&request.data[idx] = htons(0x01); idx += 2;
sockaddr_in nameserver;
nameserver.sin_family = AF_INET;
nameserver.sin_port = htons(53);
nameserver.sin_addr.s_addr = inet_addr("8.8.8.8");
if (sendto(socket, &request, sizeof(DNSPacket) + idx, 0, (sockaddr*)&nameserver, sizeof(nameserver)) == -1)
{
perror("sendto");
return false;
}
return true;
}
BAN::Optional<BAN::String> read_dns_response(int socket, uint16_t id)
{
static uint8_t buffer[4096];
ssize_t nrecv = recvfrom(socket, buffer, sizeof(buffer), 0, nullptr, nullptr);
if (nrecv == -1)
{
perror("recvfrom");
return {};
}
DNSPacket& reply = *reinterpret_cast<DNSPacket*>(buffer);
if (reply.identification != id)
{
fprintf(stderr, "Reply to invalid packet\n");
return {};
}
if (reply.flags & 0x0F)
{
fprintf(stderr, "DNS error (rcode %u)\n", (unsigned)(reply.flags & 0xF));
return {};
}
size_t idx = 0;
for (size_t i = 0; i < reply.question_count; i++)
{
while (reply.data[idx])
idx += reply.data[idx] + 1;
idx += 5;
}
DNSAnswer& answer = *reinterpret_cast<DNSAnswer*>(&reply.data[idx]);
if (answer.data_len() != 4)
{
fprintf(stderr, "Not IPv4\n");
return {};
}
return inet_ntoa({ .s_addr = *reinterpret_cast<uint32_t*>(answer.data) });
}
int create_service_socket()
{
int socket = ::socket(AF_UNIX, SOCK_SEQPACKET, 0);
if (socket == -1)
{
perror("socket");
return -1;
}
sockaddr_un addr;
addr.sun_family = AF_UNIX;
strcpy(addr.sun_path, "/tmp/resolver.sock");
if (bind(socket, (sockaddr*)&addr, sizeof(addr)) == -1)
{
perror("bind");
close(socket);
return -1;
}
if (chmod("/tmp/resolver.sock", 0777) == -1)
{
perror("chmod");
close(socket);
return -1;
}
if (listen(socket, 10) == -1)
{
perror("listen");
close(socket);
return -1;
}
return socket;
}
BAN::Optional<BAN::String> read_service_query(int socket)
{
static char buffer[4096];
ssize_t nrecv = recv(socket, buffer, sizeof(buffer), 0);
if (nrecv == -1)
{
perror("recv");
return {};
}
buffer[nrecv] = '\0';
return BAN::String(buffer);
}
int main(int, char**)
{
srand(time(nullptr));
int service_socket = create_service_socket();
if (service_socket == -1)
return 1;
int dns_socket = socket(AF_INET, SOCK_DGRAM, 0);
if (dns_socket == -1)
{
perror("socket");
return 1;
}
for (;;)
{
int client = accept(service_socket, nullptr, nullptr);
if (client == -1)
{
perror("accept");
continue;
}
auto query = read_service_query(client);
if (!query.has_value())
continue;
uint16_t id = rand() % 0xFFFF;
if (send_dns_query(dns_socket, *query, id))
{
auto response = read_dns_response(dns_socket, id);
if (response.has_value())
{
if (send(client, response->data(), response->size() + 1, 0) == -1)
perror("send");
close(client);
continue;
}
}
char message[] = "unavailable";
send(client, message, sizeof(message), 0);
close(client);
}
return 0;
}