diff --git a/kernel/arch/i386/MMU.cpp b/kernel/arch/i386/MMU.cpp index a902655ba9..b16e529842 100644 --- a/kernel/arch/i386/MMU.cpp +++ b/kernel/arch/i386/MMU.cpp @@ -13,10 +13,13 @@ static MMU* s_instance = nullptr; -void MMU::intialize() +void MMU::initialize() { ASSERT(s_instance == nullptr); s_instance = new MMU(); + ASSERT(s_instance); + s_instance->initialize_kernel(); + s_instance->load(); } MMU& MMU::get() @@ -34,7 +37,7 @@ static uint64_t* allocate_page_aligned_page() return page; } -MMU::MMU() +void MMU::initialize_kernel() { m_highest_paging_struct = (uint64_t*)kmalloc(sizeof(uint64_t) * 4, 32); ASSERT(m_highest_paging_struct); @@ -64,8 +67,50 @@ MMU::MMU() // causes page fault :) uint64_t* page_table1 = (uint64_t*)(page_directory1[0] & PAGE_MASK); page_table1[0] = 0; +} - // reload this new pdpt +MMU::MMU() +{ + if (s_instance == nullptr) + return; + + // Here we copy the s_instances paging structs since they are + // global for every process + + uint64_t* global_pdpt = s_instance->m_highest_paging_struct; + + uint64_t* pdpt = (uint64_t*)kmalloc(sizeof(uint64_t) * 4, 32); + ASSERT(pdpt); + + for (uint32_t pdpte = 0; pdpte < 4; pdpte++) + { + if (!(global_pdpt[pdpte] & Flags::Present)) + continue; + + uint64_t* global_pd = (uint64_t*)(global_pdpt[pdpte] & PAGE_MASK); + + uint64_t* pd = allocate_page_aligned_page(); + pdpt[pdpte] = (uint64_t)pd | (global_pdpt[pdpte] & ~PAGE_MASK); + + for (uint32_t pde = 0; pde < 512; pde++) + { + if (!(global_pd[pde] & Flags::Present)) + continue; + + uint64_t* global_pt = (uint64_t*)(global_pd[pde] & PAGE_MASK); + + uint64_t* pt = allocate_page_aligned_page(); + pd[pde] = (uint64_t)pt | (global_pd[pde] & ~PAGE_MASK); + + memcpy(pt, global_pt, PAGE_SIZE); + } + } + + m_highest_paging_struct = pdpt; +} + +void MMU::load() +{ asm volatile("movl %0, %%cr3" :: "r"(m_highest_paging_struct)); } @@ -92,8 +137,6 @@ void MMU::map_page(uintptr_t address, uint8_t flags) uint64_t* page_table = (uint64_t*)(page_directory[pde] & PAGE_MASK); page_table[pte] = address | flags; - - asm volatile("invlpg (%0)" :: "r"(address) : "memory"); } void MMU::map_range(uintptr_t address, ptrdiff_t size, uint8_t flags) @@ -125,8 +168,6 @@ void MMU::unmap_page(uintptr_t address) page_table[pte] = 0; // TODO: Unallocate the page table if this was the only allocated page - - asm volatile("invlpg (%0)" :: "r"(address & PAGE_MASK) : "memory"); } void MMU::unmap_range(uintptr_t address, ptrdiff_t size) diff --git a/kernel/arch/x86_64/MMU.cpp b/kernel/arch/x86_64/MMU.cpp index 94c16f4f4b..311db566d4 100644 --- a/kernel/arch/x86_64/MMU.cpp +++ b/kernel/arch/x86_64/MMU.cpp @@ -8,15 +8,18 @@ #define CLEANUP_STRUCTURE(s) \ for (uint64_t i = 0; i < 512; i++) \ if (s[i] & Flags::Present) \ - goto cleanup_done; \ + return; \ kfree(s) static MMU* s_instance = nullptr; -void MMU::intialize() +void MMU::initialize() { ASSERT(s_instance == nullptr); s_instance = new MMU(); + ASSERT(s_instance); + s_instance->initialize_kernel(); + s_instance->load(); } MMU& MMU::get() @@ -33,10 +36,14 @@ static uint64_t* allocate_page_aligned_page() return (uint64_t*)page; } -MMU::MMU() +extern uint8_t g_kernel_end[]; + +void MMU::initialize_kernel() { // FIXME: We should just identity map until g_kernel_end + ASSERT((uintptr_t)g_kernel_end <= 6 * (1 << 20)); + // Identity map from 0 -> 6 MiB m_highest_paging_struct = allocate_page_aligned_page(); @@ -57,9 +64,55 @@ MMU::MMU() // Unmap 0 -> 4 KiB uint64_t* pt1 = (uint64_t*)(pd[0] & PAGE_MASK); pt1[0] = 0; +} - // Load the new pml4 - asm volatile("movq %0, %%cr3" :: "r"(m_highest_paging_struct)); +MMU::MMU() +{ + if (s_instance == nullptr) + return; + + // Here we copy the s_instances paging structs since they are + // global for every process + + uint64_t* global_pml4 = s_instance->m_highest_paging_struct; + + uint64_t* pml4 = allocate_page_aligned_page(); + for (uint32_t pml4e = 0; pml4e < 512; pml4e++) + { + if (!(global_pml4[pml4e] & Flags::Present)) + continue; + + uint64_t* global_pdpt = (uint64_t*)(global_pml4[pml4e] & PAGE_MASK); + + uint64_t* pdpt = allocate_page_aligned_page(); + pml4[pml4e] = (uint64_t)pdpt | (global_pml4[pml4e] & ~PAGE_MASK); + + for (uint32_t pdpte = 0; pdpte < 512; pdpte++) + { + if (!(global_pdpt[pdpte] & Flags::Present)) + continue; + + uint64_t* global_pd = (uint64_t*)(global_pdpt[pdpte] & PAGE_MASK); + + uint64_t* pd = allocate_page_aligned_page(); + pdpt[pdpte] = (uint64_t)pd | (global_pdpt[pdpte] & ~PAGE_MASK); + + for (uint32_t pde = 0; pde < 512; pde++) + { + if (!(global_pd[pde] & Flags::Present)) + continue; + + uint64_t* global_pt = (uint64_t*)(global_pd[pde] & PAGE_MASK); + + uint64_t* pt = allocate_page_aligned_page(); + pd[pde] = (uint64_t)pt | (global_pd[pde] & ~PAGE_MASK); + + memcpy(pt, global_pt, PAGE_SIZE); + } + } + } + + m_highest_paging_struct = pml4; } MMU::~MMU() @@ -88,6 +141,11 @@ MMU::~MMU() kfree(pml4); } +void MMU::load() +{ + asm volatile("movq %0, %%cr3" :: "r"(m_highest_paging_struct)); +} + void MMU::map_page(uintptr_t address, uint8_t flags) { address &= PAGE_MASK; @@ -137,9 +195,6 @@ void MMU::unmap_page(uintptr_t address) pdpt[pdpte] = 0; CLEANUP_STRUCTURE(pdpt); pml4[pml4e] = 0; -cleanup_done: - - asm volatile("invlpg (%0)" :: "r"(address) : "memory"); } void MMU::unmap_range(uintptr_t address, ptrdiff_t size) @@ -159,7 +214,6 @@ void MMU::map_page_at(paddr_t paddr, vaddr_t vaddr, uint8_t flags) ASSERT((vaddr & ~PAGE_MASK) == 0);; ASSERT(flags & Flags::Present); - bool should_invalidate = false; uint64_t pml4e = (vaddr >> 39) & 0x1FF; uint64_t pdpte = (vaddr >> 30) & 0x1FF; @@ -172,7 +226,6 @@ void MMU::map_page_at(paddr_t paddr, vaddr_t vaddr, uint8_t flags) if (!(pml4[pml4e] & Flags::Present)) pml4[pml4e] = (uint64_t)allocate_page_aligned_page(); pml4[pml4e] = (pml4[pml4e] & PAGE_MASK) | flags; - should_invalidate = true; } uint64_t* pdpt = (uint64_t*)(pml4[pml4e] & PAGE_MASK); @@ -181,7 +234,6 @@ void MMU::map_page_at(paddr_t paddr, vaddr_t vaddr, uint8_t flags) if (!(pdpt[pdpte] & Flags::Present)) pdpt[pdpte] = (uint64_t)allocate_page_aligned_page(); pdpt[pdpte] = (pdpt[pdpte] & PAGE_MASK) | flags; - should_invalidate = true; } uint64_t* pd = (uint64_t*)(pdpt[pdpte] & PAGE_MASK); @@ -190,16 +242,9 @@ void MMU::map_page_at(paddr_t paddr, vaddr_t vaddr, uint8_t flags) if (!(pd[pde] & Flags::Present)) pd[pde] = (uint64_t)allocate_page_aligned_page(); pd[pde] = (pd[pde] & PAGE_MASK) | flags; - should_invalidate = true; } uint64_t* pt = (uint64_t*)(pd[pde] & PAGE_MASK); if ((pt[pte] & flags) != flags) - { pt[pte] = paddr | flags; - should_invalidate = true; - } - - if (should_invalidate) - asm volatile("movq %0, %%cr3" :: "r"(m_highest_paging_struct)); } diff --git a/kernel/include/kernel/Memory/MMU.h b/kernel/include/kernel/Memory/MMU.h index 04552d09c2..38e1c47da2 100644 --- a/kernel/include/kernel/Memory/MMU.h +++ b/kernel/include/kernel/Memory/MMU.h @@ -17,7 +17,7 @@ public: using paddr_t = uintptr_t; public: - static void intialize(); + static void initialize(); static MMU& get(); MMU(); @@ -31,6 +31,11 @@ public: void map_page_at(paddr_t, vaddr_t, uint8_t); + void load(); + +private: + void initialize_kernel(); + private: uint64_t* m_highest_paging_struct; }; diff --git a/kernel/include/kernel/Process.h b/kernel/include/kernel/Process.h index c5934856d2..826af39d7e 100644 --- a/kernel/include/kernel/Process.h +++ b/kernel/include/kernel/Process.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -54,6 +55,8 @@ namespace Kernel static Process& current() { return Thread::current().process(); } + MMU& mmu() { return m_mmu ? *m_mmu : MMU::get(); } + private: Process(pid_t); static Process* create_process(); @@ -81,6 +84,7 @@ namespace Kernel BAN::String m_working_directory; BAN::Vector m_threads; + MMU* m_mmu { nullptr }; TTY* m_tty { nullptr }; }; diff --git a/kernel/include/kernel/Thread.h b/kernel/include/kernel/Thread.h index 14204206a4..a4746e4175 100644 --- a/kernel/include/kernel/Thread.h +++ b/kernel/include/kernel/Thread.h @@ -47,6 +47,7 @@ namespace Kernel static Thread& current() ; Process& process(); + bool has_process() const { return m_process; } private: Thread(pid_t tid, Process*); diff --git a/kernel/kernel/Process.cpp b/kernel/kernel/Process.cpp index 0730e9bcd5..fb413bf98d 100644 --- a/kernel/kernel/Process.cpp +++ b/kernel/kernel/Process.cpp @@ -1,7 +1,6 @@ #include #include #include -#include #include #include @@ -39,12 +38,14 @@ namespace Kernel auto* process = create_process(); TRY(process->m_working_directory.push_back('/')); TRY(process->init_stdio()); - + process->m_mmu = new MMU(); + ASSERT(process->m_mmu); + TRY(process->add_thread( [](void* entry_func) { Thread& current = Thread::current(); - MMU::get().map_range(current.stack_base(), current.stack_size(), MMU::Flags::UserSupervisor | MMU::Flags::ReadWrite | MMU::Flags::Present); + Process::current().m_mmu->map_range(current.stack_base(), current.stack_size(), MMU::Flags::UserSupervisor | MMU::Flags::ReadWrite | MMU::Flags::Present); current.jump_userspace((uintptr_t)entry_func); ASSERT_NOT_REACHED(); }, (void*)entry diff --git a/kernel/kernel/Scheduler.cpp b/kernel/kernel/Scheduler.cpp index 6ad00a993c..99d1f6ce0c 100644 --- a/kernel/kernel/Scheduler.cpp +++ b/kernel/kernel/Scheduler.cpp @@ -169,6 +169,11 @@ namespace Kernel Thread& current = current_thread(); + if (current.has_process()) + current.process().mmu().load(); + else + MMU::get().load(); + switch (current.state()) { case Thread::State::NotStarted: @@ -245,7 +250,6 @@ namespace Kernel void Scheduler::set_current_process_done() { - VERIFY_STI(); DISABLE_INTERRUPTS(); pid_t pid = m_current_thread->thread->process().pid(); diff --git a/kernel/kernel/kernel.cpp b/kernel/kernel/kernel.cpp index ba5d508b85..851af4bad2 100644 --- a/kernel/kernel/kernel.cpp +++ b/kernel/kernel/kernel.cpp @@ -134,7 +134,7 @@ extern "C" void kernel_main() IDT::initialize(); dprintln("IDT initialized"); - MMU::intialize(); + MMU::initialize(); dprintln("MMU initialized"); TerminalDriver* terminal_driver = VesaTerminalDriver::create(); @@ -194,8 +194,8 @@ static void init2(void* tty1) }, nullptr )); - //jump_userspace(); - //return; + jump_userspace(); + return; MUST(Process::create_kernel( [](void*)