Compare commits

...

23 Commits

Author SHA1 Message Date
Bananymous 6e981d1222 Shell: Add support for inline environment variables
e.g. `foo=$(echo lol) Shell -c 'echo $foo'` will now print lol!
2024-10-15 23:45:08 +03:00
Bananymous 8317bb13ca Shell: Cleanup code by defining argument types as nested types 2024-10-15 23:44:06 +03:00
Bananymous c40f244dff Shell: remove `env` builtin and add `type`
other shells don't seem to implement as a builtin, so i won't either
2024-10-15 23:42:01 +03:00
Bananymous a6aa048be0 userspace: Implement `env` as an executable 2024-10-15 23:42:01 +03:00
Bananymous 8fd0162393 Kernel: Rewrite x86_64 page tables to use HHDM instead of kmalloc
This allows page tables to not crash the kernel once kmalloc runs out of
its (limited) static memory.
2024-10-14 11:40:30 +03:00
Bananymous f0b18da881 Kernel: Add kmalloc helper APIs 2024-10-14 11:39:04 +03:00
Bananymous 5f63ea8f8a Kernel: Add CPUID check for 1 GiB page support 2024-10-14 11:38:03 +03:00
Bananymous 2b43569927 Kernel: Use enums in boot info instead of magic values 2024-10-14 11:36:51 +03:00
Bananymous 4ba33175cf Kernel: Don't leak memory when preparing fast page
For some reason I was allocating memory for page table entry...
2024-10-14 11:34:48 +03:00
Bananymous 3edc1af560 Kernel: Don't map main bios area in page table initialization
This is only needed for RSDP lookup so it can be done with fast pages
2024-10-14 11:32:54 +03:00
Bananymous 55fbd09e45 Kernel: Rewrite physical memory allocation with PageTable::fast_pages 2024-10-14 11:32:54 +03:00
Bananymous 6a46a25f48 image: Add benchmark flag to measure performance of image operations 2024-10-13 22:05:13 +03:00
Bananymous 88b8ca5b29 LibC: Fix some string functions
I was not casting some required values to char or handling length of
zero
2024-10-13 22:04:08 +03:00
Bananymous fdddb556ae LibC: Implement system() more properly
Old implementation did not ignore and block needed signals
2024-10-13 22:03:15 +03:00
Bananymous d36b64e0c8 LibImage: name color to u32 function to to_argb from to_rgba
This is the actual format that it returns
2024-10-13 22:01:46 +03:00
Bananymous 8adc97980a Shell: rewrite the whole shell to use tokens instead of raw strings
tab completion is still running with raw strings and that has to be
fixed in the future.
2024-10-13 22:00:16 +03:00
Bananymous dab6e5a60f BAN: Cleanup HashMap implementation and add {insert,emplace}_or_assign 2024-10-13 22:00:16 +03:00
Bananymous 0b05e9827b BAN: Use memmove instead of memcpy on overlapping data
I was accidentally using memcpy where memmove was needed
2024-10-13 22:00:16 +03:00
Bananymous 1c1a76d6d7 BAN: Member function pointers now use references instead of pointers
This seems cleaner as class pointer cannot be null anymore
2024-10-13 22:00:16 +03:00
Bananymous df4f37d68d BAN: only define placement new operators for banan-os targets
This allows building and using BAN library outside of banan-os!
2024-10-10 21:55:25 +03:00
Bananymous 44629ba5dd BAN: Allow userspace to use string literals with BAN::Error 2024-10-10 21:54:52 +03:00
Bananymous 2da6776451 BAN: Update {Byte}Span API with better constness
const BAN::Span<int> is now allowed to modify its underlying data, but
the container itself is const.

BAN::Span<const int> can be used for spans over constant data.
2024-10-10 21:53:23 +03:00
Bananymous a68f411024 BAN: Add requires clauses for Container::emplace{,_back} functions
This allows syntax highlighters to report errors!
2024-10-10 21:51:44 +03:00
59 changed files with 3878 additions and 2350 deletions

View File

@ -21,75 +21,56 @@ namespace BAN
, m_size(size) , m_size(size)
{ } { }
ByteSpanGeneral(ByteSpanGeneral& other) template<bool SRC_CONST>
ByteSpanGeneral(const ByteSpanGeneral<SRC_CONST>& other) requires(CONST || !SRC_CONST)
: m_data(other.data()) : m_data(other.data())
, m_size(other.size()) , m_size(other.size())
{ } { }
ByteSpanGeneral(ByteSpanGeneral&& other) template<bool SRC_CONST>
ByteSpanGeneral(ByteSpanGeneral<SRC_CONST>&& other) requires(CONST || !SRC_CONST)
: m_data(other.data()) : m_data(other.data())
, m_size(other.size()) , m_size(other.size())
{ {
other.m_data = nullptr; other.clear();
other.m_size = 0;
} }
template<bool C2>
ByteSpanGeneral(const ByteSpanGeneral<C2>& other) requires(CONST)
: m_data(other.data())
, m_size(other.size())
{ }
template<bool C2>
ByteSpanGeneral(ByteSpanGeneral<C2>&& other) requires(CONST)
: m_data(other.data())
, m_size(other.size())
{
other.m_data = nullptr;
other.m_size = 0;
}
ByteSpanGeneral(Span<uint8_t> other)
: m_data(other.data())
, m_size(other.size())
{ }
ByteSpanGeneral(const Span<const uint8_t>& other) requires(CONST)
: m_data(other.data())
, m_size(other.size())
{ }
ByteSpanGeneral& operator=(ByteSpanGeneral other) template<typename T>
ByteSpanGeneral(const Span<T>& other) requires(is_same_v<T, uint8_t> || (is_same_v<T, const uint8_t> && CONST))
: m_data(other.data())
, m_size(other.size())
{ }
template<typename T>
ByteSpanGeneral(Span<T>&& other) requires(is_same_v<T, uint8_t> || (is_same_v<T, const uint8_t> && CONST))
: m_data(other.data())
, m_size(other.size())
{
other.clear();
}
template<bool SRC_CONST>
ByteSpanGeneral& operator=(const ByteSpanGeneral<SRC_CONST>& other) requires(CONST || !SRC_CONST)
{ {
m_data = other.data(); m_data = other.data();
m_size = other.size(); m_size = other.size();
return *this; return *this;
} }
template<bool C2> template<bool SRC_CONST>
ByteSpanGeneral& operator=(const ByteSpanGeneral<C2>& other) requires(CONST) ByteSpanGeneral& operator=(ByteSpanGeneral<SRC_CONST>&& other) requires(CONST || !SRC_CONST)
{
m_data = other.data();
m_size = other.size();
return *this;
}
ByteSpanGeneral& operator=(Span<uint8_t> other)
{
m_data = other.data();
m_size = other.size();
return *this;
}
ByteSpanGeneral& operator=(const Span<const uint8_t>& other) requires(CONST)
{ {
m_data = other.data(); m_data = other.data();
m_size = other.size(); m_size = other.size();
other.clear();
return *this; return *this;
} }
template<typename S> template<typename S>
requires(CONST || !is_const_v<S>) static ByteSpanGeneral from(S& value) requires(CONST || !is_const_v<S>)
static ByteSpanGeneral from(S& value)
{ {
return ByteSpanGeneral(reinterpret_cast<value_type*>(&value), sizeof(S)); return ByteSpanGeneral(reinterpret_cast<value_type*>(&value), sizeof(S));
} }
template<typename S> template<typename S>
requires(!CONST && !is_const_v<S>) S& as() const requires(!CONST || is_const_v<S>)
S& as()
{ {
ASSERT(m_data); ASSERT(m_data);
ASSERT(m_size >= sizeof(S)); ASSERT(m_size >= sizeof(S));
@ -97,30 +78,13 @@ namespace BAN
} }
template<typename S> template<typename S>
requires(is_const_v<S>) Span<S> as_span() const requires(!CONST || is_const_v<S>)
S& as() const
{
ASSERT(m_data);
ASSERT(m_size >= sizeof(S));
return *reinterpret_cast<S*>(m_data);
}
template<typename S>
requires(!CONST && !is_const_v<S>)
Span<S> as_span()
{ {
ASSERT(m_data); ASSERT(m_data);
return Span<S>(reinterpret_cast<S*>(m_data), m_size / sizeof(S)); return Span<S>(reinterpret_cast<S*>(m_data), m_size / sizeof(S));
} }
template<typename S> ByteSpanGeneral slice(size_type offset, size_type length = size_type(-1)) const
const Span<S> as_span() const
{
ASSERT(m_data);
return Span<S>(reinterpret_cast<S*>(m_data), m_size / sizeof(S));
}
ByteSpanGeneral slice(size_type offset, size_type length = size_type(-1))
{ {
ASSERT(m_data); ASSERT(m_data);
ASSERT(m_size >= offset); ASSERT(m_size >= offset);
@ -130,22 +94,23 @@ namespace BAN
return ByteSpanGeneral(m_data + offset, length); return ByteSpanGeneral(m_data + offset, length);
} }
value_type& operator[](size_type offset) value_type& operator[](size_type offset) const
{
ASSERT(offset < m_size);
return m_data[offset];
}
const value_type& operator[](size_type offset) const
{ {
ASSERT(offset < m_size); ASSERT(offset < m_size);
return m_data[offset]; return m_data[offset];
} }
value_type* data() { return m_data; } value_type* data() const { return m_data; }
const value_type* data() const { return m_data; }
bool empty() const { return m_size == 0; }
size_type size() const { return m_size; } size_type size() const { return m_size; }
void clear()
{
m_data = nullptr;
m_size = 0;
}
private: private:
value_type* m_data { nullptr }; value_type* m_data { nullptr };
size_type m_size { 0 }; size_type m_size { 0 };

View File

@ -24,7 +24,7 @@ namespace BAN
void push(const T&); void push(const T&);
void push(T&&); void push(T&&);
template<typename... Args> template<typename... Args>
void emplace(Args&&... args); void emplace(Args&&... args) requires is_constructible_v<T, Args...>;
void pop(); void pop();
@ -71,7 +71,7 @@ namespace BAN
template<typename T, size_t S> template<typename T, size_t S>
template<typename... Args> template<typename... Args>
void CircularQueue<T, S>::emplace(Args&&... args) void CircularQueue<T, S>::emplace(Args&&... args) requires is_constructible_v<T, Args...>
{ {
ASSERT(!full()); ASSERT(!full());
new (element_at(((m_first + m_size) % capacity()))) T(BAN::forward<Args>(args)...); new (element_at(((m_first + m_size) % capacity()))) T(BAN::forward<Args>(args)...);

View File

@ -36,7 +36,14 @@ namespace BAN
{ {
return Error((uint64_t)error | kernel_error_mask); return Error((uint64_t)error | kernel_error_mask);
} }
#else
template<size_t N>
consteval static Error from_literal(const char (&message)[N])
{
return Error(message);
}
#endif #endif
static Error from_errno(int error) static Error from_errno(int error)
{ {
return Error(error); return Error(error);
@ -54,12 +61,15 @@ namespace BAN
} }
#endif #endif
uint64_t get_error_code() const { return m_error_code; } constexpr uint64_t get_error_code() const { return m_error_code; }
const char* get_message() const const char* get_message() const
{ {
#ifdef __is_kernel #ifdef __is_kernel
if (m_error_code & kernel_error_mask) if (m_error_code & kernel_error_mask)
return Kernel::error_string(kernel_error()); return Kernel::error_string(kernel_error());
#else
if (m_message)
return m_message;
#endif #endif
if (auto* desc = strerrordesc_np(m_error_code)) if (auto* desc = strerrordesc_np(m_error_code))
return desc; return desc;
@ -67,11 +77,21 @@ namespace BAN
} }
private: private:
Error(uint64_t error) constexpr Error(uint64_t error)
: m_error_code(error) : m_error_code(error)
{} {}
uint64_t m_error_code; #ifndef __is_kernel
constexpr Error(const char* message)
: m_message(message)
{}
#endif
uint64_t m_error_code { 0 };
#ifndef __is_kernel
const char* m_message { nullptr };
#endif
}; };
template<typename T> template<typename T>

View File

@ -20,13 +20,13 @@ namespace BAN
new (m_storage) CallablePointer(function); new (m_storage) CallablePointer(function);
} }
template<typename Own> template<typename Own>
Function(Ret(Own::*function)(Args...), Own* owner) Function(Ret(Own::*function)(Args...), Own& owner)
{ {
static_assert(sizeof(CallableMember<Own>) <= m_size); static_assert(sizeof(CallableMember<Own>) <= m_size);
new (m_storage) CallableMember<Own>(function, owner); new (m_storage) CallableMember<Own>(function, owner);
} }
template<typename Own> template<typename Own>
Function(Ret(Own::*function)(Args...) const, const Own* owner) Function(Ret(Own::*function)(Args...) const, const Own& owner)
{ {
static_assert(sizeof(CallableMemberConst<Own>) <= m_size); static_assert(sizeof(CallableMemberConst<Own>) <= m_size);
new (m_storage) CallableMemberConst<Own>(function, owner); new (m_storage) CallableMemberConst<Own>(function, owner);
@ -91,36 +91,36 @@ namespace BAN
template<typename Own> template<typename Own>
struct CallableMember : public CallableBase struct CallableMember : public CallableBase
{ {
CallableMember(Ret(Own::*function)(Args...), Own* owner) CallableMember(Ret(Own::*function)(Args...), Own& owner)
: m_owner(owner) : m_owner(owner)
, m_function(function) , m_function(function)
{ } { }
virtual Ret call(Args... args) const override virtual Ret call(Args... args) const override
{ {
return (m_owner->*m_function)(forward<Args>(args)...); return (m_owner.*m_function)(forward<Args>(args)...);
} }
private: private:
Own* m_owner = nullptr; Own& m_owner;
Ret(Own::*m_function)(Args...) = nullptr; Ret(Own::*m_function)(Args...) = nullptr;
}; };
template<typename Own> template<typename Own>
struct CallableMemberConst : public CallableBase struct CallableMemberConst : public CallableBase
{ {
CallableMemberConst(Ret(Own::*function)(Args...) const, const Own* owner) CallableMemberConst(Ret(Own::*function)(Args...) const, const Own& owner)
: m_owner(owner) : m_owner(owner)
, m_function(function) , m_function(function)
{ } { }
virtual Ret call(Args... args) const override virtual Ret call(Args... args) const override
{ {
return (m_owner->*m_function)(forward<Args>(args)...); return (m_owner.*m_function)(forward<Args>(args)...);
} }
private: private:
const Own* m_owner = nullptr; const Own& m_owner;
Ret(Own::*m_function)(Args...) const = nullptr; Ret(Own::*m_function)(Args...) const = nullptr;
}; };

View File

@ -14,7 +14,7 @@ namespace BAN
struct Entry struct Entry
{ {
template<typename... Args> template<typename... Args>
Entry(const Key& key, Args&&... args) Entry(const Key& key, Args&&... args) requires is_constructible_v<T, Args...>
: key(key) : key(key)
, value(forward<Args>(args)...) , value(forward<Args>(args)...)
{} {}
@ -39,10 +39,27 @@ namespace BAN
HashMap<Key, T, HASH>& operator=(const HashMap<Key, T, HASH>&); HashMap<Key, T, HASH>& operator=(const HashMap<Key, T, HASH>&);
HashMap<Key, T, HASH>& operator=(HashMap<Key, T, HASH>&&); HashMap<Key, T, HASH>& operator=(HashMap<Key, T, HASH>&&);
ErrorOr<void> insert(const Key&, const T&); ErrorOr<void> insert(const Key& key, const T& value) { return emplace(key, value); }
ErrorOr<void> insert(const Key&, T&&); ErrorOr<void> insert(const Key& key, T&& value) { return emplace(key, move(value)); }
ErrorOr<void> insert(Key&& key, const T& value) { return emplace(move(key), value); }
ErrorOr<void> insert(Key&& key, T&& value) { return emplace(move(key), move(value)); }
ErrorOr<void> insert_or_assign(const Key& key, const T& value) { return emplace_or_assign(key, value); }
ErrorOr<void> insert_or_assign(const Key& key, T&& value) { return emplace_or_assign(key, move(value)); }
ErrorOr<void> insert_or_assign(Key&& key, const T& value) { return emplace_or_assign(move(key), value); }
ErrorOr<void> insert_or_assign(Key&& key, T&& value) { return emplace_or_assign(move(key), move(value)); }
template<typename... Args> template<typename... Args>
ErrorOr<void> emplace(const Key&, Args&&...); ErrorOr<void> emplace(const Key& key, Args&&... args) requires is_constructible_v<T, Args...>
{ return emplace(Key(key), forward<Args>(args)...); }
template<typename... Args>
ErrorOr<void> emplace(Key&&, Args&&...) requires is_constructible_v<T, Args...>;
template<typename... Args>
ErrorOr<void> emplace_or_assign(const Key& key, Args&&... args) requires is_constructible_v<T, Args...>
{ return emplace_or_assign(Key(key), forward<Args>(args)...); }
template<typename... Args>
ErrorOr<void> emplace_or_assign(Key&&, Args&&...) requires is_constructible_v<T, Args...>;
iterator begin() { return iterator(m_buckets.end(), m_buckets.begin()); } iterator begin() { return iterator(m_buckets.end(), m_buckets.begin()); }
iterator end() { return iterator(m_buckets.end(), m_buckets.end()); } iterator end() { return iterator(m_buckets.end(), m_buckets.end()); }
@ -116,26 +133,29 @@ namespace BAN
return *this; return *this;
} }
template<typename Key, typename T, typename HASH>
ErrorOr<void> HashMap<Key, T, HASH>::insert(const Key& key, const T& value)
{
return insert(key, move(T(value)));
}
template<typename Key, typename T, typename HASH>
ErrorOr<void> HashMap<Key, T, HASH>::insert(const Key& key, T&& value)
{
return emplace(key, move(value));
}
template<typename Key, typename T, typename HASH> template<typename Key, typename T, typename HASH>
template<typename... Args> template<typename... Args>
ErrorOr<void> HashMap<Key, T, HASH>::emplace(const Key& key, Args&&... args) ErrorOr<void> HashMap<Key, T, HASH>::emplace(Key&& key, Args&&... args) requires is_constructible_v<T, Args...>
{ {
ASSERT(!contains(key)); ASSERT(!contains(key));
TRY(rebucket(m_size + 1)); TRY(rebucket(m_size + 1));
auto& bucket = get_bucket(key); auto& bucket = get_bucket(key);
TRY(bucket.emplace_back(key, forward<Args>(args)...)); TRY(bucket.emplace_back(move(key), forward<Args>(args)...));
m_size++;
return {};
}
template<typename Key, typename T, typename HASH>
template<typename... Args>
ErrorOr<void> HashMap<Key, T, HASH>::emplace_or_assign(Key&& key, Args&&... args) requires is_constructible_v<T, Args...>
{
if (empty())
return emplace(move(key), forward<Args>(args)...);
auto& bucket = get_bucket(key);
for (Entry& entry : bucket)
if (entry.key == key)
return {};
TRY(bucket.emplace_back(move(key), forward<Args>(args)...));
m_size++; m_size++;
return {}; return {};
} }

View File

@ -34,9 +34,9 @@ namespace BAN
ErrorOr<void> insert(iterator, const T&); ErrorOr<void> insert(iterator, const T&);
ErrorOr<void> insert(iterator, T&&); ErrorOr<void> insert(iterator, T&&);
template<typename... Args> template<typename... Args>
ErrorOr<void> emplace_back(Args&&...); ErrorOr<void> emplace_back(Args&&...) requires is_constructible_v<T, Args...>;
template<typename... Args> template<typename... Args>
ErrorOr<void> emplace(iterator, Args&&...); ErrorOr<void> emplace(iterator, Args&&...) requires is_constructible_v<T, Args...>;
void pop_back(); void pop_back();
iterator remove(iterator); iterator remove(iterator);
@ -196,14 +196,14 @@ namespace BAN
template<typename T> template<typename T>
template<typename... Args> template<typename... Args>
ErrorOr<void> LinkedList<T>::emplace_back(Args&&... args) ErrorOr<void> LinkedList<T>::emplace_back(Args&&... args) requires is_constructible_v<T, Args...>
{ {
return emplace(end(), forward<Args>(args)...); return emplace(end(), forward<Args>(args)...);
} }
template<typename T> template<typename T>
template<typename... Args> template<typename... Args>
ErrorOr<void> LinkedList<T>::emplace(iterator iter, Args&&... args) ErrorOr<void> LinkedList<T>::emplace(iterator iter, Args&&... args) requires is_constructible_v<T, Args...>
{ {
Node* new_node = TRY(allocate_node(forward<Args>(args)...)); Node* new_node = TRY(allocate_node(forward<Args>(args)...));
insert_node(iter, new_node); insert_node(iter, new_node);

View File

@ -25,7 +25,7 @@ namespace BAN
constexpr Optional& operator=(const Optional&); constexpr Optional& operator=(const Optional&);
template<typename... Args> template<typename... Args>
constexpr Optional& emplace(Args&&...); constexpr Optional& emplace(Args&&...) requires is_constructible_v<T, Args...>;
constexpr T* operator->(); constexpr T* operator->();
constexpr const T* operator->() const; constexpr const T* operator->() const;
@ -111,7 +111,7 @@ namespace BAN
template<typename T> template<typename T>
template<typename... Args> template<typename... Args>
constexpr Optional<T>& Optional<T>::emplace(Args&&... args) constexpr Optional<T>& Optional<T>::emplace(Args&&... args) requires is_constructible_v<T, Args...>
{ {
clear(); clear();
m_has_value = true; m_has_value = true;

View File

@ -2,5 +2,9 @@
#include <stddef.h> #include <stddef.h>
#ifdef __banan_os__
inline void* operator new(size_t, void* addr) { return addr; } inline void* operator new(size_t, void* addr) { return addr; }
inline void* operator new[](size_t, void* addr) { return addr; } inline void* operator new[](size_t, void* addr) { return addr; }
#else
#include <new>
#endif

View File

@ -31,7 +31,7 @@ namespace BAN
ErrorOr<void> push(T&&); ErrorOr<void> push(T&&);
ErrorOr<void> push(const T&); ErrorOr<void> push(const T&);
template<typename... Args> template<typename... Args>
ErrorOr<void> emplace(Args&&...); ErrorOr<void> emplace(Args&&...) requires is_constructible_v<T, Args...>;
ErrorOr<void> reserve(size_type); ErrorOr<void> reserve(size_type);
ErrorOr<void> shrink_to_fit(); ErrorOr<void> shrink_to_fit();
@ -131,7 +131,7 @@ namespace BAN
template<typename T> template<typename T>
template<typename... Args> template<typename... Args>
ErrorOr<void> Queue<T>::emplace(Args&&... args) ErrorOr<void> Queue<T>::emplace(Args&&... args) requires is_constructible_v<T, Args...>
{ {
TRY(ensure_capacity(m_size + 1)); TRY(ensure_capacity(m_size + 1));
new (m_data + m_size) T(forward<Args>(args)...); new (m_data + m_size) T(forward<Args>(args)...);

View File

@ -76,8 +76,9 @@ namespace BAN
return ptr; return ptr;
} }
// NOTE: don't use is_constructible_v<T, Args...> as RefPtr<T> is allowed with friends
template<typename... Args> template<typename... Args>
static ErrorOr<RefPtr> create(Args&&... args) static ErrorOr<RefPtr> create(Args&&... args) requires requires(Args&&... args) { T(forward<Args>(args)...); }
{ {
T* pointer = new T(forward<Args>(args)...); T* pointer = new T(forward<Args>(args)...);
if (pointer == nullptr) if (pointer == nullptr)

View File

@ -14,121 +14,91 @@ namespace BAN
public: public:
using value_type = T; using value_type = T;
using size_type = size_t; using size_type = size_t;
using iterator = IteratorSimple<T, Span>; using iterator = IteratorSimple<value_type, Span>;
using const_iterator = ConstIteratorSimple<T, Span>; using const_iterator = ConstIteratorSimple<value_type, Span>;
private:
template<typename S>
static inline constexpr bool can_init_from_v = is_same_v<value_type, const S> || is_same_v<value_type, S>;
public: public:
Span() = default; Span() = default;
Span(T*, size_type); Span(value_type* data, size_type size)
Span(Span<T>&); : m_data(data)
, m_size(size)
{ }
template<typename S> template<typename S>
requires(is_same_v<T, const S>) Span(const Span<S>& other) requires can_init_from_v<S>
Span(const Span<S>&); : m_data(other.m_data)
, m_size(other.m_size)
{ }
template<typename S>
Span(Span<S>&& other) requires can_init_from_v<S>
: m_data(other.m_data)
, m_size(other.m_size)
{
other.clear();
}
template<typename S>
Span& operator=(const Span<S>& other) requires can_init_from_v<S>
{
m_data = other.m_data;
m_size = other.m_size;
return *this;
}
template<typename S>
Span& operator=(Span<S>&& other) requires can_init_from_v<S>
{
m_data = other.m_data;
m_size = other.m_size;
return *this;
}
iterator begin() { return iterator(m_data); } iterator begin() { return iterator(m_data); }
iterator end() { return iterator(m_data + m_size); } iterator end() { return iterator(m_data + m_size); }
const_iterator begin() const { return const_iterator(m_data); } const_iterator begin() const { return const_iterator(m_data); }
const_iterator end() const { return const_iterator(m_data + m_size); } const_iterator end() const { return const_iterator(m_data + m_size); }
T& operator[](size_type); value_type& operator[](size_type index) const
const T& operator[](size_type) const; {
ASSERT(index < m_size);
return m_data[index];
}
T* data(); value_type* data() const
const T* data() const; {
ASSERT(m_data);
return m_data;
}
bool empty() const; bool empty() const { return m_size == 0; }
size_type size() const; size_type size() const { return m_size; }
void clear(); void clear()
{
m_data = nullptr;
m_size = 0;
}
Span slice(size_type, size_type = ~size_type(0)); Span slice(size_type start, size_type length = ~size_type(0)) const
{
ASSERT(m_data);
ASSERT(start <= m_size);
if (length == ~size_type(0))
length = m_size - start;
ASSERT(m_size - start >= length);
return Span(m_data + start, length);
}
Span<const T> as_const() const { return Span<const T>(m_data, m_size); } Span<const value_type> as_const() const { return *this; }
private: private:
T* m_data = nullptr; value_type* m_data = nullptr;
size_type m_size = 0; size_type m_size = 0;
friend class Span<const value_type>;
}; };
template<typename T>
Span<T>::Span(T* data, size_type size)
: m_data(data)
, m_size(size)
{
}
template<typename T>
Span<T>::Span(Span& other)
: m_data(other.data())
, m_size(other.size())
{
}
template<typename T>
template<typename S>
requires(is_same_v<T, const S>)
Span<T>::Span(const Span<S>& other)
: m_data(other.data())
, m_size(other.size())
{
}
template<typename T>
T& Span<T>::operator[](size_type index)
{
ASSERT(m_data);
ASSERT(index < m_size);
return m_data[index];
}
template<typename T>
const T& Span<T>::operator[](size_type index) const
{
ASSERT(m_data);
ASSERT(index < m_size);
return m_data[index];
}
template<typename T>
T* Span<T>::data()
{
return m_data;
}
template<typename T>
const T* Span<T>::data() const
{
return m_data;
}
template<typename T>
bool Span<T>::empty() const
{
return m_size == 0;
}
template<typename T>
typename Span<T>::size_type Span<T>::size() const
{
return m_size;
}
template<typename T>
void Span<T>::clear()
{
m_data = nullptr;
m_size = 0;
}
template<typename T>
Span<T> Span<T>::slice(size_type start, size_type length)
{
ASSERT(m_data);
ASSERT(start <= m_size);
if (length == ~size_type(0))
length = m_size - start;
ASSERT(m_size - start >= length);
return Span(m_data + start, length);
}
} }

View File

@ -127,7 +127,7 @@ namespace BAN
void remove(size_type index) void remove(size_type index)
{ {
ASSERT(index < m_size); ASSERT(index < m_size);
memcpy(data() + index, data() + index + 1, m_size - index); memmove(data() + index, data() + index + 1, m_size - index);
m_size--; m_size--;
data()[m_size] = '\0'; data()[m_size] = '\0';
} }

View File

@ -33,8 +33,9 @@ namespace BAN
return uniq; return uniq;
} }
// NOTE: don't use is_constructible_v<T, Args...> as UniqPtr<T> is allowed with friends
template<typename... Args> template<typename... Args>
static BAN::ErrorOr<UniqPtr> create(Args&&... args) static BAN::ErrorOr<UniqPtr> create(Args&&... args) requires requires(Args&&... args) { T(forward<Args>(args)...); }
{ {
UniqPtr uniq; UniqPtr uniq;
uniq.m_pointer = new T(BAN::forward<Args>(args)...); uniq.m_pointer = new T(BAN::forward<Args>(args)...);

View File

@ -217,7 +217,7 @@ namespace BAN
} }
template<typename T, typename... Args> template<typename T, typename... Args>
void emplace(Args&&... args) requires (can_have<T>()) void emplace(Args&&... args) requires (can_have<T>() && is_constructible_v<T, Args...>)
{ {
clear(); clear();
m_index = detail::index<T, Ts...>(); m_index = detail::index<T, Ts...>();

View File

@ -35,9 +35,9 @@ namespace BAN
ErrorOr<void> push_back(T&&); ErrorOr<void> push_back(T&&);
ErrorOr<void> push_back(const T&); ErrorOr<void> push_back(const T&);
template<typename... Args> template<typename... Args>
ErrorOr<void> emplace_back(Args&&...); ErrorOr<void> emplace_back(Args&&...) requires is_constructible_v<T, Args...>;
template<typename... Args> template<typename... Args>
ErrorOr<void> emplace(size_type, Args&&...); ErrorOr<void> emplace(size_type, Args&&...) requires is_constructible_v<T, Args...>;
ErrorOr<void> insert(size_type, T&&); ErrorOr<void> insert(size_type, T&&);
ErrorOr<void> insert(size_type, const T&); ErrorOr<void> insert(size_type, const T&);
@ -169,7 +169,7 @@ namespace BAN
template<typename T> template<typename T>
template<typename... Args> template<typename... Args>
ErrorOr<void> Vector<T>::emplace_back(Args&&... args) ErrorOr<void> Vector<T>::emplace_back(Args&&... args) requires is_constructible_v<T, Args...>
{ {
TRY(ensure_capacity(m_size + 1)); TRY(ensure_capacity(m_size + 1));
new (m_data + m_size) T(forward<Args>(args)...); new (m_data + m_size) T(forward<Args>(args)...);
@ -179,7 +179,7 @@ namespace BAN
template<typename T> template<typename T>
template<typename... Args> template<typename... Args>
ErrorOr<void> Vector<T>::emplace(size_type index, Args&&... args) ErrorOr<void> Vector<T>::emplace(size_type index, Args&&... args) requires is_constructible_v<T, Args...>
{ {
ASSERT(index <= m_size); ASSERT(index <= m_size);
TRY(ensure_capacity(m_size + 1)); TRY(ensure_capacity(m_size + 1));

View File

@ -46,7 +46,7 @@ namespace Kernel
return result; return result;
} }
void PageTable::initialize() void PageTable::initialize_pre_heap()
{ {
if (CPUID::has_nxe()) if (CPUID::has_nxe())
s_has_nxe = true; s_has_nxe = true;
@ -65,6 +65,11 @@ namespace Kernel
s_kernel->initial_load(); s_kernel->initial_load();
} }
void PageTable::initialize_post_heap()
{
// NOTE: this is no-op as our 32 bit target does not use hhdm
}
void PageTable::initial_load() void PageTable::initial_load()
{ {
if (s_has_nxe) if (s_has_nxe)
@ -150,14 +155,6 @@ namespace Kernel
prepare_fast_page(); prepare_fast_page();
// Map main bios area below 1 MiB
map_range_at(
0x000E0000,
P2V(0x000E0000),
0x00100000 - 0x000E0000,
PageTable::Flags::Present
);
// Map (phys_kernel_start -> phys_kernel_end) to (virt_kernel_start -> virt_kernel_end) // Map (phys_kernel_start -> phys_kernel_end) to (virt_kernel_start -> virt_kernel_end)
ASSERT((vaddr_t)g_kernel_start % PAGE_SIZE == 0); ASSERT((vaddr_t)g_kernel_start % PAGE_SIZE == 0);
map_range_at( map_range_at(
@ -196,7 +193,6 @@ namespace Kernel
{ {
constexpr uint64_t pdpte = (fast_page() >> 30) & 0x1FF; constexpr uint64_t pdpte = (fast_page() >> 30) & 0x1FF;
constexpr uint64_t pde = (fast_page() >> 21) & 0x1FF; constexpr uint64_t pde = (fast_page() >> 21) & 0x1FF;
constexpr uint64_t pte = (fast_page() >> 12) & 0x1FF;
uint64_t* pdpt = reinterpret_cast<uint64_t*>(P2V(m_highest_paging_struct)); uint64_t* pdpt = reinterpret_cast<uint64_t*>(P2V(m_highest_paging_struct));
ASSERT(pdpt[pdpte] & Flags::Present); ASSERT(pdpt[pdpte] & Flags::Present);
@ -204,10 +200,6 @@ namespace Kernel
uint64_t* pd = reinterpret_cast<uint64_t*>(P2V(pdpt[pdpte]) & PAGE_ADDR_MASK); uint64_t* pd = reinterpret_cast<uint64_t*>(P2V(pdpt[pdpte]) & PAGE_ADDR_MASK);
ASSERT(!(pd[pde] & Flags::Present)); ASSERT(!(pd[pde] & Flags::Present));
pd[pde] = V2P(allocate_zeroed_page_aligned_page()) | Flags::ReadWrite | Flags::Present; pd[pde] = V2P(allocate_zeroed_page_aligned_page()) | Flags::ReadWrite | Flags::Present;
uint64_t* pt = reinterpret_cast<uint64_t*>(P2V(pd[pde]) & PAGE_ADDR_MASK);
ASSERT(!(pt[pte] & Flags::Present));
pt[pte] = V2P(allocate_zeroed_page_aligned_page());
} }
void PageTable::map_fast_page(paddr_t paddr) void PageTable::map_fast_page(paddr_t paddr)

View File

@ -1,6 +1,7 @@
#include <kernel/BootInfo.h> #include <kernel/BootInfo.h>
#include <kernel/CPUID.h> #include <kernel/CPUID.h>
#include <kernel/Lock/SpinLock.h> #include <kernel/Lock/SpinLock.h>
#include <kernel/Memory/Heap.h>
#include <kernel/Memory/kmalloc.h> #include <kernel/Memory/kmalloc.h>
#include <kernel/Memory/PageTable.h> #include <kernel/Memory/PageTable.h>
@ -21,12 +22,18 @@ namespace Kernel
SpinLock PageTable::s_fast_page_lock; SpinLock PageTable::s_fast_page_lock;
static constexpr vaddr_t s_hhdm_offset = 0xFFFF800000000000;
static bool s_is_hddm_initialized = false;
constexpr uint64_t s_page_flag_mask = 0x8000000000000FFF;
constexpr uint64_t s_page_addr_mask = ~s_page_flag_mask;
static PageTable* s_kernel = nullptr; static PageTable* s_kernel = nullptr;
static bool s_has_nxe = false; static bool s_has_nxe = false;
static bool s_has_pge = false; static bool s_has_pge = false;
static bool s_has_gib = false;
// PML4 entry for kernel memory static paddr_t s_global_pml4_entries[512] { 0 };
static paddr_t s_global_pml4e = 0;
static constexpr inline bool is_canonical(uintptr_t addr) static constexpr inline bool is_canonical(uintptr_t addr)
{ {
@ -47,6 +54,67 @@ namespace Kernel
return addr; return addr;
} }
struct FuncsKmalloc
{
static paddr_t allocate_zeroed_page_aligned_page()
{
void* page = kmalloc(PAGE_SIZE, PAGE_SIZE, true);
ASSERT(page);
memset(page, 0, PAGE_SIZE);
return kmalloc_paddr_of(reinterpret_cast<vaddr_t>(page)).value();
}
static void unallocate_page(paddr_t paddr)
{
kfree(reinterpret_cast<void*>(kmalloc_vaddr_of(paddr).value()));
}
static paddr_t V2P(vaddr_t vaddr)
{
return vaddr - KERNEL_OFFSET + g_boot_info.kernel_paddr;
}
static uint64_t* P2V(paddr_t paddr)
{
return reinterpret_cast<uint64_t*>(paddr - g_boot_info.kernel_paddr + KERNEL_OFFSET);
}
};
struct FuncsHHDM
{
static paddr_t allocate_zeroed_page_aligned_page()
{
const paddr_t paddr = Heap::get().take_free_page();
ASSERT(paddr);
memset(reinterpret_cast<void*>(paddr + s_hhdm_offset), 0, PAGE_SIZE);
return paddr;
}
static void unallocate_page(paddr_t paddr)
{
Heap::get().release_page(paddr);
}
static paddr_t V2P(vaddr_t vaddr)
{
ASSERT(vaddr >= s_hhdm_offset);
ASSERT(vaddr < KERNEL_OFFSET);
return vaddr - s_hhdm_offset;
}
static uint64_t* P2V(paddr_t paddr)
{
ASSERT(paddr != 0);
ASSERT(!BAN::Math::will_addition_overflow(paddr, s_hhdm_offset));
return reinterpret_cast<uint64_t*>(paddr + s_hhdm_offset);
}
};
static paddr_t (*allocate_zeroed_page_aligned_page)() = &FuncsKmalloc::allocate_zeroed_page_aligned_page;
static void (*unallocate_page)(paddr_t) = &FuncsKmalloc::unallocate_page;
static paddr_t (*V2P)(vaddr_t) = &FuncsKmalloc::V2P;
static uint64_t* (*P2V)(paddr_t) = &FuncsKmalloc::P2V;
static inline PageTable::flags_t parse_flags(uint64_t entry) static inline PageTable::flags_t parse_flags(uint64_t entry)
{ {
using Flags = PageTable::Flags; using Flags = PageTable::Flags;
@ -65,7 +133,190 @@ namespace Kernel
return result; return result;
} }
void PageTable::initialize() // page size:
// 0: 4 KiB
// 1: 2 MiB
// 2: 1 GiB
static void init_map_hhdm_page(paddr_t pml4, paddr_t paddr, uint8_t page_size)
{
ASSERT(0 <= page_size && page_size <= 2);
const vaddr_t vaddr = paddr + s_hhdm_offset;
ASSERT(vaddr < KERNEL_OFFSET);
const vaddr_t uc_vaddr = uncanonicalize(vaddr);
const uint16_t pml4e = (uc_vaddr >> 39) & 0x1FF;
const uint16_t pdpte = (uc_vaddr >> 30) & 0x1FF;
const uint16_t pde = (uc_vaddr >> 21) & 0x1FF;
const uint16_t pte = (uc_vaddr >> 12) & 0x1FF;
static constexpr uint64_t hhdm_flags = (1u << 1) | (1u << 0);
const auto get_or_allocate_entry =
[](paddr_t table, uint16_t table_entry, uint64_t extra_flags) -> paddr_t
{
paddr_t result = 0;
PageTable::with_fast_page(table, [&] {
const uint64_t entry = PageTable::fast_page_as_sized<uint64_t>(table_entry);
if (entry & (1u << 0))
result = entry & s_page_addr_mask;
});
if (result != 0)
return result;
const paddr_t new_paddr = Heap::get().take_free_page();
ASSERT(new_paddr);
PageTable::with_fast_page(new_paddr, [] {
memset(reinterpret_cast<void*>(PageTable::fast_page_as_ptr()), 0, PAGE_SIZE);
});
PageTable::with_fast_page(table, [&] {
uint64_t& entry = PageTable::fast_page_as_sized<uint64_t>(table_entry);
entry = new_paddr | hhdm_flags | extra_flags;
});
return new_paddr;
};
const uint64_t pgsize_flag = page_size ? (static_cast<uint64_t>(1) << 7) : 0;
const uint64_t global_flag = s_has_pge ? (static_cast<uint64_t>(1) << 8) : 0;
const uint64_t noexec_flag = s_has_nxe ? (static_cast<uint64_t>(1) << 63) : 0;
const paddr_t pdpt = get_or_allocate_entry(pml4, pml4e, noexec_flag);
s_global_pml4_entries[pml4e] = pdpt | hhdm_flags;
paddr_t lowest_paddr = pdpt;
uint16_t lowest_entry = pdpte;
if (page_size < 2)
{
lowest_paddr = get_or_allocate_entry(lowest_paddr, lowest_entry, noexec_flag);
lowest_entry = pde;
}
if (page_size < 1)
{
lowest_paddr = get_or_allocate_entry(lowest_paddr, lowest_entry, noexec_flag);
lowest_entry = pte;
}
PageTable::with_fast_page(lowest_paddr, [&] {
uint64_t& entry = PageTable::fast_page_as_sized<uint64_t>(lowest_entry);
entry = paddr | hhdm_flags | noexec_flag | global_flag | pgsize_flag;
});
}
static void init_map_hhdm(paddr_t pml4)
{
for (const auto& entry : g_boot_info.memory_map_entries)
{
bool should_map = false;
switch (entry.type)
{
case MemoryMapEntry::Type::Available:
should_map = true;
break;
case MemoryMapEntry::Type::ACPIReclaim:
case MemoryMapEntry::Type::ACPINVS:
case MemoryMapEntry::Type::Reserved:
should_map = false;
break;
}
if (!should_map)
continue;
constexpr size_t one_gib = 1024 * 1024 * 1024;
constexpr size_t two_mib = 2 * 1024 * 1024;
const paddr_t entry_start = (entry.address + PAGE_SIZE - 1) & PAGE_ADDR_MASK;
const paddr_t entry_end = (entry.address + entry.length) & PAGE_ADDR_MASK;
for (paddr_t paddr = entry_start; paddr < entry_end;)
{
if (s_has_gib && paddr % one_gib == 0 && paddr + one_gib <= entry_end)
{
init_map_hhdm_page(pml4, paddr, 2);
paddr += one_gib;
}
else if (paddr % two_mib == 0 && paddr + two_mib <= entry_end)
{
init_map_hhdm_page(pml4, paddr, 1);
paddr += two_mib;
}
else
{
init_map_hhdm_page(pml4, paddr, 0);
paddr += PAGE_SIZE;
}
}
}
}
static paddr_t copy_page_from_kmalloc_to_heap(paddr_t kmalloc_paddr)
{
const paddr_t heap_paddr = Heap::get().take_free_page();
ASSERT(heap_paddr);
const vaddr_t kmalloc_vaddr = kmalloc_vaddr_of(kmalloc_paddr).value();
PageTable::with_fast_page(heap_paddr, [kmalloc_vaddr] {
memcpy(PageTable::fast_page_as_ptr(), reinterpret_cast<void*>(kmalloc_vaddr), PAGE_SIZE);
});
return heap_paddr;
}
static void copy_paging_structure_to_heap(uint64_t* old_table, uint64_t* new_table, int depth)
{
if (depth == 0)
return;
constexpr uint64_t page_flag_mask = 0x8000000000000FFF;
constexpr uint64_t page_addr_mask = ~page_flag_mask;
for (uint16_t index = 0; index < 512; index++)
{
const uint64_t old_entry = old_table[index];
if (old_entry == 0)
{
new_table[index] = 0;
continue;
}
const paddr_t old_paddr = old_entry & page_addr_mask;
const paddr_t new_paddr = copy_page_from_kmalloc_to_heap(old_paddr);
new_table[index] = new_paddr | (old_entry & page_flag_mask);
uint64_t* next_old_table = reinterpret_cast<uint64_t*>(old_paddr + s_hhdm_offset);
uint64_t* next_new_table = reinterpret_cast<uint64_t*>(new_paddr + s_hhdm_offset);
copy_paging_structure_to_heap(next_old_table, next_new_table, depth - 1);
}
}
static void free_kmalloc_paging_structure(uint64_t* table, int depth)
{
if (depth == 0)
return;
constexpr uint64_t page_flag_mask = 0x8000000000000FFF;
constexpr uint64_t page_addr_mask = ~page_flag_mask;
for (uint16_t index = 0; index < 512; index++)
{
const uint64_t entry = table[index];
if (entry == 0)
continue;
const paddr_t paddr = entry & page_addr_mask;
uint64_t* next_table = reinterpret_cast<uint64_t*>(paddr + s_hhdm_offset);
free_kmalloc_paging_structure(next_table, depth - 1);
kfree(reinterpret_cast<void*>(kmalloc_vaddr_of(paddr).value()));
}
}
void PageTable::initialize_pre_heap()
{ {
if (CPUID::has_nxe()) if (CPUID::has_nxe())
s_has_nxe = true; s_has_nxe = true;
@ -73,11 +324,64 @@ namespace Kernel
if (CPUID::has_pge()) if (CPUID::has_pge())
s_has_pge = true; s_has_pge = true;
if (CPUID::has_1gib_pages())
s_has_gib = true;
ASSERT(s_kernel == nullptr); ASSERT(s_kernel == nullptr);
s_kernel = new PageTable(); s_kernel = new PageTable();
ASSERT(s_kernel); ASSERT(s_kernel);
s_kernel->m_highest_paging_struct = allocate_zeroed_page_aligned_page();
s_kernel->prepare_fast_page();
s_kernel->initialize_kernel(); s_kernel->initialize_kernel();
for (auto pml4e : s_global_pml4_entries)
ASSERT(pml4e == 0);
const uint64_t* pml4 = P2V(s_kernel->m_highest_paging_struct);
s_global_pml4_entries[511] = pml4[511];
}
void PageTable::initialize_post_heap()
{
ASSERT(s_kernel);
init_map_hhdm(s_kernel->m_highest_paging_struct);
const paddr_t old_pml4_paddr = s_kernel->m_highest_paging_struct;
const paddr_t new_pml4_paddr = copy_page_from_kmalloc_to_heap(old_pml4_paddr);
uint64_t* old_pml4 = reinterpret_cast<uint64_t*>(kmalloc_vaddr_of(old_pml4_paddr).value());
uint64_t* new_pml4 = reinterpret_cast<uint64_t*>(new_pml4_paddr + s_hhdm_offset);
const paddr_t old_pdpt_paddr = old_pml4[511] & s_page_addr_mask;
const paddr_t new_pdpt_paddr = Heap::get().take_free_page();
ASSERT(new_pdpt_paddr);
uint64_t* old_pdpt = reinterpret_cast<uint64_t*>(old_pdpt_paddr + s_hhdm_offset);
uint64_t* new_pdpt = reinterpret_cast<uint64_t*>(new_pdpt_paddr + s_hhdm_offset);
copy_paging_structure_to_heap(old_pdpt, new_pdpt, 2);
new_pml4[511] = new_pdpt_paddr | (old_pml4[511] & s_page_flag_mask);
s_global_pml4_entries[511] = new_pml4[511];
s_kernel->m_highest_paging_struct = new_pml4_paddr;
s_kernel->load();
free_kmalloc_paging_structure(old_pdpt, 2);
kfree(reinterpret_cast<void*>(kmalloc_vaddr_of(old_pdpt_paddr).value()));
kfree(reinterpret_cast<void*>(kmalloc_vaddr_of(old_pml4_paddr).value()));
allocate_zeroed_page_aligned_page = &FuncsHHDM::allocate_zeroed_page_aligned_page;
unallocate_page = &FuncsHHDM::unallocate_page;
V2P = &FuncsHHDM::V2P;
P2V = &FuncsHHDM::P2V;
s_is_hddm_initialized = true;
// This is a hack to unmap fast page. fast page pt is copied
// while it is mapped, so we need to manually unmap it
SpinLockGuard _(s_fast_page_lock);
unmap_fast_page();
} }
void PageTable::initial_load() void PageTable::initial_load()
@ -136,75 +440,40 @@ namespace Kernel
return true; return true;
} }
static uint64_t* allocate_zeroed_page_aligned_page()
{
void* page = kmalloc(PAGE_SIZE, PAGE_SIZE, true);
ASSERT(page);
memset(page, 0, PAGE_SIZE);
return (uint64_t*)page;
}
template<typename T>
static paddr_t V2P(const T vaddr)
{
return (vaddr_t)vaddr - KERNEL_OFFSET + g_boot_info.kernel_paddr;
}
template<typename T>
static vaddr_t P2V(const T paddr)
{
return (paddr_t)paddr - g_boot_info.kernel_paddr + KERNEL_OFFSET;
}
void PageTable::initialize_kernel() void PageTable::initialize_kernel()
{ {
ASSERT(s_global_pml4e == 0);
s_global_pml4e = V2P(allocate_zeroed_page_aligned_page());
m_highest_paging_struct = V2P(allocate_zeroed_page_aligned_page());
uint64_t* pml4 = (uint64_t*)P2V(m_highest_paging_struct);
pml4[511] = s_global_pml4e;
prepare_fast_page();
// Map main bios area below 1 MiB
map_range_at(
0x000E0000,
P2V(0x000E0000),
0x00100000 - 0x000E0000,
PageTable::Flags::Present
);
// Map (phys_kernel_start -> phys_kernel_end) to (virt_kernel_start -> virt_kernel_end) // Map (phys_kernel_start -> phys_kernel_end) to (virt_kernel_start -> virt_kernel_end)
ASSERT((vaddr_t)g_kernel_start % PAGE_SIZE == 0); const vaddr_t kernel_start = reinterpret_cast<vaddr_t>(g_kernel_start);
map_range_at( map_range_at(
V2P(g_kernel_start), V2P(kernel_start),
(vaddr_t)g_kernel_start, kernel_start,
g_kernel_end - g_kernel_start, g_kernel_end - g_kernel_start,
Flags::Present Flags::Present
); );
// Map executable kernel memory as executable // Map executable kernel memory as executable
const vaddr_t kernel_execute_start = reinterpret_cast<vaddr_t>(g_kernel_execute_start);
map_range_at( map_range_at(
V2P(g_kernel_execute_start), V2P(kernel_execute_start),
(vaddr_t)g_kernel_execute_start, kernel_execute_start,
g_kernel_execute_end - g_kernel_execute_start, g_kernel_execute_end - g_kernel_execute_start,
Flags::Execute | Flags::Present Flags::Execute | Flags::Present
); );
// Map writable kernel memory as writable // Map writable kernel memory as writable
const vaddr_t kernel_writable_start = reinterpret_cast<vaddr_t>(g_kernel_writable_start);
map_range_at( map_range_at(
V2P(g_kernel_writable_start), V2P(kernel_writable_start),
(vaddr_t)g_kernel_writable_start, kernel_writable_start,
g_kernel_writable_end - g_kernel_writable_start, g_kernel_writable_end - g_kernel_writable_start,
Flags::ReadWrite | Flags::Present Flags::ReadWrite | Flags::Present
); );
// Map userspace memory // Map userspace memory
const vaddr_t userspace_start = reinterpret_cast<vaddr_t>(g_userspace_start);
map_range_at( map_range_at(
V2P(g_userspace_start), V2P(userspace_start),
(vaddr_t)g_userspace_start, userspace_start,
g_userspace_end - g_userspace_start, g_userspace_end - g_userspace_start,
Flags::Execute | Flags::UserSupervisor | Flags::Present Flags::Execute | Flags::UserSupervisor | Flags::Present
); );
@ -216,23 +485,18 @@ namespace Kernel
constexpr uint64_t pml4e = (uc_vaddr >> 39) & 0x1FF; constexpr uint64_t pml4e = (uc_vaddr >> 39) & 0x1FF;
constexpr uint64_t pdpte = (uc_vaddr >> 30) & 0x1FF; constexpr uint64_t pdpte = (uc_vaddr >> 30) & 0x1FF;
constexpr uint64_t pde = (uc_vaddr >> 21) & 0x1FF; constexpr uint64_t pde = (uc_vaddr >> 21) & 0x1FF;
constexpr uint64_t pte = (uc_vaddr >> 12) & 0x1FF;
uint64_t* pml4 = (uint64_t*)P2V(m_highest_paging_struct); uint64_t* pml4 = P2V(m_highest_paging_struct);
ASSERT(!(pml4[pml4e] & Flags::Present)); ASSERT(!(pml4[pml4e] & Flags::Present));
pml4[pml4e] = V2P(allocate_zeroed_page_aligned_page()) | Flags::ReadWrite | Flags::Present; pml4[pml4e] = allocate_zeroed_page_aligned_page() | Flags::ReadWrite | Flags::Present;
uint64_t* pdpt = (uint64_t*)P2V(pml4[pml4e] & PAGE_ADDR_MASK); uint64_t* pdpt = P2V(pml4[pml4e] & s_page_addr_mask);
ASSERT(!(pdpt[pdpte] & Flags::Present)); ASSERT(!(pdpt[pdpte] & Flags::Present));
pdpt[pdpte] = V2P(allocate_zeroed_page_aligned_page()) | Flags::ReadWrite | Flags::Present; pdpt[pdpte] = allocate_zeroed_page_aligned_page() | Flags::ReadWrite | Flags::Present;
uint64_t* pd = (uint64_t*)P2V(pdpt[pdpte] & PAGE_ADDR_MASK); uint64_t* pd = P2V(pdpt[pdpte] & s_page_addr_mask);
ASSERT(!(pd[pde] & Flags::Present)); ASSERT(!(pd[pde] & Flags::Present));
pd[pde] = V2P(allocate_zeroed_page_aligned_page()) | Flags::ReadWrite | Flags::Present; pd[pde] = allocate_zeroed_page_aligned_page() | Flags::ReadWrite | Flags::Present;
uint64_t* pt = (uint64_t*)P2V(pd[pde] & PAGE_ADDR_MASK);
ASSERT(!(pt[pte] & Flags::Present));
pt[pte] = V2P(allocate_zeroed_page_aligned_page());
} }
void PageTable::map_fast_page(paddr_t paddr) void PageTable::map_fast_page(paddr_t paddr)
@ -248,10 +512,10 @@ namespace Kernel
constexpr uint64_t pde = (uc_vaddr >> 21) & 0x1FF; constexpr uint64_t pde = (uc_vaddr >> 21) & 0x1FF;
constexpr uint64_t pte = (uc_vaddr >> 12) & 0x1FF; constexpr uint64_t pte = (uc_vaddr >> 12) & 0x1FF;
uint64_t* pml4 = (uint64_t*)P2V(s_kernel->m_highest_paging_struct); const uint64_t* pml4 = P2V(s_kernel->m_highest_paging_struct);
uint64_t* pdpt = (uint64_t*)P2V(pml4[pml4e] & PAGE_ADDR_MASK); const uint64_t* pdpt = P2V(pml4[pml4e] & s_page_addr_mask);
uint64_t* pd = (uint64_t*)P2V(pdpt[pdpte] & PAGE_ADDR_MASK); const uint64_t* pd = P2V(pdpt[pdpte] & s_page_addr_mask);
uint64_t* pt = (uint64_t*)P2V(pd[pde] & PAGE_ADDR_MASK); uint64_t* pt = P2V(pd[pde] & s_page_addr_mask);
ASSERT(!(pt[pte] & Flags::Present)); ASSERT(!(pt[pte] & Flags::Present));
pt[pte] = paddr | Flags::ReadWrite | Flags::Present; pt[pte] = paddr | Flags::ReadWrite | Flags::Present;
@ -271,10 +535,10 @@ namespace Kernel
constexpr uint64_t pde = (uc_vaddr >> 21) & 0x1FF; constexpr uint64_t pde = (uc_vaddr >> 21) & 0x1FF;
constexpr uint64_t pte = (uc_vaddr >> 12) & 0x1FF; constexpr uint64_t pte = (uc_vaddr >> 12) & 0x1FF;
uint64_t* pml4 = (uint64_t*)P2V(s_kernel->m_highest_paging_struct); const uint64_t* pml4 = P2V(s_kernel->m_highest_paging_struct);
uint64_t* pdpt = (uint64_t*)P2V(pml4[pml4e] & PAGE_ADDR_MASK); const uint64_t* pdpt = P2V(pml4[pml4e] & s_page_addr_mask);
uint64_t* pd = (uint64_t*)P2V(pdpt[pdpte] & PAGE_ADDR_MASK); const uint64_t* pd = P2V(pdpt[pdpte] & s_page_addr_mask);
uint64_t* pt = (uint64_t*)P2V(pd[pde] & PAGE_ADDR_MASK); uint64_t* pt = P2V(pd[pde] & s_page_addr_mask);
ASSERT(pt[pte] & Flags::Present); ASSERT(pt[pte] & Flags::Present);
pt[pte] = 0; pt[pte] = 0;
@ -295,43 +559,46 @@ namespace Kernel
void PageTable::map_kernel_memory() void PageTable::map_kernel_memory()
{ {
ASSERT(s_kernel); ASSERT(s_kernel);
ASSERT(s_global_pml4e); ASSERT(s_global_pml4_entries[511]);
ASSERT(m_highest_paging_struct == 0); ASSERT(m_highest_paging_struct == 0);
m_highest_paging_struct = V2P(allocate_zeroed_page_aligned_page()); m_highest_paging_struct = allocate_zeroed_page_aligned_page();
uint64_t* kernel_pml4 = (uint64_t*)P2V(s_kernel->m_highest_paging_struct); PageTable::with_fast_page(m_highest_paging_struct, [] {
for (size_t i = 0; i < 512; i++)
uint64_t* pml4 = (uint64_t*)P2V(m_highest_paging_struct); {
pml4[511] = kernel_pml4[511]; if (s_global_pml4_entries[i] == 0)
continue;
ASSERT(i >= 256);
PageTable::fast_page_as_sized<uint64_t>(i) = s_global_pml4_entries[i];
}
});
} }
PageTable::~PageTable() PageTable::~PageTable()
{ {
uint64_t* pml4 = (uint64_t*)P2V(m_highest_paging_struct); // NOTE: we only loop until 256 since after that is hhdm
const uint64_t* pml4 = P2V(m_highest_paging_struct);
// NOTE: we only loop until 511 since the last one is the kernel memory for (uint64_t pml4e = 0; pml4e < 256; pml4e++)
for (uint64_t pml4e = 0; pml4e < 511; pml4e++)
{ {
if (!(pml4[pml4e] & Flags::Present)) if (!(pml4[pml4e] & Flags::Present))
continue; continue;
uint64_t* pdpt = (uint64_t*)P2V(pml4[pml4e] & PAGE_ADDR_MASK); const uint64_t* pdpt = P2V(pml4[pml4e] & s_page_addr_mask);
for (uint64_t pdpte = 0; pdpte < 512; pdpte++) for (uint64_t pdpte = 0; pdpte < 512; pdpte++)
{ {
if (!(pdpt[pdpte] & Flags::Present)) if (!(pdpt[pdpte] & Flags::Present))
continue; continue;
uint64_t* pd = (uint64_t*)P2V(pdpt[pdpte] & PAGE_ADDR_MASK); const uint64_t* pd = P2V(pdpt[pdpte] & s_page_addr_mask);
for (uint64_t pde = 0; pde < 512; pde++) for (uint64_t pde = 0; pde < 512; pde++)
{ {
if (!(pd[pde] & Flags::Present)) if (!(pd[pde] & Flags::Present))
continue; continue;
kfree((void*)P2V(pd[pde] & PAGE_ADDR_MASK)); unallocate_page(pd[pde] & s_page_addr_mask);
} }
kfree(pd); unallocate_page(pdpt[pdpte] & s_page_addr_mask);
} }
kfree(pdpt); unallocate_page(pml4[pml4e] & s_page_addr_mask);
} }
kfree(pml4); unallocate_page(m_highest_paging_struct);
} }
void PageTable::load() void PageTable::load()
@ -368,24 +635,24 @@ namespace Kernel
Kernel::panic("unmapping {8H}, kernel: {}", vaddr, this == s_kernel); Kernel::panic("unmapping {8H}, kernel: {}", vaddr, this == s_kernel);
ASSERT(is_canonical(vaddr)); ASSERT(is_canonical(vaddr));
vaddr_t uc_vaddr = uncanonicalize(vaddr); const vaddr_t uc_vaddr = uncanonicalize(vaddr);
ASSERT(vaddr % PAGE_SIZE == 0); ASSERT(vaddr % PAGE_SIZE == 0);
uint64_t pml4e = (uc_vaddr >> 39) & 0x1FF; const uint16_t pml4e = (uc_vaddr >> 39) & 0x1FF;
uint64_t pdpte = (uc_vaddr >> 30) & 0x1FF; const uint16_t pdpte = (uc_vaddr >> 30) & 0x1FF;
uint64_t pde = (uc_vaddr >> 21) & 0x1FF; const uint16_t pde = (uc_vaddr >> 21) & 0x1FF;
uint64_t pte = (uc_vaddr >> 12) & 0x1FF; const uint16_t pte = (uc_vaddr >> 12) & 0x1FF;
SpinLockGuard _(m_lock); SpinLockGuard _(m_lock);
if (is_page_free(vaddr)) if (is_page_free(vaddr))
Kernel::panic("trying to unmap unmapped page 0x{H}", vaddr); Kernel::panic("trying to unmap unmapped page 0x{H}", vaddr);
uint64_t* pml4 = (uint64_t*)P2V(m_highest_paging_struct); uint64_t* pml4 = P2V(m_highest_paging_struct);
uint64_t* pdpt = (uint64_t*)P2V(pml4[pml4e] & PAGE_ADDR_MASK); uint64_t* pdpt = P2V(pml4[pml4e] & s_page_addr_mask);
uint64_t* pd = (uint64_t*)P2V(pdpt[pdpte] & PAGE_ADDR_MASK); uint64_t* pd = P2V(pdpt[pdpte] & s_page_addr_mask);
uint64_t* pt = (uint64_t*)P2V(pd[pde] & PAGE_ADDR_MASK); uint64_t* pt = P2V(pd[pde] & s_page_addr_mask);
pt[pte] = 0; pt[pte] = 0;
invalidate(vaddr, send_smp_message); invalidate(vaddr, send_smp_message);
@ -414,20 +681,22 @@ namespace Kernel
{ {
ASSERT(vaddr); ASSERT(vaddr);
ASSERT(vaddr != fast_page()); ASSERT(vaddr != fast_page());
if ((vaddr >= KERNEL_OFFSET) != (this == s_kernel)) if (vaddr < KERNEL_OFFSET && this == s_kernel)
Kernel::panic("mapping {8H} to {8H}, kernel: {}", paddr, vaddr, this == s_kernel); panic("kernel is mapping below kernel offset");
if (vaddr >= s_hhdm_offset && this != s_kernel)
panic("user is mapping above hhdm offset");
ASSERT(is_canonical(vaddr)); ASSERT(is_canonical(vaddr));
vaddr_t uc_vaddr = uncanonicalize(vaddr); const vaddr_t uc_vaddr = uncanonicalize(vaddr);
ASSERT(paddr % PAGE_SIZE == 0); ASSERT(paddr % PAGE_SIZE == 0);
ASSERT(vaddr % PAGE_SIZE == 0); ASSERT(vaddr % PAGE_SIZE == 0);
ASSERT(flags & Flags::Used); ASSERT(flags & Flags::Used);
uint64_t pml4e = (uc_vaddr >> 39) & 0x1FF; const uint16_t pml4e = (uc_vaddr >> 39) & 0x1FF;
uint64_t pdpte = (uc_vaddr >> 30) & 0x1FF; const uint16_t pdpte = (uc_vaddr >> 30) & 0x1FF;
uint64_t pde = (uc_vaddr >> 21) & 0x1FF; const uint16_t pde = (uc_vaddr >> 21) & 0x1FF;
uint64_t pte = (uc_vaddr >> 12) & 0x1FF; const uint16_t pte = (uc_vaddr >> 12) & 0x1FF;
uint64_t extra_flags = 0; uint64_t extra_flags = 0;
if (s_has_pge && pml4e == 511) // Map kernel memory as global if (s_has_pge && pml4e == 511) // Map kernel memory as global
@ -449,34 +718,26 @@ namespace Kernel
SpinLockGuard _(m_lock); SpinLockGuard _(m_lock);
uint64_t* pml4 = (uint64_t*)P2V(m_highest_paging_struct); const auto allocate_entry_if_needed =
if ((pml4[pml4e] & uwr_flags) != uwr_flags) [](uint64_t* table, uint16_t index, flags_t flags) -> uint64_t*
{ {
if (!(pml4[pml4e] & Flags::Present)) uint64_t entry = table[index];
pml4[pml4e] = V2P(allocate_zeroed_page_aligned_page()); if ((entry & flags) == flags)
pml4[pml4e] |= uwr_flags; return P2V(entry & s_page_addr_mask);
} if (!(entry & Flags::Present))
entry = allocate_zeroed_page_aligned_page();
table[index] = entry | flags;
return P2V(entry & s_page_addr_mask);
};
uint64_t* pdpt = (uint64_t*)P2V(pml4[pml4e] & PAGE_ADDR_MASK); uint64_t* pml4 = P2V(m_highest_paging_struct);
if ((pdpt[pdpte] & uwr_flags) != uwr_flags) uint64_t* pdpt = allocate_entry_if_needed(pml4, pml4e, uwr_flags);
{ uint64_t* pd = allocate_entry_if_needed(pdpt, pdpte, uwr_flags);
if (!(pdpt[pdpte] & Flags::Present)) uint64_t* pt = allocate_entry_if_needed(pd, pde, uwr_flags);
pdpt[pdpte] = V2P(allocate_zeroed_page_aligned_page());
pdpt[pdpte] |= uwr_flags;
}
uint64_t* pd = (uint64_t*)P2V(pdpt[pdpte] & PAGE_ADDR_MASK);
if ((pd[pde] & uwr_flags) != uwr_flags)
{
if (!(pd[pde] & Flags::Present))
pd[pde] = V2P(allocate_zeroed_page_aligned_page());
pd[pde] |= uwr_flags;
}
if (!(flags & Flags::Present)) if (!(flags & Flags::Present))
uwr_flags &= ~Flags::Present; uwr_flags &= ~Flags::Present;
uint64_t* pt = (uint64_t*)P2V(pd[pde] & PAGE_ADDR_MASK);
pt[pte] = paddr | uwr_flags | extra_flags; pt[pte] = paddr | uwr_flags | extra_flags;
invalidate(vaddr, send_smp_message); invalidate(vaddr, send_smp_message);
@ -508,30 +769,30 @@ namespace Kernel
uint64_t PageTable::get_page_data(vaddr_t vaddr) const uint64_t PageTable::get_page_data(vaddr_t vaddr) const
{ {
ASSERT(is_canonical(vaddr)); ASSERT(is_canonical(vaddr));
vaddr_t uc_vaddr = uncanonicalize(vaddr); const vaddr_t uc_vaddr = uncanonicalize(vaddr);
ASSERT(vaddr % PAGE_SIZE == 0); ASSERT(vaddr % PAGE_SIZE == 0);
uint64_t pml4e = (uc_vaddr >> 39) & 0x1FF; const uint16_t pml4e = (uc_vaddr >> 39) & 0x1FF;
uint64_t pdpte = (uc_vaddr >> 30) & 0x1FF; const uint16_t pdpte = (uc_vaddr >> 30) & 0x1FF;
uint64_t pde = (uc_vaddr >> 21) & 0x1FF; const uint16_t pde = (uc_vaddr >> 21) & 0x1FF;
uint64_t pte = (uc_vaddr >> 12) & 0x1FF; const uint16_t pte = (uc_vaddr >> 12) & 0x1FF;
SpinLockGuard _(m_lock); SpinLockGuard _(m_lock);
uint64_t* pml4 = (uint64_t*)P2V(m_highest_paging_struct); const uint64_t* pml4 = P2V(m_highest_paging_struct);
if (!(pml4[pml4e] & Flags::Present)) if (!(pml4[pml4e] & Flags::Present))
return 0; return 0;
uint64_t* pdpt = (uint64_t*)P2V(pml4[pml4e] & PAGE_ADDR_MASK); const uint64_t* pdpt = P2V(pml4[pml4e] & s_page_addr_mask);
if (!(pdpt[pdpte] & Flags::Present)) if (!(pdpt[pdpte] & Flags::Present))
return 0; return 0;
uint64_t* pd = (uint64_t*)P2V(pdpt[pdpte] & PAGE_ADDR_MASK); const uint64_t* pd = P2V(pdpt[pdpte] & s_page_addr_mask);
if (!(pd[pde] & Flags::Present)) if (!(pd[pde] & Flags::Present))
return 0; return 0;
uint64_t* pt = (uint64_t*)P2V(pd[pde] & PAGE_ADDR_MASK); const uint64_t* pt = P2V(pd[pde] & s_page_addr_mask);
if (!(pt[pte] & Flags::Used)) if (!(pt[pte] & Flags::Used))
return 0; return 0;
@ -546,7 +807,7 @@ namespace Kernel
paddr_t PageTable::physical_address_of(vaddr_t addr) const paddr_t PageTable::physical_address_of(vaddr_t addr) const
{ {
uint64_t page_data = get_page_data(addr); uint64_t page_data = get_page_data(addr);
return (page_data & PAGE_ADDR_MASK) & ~(1ull << 63); return page_data & s_page_addr_mask;
} }
bool PageTable::reserve_page(vaddr_t vaddr, bool only_free) bool PageTable::reserve_page(vaddr_t vaddr, bool only_free)
@ -601,28 +862,28 @@ namespace Kernel
// Try to find free page that can be mapped without // Try to find free page that can be mapped without
// allocations (page table with unused entries) // allocations (page table with unused entries)
uint64_t* pml4 = (uint64_t*)P2V(m_highest_paging_struct); const uint64_t* pml4 = P2V(m_highest_paging_struct);
for (; pml4e < 512; pml4e++) for (; pml4e < 512; pml4e++)
{ {
if (pml4e > e_pml4e) if (pml4e > e_pml4e)
break; break;
if (!(pml4[pml4e] & Flags::Present)) if (!(pml4[pml4e] & Flags::Present))
continue; continue;
uint64_t* pdpt = (uint64_t*)P2V(pml4[pml4e] & PAGE_ADDR_MASK); const uint64_t* pdpt = P2V(pml4[pml4e] & s_page_addr_mask);
for (; pdpte < 512; pdpte++) for (; pdpte < 512; pdpte++)
{ {
if (pml4e == e_pml4e && pdpte > e_pdpte) if (pml4e == e_pml4e && pdpte > e_pdpte)
break; break;
if (!(pdpt[pdpte] & Flags::Present)) if (!(pdpt[pdpte] & Flags::Present))
continue; continue;
uint64_t* pd = (uint64_t*)P2V(pdpt[pdpte] & PAGE_ADDR_MASK); const uint64_t* pd = P2V(pdpt[pdpte] & s_page_addr_mask);
for (; pde < 512; pde++) for (; pde < 512; pde++)
{ {
if (pml4e == e_pml4e && pdpte == e_pdpte && pde > e_pde) if (pml4e == e_pml4e && pdpte == e_pdpte && pde > e_pde)
break; break;
if (!(pd[pde] & Flags::Present)) if (!(pd[pde] & Flags::Present))
continue; continue;
uint64_t* pt = (uint64_t*)P2V(pd[pde] & PAGE_ADDR_MASK); const uint64_t* pt = P2V(pd[pde] & s_page_addr_mask);
for (; pte < 512; pte++) for (; pte < 512; pte++)
{ {
if (pml4e == e_pml4e && pdpte == e_pdpte && pde == e_pde && pte >= e_pte) if (pml4e == e_pml4e && pdpte == e_pdpte && pde == e_pde && pte >= e_pte)
@ -630,10 +891,10 @@ namespace Kernel
if (!(pt[pte] & Flags::Used)) if (!(pt[pte] & Flags::Used))
{ {
vaddr_t vaddr = 0; vaddr_t vaddr = 0;
vaddr |= (uint64_t)pml4e << 39; vaddr |= static_cast<uint64_t>(pml4e) << 39;
vaddr |= (uint64_t)pdpte << 30; vaddr |= static_cast<uint64_t>(pdpte) << 30;
vaddr |= (uint64_t)pde << 21; vaddr |= static_cast<uint64_t>(pde) << 21;
vaddr |= (uint64_t)pte << 12; vaddr |= static_cast<uint64_t>(pte) << 12;
vaddr = canonicalize(vaddr); vaddr = canonicalize(vaddr);
ASSERT(reserve_page(vaddr)); ASSERT(reserve_page(vaddr));
return vaddr; return vaddr;
@ -643,16 +904,13 @@ namespace Kernel
} }
} }
// Find any free page for (vaddr_t uc_vaddr = uc_vaddr_start; uc_vaddr < uc_vaddr_end; uc_vaddr += PAGE_SIZE)
vaddr_t uc_vaddr = uc_vaddr_start;
while (uc_vaddr < uc_vaddr_end)
{ {
if (vaddr_t vaddr = canonicalize(uc_vaddr); is_page_free(vaddr)) if (vaddr_t vaddr = canonicalize(uc_vaddr); is_page_free(vaddr))
{ {
ASSERT(reserve_page(vaddr)); ASSERT(reserve_page(vaddr));
return vaddr; return vaddr;
} }
uc_vaddr += PAGE_SIZE;
} }
ASSERT_NOT_REACHED(); ASSERT_NOT_REACHED();
@ -739,16 +997,16 @@ namespace Kernel
flags_t flags = 0; flags_t flags = 0;
vaddr_t start = 0; vaddr_t start = 0;
uint64_t* pml4 = (uint64_t*)P2V(m_highest_paging_struct); const uint64_t* pml4 = P2V(m_highest_paging_struct);
for (uint64_t pml4e = 0; pml4e < 512; pml4e++) for (uint64_t pml4e = 0; pml4e < 512; pml4e++)
{ {
if (!(pml4[pml4e] & Flags::Present)) if (!(pml4[pml4e] & Flags::Present) || (pml4e >= 256 && pml4e < 511))
{ {
dump_range(start, (pml4e << 39), flags); dump_range(start, (pml4e << 39), flags);
start = 0; start = 0;
continue; continue;
} }
uint64_t* pdpt = (uint64_t*)P2V(pml4[pml4e] & PAGE_ADDR_MASK); const uint64_t* pdpt = P2V(pml4[pml4e] & s_page_addr_mask);
for (uint64_t pdpte = 0; pdpte < 512; pdpte++) for (uint64_t pdpte = 0; pdpte < 512; pdpte++)
{ {
if (!(pdpt[pdpte] & Flags::Present)) if (!(pdpt[pdpte] & Flags::Present))
@ -757,7 +1015,7 @@ namespace Kernel
start = 0; start = 0;
continue; continue;
} }
uint64_t* pd = (uint64_t*)P2V(pdpt[pdpte] & PAGE_ADDR_MASK); const uint64_t* pd = P2V(pdpt[pdpte] & s_page_addr_mask);
for (uint64_t pde = 0; pde < 512; pde++) for (uint64_t pde = 0; pde < 512; pde++)
{ {
if (!(pd[pde] & Flags::Present)) if (!(pd[pde] & Flags::Present))
@ -766,7 +1024,7 @@ namespace Kernel
start = 0; start = 0;
continue; continue;
} }
uint64_t* pt = (uint64_t*)P2V(pd[pde] & PAGE_ADDR_MASK); const uint64_t* pt = P2V(pd[pde] & s_page_addr_mask);
for (uint64_t pte = 0; pte < 512; pte++) for (uint64_t pte = 0; pte < 512; pte++)
{ {
if (parse_flags(pt[pte]) != flags) if (parse_flags(pt[pte]) != flags)

View File

@ -8,36 +8,45 @@
namespace Kernel namespace Kernel
{ {
enum class FramebufferType
{
NONE,
UNKNOWN,
RGB
};
struct FramebufferInfo struct FramebufferInfo
{ {
paddr_t address; enum class Type
uint32_t pitch; {
uint32_t width; None,
uint32_t height; Unknown,
uint8_t bpp; RGB,
FramebufferType type = FramebufferType::NONE; };
paddr_t address;
uint32_t pitch;
uint32_t width;
uint32_t height;
uint8_t bpp;
Type type;
}; };
struct MemoryMapEntry struct MemoryMapEntry
{ {
uint32_t type; enum class Type
paddr_t address; {
uint64_t length; Available,
Reserved,
ACPIReclaim,
ACPINVS,
};
paddr_t address;
uint64_t length;
Type type;
}; };
struct BootInfo struct BootInfo
{ {
BAN::String command_line; BAN::String command_line;
FramebufferInfo framebuffer {}; FramebufferInfo framebuffer {};
RSDP rsdp {}; RSDP rsdp {};
paddr_t kernel_paddr {}; paddr_t kernel_paddr {};
BAN::Vector<MemoryMapEntry> memory_map_entries; BAN::Vector<MemoryMapEntry> memory_map_entries;
}; };

View File

@ -80,5 +80,6 @@ namespace CPUID
bool has_nxe(); bool has_nxe();
bool has_pge(); bool has_pge();
bool has_pat(); bool has_pat();
bool has_1gib_pages();
} }

View File

@ -43,7 +43,8 @@ namespace Kernel
}; };
public: public:
static void initialize(); static void initialize_pre_heap();
static void initialize_post_heap();
static PageTable& kernel(); static PageTable& kernel();
static PageTable& current() { return *reinterpret_cast<PageTable*>(Processor::get_current_page_table()); } static PageTable& current() { return *reinterpret_cast<PageTable*>(Processor::get_current_page_table()); }

View File

@ -19,31 +19,17 @@ namespace Kernel
void release_contiguous_pages(paddr_t paddr, size_t pages); void release_contiguous_pages(paddr_t paddr, size_t pages);
paddr_t start() const { return m_paddr; } paddr_t start() const { return m_paddr; }
paddr_t end() const { return m_paddr + m_size; } paddr_t end() const { return m_paddr + m_page_count * PAGE_SIZE; }
bool contains(paddr_t addr) const { return m_paddr <= addr && addr < m_paddr + m_size; } bool contains(paddr_t addr) const { return start() <= addr && addr < end(); }
size_t usable_memory() const { return m_data_pages * PAGE_SIZE; } size_t usable_memory() const { return m_page_count * PAGE_SIZE; }
size_t used_pages() const { return m_data_pages - m_free_pages; } size_t used_pages() const { return m_page_count - m_free_pages; }
size_t free_pages() const { return m_free_pages; } size_t free_pages() const { return m_free_pages; }
private:
unsigned long long* ull_bitmap_ptr() { return (unsigned long long*)m_vaddr; }
const unsigned long long* ull_bitmap_ptr() const { return (const unsigned long long*)m_vaddr; }
paddr_t paddr_for_bit(unsigned long long) const;
unsigned long long bit_for_paddr(paddr_t paddr) const;
unsigned long long contiguous_bits_set(unsigned long long start, unsigned long long count) const;
private: private:
const paddr_t m_paddr { 0 }; const paddr_t m_paddr { 0 };
const size_t m_size { 0 }; const size_t m_page_count { 0 };
vaddr_t m_vaddr { 0 };
const size_t m_bitmap_pages { 0 };
const size_t m_data_pages { 0 };
size_t m_free_pages { 0 }; size_t m_free_pages { 0 };
}; };

View File

@ -13,3 +13,4 @@ void* kmalloc(size_t size, size_t align, bool force_identity_map = false);
void kfree(void*); void kfree(void*);
BAN::Optional<Kernel::paddr_t> kmalloc_paddr_of(Kernel::vaddr_t); BAN::Optional<Kernel::paddr_t> kmalloc_paddr_of(Kernel::vaddr_t);
BAN::Optional<Kernel::vaddr_t> kmalloc_vaddr_of(Kernel::paddr_t);

View File

@ -281,16 +281,32 @@ acpi_release_global_lock:
return true; return true;
} }
static const RSDP* locate_rsdp() static BAN::Optional<RSDP> locate_rsdp()
{ {
if (g_boot_info.rsdp.length) if (g_boot_info.rsdp.length)
return &g_boot_info.rsdp; return g_boot_info.rsdp;
// Look in main BIOS area below 1 MB // Look in main BIOS area below 1 MB
for (vaddr_t addr = 0x000E0000 + KERNEL_OFFSET; addr < 0x000FFFFF + KERNEL_OFFSET; addr += 16) for (paddr_t paddr = 0x000E0000; paddr < 0x00100000; paddr += PAGE_SIZE)
if (is_rsdp(addr)) {
return reinterpret_cast<const RSDP*>(addr); BAN::Optional<RSDP> rsdp;
return nullptr;
PageTable::with_fast_page(paddr, [&rsdp] {
for (size_t offset = 0; offset + sizeof(RSDP) <= PAGE_SIZE; offset += 16)
{
if (is_rsdp(PageTable::fast_page() + offset))
{
rsdp = PageTable::fast_page_as<RSDP>(offset);
break;
}
}
});
if (rsdp.has_value())
return rsdp.release_value();
}
return {};
} }
static bool is_valid_std_header(const SDTHeader* header) static bool is_valid_std_header(const SDTHeader* header)
@ -303,24 +319,25 @@ acpi_release_global_lock:
BAN::ErrorOr<void> ACPI::initialize_impl() BAN::ErrorOr<void> ACPI::initialize_impl()
{ {
const RSDP* rsdp = locate_rsdp(); auto opt_rsdp = locate_rsdp();
if (rsdp == nullptr) if (!opt_rsdp.has_value())
return BAN::Error::from_error_code(ErrorCode::ACPI_NoRootSDT); return BAN::Error::from_error_code(ErrorCode::ACPI_NoRootSDT);
const RSDP rsdp = opt_rsdp.release_value();
uint32_t root_entry_count = 0; uint32_t root_entry_count = 0;
if (rsdp->revision >= 2) if (rsdp.revision >= 2)
{ {
TRY(PageTable::with_fast_page(rsdp->xsdt_address & PAGE_ADDR_MASK, TRY(PageTable::with_fast_page(rsdp.xsdt_address & PAGE_ADDR_MASK,
[&]() -> BAN::ErrorOr<void> [&]() -> BAN::ErrorOr<void>
{ {
auto& xsdt = PageTable::fast_page_as<const XSDT>(rsdp->xsdt_address % PAGE_SIZE); auto& xsdt = PageTable::fast_page_as<const XSDT>(rsdp.xsdt_address % PAGE_SIZE);
if (memcmp(xsdt.signature, "XSDT", 4) != 0) if (memcmp(xsdt.signature, "XSDT", 4) != 0)
return BAN::Error::from_error_code(ErrorCode::ACPI_RootInvalid); return BAN::Error::from_error_code(ErrorCode::ACPI_RootInvalid);
if (!is_valid_std_header(&xsdt)) if (!is_valid_std_header(&xsdt))
return BAN::Error::from_error_code(ErrorCode::ACPI_RootInvalid); return BAN::Error::from_error_code(ErrorCode::ACPI_RootInvalid);
m_header_table_paddr = rsdp->xsdt_address + offsetof(XSDT, entries); m_header_table_paddr = rsdp.xsdt_address + offsetof(XSDT, entries);
m_entry_size = 8; m_entry_size = 8;
root_entry_count = (xsdt.length - sizeof(SDTHeader)) / 8; root_entry_count = (xsdt.length - sizeof(SDTHeader)) / 8;
return {}; return {};
@ -329,16 +346,16 @@ acpi_release_global_lock:
} }
else else
{ {
TRY(PageTable::with_fast_page(rsdp->rsdt_address & PAGE_ADDR_MASK, TRY(PageTable::with_fast_page(rsdp.rsdt_address & PAGE_ADDR_MASK,
[&]() -> BAN::ErrorOr<void> [&]() -> BAN::ErrorOr<void>
{ {
auto& rsdt = PageTable::fast_page_as<const RSDT>(rsdp->rsdt_address % PAGE_SIZE); auto& rsdt = PageTable::fast_page_as<const RSDT>(rsdp.rsdt_address % PAGE_SIZE);
if (memcmp(rsdt.signature, "RSDT", 4) != 0) if (memcmp(rsdt.signature, "RSDT", 4) != 0)
return BAN::Error::from_error_code(ErrorCode::ACPI_RootInvalid); return BAN::Error::from_error_code(ErrorCode::ACPI_RootInvalid);
if (!is_valid_std_header(&rsdt)) if (!is_valid_std_header(&rsdt))
return BAN::Error::from_error_code(ErrorCode::ACPI_RootInvalid); return BAN::Error::from_error_code(ErrorCode::ACPI_RootInvalid);
m_header_table_paddr = rsdp->rsdt_address + offsetof(RSDT, entries); m_header_table_paddr = rsdp.rsdt_address + offsetof(RSDT, entries);
m_entry_size = 4; m_entry_size = 4;
root_entry_count = (rsdt.length - sizeof(SDTHeader)) / 4; root_entry_count = (rsdt.length - sizeof(SDTHeader)) / 4;
return {}; return {};

View File

@ -7,6 +7,18 @@ namespace Kernel
BootInfo g_boot_info; BootInfo g_boot_info;
static MemoryMapEntry::Type bios_number_to_memory_type(uint32_t number)
{
switch (number)
{
case 1: return MemoryMapEntry::Type::Available;
case 2: return MemoryMapEntry::Type::Reserved;
case 3: return MemoryMapEntry::Type::ACPIReclaim;
case 4: return MemoryMapEntry::Type::ACPINVS;
}
return MemoryMapEntry::Type::Reserved;
}
static void parse_boot_info_multiboot2(uint32_t info) static void parse_boot_info_multiboot2(uint32_t info)
{ {
const auto& multiboot2_info = *reinterpret_cast<const multiboot2_info_t*>(info); const auto& multiboot2_info = *reinterpret_cast<const multiboot2_info_t*>(info);
@ -27,9 +39,9 @@ namespace Kernel
g_boot_info.framebuffer.height = framebuffer_tag.framebuffer_height; g_boot_info.framebuffer.height = framebuffer_tag.framebuffer_height;
g_boot_info.framebuffer.bpp = framebuffer_tag.framebuffer_bpp; g_boot_info.framebuffer.bpp = framebuffer_tag.framebuffer_bpp;
if (framebuffer_tag.framebuffer_type == MULTIBOOT2_FRAMEBUFFER_TYPE_RGB) if (framebuffer_tag.framebuffer_type == MULTIBOOT2_FRAMEBUFFER_TYPE_RGB)
g_boot_info.framebuffer.type = FramebufferType::RGB; g_boot_info.framebuffer.type = FramebufferInfo::Type::RGB;
else else
g_boot_info.framebuffer.type = FramebufferType::UNKNOWN; g_boot_info.framebuffer.type = FramebufferInfo::Type::Unknown;
} }
else if (tag->type == MULTIBOOT2_TAG_MMAP) else if (tag->type == MULTIBOOT2_TAG_MMAP)
{ {
@ -47,9 +59,9 @@ namespace Kernel
(uint64_t)mmap_entry.length, (uint64_t)mmap_entry.length,
(uint64_t)mmap_entry.type (uint64_t)mmap_entry.type
); );
g_boot_info.memory_map_entries[i].address = mmap_entry.base_addr; g_boot_info.memory_map_entries[i].address = mmap_entry.base_addr;
g_boot_info.memory_map_entries[i].length = mmap_entry.length; g_boot_info.memory_map_entries[i].length = mmap_entry.length;
g_boot_info.memory_map_entries[i].type = mmap_entry.type; g_boot_info.memory_map_entries[i].type = bios_number_to_memory_type(mmap_entry.type);
} }
} }
else if (tag->type == MULTIBOOT2_TAG_OLD_RSDP) else if (tag->type == MULTIBOOT2_TAG_OLD_RSDP)
@ -87,24 +99,24 @@ namespace Kernel
MUST(g_boot_info.command_line.append(command_line)); MUST(g_boot_info.command_line.append(command_line));
const auto& framebuffer = *reinterpret_cast<BananBootFramebufferInfo*>(banan_bootloader_info.framebuffer_addr); const auto& framebuffer = *reinterpret_cast<BananBootFramebufferInfo*>(banan_bootloader_info.framebuffer_addr);
g_boot_info.framebuffer.address = framebuffer.address;
g_boot_info.framebuffer.width = framebuffer.width;
g_boot_info.framebuffer.height = framebuffer.height;
g_boot_info.framebuffer.pitch = framebuffer.pitch;
g_boot_info.framebuffer.bpp = framebuffer.bpp;
if (framebuffer.type == BANAN_BOOTLOADER_FB_RGB) if (framebuffer.type == BANAN_BOOTLOADER_FB_RGB)
{ g_boot_info.framebuffer.type = FramebufferInfo::Type::RGB;
g_boot_info.framebuffer.address = framebuffer.address; else
g_boot_info.framebuffer.width = framebuffer.width; g_boot_info.framebuffer.type = FramebufferInfo::Type::Unknown;
g_boot_info.framebuffer.height = framebuffer.height;
g_boot_info.framebuffer.pitch = framebuffer.pitch;
g_boot_info.framebuffer.bpp = framebuffer.bpp;
g_boot_info.framebuffer.type = FramebufferType::RGB;
}
const auto& memory_map = *reinterpret_cast<BananBootloaderMemoryMapInfo*>(banan_bootloader_info.memory_map_addr); const auto& memory_map = *reinterpret_cast<BananBootloaderMemoryMapInfo*>(banan_bootloader_info.memory_map_addr);
MUST(g_boot_info.memory_map_entries.resize(memory_map.entry_count)); MUST(g_boot_info.memory_map_entries.resize(memory_map.entry_count));
for (size_t i = 0; i < memory_map.entry_count; i++) for (size_t i = 0; i < memory_map.entry_count; i++)
{ {
const auto& mmap_entry = memory_map.entries[i]; const auto& mmap_entry = memory_map.entries[i];
g_boot_info.memory_map_entries[i].address = mmap_entry.address; g_boot_info.memory_map_entries[i].address = mmap_entry.address;
g_boot_info.memory_map_entries[i].length = mmap_entry.length; g_boot_info.memory_map_entries[i].length = mmap_entry.length;
g_boot_info.memory_map_entries[i].type = mmap_entry.type; g_boot_info.memory_map_entries[i].type = bios_number_to_memory_type(mmap_entry.type);
} }
g_boot_info.kernel_paddr = banan_bootloader_info.kernel_paddr; g_boot_info.kernel_paddr = banan_bootloader_info.kernel_paddr;

View File

@ -64,6 +64,16 @@ namespace CPUID
return edx & CPUID::EDX_PAT; return edx & CPUID::EDX_PAT;
} }
bool has_1gib_pages()
{
uint32_t buffer[4] {};
get_cpuid(0x80000000, buffer);
if (buffer[0] < 0x80000001)
return false;
get_cpuid(0x80000001, buffer);
return buffer[3] & (1 << 26);
}
const char* feature_string_ecx(uint32_t feat) const char* feature_string_ecx(uint32_t feat)
{ {
switch (feat) switch (feat)

View File

@ -19,7 +19,7 @@ namespace Kernel
BAN::ErrorOr<BAN::RefPtr<FramebufferDevice>> FramebufferDevice::create_from_boot_framebuffer() BAN::ErrorOr<BAN::RefPtr<FramebufferDevice>> FramebufferDevice::create_from_boot_framebuffer()
{ {
if (g_boot_info.framebuffer.type != FramebufferType::RGB) if (g_boot_info.framebuffer.type != FramebufferInfo::Type::RGB)
return BAN::Error::from_errno(ENODEV); return BAN::Error::from_errno(ENODEV);
if (g_boot_info.framebuffer.bpp != 24 && g_boot_info.framebuffer.bpp != 32) if (g_boot_info.framebuffer.bpp != 24 && g_boot_info.framebuffer.bpp != 32)
return BAN::Error::from_errno(ENOTSUP); return BAN::Error::from_errno(ENOTSUP);

View File

@ -191,15 +191,17 @@ namespace Kernel
paddr_t page_containing = find_indirect(m_data_pages, index_of_page, 2); paddr_t page_containing = find_indirect(m_data_pages, index_of_page, 2);
paddr_t paddr_to_free = 0;
PageTable::with_fast_page(page_containing, [&] { PageTable::with_fast_page(page_containing, [&] {
auto& page_info = PageTable::fast_page_as_sized<PageInfo>(index_in_page); auto& page_info = PageTable::fast_page_as_sized<PageInfo>(index_in_page);
ASSERT(page_info.flags() & PageInfo::Flags::Present); ASSERT(page_info.flags() & PageInfo::Flags::Present);
Heap::get().release_page(page_info.paddr()); paddr_to_free = page_info.paddr();
m_used_pages--; m_used_pages--;
page_info.set_paddr(0); page_info.set_paddr(0);
page_info.set_flags(0); page_info.set_flags(0);
}); });
Heap::get().release_page(paddr_to_free);
} }
BAN::ErrorOr<size_t> TmpFileSystem::allocate_block() BAN::ErrorOr<size_t> TmpFileSystem::allocate_block()

View File

@ -26,17 +26,36 @@ namespace Kernel
void Heap::initialize_impl() void Heap::initialize_impl()
{ {
if (g_boot_info.memory_map_entries.empty()) if (g_boot_info.memory_map_entries.empty())
Kernel::panic("Bootloader did not provide a memory map"); panic("Bootloader did not provide a memory map");
for (const auto& entry : g_boot_info.memory_map_entries) for (const auto& entry : g_boot_info.memory_map_entries)
{ {
dprintln("{16H}, {16H}, {8H}", const char* entry_type_string = nullptr;
switch (entry.type)
{
case MemoryMapEntry::Type::Available:
entry_type_string = "available";
break;
case MemoryMapEntry::Type::Reserved:
entry_type_string = "reserved";
break;
case MemoryMapEntry::Type::ACPIReclaim:
entry_type_string = "acpi reclaim";
break;
case MemoryMapEntry::Type::ACPINVS:
entry_type_string = "acpi nvs";
break;
default:
ASSERT_NOT_REACHED();
}
dprintln("{16H}, {16H}, {}",
entry.address, entry.address,
entry.length, entry.length,
entry.type entry_type_string
); );
if (entry.type != 1) if (entry.type != MemoryMapEntry::Type::Available)
continue; continue;
paddr_t start = entry.address; paddr_t start = entry.address;
@ -79,7 +98,7 @@ namespace Kernel
for (auto& range : m_physical_ranges) for (auto& range : m_physical_ranges)
if (range.contains(paddr)) if (range.contains(paddr))
return range.release_page(paddr); return range.release_page(paddr);
ASSERT_NOT_REACHED(); panic("tried to free invalid paddr {16H}", paddr);
} }
paddr_t Heap::take_free_contiguous_pages(size_t pages) paddr_t Heap::take_free_contiguous_pages(size_t pages)

View File

@ -1,81 +1,67 @@
#include <BAN/Assert.h> #include <BAN/Assert.h>
#include <BAN/Math.h> #include <BAN/Math.h>
#include <BAN/Optional.h>
#include <kernel/Memory/PageTable.h> #include <kernel/Memory/PageTable.h>
#include <kernel/Memory/PhysicalRange.h> #include <kernel/Memory/PhysicalRange.h>
namespace Kernel namespace Kernel
{ {
using ull = unsigned long long; static constexpr size_t bits_per_page = PAGE_SIZE * 8;
static constexpr ull ull_bits = sizeof(ull) * 8;
PhysicalRange::PhysicalRange(paddr_t paddr, size_t size) PhysicalRange::PhysicalRange(paddr_t paddr, size_t size)
: m_paddr(paddr) : m_paddr(paddr)
, m_size(size) , m_page_count(size / PAGE_SIZE)
, m_bitmap_pages(BAN::Math::div_round_up<size_t>(size / PAGE_SIZE, PAGE_SIZE * 8)) , m_free_pages(m_page_count)
, m_data_pages((size / PAGE_SIZE) - m_bitmap_pages)
, m_free_pages(m_data_pages)
{ {
ASSERT(paddr % PAGE_SIZE == 0); ASSERT(paddr % PAGE_SIZE == 0);
ASSERT(size % PAGE_SIZE == 0); ASSERT(size % PAGE_SIZE == 0);
ASSERT(m_bitmap_pages < size / PAGE_SIZE);
m_vaddr = PageTable::kernel().reserve_free_contiguous_pages(m_bitmap_pages, KERNEL_OFFSET); const size_t bitmap_page_count = BAN::Math::div_round_up<size_t>(m_page_count, bits_per_page);
ASSERT(m_vaddr); for (size_t i = 0; i < bitmap_page_count; i++)
PageTable::kernel().map_range_at(m_paddr, m_vaddr, m_bitmap_pages * PAGE_SIZE, PageTable::Flags::ReadWrite | PageTable::Flags::Present);
memset((void*)m_vaddr, 0x00, m_bitmap_pages * PAGE_SIZE);
for (ull i = 0; i < m_data_pages / ull_bits; i++)
ull_bitmap_ptr()[i] = ~0ull;
if (m_data_pages % ull_bits)
{ {
ull off = m_data_pages / ull_bits; PageTable::with_fast_page(paddr + i * PAGE_SIZE, [] {
ull bits = m_data_pages % ull_bits; memset(PageTable::fast_page_as_ptr(), 0, PAGE_SIZE);
ull_bitmap_ptr()[off] = ~(~0ull << bits); });
} }
}
paddr_t PhysicalRange::paddr_for_bit(ull bit) const ASSERT(reserve_contiguous_pages(bitmap_page_count) == m_paddr);
{
return m_paddr + (m_bitmap_pages + bit) * PAGE_SIZE;
}
ull PhysicalRange::bit_for_paddr(paddr_t paddr) const
{
return (paddr - m_paddr) / PAGE_SIZE - m_bitmap_pages;
}
ull PhysicalRange::contiguous_bits_set(ull start, ull count) const
{
for (ull i = 0; i < count; i++)
{
ull off = (start + i) / ull_bits;
ull bit = (start + i) % ull_bits;
if (!(ull_bitmap_ptr()[off] & (1ull << bit)))
return i;
}
return count;
} }
paddr_t PhysicalRange::reserve_page() paddr_t PhysicalRange::reserve_page()
{ {
ASSERT(free_pages() > 0); ASSERT(free_pages() > 0);
ull ull_count = BAN::Math::div_round_up<ull>(m_data_pages, ull_bits); const size_t bitmap_page_count = BAN::Math::div_round_up<size_t>(m_page_count, bits_per_page);
for (ull i = 0; i < ull_count; i++) for (size_t i = 0; i < bitmap_page_count; i++)
{ {
if (ull_bitmap_ptr()[i] == 0) BAN::Optional<size_t> page_matched_bit;
continue;
int lsb = __builtin_ctzll(ull_bitmap_ptr()[i]); const paddr_t current_paddr = m_paddr + i * PAGE_SIZE;
PageTable::with_fast_page(current_paddr, [&page_matched_bit] {
for (size_t j = 0; j < PAGE_SIZE / sizeof(size_t); j++)
{
static_assert(sizeof(size_t) == sizeof(long));
const size_t current = PageTable::fast_page_as_sized<volatile size_t>(j);
if (current == BAN::numeric_limits<size_t>::max())
continue;
const int ctz = __builtin_ctzl(~current);
PageTable::fast_page_as_sized<volatile size_t>(j) = current | (static_cast<size_t>(1) << ctz);
page_matched_bit = j * sizeof(size_t) * 8 + ctz;
return;
}
});
ull_bitmap_ptr()[i] &= ~(1ull << lsb); if (page_matched_bit.has_value())
m_free_pages--; {
return paddr_for_bit(i * ull_bits + lsb); m_free_pages--;
const size_t matched_bit = (i * bits_per_page) + page_matched_bit.value();
ASSERT(matched_bit < m_page_count);
return m_paddr + matched_bit * PAGE_SIZE;
}
} }
ASSERT_NOT_REACHED(); ASSERT_NOT_REACHED();
@ -84,15 +70,21 @@ namespace Kernel
void PhysicalRange::release_page(paddr_t paddr) void PhysicalRange::release_page(paddr_t paddr)
{ {
ASSERT(paddr % PAGE_SIZE == 0); ASSERT(paddr % PAGE_SIZE == 0);
ASSERT(paddr - m_paddr <= m_size); ASSERT(paddr >= m_paddr);
ASSERT(paddr < m_paddr + m_page_count * PAGE_SIZE);
ull full_bit = bit_for_paddr(paddr); const size_t paddr_index = (paddr - m_paddr) / PAGE_SIZE;
ull off = full_bit / ull_bits;
ull bit = full_bit % ull_bits;
ull mask = 1ull << bit;
ASSERT(!(ull_bitmap_ptr()[off] & mask)); PageTable::with_fast_page(m_paddr + paddr_index / bits_per_page * PAGE_SIZE, [paddr_index] {
ull_bitmap_ptr()[off] |= mask; const size_t bitmap_bit = paddr_index % bits_per_page;
const size_t byte = bitmap_bit / 8;
const size_t bit = bitmap_bit % 8;
volatile uint8_t& bitmap_byte = PageTable::fast_page_as_sized<volatile uint8_t>(byte);
ASSERT(bitmap_byte & (1u << bit));
bitmap_byte = bitmap_byte & ~(1u << bit);
});
m_free_pages++; m_free_pages++;
} }
@ -100,58 +92,60 @@ namespace Kernel
paddr_t PhysicalRange::reserve_contiguous_pages(size_t pages) paddr_t PhysicalRange::reserve_contiguous_pages(size_t pages)
{ {
ASSERT(pages > 0); ASSERT(pages > 0);
ASSERT(free_pages() > 0); ASSERT(pages <= free_pages());
if (pages == 1) const auto bitmap_is_set =
return reserve_page(); [this](size_t buffer_bit) -> bool
ull ull_count = BAN::Math::div_round_up<ull>(m_data_pages, ull_bits);
// NOTE: This feels kinda slow, but I don't want to be
// doing premature optimization. This will be only
// used when creating DMA regions.
for (ull i = 0; i < ull_count; i++)
{
if (ull_bitmap_ptr()[i] == 0)
continue;
for (ull bit = 0; bit < ull_bits;)
{ {
ull start = i * ull_bits + bit; const size_t page_index = buffer_bit / bits_per_page;
ull set_cnt = contiguous_bits_set(start, pages); const size_t byte = buffer_bit / 8;
if (set_cnt == pages) const size_t bit = buffer_bit % 8;
{
for (ull j = 0; j < pages; j++) uint8_t current;
ull_bitmap_ptr()[(start + j) / ull_bits] &= ~(1ull << ((start + j) % ull_bits)); PageTable::with_fast_page(m_paddr + page_index * PAGE_SIZE, [&current, byte] {
m_free_pages -= pages; current = PageTable::fast_page_as_sized<volatile uint8_t>(byte);
return paddr_for_bit(start); });
}
bit += set_cnt + 1; return current & (1u << bit);
} };
const auto bitmap_set_bit =
[this](size_t buffer_bit) -> void
{
const size_t page_index = buffer_bit / bits_per_page;
const size_t byte = buffer_bit / 8;
const size_t bit = buffer_bit % 8;
PageTable::with_fast_page(m_paddr + page_index * PAGE_SIZE, [byte, bit] {
volatile uint8_t& current = PageTable::fast_page_as_sized<volatile uint8_t>(byte);
current = current | (1u << bit);
});
};
// FIXME: optimize this :)
for (size_t i = 0; i <= m_page_count - pages; i++)
{
bool all_unset = true;
for (size_t j = 0; j < pages && all_unset; j++)
if (bitmap_is_set(i + j))
all_unset = false;
if (!all_unset)
continue;
for (size_t j = 0; j < pages; j++)
bitmap_set_bit(i + j);
m_free_pages -= pages;
return m_paddr + i * PAGE_SIZE;
} }
ASSERT_NOT_REACHED(); return 0;
} }
void PhysicalRange::release_contiguous_pages(paddr_t paddr, size_t pages) void PhysicalRange::release_contiguous_pages(paddr_t paddr, size_t pages)
{ {
ASSERT(paddr % PAGE_SIZE == 0);
ASSERT(paddr - m_paddr <= m_size);
ASSERT(pages > 0); ASSERT(pages > 0);
// FIXME: optimize this :)
ull start_bit = bit_for_paddr(paddr);
for (size_t i = 0; i < pages; i++) for (size_t i = 0; i < pages; i++)
{ release_page(paddr + i * PAGE_SIZE);
ull off = (start_bit + i) / ull_bits;
ull bit = (start_bit + i) % ull_bits;
ull mask = 1ull << bit;
ASSERT(!(ull_bitmap_ptr()[off] & mask));
ull_bitmap_ptr()[off] |= mask;
}
m_free_pages += pages;
} }
} }

View File

@ -418,12 +418,29 @@ void kfree(void* address)
} }
static bool is_kmalloc_vaddr(Kernel::vaddr_t vaddr)
{
using namespace Kernel;
if (vaddr < reinterpret_cast<vaddr_t>(s_kmalloc_storage))
return false;
if (vaddr >= reinterpret_cast<vaddr_t>(s_kmalloc_storage) + sizeof(s_kmalloc_storage))
return false;
return true;
}
BAN::Optional<Kernel::paddr_t> kmalloc_paddr_of(Kernel::vaddr_t vaddr) BAN::Optional<Kernel::paddr_t> kmalloc_paddr_of(Kernel::vaddr_t vaddr)
{ {
using namespace Kernel; using namespace Kernel;
if (!is_kmalloc_vaddr(vaddr))
if ((vaddr_t)s_kmalloc_storage <= vaddr && vaddr < (vaddr_t)s_kmalloc_storage + sizeof(s_kmalloc_storage)) return {};
return vaddr - KERNEL_OFFSET + g_boot_info.kernel_paddr; return vaddr - KERNEL_OFFSET + g_boot_info.kernel_paddr;
}
return {};
BAN::Optional<Kernel::vaddr_t> kmalloc_vaddr_of(Kernel::paddr_t paddr)
{
using namespace Kernel;
const vaddr_t vaddr = paddr + KERNEL_OFFSET - g_boot_info.kernel_paddr;
if (!is_kmalloc_vaddr(vaddr))
return {};
return vaddr;
} }

View File

@ -131,13 +131,16 @@ extern "C" void kernel_main(uint32_t boot_magic, uint32_t boot_info)
Processor::initialize(); Processor::initialize();
dprintln("BSP initialized"); dprintln("BSP initialized");
PageTable::initialize(); PageTable::initialize_pre_heap();
PageTable::kernel().initial_load(); PageTable::kernel().initial_load();
dprintln("PageTable initialized"); dprintln("PageTable stage1 initialized");
Heap::initialize(); Heap::initialize();
dprintln("Heap initialzed"); dprintln("Heap initialzed");
PageTable::initialize_post_heap();
dprintln("PageTable stage2 initialized");
parse_command_line(); parse_command_line();
dprintln("command line parsed, root='{}', console='{}'", cmdline.root, cmdline.console); dprintln("command line parsed, root='{}', console='{}'", cmdline.root, cmdline.console);

View File

@ -2,13 +2,16 @@
#include <BAN/Limits.h> #include <BAN/Limits.h>
#include <BAN/Math.h> #include <BAN/Math.h>
#include <BAN/UTF8.h> #include <BAN/UTF8.h>
#include <ctype.h> #include <ctype.h>
#include <errno.h> #include <errno.h>
#include <locale.h> #include <locale.h>
#include <signal.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <strings.h> #include <strings.h>
#include <sys/stat.h>
#include <sys/syscall.h> #include <sys/syscall.h>
#include <unistd.h> #include <unistd.h>
@ -26,6 +29,7 @@ void abort(void)
fflush(nullptr); fflush(nullptr);
fprintf(stderr, "abort()\n"); fprintf(stderr, "abort()\n");
exit(1); exit(1);
ASSERT_NOT_REACHED();
} }
void exit(int status) void exit(int status)
@ -395,22 +399,58 @@ char* realpath(const char* __restrict file_name, char* __restrict resolved_name)
int system(const char* command) int system(const char* command)
{ {
// FIXME // FIXME: maybe implement POSIX compliant shell?
constexpr const char* shell_path = "/bin/Shell";
if (command == nullptr) if (command == nullptr)
return 1; {
struct stat st;
if (stat(shell_path, &st) == -1)
return 0;
if (S_ISDIR(st.st_mode))
return 0;
return !!(st.st_mode & (S_IXUSR | S_IXGRP | S_IXOTH));
}
struct sigaction sa;
sa.sa_flags = 0;
sa.sa_handler = SIG_IGN;
sigemptyset(&sa.sa_mask);
struct sigaction sigint_save, sigquit_save;
sigaction(SIGINT, &sa, &sigint_save);
sigaction(SIGQUIT, &sa, &sigquit_save);
sigset_t sigchld_save;
sigaddset(&sa.sa_mask, SIGCHLD);
sigprocmask(SIG_BLOCK, &sa.sa_mask, &sigchld_save);
int pid = fork(); int pid = fork();
if (pid == 0) if (pid == 0)
{ {
execl("/bin/Shell", "Shell", "-c", command, (char*)0); sigaction(SIGINT, &sigint_save, nullptr);
exit(1); sigaction(SIGQUIT, &sigquit_save, nullptr);
sigprocmask(SIG_SETMASK, &sigchld_save, nullptr);
execl(shell_path, "sh", "-c", command, nullptr);
exit(127);
} }
if (pid == -1) int stat_val = -1;
return -1; if (pid != -1)
{
while (waitpid(pid, &stat_val, 0) == -1)
{
if (errno == EINTR)
continue;
stat_val = -1;
break;
}
}
sigaction(SIGINT, &sigint_save, nullptr);
sigaction(SIGQUIT, &sigquit_save, nullptr);
sigprocmask(SIG_SETMASK, &sigchld_save, nullptr);
int stat_val;
waitpid(pid, &stat_val, 0);
return stat_val; return stat_val;
} }

View File

@ -87,6 +87,8 @@ int strcmp(const char* s1, const char* s2)
int strncmp(const char* s1, const char* s2, size_t n) int strncmp(const char* s1, const char* s2, size_t n)
{ {
if (n == 0)
return 0;
const unsigned char* u1 = (unsigned char*)s1; const unsigned char* u1 = (unsigned char*)s1;
const unsigned char* u2 = (unsigned char*)s2; const unsigned char* u2 = (unsigned char*)s2;
for (; --n && *u1 && *u2; u1++, u2++) for (; --n && *u1 && *u2; u1++, u2++)
@ -220,11 +222,11 @@ char* strchr(const char* str, int c)
{ {
while (*str) while (*str)
{ {
if (*str == c) if (*str == (char)c)
return (char*)str; return (char*)str;
str++; str++;
} }
return (*str == c) ? (char*)str : nullptr; return (*str == (char)c) ? (char*)str : nullptr;
} }
char* strchrnul(const char* str, int c) char* strchrnul(const char* str, int c)
@ -252,7 +254,9 @@ char* strrchr(const char* str, int c)
char* strstr(const char* haystack, const char* needle) char* strstr(const char* haystack, const char* needle)
{ {
size_t needle_len = strlen(needle); const size_t needle_len = strlen(needle);
if (needle_len == 0)
return const_cast<char*>(haystack);
for (size_t i = 0; haystack[i]; i++) for (size_t i = 0; haystack[i]; i++)
if (strncmp(haystack + i, needle, needle_len) == 0) if (strncmp(haystack + i, needle, needle_len) == 0)
return const_cast<char*>(haystack + i); return const_cast<char*>(haystack + i);

View File

@ -7,7 +7,6 @@
#include <fcntl.h> #include <fcntl.h>
#include <math.h> #include <math.h>
#include <stdlib.h>
#include <sys/mman.h> #include <sys/mman.h>
namespace LibImage namespace LibImage
@ -87,9 +86,9 @@ namespace LibImage
constexpr Image::Color as_color() const constexpr Image::Color as_color() const
{ {
return Image::Color { return Image::Color {
.r = static_cast<uint8_t>(r < 0.0 ? 0.0 : r > 255.0 ? 255.0 : r),
.g = static_cast<uint8_t>(g < 0.0 ? 0.0 : g > 255.0 ? 255.0 : g),
.b = static_cast<uint8_t>(b < 0.0 ? 0.0 : b > 255.0 ? 255.0 : b), .b = static_cast<uint8_t>(b < 0.0 ? 0.0 : b > 255.0 ? 255.0 : b),
.g = static_cast<uint8_t>(g < 0.0 ? 0.0 : g > 255.0 ? 255.0 : g),
.r = static_cast<uint8_t>(r < 0.0 ? 0.0 : r > 255.0 ? 255.0 : r),
.a = static_cast<uint8_t>(a < 0.0 ? 0.0 : a > 255.0 ? 255.0 : a), .a = static_cast<uint8_t>(a < 0.0 ? 0.0 : a > 255.0 ? 255.0 : a),
}; };
} }

View File

@ -560,52 +560,47 @@ namespace LibImage
const auto extract_color = const auto extract_color =
[&](auto& bit_buffer) -> Image::Color [&](auto& bit_buffer) -> Image::Color
{ {
uint8_t tmp; Image::Color color;
switch (ihdr.colour_type) switch (ihdr.colour_type)
{ {
case ColourType::Greyscale: case ColourType::Greyscale:
tmp = extract_channel(bit_buffer); color.r = extract_channel(bit_buffer);
return Image::Color { color.g = color.r;
.r = tmp, color.b = color.r;
.g = tmp, color.a = 0xFF;
.b = tmp, break;
.a = 0xFF
};
case ColourType::Truecolour: case ColourType::Truecolour:
return Image::Color { color.r = extract_channel(bit_buffer);
.r = extract_channel(bit_buffer), color.g = extract_channel(bit_buffer);
.g = extract_channel(bit_buffer), color.b = extract_channel(bit_buffer);
.b = extract_channel(bit_buffer), color.a = 0xFF;
.a = 0xFF break;
};
case ColourType::IndexedColour: case ColourType::IndexedColour:
return palette[MUST(bit_buffer.get_bits(bits_per_channel))]; color = palette[MUST(bit_buffer.get_bits(bits_per_channel))];
break;
case ColourType::GreyscaleAlpha: case ColourType::GreyscaleAlpha:
tmp = extract_channel(bit_buffer); color.r = extract_channel(bit_buffer);
return Image::Color { color.g = color.r;
.r = tmp, color.b = color.r;
.g = tmp, color.a = extract_channel(bit_buffer);
.b = tmp, break;
.a = extract_channel(bit_buffer)
};
case ColourType::TruecolourAlpha: case ColourType::TruecolourAlpha:
return Image::Color { color.r = extract_channel(bit_buffer);
.r = extract_channel(bit_buffer), color.g = extract_channel(bit_buffer);
.g = extract_channel(bit_buffer), color.b = extract_channel(bit_buffer);
.b = extract_channel(bit_buffer), color.a = extract_channel(bit_buffer);
.a = extract_channel(bit_buffer) break;
};
} }
ASSERT_NOT_REACHED(); return color;
}; };
constexpr auto paeth_predictor = constexpr auto paeth_predictor =
[](int16_t a, int16_t b, int16_t c) -> uint8_t [](int16_t a, int16_t b, int16_t c) -> uint8_t
{ {
int16_t p = a + b - c; const int16_t p = a + b - c;
int16_t pa = BAN::Math::abs(p - a); const int16_t pa = BAN::Math::abs(p - a);
int16_t pb = BAN::Math::abs(p - b); const int16_t pb = BAN::Math::abs(p - b);
int16_t pc = BAN::Math::abs(p - c); const int16_t pc = BAN::Math::abs(p - c);
if (pa <= pb && pa <= pc) if (pa <= pb && pa <= pc)
return a; return a;
if (pb <= pc) if (pb <= pc)

View File

@ -13,9 +13,9 @@ namespace LibImage
public: public:
struct Color struct Color
{ {
uint8_t r;
uint8_t g;
uint8_t b; uint8_t b;
uint8_t g;
uint8_t r;
uint8_t a; uint8_t a;
// Calculate weighted average of colors // Calculate weighted average of colors
@ -25,14 +25,14 @@ namespace LibImage
const double b_mult = weight < 0.0 ? 0.0 : weight > 1.0 ? 1.0 : weight; const double b_mult = weight < 0.0 ? 0.0 : weight > 1.0 ? 1.0 : weight;
const double a_mult = 1.0 - b_mult; const double a_mult = 1.0 - b_mult;
return Color { return Color {
.r = static_cast<uint8_t>(a.r * a_mult + b.r * b_mult),
.g = static_cast<uint8_t>(a.g * a_mult + b.g * b_mult),
.b = static_cast<uint8_t>(a.b * a_mult + b.b * b_mult), .b = static_cast<uint8_t>(a.b * a_mult + b.b * b_mult),
.g = static_cast<uint8_t>(a.g * a_mult + b.g * b_mult),
.r = static_cast<uint8_t>(a.r * a_mult + b.r * b_mult),
.a = static_cast<uint8_t>(a.a * a_mult + b.a * b_mult), .a = static_cast<uint8_t>(a.a * a_mult + b.a * b_mult),
}; };
} }
uint32_t as_rgba() const uint32_t as_argb() const
{ {
return ((uint32_t)a << 24) | ((uint32_t)r << 16) | ((uint32_t)g << 8) | b; return ((uint32_t)a << 24) | ((uint32_t)r << 16) | ((uint32_t)g << 8) | b;
} }

View File

@ -8,6 +8,7 @@ set(USERSPACE_PROGRAMS
dhcp-client dhcp-client
DynamicLoader DynamicLoader
echo echo
env
getopt getopt
http-server http-server
id id

View File

@ -0,0 +1,34 @@
#include "Alias.h"
BAN::ErrorOr<void> Alias::set_alias(BAN::StringView name, BAN::StringView value)
{
TRY(m_aliases.insert_or_assign(
TRY(BAN::String::formatted("{}", name)),
TRY(BAN::String::formatted("{}", value))
));
return {};
}
BAN::Optional<BAN::StringView> Alias::get_alias(const BAN::String& name) const
{
auto it = m_aliases.find(name);
if (it == m_aliases.end())
return {};
return it->value.sv();
}
void Alias::for_each_alias(BAN::Function<BAN::Iteration(BAN::StringView, BAN::StringView)> callback) const
{
for (const auto& [name, value] : m_aliases)
{
switch (callback(name.sv(), value.sv()))
{
case BAN::Iteration::Break:
break;
case BAN::Iteration::Continue:
continue;;
}
break;
}
}

View File

@ -0,0 +1,31 @@
#pragma once
#include <BAN/Function.h>
#include <BAN/HashMap.h>
#include <BAN/Iteration.h>
#include <BAN/NoCopyMove.h>
#include <BAN/String.h>
class Alias
{
BAN_NON_COPYABLE(Alias);
BAN_NON_MOVABLE(Alias);
public:
Alias() = default;
static Alias& get()
{
static Alias s_instance;
return s_instance;
}
BAN::ErrorOr<void> set_alias(BAN::StringView name, BAN::StringView value);
// NOTE: `const BAN::String&` instead of `BAN::StringView` to avoid BAN::String construction
// for hashmap accesses
BAN::Optional<BAN::StringView> get_alias(const BAN::String& name) const;
void for_each_alias(BAN::Function<BAN::Iteration(BAN::StringView, BAN::StringView)>) const;
private:
BAN::HashMap<BAN::String, BAN::String> m_aliases;
};

View File

@ -0,0 +1,306 @@
#include "Alias.h"
#include "Builtin.h"
#include "Execute.h"
#include <ctype.h>
#include <limits.h>
#include <sys/stat.h>
#include <time.h>
#include <unistd.h>
#define ERROR_RETURN(__msg, __ret) do { perror(__msg); return __ret; } while (false)
void Builtin::initialize()
{
MUST(m_builtin_commands.emplace("clear"_sv,
[](Execute&, BAN::Span<const BAN::String>, FILE*, FILE* fout) -> int
{
fprintf(fout, "\e[H\e[3J\e[2J");
fflush(fout);
return 0;
}, true
));
MUST(m_builtin_commands.emplace("exit"_sv,
[](Execute&, BAN::Span<const BAN::String> arguments, FILE*, FILE*) -> int
{
int exit_code = 0;
if (arguments.size() > 1)
{
auto exit_string = arguments[1].sv();
for (size_t i = 0; i < exit_string.size() && isdigit(exit_string[i]); i++)
exit_code = (exit_code * 10) + (exit_string[i] - '0');
}
exit(exit_code);
ASSERT_NOT_REACHED();
}, true
));
MUST(m_builtin_commands.emplace("export"_sv,
[](Execute&, BAN::Span<const BAN::String> arguments, FILE*, FILE*) -> int
{
bool first = false;
for (const auto& argument : arguments)
{
if (first)
{
first = false;
continue;
}
auto split = MUST(argument.sv().split('=', true));
if (split.size() != 2)
continue;
if (setenv(BAN::String(split[0]).data(), BAN::String(split[1]).data(), true) == -1)
ERROR_RETURN("setenv", 1);
}
return 0;
}, true
));
MUST(m_builtin_commands.emplace("unset"_sv,
[](Execute&, BAN::Span<const BAN::String> arguments, FILE*, FILE*) -> int
{
for (const auto& argument : arguments)
if (unsetenv(argument.data()) == -1)
ERROR_RETURN("unsetenv", 1);
return 0;
}, true
));
MUST(m_builtin_commands.emplace("alias"_sv,
[](Execute&, BAN::Span<const BAN::String> arguments, FILE*, FILE* fout) -> int
{
if (arguments.size() == 1)
{
Alias::get().for_each_alias(
[fout](BAN::StringView name, BAN::StringView value) -> BAN::Iteration
{
fprintf(fout, "%.*s='%.*s'\n",
(int)name.size(), name.data(),
(int)value.size(), value.data()
);
return BAN::Iteration::Continue;
}
);
return 0;
}
for (size_t i = 1; i < arguments.size(); i++)
{
auto idx = arguments[i].sv().find('=');
if (idx.has_value() && idx.value() == 0)
continue;
if (!idx.has_value())
{
auto value = Alias::get().get_alias(arguments[i]);
if (value.has_value())
fprintf(fout, "%s='%.*s'\n", arguments[i].data(), (int)value->size(), value->data());
}
else
{
auto alias = arguments[i].sv().substring(0, idx.value());
auto value = arguments[i].sv().substring(idx.value() + 1);
if (auto ret = Alias::get().set_alias(alias, value); ret.is_error())
fprintf(stderr, "could not set alias: %s\n", ret.error().get_message());
}
}
return 0;
}, true
));
MUST(m_builtin_commands.emplace("source"_sv,
[](Execute& execute, BAN::Span<const BAN::String> arguments, FILE*, FILE* fout) -> int
{
if (arguments.size() != 2)
{
fprintf(fout, "usage: source FILE\n");
return 1;
}
if (execute.source_script(arguments[1]).is_error())
return 1;
return 0;
}, true
));
MUST(m_builtin_commands.emplace("cd"_sv,
[](Execute&, BAN::Span<const BAN::String> arguments, FILE*, FILE* fout) -> int
{
if (arguments.size() > 2)
{
fprintf(fout, "cd: too many arguments\n");
return 1;
}
BAN::StringView path;
if (arguments.size() == 1)
{
if (const char* path_env = getenv("HOME"))
path = path_env;
else
return 0;
}
else
path = arguments[1];
if (chdir(path.data()) == -1)
ERROR_RETURN("chdir", 1);
return 0;
}, true
));
MUST(m_builtin_commands.emplace("type"_sv,
[](Execute&, BAN::Span<const BAN::String> arguments, FILE*, FILE* fout) -> int
{
const auto is_executable_file =
[](const char* path) -> bool
{
struct stat st;
if (stat(path, &st) == -1)
return false;
if (!(st.st_mode & (S_IXUSR | S_IXGRP | S_IXOTH)))
return false;
return true;
};
if (!arguments.empty())
arguments = arguments.slice(1);
BAN::Vector<BAN::StringView> path_dirs;
if (const char* path_env = getenv("PATH"))
if (auto split_ret = BAN::StringView(path_env ? path_env : "").split(':'); !split_ret.is_error())
path_dirs = split_ret.release_value();
for (const auto& argument : arguments)
{
if (auto alias = Alias::get().get_alias(argument); alias.has_value())
{
fprintf(fout, "%s is an alias for %s\n", argument.data(), alias->data());
continue;
}
if (Builtin::get().find_builtin(argument))
{
fprintf(fout, "%s is a shell builtin\n", argument.data());
continue;
}
if (argument.sv().contains('/'))
{
if (is_executable_file(argument.data()))
{
fprintf(fout, "%s is %s\n", argument.data(), argument.data());
continue;
}
}
else
{
bool found = false;
for (const auto& path_dir : path_dirs)
{
char path_buffer[PATH_MAX];
memcpy(path_buffer, path_dir.data(), path_dir.size());
memcpy(path_buffer + path_dir.size(), argument.data(), argument.size());
path_buffer[path_dir.size() + argument.size()] = '\0';
if (is_executable_file(path_buffer))
{
fprintf(fout, "%s is %s\n", argument.data(), path_buffer);
found = true;
break;
}
}
if (found)
continue;
}
fprintf(fout, "%s not found\n", argument.data());
}
return 0;
}, true
));
// FIXME: time should not actually be a builtin command but a shell reserved keyword
// e.g. `time foobar=lol sh -c 'echo $foobar'` should resolve set foobar env
MUST(m_builtin_commands.emplace("time"_sv,
[](Execute& execute, BAN::Span<const BAN::String> arguments, FILE* fin, FILE* fout) -> int
{
timespec start, end;
if (clock_gettime(CLOCK_MONOTONIC, &start) == -1)
ERROR_RETURN("clock_gettime", 1);
auto execute_ret = execute.execute_command_sync(arguments.slice(1), fileno(fin), fileno(fout));
if (clock_gettime(CLOCK_MONOTONIC, &end) == -1)
ERROR_RETURN("clock_gettime", 1);
uint64_t total_ns = 0;
total_ns += (end.tv_sec - start.tv_sec) * 1'000'000'000;
total_ns += end.tv_nsec - start.tv_nsec;
int secs = total_ns / 1'000'000'000;
int msecs = (total_ns % 1'000'000'000) / 1'000'000;
fprintf(fout, "took %d.%03d s\n", secs, msecs);
if (execute_ret.is_error())
return 256 + execute_ret.error().get_error_code();
return execute_ret.value();
}, false
));
}
void Builtin::for_each_builtin(BAN::Function<BAN::Iteration(BAN::StringView, const BuiltinCommand&)> callback) const
{
for (const auto& [name, function] : m_builtin_commands)
{
switch (callback(name.sv(), function))
{
case BAN::Iteration::Break:
break;
case BAN::Iteration::Continue:
continue;;
}
break;
}
}
const Builtin::BuiltinCommand* Builtin::find_builtin(const BAN::String& name) const
{
auto it = m_builtin_commands.find(name);
if (it == m_builtin_commands.end())
return nullptr;
return &it->value;
}
BAN::ErrorOr<int> Builtin::BuiltinCommand::execute(Execute& execute, BAN::Span<const BAN::String> arguments, int fd_in, int fd_out) const
{
const auto fd_to_file =
[](int fd, FILE* file, const char* mode) -> BAN::ErrorOr<FILE*>
{
if (fd == fileno(file))
return file;
int fd_dup = dup(fd);
if (fd_dup == -1)
return BAN::Error::from_errno(errno);
file = fdopen(fd_dup, mode);
if (file == nullptr)
return BAN::Error::from_errno(errno);
return file;
};
FILE* fin = TRY(fd_to_file(fd_in, stdin, "r"));
FILE* fout = TRY(fd_to_file(fd_out, stdout, "w"));
int ret = function(execute, arguments, fin, fout);
if (fileno(fin) != fd_in ) fclose(fin);
if (fileno(fout) != fd_out) fclose(fout);
return ret;
}

View File

@ -0,0 +1,50 @@
#pragma once
#include <BAN/Function.h>
#include <BAN/HashMap.h>
#include <BAN/Iteration.h>
#include <BAN/NoCopyMove.h>
#include <BAN/String.h>
#include <stdio.h>
class Execute;
class Builtin
{
BAN_NON_COPYABLE(Builtin);
BAN_NON_MOVABLE(Builtin);
public:
struct BuiltinCommand
{
using function_t = int (*)(Execute&, BAN::Span<const BAN::String>, FILE* fin, FILE* fout);
function_t function { nullptr };
bool immediate { false };
BuiltinCommand(function_t function, bool immediate)
: function(function)
, immediate(immediate)
{ }
BAN::ErrorOr<int> execute(Execute&, BAN::Span<const BAN::String> arguments, int fd_in, int fd_out) const;
};
public:
Builtin() = default;
static Builtin& get()
{
static Builtin s_instance;
return s_instance;
}
void initialize();
void for_each_builtin(BAN::Function<BAN::Iteration(BAN::StringView, const BuiltinCommand&)>) const;
// return nullptr if not found
const BuiltinCommand* find_builtin(const BAN::String& name) const;
private:
BAN::HashMap<BAN::String, BuiltinCommand> m_builtin_commands;
};

View File

@ -1,5 +1,13 @@
set(SOURCES set(SOURCES
main.cpp main.cpp
Alias.cpp
Builtin.cpp
CommandTypes.cpp
Execute.cpp
Input.cpp
Lexer.cpp
Token.cpp
TokenParser.cpp
) )
add_executable(Shell ${SOURCES}) add_executable(Shell ${SOURCES})

View File

@ -0,0 +1,141 @@
#include "CommandTypes.h"
#include "Execute.h"
#include <BAN/ScopeGuard.h>
#include <ctype.h>
#include <limits.h>
#include <stdio.h>
#include <unistd.h>
extern int g_pid;
extern int g_argc;
extern char** g_argv;
BAN::ErrorOr<BAN::String> CommandArgument::evaluate(Execute& execute) const
{
static_assert(
BAN::is_same_v<CommandArgument::ArgumentPart,
BAN::Variant<
FixedString,
EnvironmentVariable,
BuiltinVariable,
CommandTree
>
>
);
BAN::String evaluated;
for (const auto& part : parts)
{
ASSERT(part.has_value());
if (part.has<FixedString>())
TRY(evaluated.append(part.get<FixedString>().value));
else if (part.has<EnvironmentVariable>())
{
const char* env = getenv(part.get<EnvironmentVariable>().value.data());
if (env != nullptr)
TRY(evaluated.append(env));
}
else if (part.has<BuiltinVariable>())
{
const auto& builtin = part.get<BuiltinVariable>();
ASSERT(!builtin.value.empty());
if (!isdigit(builtin.value.front()))
{
ASSERT(builtin.value.size() == 1);
switch (builtin.value.front())
{
case '_':
case '@':
case '*':
case '-':
fprintf(stderr, "TODO: $%c\n", builtin.value.front());
break;
case '$':
evaluated = TRY(BAN::String::formatted("{}", g_pid));
break;
case '#':
evaluated = TRY(BAN::String::formatted("{}", g_argc - 1));
break;
case '?':
evaluated = TRY(BAN::String::formatted("{}", execute.last_return_value()));
break;
case '!':
evaluated = TRY(BAN::String::formatted("{}", execute.last_background_pid()));
break;
default:
ASSERT_NOT_REACHED();
}
}
else
{
int argv_index = 0;
for (char c : builtin.value)
{
ASSERT(isdigit(c));
if (BAN::Math::will_multiplication_overflow<int>(argv_index, 10) ||
BAN::Math::will_addition_overflow<int>(argv_index * 10, c - '0'))
{
argv_index = INT_MAX;
fprintf(stderr, "integer overflow, capping at %d\n", argv_index);
break;
}
argv_index = (argv_index * 10) + (c - '0');
}
if (argv_index < g_argc)
TRY(evaluated.append(const_cast<const char*>(g_argv[argv_index])));
}
}
else if (part.has<CommandTree>())
{
// FIXME: this should resolve to multiple arguments if not double quoted
int execute_pipe[2];
if (pipe(execute_pipe) == -1)
return BAN::Error::from_errno(errno);
BAN::ScopeGuard pipe_rd_closer([execute_pipe] { close(execute_pipe[0]); });
BAN::ScopeGuard pipe_wr_closer([execute_pipe] { close(execute_pipe[1]); });
const pid_t child_pid = fork();
if (child_pid == -1)
return BAN::Error::from_errno(errno);
if (child_pid == 0)
{
if (dup2(execute_pipe[1], STDOUT_FILENO) == -1)
return BAN::Error::from_errno(errno);
setpgrp();
auto ret = execute.execute_command(part.get<CommandTree>());
if (ret.is_error())
exit(ret.error().get_error_code());
exit(execute.last_return_value());
}
pipe_wr_closer.disable();
close(execute_pipe[1]);
char buffer[128];
while (true)
{
const ssize_t nread = read(execute_pipe[0], buffer, sizeof(buffer));
if (nread < 0)
perror("read");
if (nread <= 0)
break;
TRY(evaluated.append(BAN::StringView(buffer, nread)));
}
while (!evaluated.empty() && isspace(evaluated.back()))
evaluated.pop_back();
}
else
{
ASSERT_NOT_REACHED();
}
}
return evaluated;
}

View File

@ -0,0 +1,107 @@
#pragma once
#include <BAN/String.h>
#define COMMAND_GET_MACRO(_0, _1, _2, NAME, ...) NAME
#define COMMAND_MOVE_0(class) \
class(class&& o) { } \
class& operator=(class&& o) { }
#define COMMAND_MOVE_1(class, var) \
class(class&& o) { var = BAN::move(o.var); } \
class& operator=(class&& o) { var = BAN::move(o.var); return *this; }
#define COMMAND_MOVE_2(class, var1, var2) \
class(class&& o) { var1 = BAN::move(o.var1); var2 = BAN::move(o.var2); } \
class& operator=(class&& o) { var1 = BAN::move(o.var1); var2 = BAN::move(o.var2); return *this; }
#define COMMAND_MOVE(class, ...) COMMAND_GET_MACRO(_0 __VA_OPT__(,) __VA_ARGS__, COMMAND_MOVE_2, COMMAND_MOVE_1, COMMAND_MOVE_0)(class, __VA_ARGS__)
#define COMMAND_RULE5(class, ...) \
class() = default; \
class(const class&) = delete; \
class& operator=(const class&) = delete; \
COMMAND_MOVE(class, __VA_ARGS__)
struct CommandTree;
class Execute;
struct CommandArgument
{
struct FixedString
{
COMMAND_RULE5(FixedString, value);
BAN::String value;
};
struct EnvironmentVariable
{
COMMAND_RULE5(EnvironmentVariable, value);
BAN::String value;
};
struct BuiltinVariable
{
COMMAND_RULE5(BuiltinVariable, value);
BAN::String value;
};
using ArgumentPart =
BAN::Variant<
FixedString,
EnvironmentVariable,
BuiltinVariable,
CommandTree
>;
BAN::ErrorOr<BAN::String> evaluate(Execute& execute) const;
COMMAND_RULE5(CommandArgument, parts);
BAN::Vector<ArgumentPart> parts;
};
struct SingleCommand
{
struct EnvironmentVariable
{
COMMAND_RULE5(EnvironmentVariable, name, value);
BAN::String name;
CommandArgument value;
};
COMMAND_RULE5(SingleCommand, environment, arguments);
BAN::Vector<EnvironmentVariable> environment;
BAN::Vector<CommandArgument> arguments;
};
struct PipedCommand
{
COMMAND_RULE5(PipedCommand, commands, background);
BAN::Vector<SingleCommand> commands;
bool background { false };
};
struct ConditionalCommand
{
enum class Condition
{
Always,
OnFailure,
OnSuccess,
};
COMMAND_RULE5(ConditionalCommand, command, condition);
PipedCommand command;
Condition condition { Condition::Always };
};
struct CommandTree
{
COMMAND_RULE5(CommandTree, commands);
BAN::Vector<ConditionalCommand> commands;
};
#undef COMMAND_GET_MACRO
#undef COMMAND_MOVE_0
#undef COMMAND_MOVE_1
#undef COMMAND_MOVE_2
#undef COMMAND_MOVE
#undef COMMAND_RULE5

View File

@ -0,0 +1,356 @@
#include "Builtin.h"
#include "Execute.h"
#include "TokenParser.h"
#include <BAN/ScopeGuard.h>
#include <limits.h>
#include <sys/stat.h>
#include <sys/wait.h>
#include <termios.h>
#include <unistd.h>
#define CHECK_FD_OR_PERROR_AND_EXIT(oldfd, newfd) ({ if ((oldfd) != (newfd) && dup2((oldfd), (newfd)) == -1) { perror("dup2"); exit(errno); } })
#define TRY_OR_PERROR_AND_BREAK(expr) ({ auto&& eval = (expr); if (eval.is_error()) { fprintf(stderr, "%s\n", eval.error().get_message()); continue; } eval.release_value(); })
#define TRY_OR_EXIT(expr) ({ auto&& eval = (expr); if (eval.is_error()) exit(eval.error().get_error_code()); eval.release_value(); })
static BAN::ErrorOr<BAN::String> find_absolute_path_of_executable(const BAN::String& command)
{
if (command.size() >= PATH_MAX)
return BAN::Error::from_errno(ENAMETOOLONG);
const auto check_executable_file =
[](const char* path) -> BAN::ErrorOr<void>
{
struct stat st;
if (stat(path, &st) == -1)
return BAN::Error::from_errno(errno);
if (!(st.st_mode & (S_IXUSR | S_IXGRP | S_IXOTH)))
return BAN::Error::from_errno(ENOEXEC);
return {};
};
if (command.sv().contains('/'))
{
TRY(check_executable_file(command.data()));
return TRY(BAN::String::formatted("{}", command));
}
const char* path_env = getenv("PATH");
if (path_env == nullptr)
return BAN::Error::from_errno(ENOENT);
auto path_dirs = TRY(BAN::StringView(path_env).split(':'));
for (auto path_dir : path_dirs)
{
const auto absolute_path = TRY(BAN::String::formatted("{}/{}", path_dir, command));
auto check_result = check_executable_file(absolute_path.data());
if (!check_result.is_error())
return absolute_path;
if (check_result.error().get_error_code() == ENOENT)
continue;
return check_result.release_error();
}
return BAN::Error::from_errno(ENOENT);
}
BAN::ErrorOr<Execute::ExecuteResult> Execute::execute_command_no_wait(const InternalCommand& command)
{
ASSERT(!command.arguments.empty());
if (command.command.has<Builtin::BuiltinCommand>() && !command.background)
{
const auto& builtin = command.command.get<Builtin::BuiltinCommand>();
if (builtin.immediate)
{
return ExecuteResult {
.pid = -1,
.exit_code = TRY(builtin.execute(*this, command.arguments, command.fd_in, command.fd_out))
};
}
}
const pid_t child_pid = fork();
if (child_pid == -1)
return BAN::Error::from_errno(errno);
if (child_pid == 0)
{
if (command.command.has<Builtin::BuiltinCommand>())
{
auto builtin_ret = command.command.get<Builtin::BuiltinCommand>().execute(*this, command.arguments, command.fd_in, command.fd_out);
if (builtin_ret.is_error())
exit(builtin_ret.error().get_error_code());
exit(builtin_ret.value());
}
for (const auto& environment : command.environments)
setenv(environment.name.data(), environment.value.data(), true);
BAN::Vector<const char*> exec_args;
TRY_OR_EXIT(exec_args.reserve(command.arguments.size() + 1));
for (const auto& argument : command.arguments)
TRY_OR_EXIT(exec_args.push_back(argument.data()));
TRY_OR_EXIT(exec_args.push_back(nullptr));
CHECK_FD_OR_PERROR_AND_EXIT(command.fd_in, STDIN_FILENO);
CHECK_FD_OR_PERROR_AND_EXIT(command.fd_out, STDOUT_FILENO);
execv(command.command.get<BAN::String>().data(), const_cast<char* const*>(exec_args.data()));
exit(errno);
}
if (setpgid(child_pid, command.pgrp ? command.pgrp : child_pid))
perror("setpgid");
if (!command.background && command.pgrp == 0 && isatty(STDIN_FILENO))
if (tcsetpgrp(STDIN_FILENO, child_pid) == -1)
perror("tcsetpgrp");
return ExecuteResult {
.pid = child_pid,
.exit_code = -1,
};
}
BAN::ErrorOr<int> Execute::execute_command_sync(BAN::Span<const BAN::String> arguments, int fd_in, int fd_out)
{
if (arguments.empty())
return 0;
InternalCommand command {
.command = {},
.arguments = arguments,
.environments = {},
.fd_in = fd_in,
.fd_out = fd_out,
.background = false,
.pgrp = getpgrp(),
};
if (const auto* builtin = Builtin::get().find_builtin(arguments[0]))
command.command = *builtin;
else
{
auto absolute_path_or_error = find_absolute_path_of_executable(arguments[0]);
if (absolute_path_or_error.is_error())
{
if (absolute_path_or_error.error().get_error_code() == ENOENT)
{
fprintf(stderr, "command not found: %s\n", arguments[0].data());
return 127;
}
fprintf(stderr, "could not execute command: %s\n", absolute_path_or_error.error().get_message());
return 126;
}
command.command = absolute_path_or_error.release_value();
}
const auto execute_result = TRY(execute_command_no_wait(command));
if (execute_result.pid == -1)
return execute_result.exit_code;
int status;
if (waitpid(execute_result.pid, &status, 0) == -1)
return BAN::Error::from_errno(errno);
if (!WIFSIGNALED(status))
return WEXITSTATUS(status);
return 128 + WTERMSIG(status);
}
BAN::ErrorOr<void> Execute::execute_command(const PipedCommand& piped_command)
{
ASSERT(!piped_command.commands.empty());
int last_pipe_rd = STDIN_FILENO;
BAN::Vector<pid_t> child_pids;
TRY(child_pids.resize(piped_command.commands.size(), 0));
BAN::Vector<int> child_codes;
TRY(child_codes.resize(piped_command.commands.size(), 126));
const auto evaluate_arguments =
[this](BAN::Span<const CommandArgument> arguments) -> BAN::ErrorOr<BAN::Vector<BAN::String>>
{
BAN::Vector<BAN::String> result;
TRY(result.reserve(arguments.size()));
for (const auto& argument : arguments)
TRY(result.push_back(TRY(argument.evaluate(*this))));
return result;
};
const auto evaluate_environment =
[this](BAN::Span<const SingleCommand::EnvironmentVariable> environments) -> BAN::ErrorOr<BAN::Vector<InternalCommand::Environment>>
{
BAN::Vector<InternalCommand::Environment> result;
TRY(result.reserve(environments.size()));
for (const auto& environment : environments)
TRY(result.emplace_back(environment.name, TRY(environment.value.evaluate(*this))));
return result;
};
for (size_t i = 0; i < piped_command.commands.size(); i++)
{
int new_pipe[2] { STDIN_FILENO, STDOUT_FILENO };
if (i != piped_command.commands.size() - 1)
if (pipe(new_pipe) == -1)
return BAN::Error::from_errno(errno);
BAN::ScopeGuard pipe_closer(
[&]()
{
if (new_pipe[1] != STDOUT_FILENO)
close(new_pipe[1]);
if (last_pipe_rd != STDIN_FILENO)
close(last_pipe_rd);
last_pipe_rd = new_pipe[0];
}
);
const int fd_in = last_pipe_rd;
const int fd_out = new_pipe[1];
const auto arguments = TRY_OR_PERROR_AND_BREAK(evaluate_arguments(piped_command.commands[i].arguments.span()));
const auto environments = TRY_OR_PERROR_AND_BREAK(evaluate_environment(piped_command.commands[i].environment.span()));
InternalCommand command {
.command = {},
.arguments = arguments.span(),
.environments = environments.span(),
.fd_in = fd_in,
.fd_out = fd_out,
.background = piped_command.background,
.pgrp = child_pids.front(),
};
if (const auto* builtin = Builtin::get().find_builtin(arguments[0]))
command.command = *builtin;
else
{
auto absolute_path_or_error = find_absolute_path_of_executable(arguments[0]);
if (absolute_path_or_error.is_error())
{
if (absolute_path_or_error.error().get_error_code() == ENOENT)
{
fprintf(stderr, "command not found: %s\n", arguments[0].data());
child_codes[i] = 127;
}
else
{
fprintf(stderr, "could not execute command: %s\n", absolute_path_or_error.error().get_message());
child_codes[i] = 126;
}
continue;
}
command.command = absolute_path_or_error.release_value();
}
auto execute_result = TRY_OR_PERROR_AND_BREAK(execute_command_no_wait(command));
if (execute_result.pid == -1)
child_codes[i] = execute_result.exit_code;
else
child_pids[i] = execute_result.pid;
}
if (last_pipe_rd != STDIN_FILENO)
close(last_pipe_rd);
if (piped_command.background)
return {};
for (size_t i = 0; i < piped_command.commands.size(); i++)
{
if (child_pids[i] == 0)
continue;
int status = 0;
if (waitpid(child_pids[i], &status, 0) == -1)
perror("waitpid");
if (WIFEXITED(status))
child_codes[i] = WEXITSTATUS(status);
else if (WIFSIGNALED(status))
child_codes[i] = 128 + WTERMSIG(status);
else
ASSERT_NOT_REACHED();
}
if (isatty(STDIN_FILENO) && tcsetpgrp(0, getpgrp()) == -1)
perror("tcsetpgrp");
m_last_return_value = child_codes.back();
return {};
}
BAN::ErrorOr<void> Execute::execute_command(const CommandTree& command_tree)
{
for (const auto& [command, condition] : command_tree.commands)
{
bool should_run = false;
switch (condition)
{
case ConditionalCommand::Condition::Always:
should_run = true;
break;
case ConditionalCommand::Condition::OnFailure:
should_run = (m_last_return_value != 0);
break;
case ConditionalCommand::Condition::OnSuccess:
should_run = (m_last_return_value == 0);
break;
}
if (!should_run)
continue;
TRY(execute_command(command));
}
return {};
}
BAN::ErrorOr<void> Execute::source_script(BAN::StringView path)
{
BAN::Vector<BAN::String> script_lines;
{
FILE* fp = fopen(path.data(), "r");
if (fp == nullptr)
return BAN::Error::from_errno(errno);
BAN::String current;
char temp_buffer[128];
while (fgets(temp_buffer, sizeof(temp_buffer), fp))
{
TRY(current.append(temp_buffer));
if (current.back() != '\n')
continue;
current.pop_back();
if (!current.empty())
TRY(script_lines.push_back(BAN::move(current)));
current.clear();
}
if (!current.empty())
TRY(script_lines.push_back(BAN::move(current)));
fclose(fp);
}
size_t index = 0;
TokenParser parser(
[&](BAN::Optional<BAN::StringView>) -> BAN::Optional<BAN::String>
{
if (index >= script_lines.size())
return {};
return script_lines[index++];
}
);
if (!parser.main_loop(true))
return BAN::Error::from_literal("oop");
return {};
}

View File

@ -0,0 +1,62 @@
#pragma once
#include "Builtin.h"
#include "CommandTypes.h"
#include <BAN/NoCopyMove.h>
class Execute
{
BAN_NON_COPYABLE(Execute);
BAN_NON_MOVABLE(Execute);
public:
Execute() = default;
BAN::ErrorOr<int> execute_command_sync(BAN::Span<const BAN::String> arguments, int fd_in, int fd_out);
BAN::ErrorOr<void> execute_command(const SingleCommand&, int fd_in, int fd_out, bool background, pid_t pgrp = 0);
BAN::ErrorOr<void> execute_command(const PipedCommand&);
BAN::ErrorOr<void> execute_command(const CommandTree&);
BAN::ErrorOr<void> source_script(BAN::StringView path);
int last_background_pid() const { return m_last_background_pid; }
int last_return_value() const { return m_last_return_value; }
private:
struct InternalCommand
{
using Command = BAN::Variant<Builtin::BuiltinCommand, BAN::String>;
enum class Type
{
Builtin,
External,
};
struct Environment
{
BAN::String name;
BAN::String value;
};
Command command;
BAN::Span<const BAN::String> arguments;
BAN::Span<const Environment> environments;
int fd_in;
int fd_out;
bool background;
pid_t pgrp;
};
struct ExecuteResult
{
pid_t pid;
int exit_code;
};
BAN::ErrorOr<ExecuteResult> execute_command_no_wait(const InternalCommand& command);
private:
int m_last_background_pid { 0 };
int m_last_return_value { 0 };
};

View File

@ -0,0 +1,682 @@
#include "Alias.h"
#include "Builtin.h"
#include "Input.h"
#include <BAN/ScopeGuard.h>
#include <BAN/Sort.h>
#include <ctype.h>
#include <dirent.h>
#include <pwd.h>
#include <sys/stat.h>
#include <unistd.h>
static struct termios s_original_termios;
static struct termios s_raw_termios;
static bool s_termios_initialized { false };
static BAN::Vector<BAN::String> list_matching_entries(BAN::StringView path, BAN::StringView start, bool require_executable)
{
ASSERT(path.size() < PATH_MAX);
char path_cstr[PATH_MAX];
memcpy(path_cstr, path.data(), path.size());
path_cstr[path.size()] = '\0';
DIR* dirp = opendir(path_cstr);
if (dirp == nullptr)
return {};
BAN::Vector<BAN::String> result;
dirent* entry;
while ((entry = readdir(dirp)))
{
if (entry->d_name[0] == '.' && !start.starts_with("."_sv))
continue;
if (strncmp(entry->d_name, start.data(), start.size()))
continue;
struct stat st;
if (fstatat(dirfd(dirp), entry->d_name, &st, 0))
continue;
if (require_executable)
{
if (S_ISDIR(st.st_mode))
continue;
if (!(st.st_mode & (S_IXUSR | S_IXGRP | S_IXUSR)))
continue;
}
MUST(result.emplace_back(entry->d_name + start.size()));
if (S_ISDIR(st.st_mode))
MUST(result.back().push_back('/'));
}
closedir(dirp);
return BAN::move(result);
}
struct TabCompletion
{
bool should_escape_spaces { false };
BAN::StringView prefix;
BAN::Vector<BAN::String> completions;
};
static TabCompletion list_tab_completion_entries(BAN::StringView current_input)
{
enum class CompletionType
{
Command,
File,
};
BAN::StringView prefix = current_input;
BAN::String last_argument;
CompletionType completion_type = CompletionType::Command;
bool should_escape_spaces = true;
for (size_t i = 0; i < current_input.size(); i++)
{
if (current_input[i] == '\\')
{
i++;
if (i < current_input.size())
MUST(last_argument.push_back(current_input[i]));
}
else if (isspace(current_input[i]) || current_input[i] == ';' || current_input[i] == '|' || current_input.substring(i).starts_with("&&"_sv))
{
if (!isspace(current_input[i]))
completion_type = CompletionType::Command;
else if (!last_argument.empty())
completion_type = CompletionType::File;
if (auto rest = current_input.substring(i); rest.starts_with("||"_sv) || rest.starts_with("&&"_sv))
i++;
prefix = current_input.substring(i + 1);
last_argument.clear();
should_escape_spaces = true;
}
else if (current_input[i] == '\'' || current_input[i] == '"')
{
const char quote_type = current_input[i++];
while (i < current_input.size() && current_input[i] != quote_type)
MUST(last_argument.push_back(current_input[i++]));
should_escape_spaces = false;
}
else
{
MUST(last_argument.push_back(current_input[i]));
}
}
if (last_argument.sv().contains('/'))
completion_type = CompletionType::File;
BAN::Vector<BAN::String> result;
switch (completion_type)
{
case CompletionType::Command:
{
const char* path_env = getenv("PATH");
if (path_env)
{
auto splitted_path_env = MUST(BAN::StringView(path_env).split(':'));
for (auto path : splitted_path_env)
{
auto matching_entries = list_matching_entries(path, last_argument, true);
MUST(result.reserve(result.size() + matching_entries.size()));
for (auto&& entry : matching_entries)
MUST(result.push_back(BAN::move(entry)));
}
}
Builtin::get().for_each_builtin(
[&](BAN::StringView name, const Builtin::BuiltinCommand&) -> BAN::Iteration
{
if (name.starts_with(last_argument))
MUST(result.emplace_back(name.substring(last_argument.size())));
return BAN::Iteration::Continue;
}
);
Alias::get().for_each_alias(
[&](BAN::StringView name, BAN::StringView) -> BAN::Iteration
{
if (name.starts_with(last_argument))
MUST(result.emplace_back(name.substring(last_argument.size())));
return BAN::Iteration::Continue;
}
);
break;
}
case CompletionType::File:
{
BAN::String dir_path;
if (last_argument.sv().starts_with("/"_sv))
MUST(dir_path.push_back('/'));
else
{
char cwd_buffer[PATH_MAX];
if (getcwd(cwd_buffer, sizeof(cwd_buffer)) == nullptr)
return {};
MUST(dir_path.reserve(strlen(cwd_buffer) + 1));
MUST(dir_path.append(cwd_buffer));
MUST(dir_path.push_back('/'));
}
auto match_against = last_argument.sv();
if (auto idx = match_against.rfind('/'); idx.has_value())
{
MUST(dir_path.append(match_against.substring(0, idx.value())));
match_against = match_against.substring(idx.value() + 1);
}
result = list_matching_entries(dir_path, match_against, false);
break;
}
}
if (auto idx = prefix.rfind('/'); idx.has_value())
prefix = prefix.substring(idx.value() + 1);
return { should_escape_spaces, prefix, BAN::move(result) };
}
static int character_length(BAN::StringView prompt)
{
int length { 0 };
bool in_escape { false };
for (char c : prompt)
{
if (in_escape)
{
if (isalpha(c))
in_escape = false;
}
else
{
if (c == '\e')
in_escape = true;
else if (((uint8_t)c & 0xC0) != 0x80)
length++;
}
}
return length;
}
BAN::String Input::parse_ps1_prompt()
{
const char* raw_prompt = getenv("PS1");
if (raw_prompt == nullptr)
return "$ "_sv;
BAN::String prompt;
for (int i = 0; raw_prompt[i]; i++)
{
char ch = raw_prompt[i];
if (ch == '\\')
{
switch (raw_prompt[++i])
{
case 'e':
MUST(prompt.push_back('\e'));
break;
case 'n':
MUST(prompt.push_back('\n'));
break;
case '\\':
MUST(prompt.push_back('\\'));
break;
case '~':
{
char buffer[256];
if (getcwd(buffer, sizeof(buffer)) == nullptr)
strcpy(buffer, strerrorname_np(errno));
const char* home = getenv("HOME");
size_t home_len = home ? strlen(home) : 0;
if (home && strncmp(buffer, home, home_len) == 0)
{
MUST(prompt.push_back('~'));
MUST(prompt.append(buffer + home_len));
}
else
{
MUST(prompt.append(buffer));
}
break;
}
case 'u':
{
static char* username = nullptr;
if (username == nullptr)
{
auto* passwd = getpwuid(geteuid());
if (passwd == nullptr)
break;
username = new char[strlen(passwd->pw_name) + 1];
strcpy(username, passwd->pw_name);
endpwent();
}
MUST(prompt.append(username));
break;
}
case 'h':
{
MUST(prompt.append(m_hostname));
break;
}
case '\0':
MUST(prompt.push_back('\\'));
break;
default:
MUST(prompt.push_back('\\'));
MUST(prompt.push_back(*raw_prompt));
break;
}
}
else
{
MUST(prompt.push_back(ch));
}
}
return prompt;
}
BAN::Optional<BAN::String> Input::get_input(BAN::Optional<BAN::StringView> custom_prompt)
{
tcsetattr(0, TCSANOW, &s_raw_termios);
BAN::ScopeGuard _([] { tcsetattr(0, TCSANOW, &s_original_termios); });
BAN::String ps1_prompt;
if (!custom_prompt.has_value())
ps1_prompt = parse_ps1_prompt();
const auto print_prompt =
[&]()
{
if (custom_prompt.has_value())
printf("%.*s", (int)custom_prompt->size(), custom_prompt->data());
else
printf("%.*s", (int)ps1_prompt.size(), ps1_prompt.data());
};
const auto prompt_length =
[&]() -> int
{
if (custom_prompt.has_value())
return custom_prompt->size();
return character_length(ps1_prompt);
};
print_prompt();
fflush(stdout);
while (true)
{
int chi = getchar();
if (chi == EOF)
{
if (errno != EINTR)
{
perror("getchar");
exit(1);
}
clearerr(stdin);
m_buffers = m_history;
MUST(m_buffers.emplace_back(""_sv));
m_buffer_index = m_buffers.size() - 1;
m_buffer_col = 0;
putchar('\n');
print_prompt();
fflush(stdout);
continue;
}
uint8_t ch = chi;
if (ch != '\t')
{
m_tab_completions.clear();
m_tab_index.clear();
}
if (m_waiting_utf8 > 0)
{
m_waiting_utf8--;
ASSERT((ch & 0xC0) == 0x80);
putchar(ch);
MUST(m_buffers[m_buffer_index].insert(ch, m_buffer_col++));
if (m_waiting_utf8 == 0)
{
printf("\e[s%s\e[u", m_buffers[m_buffer_index].data() + m_buffer_col);
fflush(stdout);
}
continue;
}
else if (ch & 0x80)
{
if ((ch & 0xE0) == 0xC0)
m_waiting_utf8 = 1;
else if ((ch & 0xF0) == 0xE0)
m_waiting_utf8 = 2;
else if ((ch & 0xF8) == 0xF0)
m_waiting_utf8 = 3;
else
ASSERT_NOT_REACHED();
putchar(ch);
MUST(m_buffers[m_buffer_index].insert(ch, m_buffer_col++));
continue;
}
switch (ch)
{
case '\e':
{
ch = getchar();
if (ch != '[')
break;
ch = getchar();
int value = 0;
while (isdigit(ch))
{
value = (value * 10) + (ch - '0');
ch = getchar();
}
switch (ch)
{
case 'A':
if (m_buffer_index > 0)
{
m_buffer_index--;
m_buffer_col = m_buffers[m_buffer_index].size();
printf("\e[%dG%s\e[K", prompt_length() + 1, m_buffers[m_buffer_index].data());
fflush(stdout);
}
break;
case 'B':
if (m_buffer_index < m_buffers.size() - 1)
{
m_buffer_index++;
m_buffer_col = m_buffers[m_buffer_index].size();
printf("\e[%dG%s\e[K", prompt_length() + 1, m_buffers[m_buffer_index].data());
fflush(stdout);
}
break;
case 'C':
if (m_buffer_col < m_buffers[m_buffer_index].size())
{
m_buffer_col++;
while ((m_buffers[m_buffer_index][m_buffer_col - 1] & 0xC0) == 0x80)
m_buffer_col++;
printf("\e[C");
fflush(stdout);
}
break;
case 'D':
if (m_buffer_col > 0)
{
while ((m_buffers[m_buffer_index][m_buffer_col - 1] & 0xC0) == 0x80)
m_buffer_col--;
m_buffer_col--;
printf("\e[D");
fflush(stdout);
}
break;
case '~':
switch (value)
{
case 3: // delete
if (m_buffer_col >= m_buffers[m_buffer_index].size())
break;
m_buffers[m_buffer_index].remove(m_buffer_col);
while (m_buffer_col < m_buffers[m_buffer_index].size() && (m_buffers[m_buffer_index][m_buffer_col] & 0xC0) == 0x80)
m_buffers[m_buffer_index].remove(m_buffer_col);
printf("\e[s%s \e[u", m_buffers[m_buffer_index].data() + m_buffer_col);
fflush(stdout);
break;
}
break;
}
break;
}
case '\x0C': // ^L
{
int x = prompt_length() + character_length(m_buffers[m_buffer_index].sv().substring(m_buffer_col)) + 1;
printf("\e[H\e[J");
print_prompt();
printf("%s\e[u\e[1;%dH", m_buffers[m_buffer_index].data(), x);
fflush(stdout);
break;
}
case '\b':
if (m_buffer_col <= 0)
break;
while ((m_buffers[m_buffer_index][m_buffer_col - 1] & 0xC0) == 0x80)
m_buffer_col--;
m_buffer_col--;
printf("\e[D");
fflush(stdout);
break;
case '\x01': // ^A
m_buffer_col = 0;
printf("\e[%dG", prompt_length() + 1);
fflush(stdout);
break;
case '\x03': // ^C
putchar('\n');
print_prompt();
fflush(stdout);
m_buffers[m_buffer_index].clear();
m_buffer_col = 0;
break;
case '\x04': // ^D
if (!m_buffers[m_buffer_index].empty())
break;
putchar('\n');
return {};
case '\x7F': // backspace
if (m_buffer_col <= 0)
break;
while ((m_buffers[m_buffer_index][m_buffer_col - 1] & 0xC0) == 0x80)
m_buffers[m_buffer_index].remove(--m_buffer_col);
m_buffers[m_buffer_index].remove(--m_buffer_col);
printf("\b\e[s%s \e[u", m_buffers[m_buffer_index].data() + m_buffer_col);
fflush(stdout);
break;
case '\n':
{
BAN::String input;
MUST(input.append(m_buffers[m_buffer_index]));
if (!m_buffers[m_buffer_index].empty())
{
MUST(m_history.push_back(m_buffers[m_buffer_index]));
m_buffers = m_history;
MUST(m_buffers.emplace_back(""_sv));
}
m_buffer_index = m_buffers.size() - 1;
m_buffer_col = 0;
putchar('\n');
return input;
}
case '\t':
{
// FIXME: tab completion is really hacked together currently.
// this should ask token parser about the current parse state
// and do completions based on that, not raw strings
if (m_buffer_col != m_buffers[m_buffer_index].size())
continue;
if (m_tab_completions.has_value())
{
ASSERT(m_tab_completions->size() >= 2);
if (!m_tab_index.has_value())
m_tab_index = 0;
else
{
MUST(m_buffers[m_buffer_index].resize(m_tab_completion_keep));
m_buffer_col = m_tab_completion_keep;
*m_tab_index = (*m_tab_index + 1) % m_tab_completions->size();
}
MUST(m_buffers[m_buffer_index].append(m_tab_completions.value()[*m_tab_index]));
m_buffer_col += m_tab_completions.value()[*m_tab_index].size();
printf("\e[%dG%s\e[K", prompt_length() + 1, m_buffers[m_buffer_index].data());
fflush(stdout);
break;
}
m_tab_completion_keep = m_buffer_col;
auto [should_escape_spaces, prefix, completions] = list_tab_completion_entries(m_buffers[m_buffer_index].sv().substring(0, m_tab_completion_keep));
BAN::sort::sort(completions.begin(), completions.end(),
[](const BAN::String& a, const BAN::String& b) {
if (auto cmp = strcmp(a.data(), b.data()))
return cmp < 0;
return a.size() < b.size();
}
);
for (size_t i = 1; i < completions.size();)
{
if (completions[i - 1] == completions[i])
completions.remove(i);
else
i++;
}
if (completions.empty())
break;
size_t all_match_len = 0;
for (;;)
{
if (completions.front().size() <= all_match_len)
break;
const char target = completions.front()[all_match_len];
bool all_matched = true;
for (const auto& completion : completions)
{
if (completion.size() > all_match_len && completion[all_match_len] == target)
continue;
all_matched = false;
break;
}
if (!all_matched)
break;
all_match_len++;
}
if (all_match_len)
{
auto completion = completions.front().sv().substring(0, all_match_len);
BAN::String temp_escaped;
if (should_escape_spaces)
{
MUST(temp_escaped.append(completion));
for (size_t i = 0; i < temp_escaped.size(); i++)
{
if (!isspace(temp_escaped[i]))
continue;
MUST(temp_escaped.insert('\\', i));
i++;
}
completion = temp_escaped.sv();
if (!m_buffers[m_buffer_index].empty() && m_buffers[m_buffer_index].back() == '\\' && completion.front() == '\\')
completion = completion.substring(1);
}
m_buffer_col += completion.size();
MUST(m_buffers[m_buffer_index].append(completion));
printf("%.*s", (int)completion.size(), completion.data());
fflush(stdout);
break;
}
if (completions.size() == 1)
{
ASSERT(all_match_len == completions.front().size());
break;
}
printf("\n");
for (size_t i = 0; i < completions.size(); i++)
{
if (i != 0)
printf(" ");
const char* format = completions[i].sv().contains(' ') ? "'%.*s%s'" : "%.*s%s";
printf(format, (int)prefix.size(), prefix.data(), completions[i].data());
}
printf("\n");
print_prompt();
printf("%s", m_buffers[m_buffer_index].data());
fflush(stdout);
if (should_escape_spaces)
{
for (auto& completion : completions)
{
for (size_t i = 0; i < completion.size(); i++)
{
if (!isspace(completion[i]))
continue;
MUST(completion.insert('\\', i));
i++;
}
}
}
m_tab_completion_keep = m_buffer_col;
m_tab_completions = BAN::move(completions);
break;
}
default:
MUST(m_buffers[m_buffer_index].insert(ch, m_buffer_col++));
if (m_buffer_col == m_buffers[m_buffer_index].size())
putchar(ch);
else
printf("%c\e[s%s\e[u", ch, m_buffers[m_buffer_index].data() + m_buffer_col);
fflush(stdout);
break;
}
}
}
Input::Input()
{
if (!s_termios_initialized)
{
tcgetattr(0, &s_original_termios);
s_raw_termios = s_original_termios;
s_raw_termios.c_lflag &= ~(ECHO | ICANON);
atexit([] { tcsetattr(0, TCSANOW, &s_original_termios); });
s_termios_initialized = true;
}
char hostname_buffer[HOST_NAME_MAX];
if (gethostname(hostname_buffer, sizeof(hostname_buffer)) == 0) {
MUST(m_hostname.append(hostname_buffer));
}
}

View File

@ -0,0 +1,36 @@
#pragma once
#include <BAN/NoCopyMove.h>
#include <BAN/String.h>
#include <BAN/Optional.h>
#include <BAN/Vector.h>
#include <sys/types.h>
#include <termios.h>
class Input
{
BAN_NON_COPYABLE(Input);
BAN_NON_MOVABLE(Input);
public:
Input();
BAN::Optional<BAN::String> get_input(BAN::Optional<BAN::StringView> custom_prompt);
private:
BAN::String parse_ps1_prompt();
private:
BAN::String m_hostname;
BAN::Vector<BAN::String> m_buffers { 1, ""_sv };
BAN::Vector<BAN::String> m_history;
size_t m_buffer_index { 0 };
size_t m_buffer_col { 0 };
BAN::Optional<ssize_t> m_tab_index;
BAN::Optional<BAN::Vector<BAN::String>> m_tab_completions;
size_t m_tab_completion_keep { 0 };
int m_waiting_utf8 { 0 };
};

View File

@ -0,0 +1,79 @@
#include "Lexer.h"
BAN::ErrorOr<BAN::Vector<Token>> tokenize_string(BAN::StringView string)
{
{
size_t i = 0;
while (i < string.size() && isspace(string[i]))
i++;
if (i >= string.size() || string[i] == '#')
return BAN::Vector<Token>();
}
constexpr auto char_to_token_type =
[](char c) -> BAN::Optional<Token::Type>
{
switch (c)
{
case '&': return Token::Type::Ampersand;
case '\\': return Token::Type::Backslash;
case '}': return Token::Type::CloseCurly;
case ')': return Token::Type::CloseParen;
case '$': return Token::Type::Dollar;
case '"': return Token::Type::DoubleQuote;
case '{': return Token::Type::OpenCurly;
case '(': return Token::Type::OpenParen;
case '|': return Token::Type::Pipe;
case ';': return Token::Type::Semicolon;
case '\'': return Token::Type::SingleQuote;
}
return {};
};
BAN::Vector<Token> result;
BAN::String current_string;
const auto append_current_if_exists =
[&]() -> BAN::ErrorOr<void>
{
if (current_string.empty())
return {};
TRY(result.emplace_back(Token::Type::String, BAN::move(current_string)));
current_string = BAN::String();
return {};
};
while (!string.empty())
{
if (isspace(string.front()))
{
TRY(append_current_if_exists());
size_t whitespace_len = 1;
while (whitespace_len < string.size() && isspace(string[whitespace_len]))
whitespace_len++;
BAN::String whitespace_str;
TRY(whitespace_str.append(string.substring(0, whitespace_len)));
TRY(result.emplace_back(Token::Type::Whitespace, BAN::move(whitespace_str)));
string = string.substring(whitespace_len);
continue;
}
if (auto token_type = char_to_token_type(string.front()); token_type.has_value())
{
TRY(append_current_if_exists());
TRY(result.emplace_back(token_type.value()));
string = string.substring(1);
continue;
}
TRY(current_string.push_back(string.front()));
string = string.substring(1);
}
TRY(append_current_if_exists());
return result;
}

View File

@ -0,0 +1,5 @@
#pragma once
#include "Token.h"
BAN::ErrorOr<BAN::Vector<Token>> tokenize_string(BAN::StringView);

View File

@ -0,0 +1,52 @@
#include "Token.h"
#include <BAN/Debug.h>
void Token::debug_dump() const
{
switch (type())
{
case Type::EOF_:
dwarnln("Token <EOF>");
break;
case Type::Ampersand:
dprintln("Token <Ampersand>");
break;
case Type::Backslash:
dprintln("Token <Backslash>");
break;
case Type::CloseCurly:
dprintln("Token <CloseCurly>");
break;
case Type::CloseParen:
dprintln("Token <CloseParen>");
break;
case Type::Dollar:
dprintln("Token <Dollar>");
break;
case Type::DoubleQuote:
dprintln("Token <DoubleQuote>");
break;
case Type::OpenCurly:
dprintln("Token <OpenCurly>");
break;
case Type::OpenParen:
dprintln("Token <OpenParen>");
break;
case Type::Pipe:
dprintln("Token <Pipe>");
break;
case Type::Semicolon:
dprintln("Token <Semicolon>");
break;
case Type::SingleQuote:
dprintln("Token <SingleQuote>");
break;
case Type::String:
dprintln("Token <String \"{}\">", string());
break;
case Type::Whitespace:
dprintln("Token <Whitespace \"{}\">", string());
break;
}
}

View File

@ -0,0 +1,84 @@
#pragma once
#include <BAN/Assert.h>
#include <BAN/String.h>
#include <BAN/Vector.h>
#include <ctype.h>
struct Token
{
public:
enum class Type
{
EOF_,
Ampersand,
Backslash,
CloseCurly,
CloseParen,
Dollar,
DoubleQuote,
OpenCurly,
OpenParen,
Pipe,
Semicolon,
SingleQuote,
String,
Whitespace,
};
Token(Type type)
: m_type(type)
{}
Token(Type type, BAN::String&& string)
: m_type(type)
{
ASSERT(type == Type::String || type == Type::Whitespace);
if (type == Type::Whitespace)
for (char c : string)
ASSERT(isspace(c));
m_value = BAN::move(string);
}
Token(Token&& other)
{
m_type = other.m_type;
m_value = other.m_value;
other.clear();
}
Token& operator=(Token&& other)
{
m_type = other.m_type;
m_value = other.m_value;
other.clear();
return *this;
}
Token(const Token&) = delete;
Token& operator=(const Token&) = delete;
~Token()
{
clear();
}
Type type() const { return m_type; }
BAN::String& string() { ASSERT(m_type == Type::String || m_type == Type::Whitespace); return m_value; }
const BAN::String& string() const { ASSERT(m_type == Type::String || m_type == Type::Whitespace); return m_value; }
void clear()
{
m_type = Type::EOF_;
m_value.clear();
}
void debug_dump() const;
private:
Type m_type { Type::EOF_ };
BAN::String m_value;
};

View File

@ -0,0 +1,702 @@
#include "Alias.h"
#include "Execute.h"
#include "Lexer.h"
#include "TokenParser.h"
#include <BAN/HashSet.h>
#include <stdio.h>
static constexpr bool can_parse_argument_from_token_type(Token::Type token_type)
{
switch (token_type)
{
case Token::Type::Whitespace:
ASSERT_NOT_REACHED();
case Token::Type::EOF_:
case Token::Type::Ampersand:
case Token::Type::CloseCurly:
case Token::Type::CloseParen:
case Token::Type::OpenCurly:
case Token::Type::OpenParen:
case Token::Type::Pipe:
case Token::Type::Semicolon:
return false;
case Token::Type::Backslash:
case Token::Type::Dollar:
case Token::Type::DoubleQuote:
case Token::Type::SingleQuote:
case Token::Type::String:
return true;
}
ASSERT_NOT_REACHED();
}
static constexpr char token_type_to_single_character(Token::Type type)
{
switch (type)
{
case Token::Type::Ampersand: return '&';
case Token::Type::Backslash: return '\\';
case Token::Type::CloseCurly: return '}';
case Token::Type::CloseParen: return ')';
case Token::Type::Dollar: return '$';
case Token::Type::DoubleQuote: return '"';
case Token::Type::OpenCurly: return '{';
case Token::Type::OpenParen: return '(';
case Token::Type::Pipe: return '|';
case Token::Type::Semicolon: return ';';
case Token::Type::SingleQuote: return '\'';
case Token::Type::String: ASSERT_NOT_REACHED();
case Token::Type::Whitespace: ASSERT_NOT_REACHED();
case Token::Type::EOF_: ASSERT_NOT_REACHED();
}
ASSERT_NOT_REACHED();
};
static constexpr BAN::Error unexpected_token_error(Token::Type type)
{
switch (type)
{
case Token::Type::EOF_:
return BAN::Error::from_literal("unexpected EOF");
case Token::Type::Ampersand:
return BAN::Error::from_literal("unexpected token &");
case Token::Type::Backslash:
return BAN::Error::from_literal("unexpected token \\");
case Token::Type::CloseCurly:
return BAN::Error::from_literal("unexpected token }");
case Token::Type::CloseParen:
return BAN::Error::from_literal("unexpected token )");
case Token::Type::Dollar:
return BAN::Error::from_literal("unexpected token $");
case Token::Type::DoubleQuote:
return BAN::Error::from_literal("unexpected token \"");
case Token::Type::OpenCurly:
return BAN::Error::from_literal("unexpected token {");
case Token::Type::Pipe:
return BAN::Error::from_literal("unexpected token |");
case Token::Type::OpenParen:
return BAN::Error::from_literal("unexpected token (");
case Token::Type::Semicolon:
return BAN::Error::from_literal("unexpected token ;");
case Token::Type::SingleQuote:
return BAN::Error::from_literal("unexpected token '");
case Token::Type::String:
return BAN::Error::from_literal("unexpected token <string>");
case Token::Type::Whitespace:
return BAN::Error::from_literal("unexpected token <whitespace>");
}
ASSERT_NOT_REACHED();
}
const Token& TokenParser::peek_token() const
{
if (m_token_stream.empty())
return m_eof_token;
ASSERT(!m_token_stream.front().empty());
return m_token_stream.front().back();
}
Token TokenParser::read_token()
{
if (m_token_stream.empty())
return Token(Token::Type::EOF_);
ASSERT(!m_token_stream.front().empty());
auto token = BAN::move(m_token_stream.front().back());
m_token_stream.front().pop_back();
if (m_token_stream.front().empty())
m_token_stream.pop();
return token;
}
void TokenParser::consume_token()
{
ASSERT(!m_token_stream.empty());
ASSERT(!m_token_stream.front().empty());
m_token_stream.front().pop_back();
if (m_token_stream.front().empty())
m_token_stream.pop();
}
BAN::ErrorOr<void> TokenParser::unget_token(Token&& token)
{
if (m_token_stream.empty())
TRY(m_token_stream.emplace());
TRY(m_token_stream.front().push_back(BAN::move(token)));
return {};
}
BAN::ErrorOr<void> TokenParser::feed_tokens(BAN::Vector<Token>&& tokens)
{
if (tokens.empty())
return {};
for (size_t i = 0; i < tokens.size() / 2; i++)
BAN::swap(tokens[i], tokens[tokens.size() - i - 1]);
TRY(m_token_stream.push(BAN::move(tokens)));
return {};
}
BAN::ErrorOr<void> TokenParser::ask_input_tokens(BAN::StringView prompt, bool add_newline)
{
if (!m_input_function)
return unexpected_token_error(Token::Type::EOF_);
auto opt_input = m_input_function(prompt);
if (!opt_input.has_value())
return unexpected_token_error(Token::Type::EOF_);
auto tokenized = TRY(tokenize_string(opt_input.release_value()));
TRY(feed_tokens(BAN::move(tokenized)));
if (add_newline)
{
auto newline_token = Token(Token::Type::String);
TRY(newline_token.string().push_back('\n'));
TRY(unget_token(BAN::move(newline_token)));
}
return {};
}
BAN::ErrorOr<CommandArgument::ArgumentPart> TokenParser::parse_backslash(bool is_quoted)
{
ASSERT(read_token().type() == Token::Type::Backslash);
auto token = read_token();
CommandArgument::FixedString fixed_string;
switch (token.type())
{
case Token::Type::EOF_:
TRY(ask_input_tokens("> ", false));
TRY(unget_token(Token(Token::Type::Backslash)));
return parse_backslash(is_quoted);
case Token::Type::Ampersand:
case Token::Type::Backslash:
case Token::Type::CloseCurly:
case Token::Type::CloseParen:
case Token::Type::Dollar:
case Token::Type::DoubleQuote:
case Token::Type::OpenCurly:
case Token::Type::OpenParen:
case Token::Type::Pipe:
case Token::Type::Semicolon:
case Token::Type::SingleQuote:
TRY(fixed_string.value.push_back(token_type_to_single_character(token.type())));
break;
case Token::Type::Whitespace:
case Token::Type::String:
{
ASSERT(!token.string().empty());
if (is_quoted)
TRY(fixed_string.value.push_back('\\'));
TRY(fixed_string.value.push_back(token.string().front()));
if (token.string().size() > 1)
{
token.string().remove(0);
TRY(unget_token(BAN::move(token)));
}
break;
}
}
return CommandArgument::ArgumentPart(BAN::move(fixed_string));
}
BAN::ErrorOr<CommandArgument::ArgumentPart> TokenParser::parse_dollar()
{
ASSERT(read_token().type() == Token::Type::Dollar);
const auto parse_dollar_string =
[](BAN::String& string) -> BAN::ErrorOr<CommandArgument::ArgumentPart>
{
if (string.empty())
return CommandArgument::ArgumentPart(CommandArgument::EnvironmentVariable());
if (isdigit(string.front()))
{
size_t number_len = 1;
while (number_len < string.size() && isdigit(string[number_len]))
number_len++;
CommandArgument::BuiltinVariable builtin;
TRY(builtin.value.append(string.sv().substring(0, number_len)));
for (size_t i = 0; i < number_len; i++)
string.remove(0);
return CommandArgument::ArgumentPart(BAN::move(builtin));
}
switch (string.front())
{
case '$':
case '_':
case '@':
case '*':
case '#':
case '-':
case '?':
case '!':
{
CommandArgument::BuiltinVariable builtin;
TRY(builtin.value.push_back(string.front()));
string.remove(0);
return CommandArgument::ArgumentPart(BAN::move(builtin));
}
}
if (isalpha(string.front()))
{
size_t env_len = 1;
while (env_len < string.size() && (isalnum(string[env_len]) || string[env_len] == '_'))
env_len++;
CommandArgument::EnvironmentVariable environment;
TRY(environment.value.append(string.sv().substring(0, env_len)));
for (size_t i = 0; i < env_len; i++)
string.remove(0);
return CommandArgument::ArgumentPart(BAN::move(environment));
}
CommandArgument::FixedString fixed_string;
TRY(fixed_string.value.push_back('$'));
return CommandArgument::ArgumentPart(BAN::move(fixed_string));
};
switch (peek_token().type())
{
case Token::Type::EOF_:
case Token::Type::Ampersand:
case Token::Type::Backslash:
case Token::Type::CloseCurly:
case Token::Type::CloseParen:
case Token::Type::DoubleQuote:
case Token::Type::Pipe:
case Token::Type::Semicolon:
case Token::Type::SingleQuote:
case Token::Type::Whitespace:
{
CommandArgument::FixedString fixed_string;
TRY(fixed_string.value.push_back('$'));
return CommandArgument::ArgumentPart(BAN::move(fixed_string));
}
case Token::Type::Dollar:
{
consume_token();
CommandArgument::BuiltinVariable builtin_variable;
TRY(builtin_variable.value.push_back('$'));
return CommandArgument::ArgumentPart(BAN::move(builtin_variable));
}
case Token::Type::OpenCurly:
{
consume_token();
BAN::String input;
for (auto token = read_token(); token.type() != Token::Type::CloseCurly; token = read_token())
{
if (token.type() == Token::Type::EOF_)
return BAN::Error::from_literal("missing closing curly brace");
if (token.type() == Token::Type::String)
TRY(input.append(token.string()));
else if (token.type() == Token::Type::Dollar)
TRY(input.push_back('$'));
else
return BAN::Error::from_literal("expected closing curly brace");
}
auto result = TRY(parse_dollar_string(input));
if (!input.empty())
return BAN::Error::from_literal("bad substitution");
return result;
}
case Token::Type::OpenParen:
{
consume_token();
auto command_tree = TRY(parse_command_tree());
if (auto token = read_token(); token.type() != Token::Type::CloseParen)
return BAN::Error::from_literal("expected closing parenthesis");
return CommandArgument::ArgumentPart(BAN::move(command_tree));
}
case Token::Type::String:
{
auto token = read_token();
auto string = BAN::move(token.string());
auto result = TRY(parse_dollar_string(string));
if (!string.empty())
{
auto remaining = Token(Token::Type::String);
remaining.string() = BAN::move(string);
TRY(unget_token(BAN::move(remaining)));
}
return result;
}
}
ASSERT_NOT_REACHED();
}
BAN::ErrorOr<CommandArgument::ArgumentPart> TokenParser::parse_single_quote()
{
ASSERT(read_token().type() == Token::Type::SingleQuote);
CommandArgument::FixedString fixed_string;
for (auto token = read_token();; token = read_token())
{
switch (token.type())
{
case Token::Type::EOF_:
TRY(ask_input_tokens("quote> ", true));
break;
case Token::Type::Ampersand:
case Token::Type::Backslash:
case Token::Type::CloseCurly:
case Token::Type::CloseParen:
case Token::Type::Dollar:
case Token::Type::DoubleQuote:
case Token::Type::OpenCurly:
case Token::Type::OpenParen:
case Token::Type::Pipe:
case Token::Type::Semicolon:
TRY(fixed_string.value.push_back(token_type_to_single_character(token.type())));
break;
case Token::Type::String:
case Token::Type::Whitespace:
TRY(fixed_string.value.append(token.string()));
break;
case Token::Type::SingleQuote:
return CommandArgument::ArgumentPart(BAN::move(fixed_string));
}
}
}
BAN::ErrorOr<CommandArgument> TokenParser::parse_argument()
{
using FixedString = CommandArgument::FixedString;
const auto token_type = peek_token().type();
if (!can_parse_argument_from_token_type(token_type))
return unexpected_token_error(token_type);
CommandArgument result;
bool is_in_double_quotes = false;
for (auto token_type = peek_token().type(); token_type != Token::Type::EOF_ || is_in_double_quotes; token_type = peek_token().type())
{
CommandArgument::ArgumentPart new_part;
switch (token_type)
{
case Token::Type::EOF_:
ASSERT(is_in_double_quotes);
TRY(ask_input_tokens("dquote> ", true));
new_part = FixedString(); // do continue
break;
case Token::Type::Ampersand:
case Token::Type::CloseCurly:
case Token::Type::CloseParen:
case Token::Type::OpenCurly:
case Token::Type::OpenParen:
case Token::Type::Pipe:
case Token::Type::Semicolon:
if (is_in_double_quotes)
{
new_part = FixedString();
TRY(new_part.get<FixedString>().value.push_back(token_type_to_single_character(token_type)));
consume_token();
}
break;
case Token::Type::Whitespace:
if (is_in_double_quotes)
{
new_part = FixedString();
TRY(new_part.get<FixedString>().value.append(peek_token().string()));
consume_token();
}
break;
case Token::Type::Backslash:
new_part = TRY(parse_backslash(is_in_double_quotes));
break;
case Token::Type::DoubleQuote:
is_in_double_quotes = !is_in_double_quotes;
new_part = FixedString(); // do continue
consume_token();
break;
case Token::Type::Dollar:
new_part = TRY(parse_dollar());
break;
case Token::Type::SingleQuote:
new_part = TRY(parse_single_quote());
break;
case Token::Type::String:
new_part = CommandArgument::ArgumentPart(FixedString {});
TRY(new_part.get<FixedString>().value.append(peek_token().string()));
consume_token();
break;
}
if (!new_part.has_value())
break;
if (new_part.has<FixedString>())
{
auto& fixed_string = new_part.get<FixedString>();
// discard empty fixed strings
if (fixed_string.value.empty())
continue;
// combine consecutive fixed strings
if (!result.parts.empty() && result.parts.back().has<FixedString>())
{
TRY(result.parts.back().get<FixedString>().value.append(fixed_string.value));
continue;
}
}
TRY(result.parts.push_back(BAN::move(new_part)));
}
return result;
}
BAN::ErrorOr<SingleCommand> TokenParser::parse_single_command()
{
SingleCommand result;
while (peek_token().type() == Token::Type::Whitespace)
consume_token();
while (peek_token().type() == Token::Type::String)
{
BAN::String env_name;
const auto& string = peek_token().string();
if (!isalpha(string.front()))
break;
const auto env_len = string.sv().find([](char ch) { return !(isalnum(ch) || ch == '_'); });
if (!env_len.has_value() || string[*env_len] != '=')
break;
TRY(env_name.append(string.sv().substring(0, *env_len)));
auto full_value = TRY(parse_argument());
ASSERT(!full_value.parts.empty());
ASSERT(full_value.parts.front().has<CommandArgument::FixedString>());
auto& first_arg = full_value.parts.front().get<CommandArgument::FixedString>();
ASSERT(first_arg.value.sv().starts_with(env_name));
ASSERT(first_arg.value[*env_len] == '=');
for (size_t i = 0; i < *env_len + 1; i++)
first_arg.value.remove(0);
if (first_arg.value.empty())
full_value.parts.remove(0);
SingleCommand::EnvironmentVariable environment_variable;
environment_variable.name = BAN::move(env_name);
environment_variable.value = BAN::move(full_value);
TRY(result.environment.emplace_back(BAN::move(environment_variable)));
while (peek_token().type() == Token::Type::Whitespace)
consume_token();
}
BAN::HashSet<BAN::String> used_aliases;
while (peek_token().type() == Token::Type::String)
{
auto token = read_token();
bool can_be_alias = false;
switch (peek_token().type())
{
case Token::Type::EOF_:
case Token::Type::Ampersand:
case Token::Type::CloseParen:
case Token::Type::Pipe:
case Token::Type::Semicolon:
case Token::Type::Whitespace:
can_be_alias = true;
break;
case Token::Type::Backslash:
case Token::Type::CloseCurly:
case Token::Type::Dollar:
case Token::Type::DoubleQuote:
case Token::Type::OpenCurly:
case Token::Type::OpenParen:
case Token::Type::SingleQuote:
case Token::Type::String:
can_be_alias = false;
break;
}
if (!can_be_alias)
{
TRY(unget_token(BAN::move(token)));
break;
}
if (used_aliases.contains(token.string()))
{
TRY(unget_token(BAN::move(token)));
break;
}
auto opt_alias = Alias::get().get_alias(token.string().sv());
if (!opt_alias.has_value())
{
TRY(unget_token(BAN::move(token)));
break;
}
auto tokenized_alias = TRY(tokenize_string(opt_alias.value()));
for (size_t i = tokenized_alias.size(); i > 0; i--)
TRY(unget_token(BAN::move(tokenized_alias[i - 1])));
TRY(used_aliases.insert(TRY(BAN::String::formatted("{}", token.string()))));
while (peek_token().type() == Token::Type::Whitespace)
consume_token();
}
while (peek_token().type() != Token::Type::EOF_)
{
while (peek_token().type() == Token::Type::Whitespace)
consume_token();
auto argument = TRY(parse_argument());
TRY(result.arguments.push_back(BAN::move(argument)));
while (peek_token().type() == Token::Type::Whitespace)
consume_token();
if (!can_parse_argument_from_token_type(peek_token().type()))
break;
}
return result;
}
BAN::ErrorOr<PipedCommand> TokenParser::parse_piped_command()
{
PipedCommand result;
result.background = false;
while (peek_token().type() != Token::Type::EOF_)
{
auto single_command = TRY(parse_single_command());
TRY(result.commands.push_back(BAN::move(single_command)));
const auto token_type = peek_token().type();
if (token_type != Token::Type::Pipe && token_type != Token::Type::Ampersand)
break;
auto temp_token = read_token();
if (peek_token().type() == temp_token.type())
{
TRY(unget_token(BAN::move(temp_token)));
break;
}
if (temp_token.type() == Token::Type::Ampersand)
{
result.background = true;
break;
}
}
return result;
}
BAN::ErrorOr<CommandTree> TokenParser::parse_command_tree()
{
CommandTree result;
auto next_condition = ConditionalCommand::Condition::Always;
while (peek_token().type() != Token::Type::EOF_)
{
ConditionalCommand conditional_command;
conditional_command.command = TRY(parse_piped_command());
conditional_command.condition = next_condition;
TRY(result.commands.push_back(BAN::move(conditional_command)));
while (peek_token().type() == Token::Type::Whitespace)
consume_token();
if (peek_token().type() == Token::Type::EOF_)
break;
bool should_break = false;
const auto token_type = peek_token().type();
switch (token_type)
{
case Token::Type::Semicolon:
consume_token();
next_condition = ConditionalCommand::Condition::Always;
break;
case Token::Type::Ampersand:
case Token::Type::Pipe:
consume_token();
if (read_token().type() != token_type)
return BAN::Error::from_literal("expected double '&' or '|'");
next_condition = (token_type == Token::Type::Ampersand)
? ConditionalCommand::Condition::OnSuccess
: ConditionalCommand::Condition::OnFailure;
break;
default:
should_break = true;
break;
}
if (should_break)
break;
}
return result;
}
BAN::ErrorOr<void> TokenParser::run(BAN::Vector<Token>&& tokens)
{
TRY(feed_tokens(BAN::move(tokens)));
auto command_tree = TRY(parse_command_tree());
const auto token_type = peek_token().type();
while (!m_token_stream.empty())
m_token_stream.pop();
if (token_type != Token::Type::EOF_)
return unexpected_token_error(token_type);
TRY(m_execute.execute_command(command_tree));
return {};
}
bool TokenParser::main_loop(bool break_on_error)
{
for (;;)
{
auto opt_input = m_input_function({});
if (!opt_input.has_value())
break;
auto tokenized_input = tokenize_string(opt_input.release_value());
if (tokenized_input.is_error())
{
fprintf(stderr, "banan-sh: %s\n", tokenized_input.error().get_message());
if (break_on_error)
return false;
continue;
}
if (auto ret = run(tokenized_input.release_value()); ret.is_error())
{
fprintf(stderr, "banan-sh: %s\n", ret.error().get_message());
if (break_on_error)
return false;
continue;
}
}
return true;
}

View File

@ -0,0 +1,57 @@
#pragma once
#include "CommandTypes.h"
#include "Execute.h"
#include "Token.h"
#include <BAN/Function.h>
#include <BAN/NoCopyMove.h>
#include <BAN/Optional.h>
#include <BAN/Queue.h>
#include <BAN/Vector.h>
class TokenParser
{
BAN_NON_COPYABLE(TokenParser);
BAN_NON_MOVABLE(TokenParser);
public:
using InputFunction = BAN::Function<BAN::Optional<BAN::String>(BAN::Optional<BAN::StringView>)>;
public:
TokenParser(const InputFunction& input_function)
: m_input_function(input_function)
{ }
Execute& execute() { return m_execute; }
const Execute& execute() const { return m_execute; }
[[nodiscard]] bool main_loop(bool break_on_error);
private:
const Token& peek_token() const;
Token read_token();
void consume_token();
BAN::ErrorOr<void> feed_tokens(BAN::Vector<Token>&& tokens);
BAN::ErrorOr<void> unget_token(Token&& token);
BAN::ErrorOr<void> ask_input_tokens(BAN::StringView prompt, bool add_newline);
BAN::ErrorOr<void> run(BAN::Vector<Token>&&);
BAN::ErrorOr<CommandArgument::ArgumentPart> parse_backslash(bool is_quoted);
BAN::ErrorOr<CommandArgument::ArgumentPart> parse_dollar();
BAN::ErrorOr<CommandArgument::ArgumentPart> parse_single_quote();
BAN::ErrorOr<CommandArgument> parse_argument();
BAN::ErrorOr<SingleCommand> parse_single_command();
BAN::ErrorOr<PipedCommand> parse_piped_command();
BAN::ErrorOr<CommandTree> parse_command_tree();
private:
Execute m_execute;
Token m_eof_token { Token::Type::EOF_ };
BAN::Queue<BAN::Vector<Token>> m_token_stream;
InputFunction m_input_function;
};

File diff suppressed because it is too large Load Diff

View File

@ -330,7 +330,7 @@ void WindowServer::invalidate(Rectangle area)
ASSERT(m_background_image->height() == (uint64_t)m_framebuffer.height); ASSERT(m_background_image->height() == (uint64_t)m_framebuffer.height);
for (int32_t y = area.y; y < area.y + area.height; y++) for (int32_t y = area.y; y < area.y + area.height; y++)
for (int32_t x = area.x; x < area.x + area.width; x++) for (int32_t x = area.x; x < area.x + area.width; x++)
m_framebuffer.mmap[y * m_framebuffer.width + x] = m_background_image->get_color(x, y).as_rgba(); m_framebuffer.mmap[y * m_framebuffer.width + x] = m_background_image->get_color(x, y).as_argb();
} }
else else
{ {

9
userspace/programs/env/CMakeLists.txt vendored Normal file
View File

@ -0,0 +1,9 @@
set(SOURCES
main.cpp
)
add_executable(env ${SOURCES})
banan_link_library(env ban)
banan_link_library(env libc)
install(TARGETS env OPTIONAL)

13
userspace/programs/env/main.cpp vendored Normal file
View File

@ -0,0 +1,13 @@
#include <stdio.h>
extern char** environ;
int main()
{
if (!environ)
return 0;
char** current = environ;
while (*current)
printf("%s\n", *current++);
return 0;
}

View File

@ -1,6 +1,7 @@
#include <LibImage/Image.h> #include <LibImage/Image.h>
#include <fcntl.h> #include <fcntl.h>
#include <inttypes.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <sys/framebuffer.h> #include <sys/framebuffer.h>
@ -78,10 +79,13 @@ int main(int argc, char** argv)
return usage(argv[0], 1); return usage(argv[0], 1);
bool scale = false; bool scale = false;
bool benchmark = false;
for (int i = 1; i < argc - 1; i++) for (int i = 1; i < argc - 1; i++)
{ {
if (strcmp(argv[i], "-s") == 0 || strcmp(argv[i], "--scale") == 0) if (strcmp(argv[i], "-s") == 0 || strcmp(argv[i], "--scale") == 0)
scale = true; scale = true;
else if (strcmp(argv[i], "-b") == 0 || strcmp(argv[i], "--benchmark") == 0)
benchmark = true;
else if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) else if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0)
return usage(argv[0], 0); return usage(argv[0], 0);
else else
@ -90,7 +94,11 @@ int main(int argc, char** argv)
auto image_path = BAN::StringView(argv[argc - 1]); auto image_path = BAN::StringView(argv[argc - 1]);
timespec load_start, load_end;
clock_gettime(CLOCK_MONOTONIC, &load_start);
auto image_or_error = LibImage::Image::load_from_file(image_path); auto image_or_error = LibImage::Image::load_from_file(image_path);
clock_gettime(CLOCK_MONOTONIC, &load_end);
if (image_or_error.is_error()) if (image_or_error.is_error())
{ {
fprintf(stderr, "Could not load image '%.*s': %s\n", fprintf(stderr, "Could not load image '%.*s': %s\n",
@ -101,6 +109,34 @@ int main(int argc, char** argv)
return 1; return 1;
} }
if (benchmark)
{
const uint64_t start_ms = load_start.tv_sec * 1000 + load_start.tv_nsec / 1'000'000;
const uint64_t end_ms = load_end.tv_sec * 1000 + load_end.tv_nsec / 1'000'000;
const uint64_t duration_ms = end_ms - start_ms;
printf("image load took %" PRIu64 ".%03" PRIu64 " s\n", duration_ms / 1000, duration_ms % 1000);
if (scale)
{
timespec scale_start, scale_end;
clock_gettime(CLOCK_MONOTONIC, &scale_start);
auto scaled = MUST(image_or_error.value()->resize(1920, 1080, LibImage::Image::ResizeAlgorithm::Linear));
clock_gettime(CLOCK_MONOTONIC, &scale_end);
const uint64_t start_ms = scale_start.tv_sec * 1000 + scale_start.tv_nsec / 1'000'000;
const uint64_t end_ms = scale_end.tv_sec * 1000 + scale_end.tv_nsec / 1'000'000;
const uint64_t duration_ms = end_ms - start_ms;
printf("image scale (%" PRIu64 "x%" PRIu64 " to %dx%d) took %" PRIu64 ".%03" PRIu64 " s\n",
image_or_error.value()->width(), image_or_error.value()->height(),
1920, 1080,
duration_ms / 1000, duration_ms % 1000
);
}
return 0;
}
render_to_framebuffer(image_or_error.release_value(), scale); render_to_framebuffer(image_or_error.release_value(), scale);
for (;;) for (;;)