Kernel: Cleanup and fix page tables and better TLB shootdown

This commit is contained in:
2026-04-03 01:53:30 +03:00
parent f77aa65dc5
commit 7d8f7753d5
4 changed files with 170 additions and 156 deletions

View File

@@ -21,6 +21,11 @@ namespace Kernel
SpinLock PageTable::s_fast_page_lock;
constexpr uint64_t s_page_flag_mask = 0x8000000000000FFF;
constexpr uint64_t s_page_addr_mask = ~s_page_flag_mask;
static bool s_is_post_heap_done = false;
static PageTable* s_kernel = nullptr;
static bool s_has_nxe = false;
static bool s_has_pge = false;
@@ -67,7 +72,7 @@ namespace Kernel
void PageTable::initialize_post_heap()
{
// NOTE: this is no-op as our 32 bit target does not use hhdm
s_is_post_heap_done = true;
}
void PageTable::initial_load()
@@ -141,9 +146,9 @@ namespace Kernel
}
template<typename T>
static vaddr_t P2V(const T paddr)
static uint64_t* P2V(const T paddr)
{
return (paddr_t)paddr - g_boot_info.kernel_paddr + KERNEL_OFFSET;
return reinterpret_cast<uint64_t*>(reinterpret_cast<paddr_t>(paddr) - g_boot_info.kernel_paddr + KERNEL_OFFSET);
}
void PageTable::initialize_kernel()
@@ -194,10 +199,10 @@ namespace Kernel
constexpr uint64_t pdpte = (fast_page() >> 30) & 0x1FF;
constexpr uint64_t pde = (fast_page() >> 21) & 0x1FF;
uint64_t* pdpt = reinterpret_cast<uint64_t*>(P2V(m_highest_paging_struct));
const uint64_t* pdpt = P2V(m_highest_paging_struct);
ASSERT(pdpt[pdpte] & Flags::Present);
uint64_t* pd = reinterpret_cast<uint64_t*>(P2V(pdpt[pdpte]) & PAGE_ADDR_MASK);
uint64_t* pd = P2V(pdpt[pdpte] & s_page_addr_mask);
ASSERT(!(pd[pde] & Flags::Present));
pd[pde] = V2P(allocate_zeroed_page_aligned_page()) | Flags::ReadWrite | Flags::Present;
}
@@ -214,9 +219,9 @@ namespace Kernel
constexpr uint64_t pde = (fast_page() >> 21) & 0x1FF;
constexpr uint64_t pte = (fast_page() >> 12) & 0x1FF;
uint64_t* pdpt = reinterpret_cast<uint64_t*>(P2V(s_kernel->m_highest_paging_struct));
uint64_t* pd = reinterpret_cast<uint64_t*>(P2V(pdpt[pdpte] & PAGE_ADDR_MASK));
uint64_t* pt = reinterpret_cast<uint64_t*>(P2V(pd[pde] & PAGE_ADDR_MASK));
uint64_t* pdpt = P2V(s_kernel->m_highest_paging_struct);
uint64_t* pd = P2V(pdpt[pdpte] & s_page_addr_mask);
uint64_t* pt = P2V(pd[pde] & s_page_addr_mask);
ASSERT(!(pt[pte] & Flags::Present));
pt[pte] = paddr | Flags::ReadWrite | Flags::Present;
@@ -234,9 +239,9 @@ namespace Kernel
constexpr uint64_t pde = (fast_page() >> 21) & 0x1FF;
constexpr uint64_t pte = (fast_page() >> 12) & 0x1FF;
uint64_t* pdpt = reinterpret_cast<uint64_t*>(P2V(s_kernel->m_highest_paging_struct));
uint64_t* pd = reinterpret_cast<uint64_t*>(P2V(pdpt[pdpte] & PAGE_ADDR_MASK));
uint64_t* pt = reinterpret_cast<uint64_t*>(P2V(pd[pde] & PAGE_ADDR_MASK));
uint64_t* pdpt = P2V(s_kernel->m_highest_paging_struct);
uint64_t* pd = P2V(pdpt[pdpte] & s_page_addr_mask);
uint64_t* pt = P2V(pd[pde] & s_page_addr_mask);
ASSERT(pt[pte] & Flags::Present);
pt[pte] = 0;
@@ -263,7 +268,7 @@ namespace Kernel
m_highest_paging_struct = V2P(kmalloc(32, 32, true));
ASSERT(m_highest_paging_struct);
uint64_t* pdpt = reinterpret_cast<uint64_t*>(P2V(m_highest_paging_struct));
uint64_t* pdpt = P2V(m_highest_paging_struct);
pdpt[0] = 0;
pdpt[1] = 0;
pdpt[2] = 0;
@@ -276,18 +281,17 @@ namespace Kernel
if (m_highest_paging_struct == 0)
return;
uint64_t* pdpt = reinterpret_cast<uint64_t*>(P2V(m_highest_paging_struct));
uint64_t* pdpt = P2V(m_highest_paging_struct);
for (uint32_t pdpte = 0; pdpte < 3; pdpte++)
{
if (!(pdpt[pdpte] & Flags::Present))
continue;
uint64_t* pd = reinterpret_cast<uint64_t*>(P2V(pdpt[pdpte] & PAGE_ADDR_MASK));
uint64_t* pd = P2V(pdpt[pdpte] & s_page_addr_mask);
for (uint32_t pde = 0; pde < 512; pde++)
{
if (!(pd[pde] & Flags::Present))
continue;
kfree(reinterpret_cast<uint64_t*>(P2V(pd[pde] & PAGE_ADDR_MASK)));
kfree(P2V(pd[pde] & s_page_addr_mask));
}
kfree(pd);
}
@@ -298,15 +302,43 @@ namespace Kernel
{
SpinLockGuard _(m_lock);
ASSERT(m_highest_paging_struct < 0x100000000);
const uint32_t pdpt_lo = m_highest_paging_struct;
asm volatile("movl %0, %%cr3" :: "r"(pdpt_lo));
asm volatile("movl %0, %%cr3" :: "r"(static_cast<uint32_t>(m_highest_paging_struct)));
Processor::set_current_page_table(this);
}
void PageTable::invalidate(vaddr_t vaddr, bool send_smp_message)
void PageTable::invalidate_range(vaddr_t vaddr, size_t pages, bool send_smp_message)
{
ASSERT(vaddr % PAGE_SIZE == 0);
asm volatile("invlpg (%0)" :: "r"(vaddr) : "memory");
const bool is_userspace = (vaddr < KERNEL_OFFSET);
if (is_userspace && this != &PageTable::current())
;
else if (pages <= 32 || !s_is_post_heap_done)
{
for (size_t i = 0; i < pages; i++, vaddr += PAGE_SIZE)
asm volatile("invlpg (%0)" :: "r"(vaddr));
}
else if (is_userspace || !s_has_pge)
{
asm volatile("movl %0, %%cr3" :: "r"(static_cast<uint32_t>(m_highest_paging_struct)));
}
else
{
asm volatile(
"movl %%cr4, %%eax;"
"andl $~0x80, %%eax;"
"movl %%eax, %%cr4;"
"movl %0, %%cr3;"
"orl $0x80, %%eax;"
"movl %%eax, %%cr4;"
:
: "r"(static_cast<uint32_t>(m_highest_paging_struct))
: "eax"
);
}
if (send_smp_message)
{
@@ -314,14 +346,14 @@ namespace Kernel
.type = Processor::SMPMessage::Type::FlushTLB,
.flush_tlb = {
.vaddr = vaddr,
.page_count = 1,
.page_count = pages,
.page_table = vaddr < KERNEL_OFFSET ? this : nullptr,
}
});
}
}
void PageTable::unmap_page(vaddr_t vaddr, bool send_smp_message)
void PageTable::unmap_page(vaddr_t vaddr, bool invalidate)
{
ASSERT(vaddr);
ASSERT(vaddr % PAGE_SIZE == 0);
@@ -340,16 +372,16 @@ namespace Kernel
if (is_page_free(vaddr))
Kernel::panic("trying to unmap unmapped page 0x{H}", vaddr);
uint64_t* pdpt = reinterpret_cast<uint64_t*>(P2V(m_highest_paging_struct));
uint64_t* pd = reinterpret_cast<uint64_t*>(P2V(pdpt[pdpte] & PAGE_ADDR_MASK));
uint64_t* pt = reinterpret_cast<uint64_t*>(P2V(pd[pde] & PAGE_ADDR_MASK));
uint64_t* pdpt = P2V(m_highest_paging_struct);
uint64_t* pd = P2V(pdpt[pdpte] & s_page_addr_mask);
uint64_t* pt = P2V(pd[pde] & s_page_addr_mask);
const paddr_t old_paddr = pt[pte] & PAGE_ADDR_MASK;
const paddr_t old_paddr = pt[pte] & s_page_addr_mask;
pt[pte] = 0;
if (old_paddr != 0)
invalidate(vaddr, send_smp_message);
if (invalidate && old_paddr != 0)
invalidate_page(vaddr, true);
}
void PageTable::unmap_range(vaddr_t vaddr, size_t size)
@@ -361,18 +393,10 @@ namespace Kernel
SpinLockGuard _(m_lock);
for (vaddr_t page = 0; page < page_count; page++)
unmap_page(vaddr + page * PAGE_SIZE, false);
Processor::broadcast_smp_message({
.type = Processor::SMPMessage::Type::FlushTLB,
.flush_tlb = {
.vaddr = vaddr,
.page_count = page_count,
.page_table = vaddr < KERNEL_OFFSET ? this : nullptr,
}
});
invalidate_range(vaddr, page_count, true);
}
void PageTable::map_page_at(paddr_t paddr, vaddr_t vaddr, flags_t flags, MemoryType memory_type, bool send_smp_message)
void PageTable::map_page_at(paddr_t paddr, vaddr_t vaddr, flags_t flags, MemoryType memory_type, bool invalidate)
{
ASSERT(vaddr);
ASSERT(vaddr != fast_page());
@@ -407,11 +431,11 @@ namespace Kernel
SpinLockGuard _(m_lock);
uint64_t* pdpt = reinterpret_cast<uint64_t*>(P2V(m_highest_paging_struct));
uint64_t* pdpt = P2V(m_highest_paging_struct);
if (!(pdpt[pdpte] & Flags::Present))
pdpt[pdpte] = V2P(allocate_zeroed_page_aligned_page()) | Flags::Present;
uint64_t* pd = reinterpret_cast<uint64_t*>(P2V(pdpt[pdpte] & PAGE_ADDR_MASK));
uint64_t* pd = P2V(pdpt[pdpte] & s_page_addr_mask);
if ((pd[pde] & uwr_flags) != uwr_flags)
{
if (!(pd[pde] & Flags::Present))
@@ -422,14 +446,14 @@ namespace Kernel
if (!(flags & Flags::Present))
uwr_flags &= ~Flags::Present;
uint64_t* pt = reinterpret_cast<uint64_t*>(P2V(pd[pde] & PAGE_ADDR_MASK));
uint64_t* pt = P2V(pd[pde] & s_page_addr_mask);
const paddr_t old_paddr = pt[pte] & PAGE_ADDR_MASK;
const paddr_t old_paddr = pt[pte] & s_page_addr_mask;
pt[pte] = paddr | uwr_flags | extra_flags;
if (old_paddr != 0)
invalidate(vaddr, send_smp_message);
if (invalidate && old_paddr != 0)
invalidate_page(vaddr, true);
}
void PageTable::map_range_at(paddr_t paddr, vaddr_t vaddr, size_t size, flags_t flags, MemoryType memory_type)
@@ -443,15 +467,7 @@ namespace Kernel
SpinLockGuard _(m_lock);
for (size_t page = 0; page < page_count; page++)
map_page_at(paddr + page * PAGE_SIZE, vaddr + page * PAGE_SIZE, flags, memory_type, false);
Processor::broadcast_smp_message({
.type = Processor::SMPMessage::Type::FlushTLB,
.flush_tlb = {
.vaddr = vaddr,
.page_count = page_count,
.page_table = vaddr < KERNEL_OFFSET ? this : nullptr,
}
});
invalidate_range(vaddr, page_count, true);
}
uint64_t PageTable::get_page_data(vaddr_t vaddr) const
@@ -464,15 +480,15 @@ namespace Kernel
SpinLockGuard _(m_lock);
uint64_t* pdpt = (uint64_t*)P2V(m_highest_paging_struct);
const uint64_t* pdpt = P2V(m_highest_paging_struct);
if (!(pdpt[pdpte] & Flags::Present))
return 0;
uint64_t* pd = (uint64_t*)P2V(pdpt[pdpte] & PAGE_ADDR_MASK);
const uint64_t* pd = P2V(pdpt[pdpte] & s_page_addr_mask);
if (!(pd[pde] & Flags::Present))
return 0;
uint64_t* pt = (uint64_t*)P2V(pd[pde] & PAGE_ADDR_MASK);
const uint64_t* pt = P2V(pd[pde] & s_page_addr_mask);
if (!(pt[pte] & Flags::Used))
return 0;
@@ -486,8 +502,7 @@ namespace Kernel
paddr_t PageTable::physical_address_of(vaddr_t vaddr) const
{
uint64_t page_data = get_page_data(vaddr);
return (page_data & PAGE_ADDR_MASK) & ~(1ull << 63);
return get_page_data(vaddr) & s_page_addr_mask;
}
bool PageTable::is_page_free(vaddr_t vaddr) const
@@ -529,14 +544,8 @@ namespace Kernel
return false;
for (size_t offset = 0; offset < bytes; offset += PAGE_SIZE)
reserve_page(vaddr + offset, true, false);
Processor::broadcast_smp_message({
.type = Processor::SMPMessage::Type::FlushTLB,
.flush_tlb = {
.vaddr = vaddr,
.page_count = bytes / PAGE_SIZE,
.page_table = vaddr < KERNEL_OFFSET ? this : nullptr,
}
});
invalidate_range(vaddr, bytes / PAGE_SIZE, true);
return true;
}
@@ -549,48 +558,47 @@ namespace Kernel
if (size_t rem = last_address % PAGE_SIZE)
last_address -= rem;
const uint32_t s_pdpte = (first_address >> 30) & 0x1FF;
const uint32_t s_pde = (first_address >> 21) & 0x1FF;
const uint32_t s_pte = (first_address >> 12) & 0x1FF;
uint32_t pdpte = (first_address >> 30) & 0x1FF;
uint32_t pde = (first_address >> 21) & 0x1FF;
uint32_t pte = (first_address >> 12) & 0x1FF;
const uint32_t e_pdpte = (last_address >> 30) & 0x1FF;
const uint32_t e_pde = (last_address >> 21) & 0x1FF;
const uint32_t e_pte = (last_address >> 12) & 0x1FF;
const uint32_t e_pdpte = ((last_address - 1) >> 30) & 0x1FF;
const uint32_t e_pde = ((last_address - 1) >> 21) & 0x1FF;
const uint32_t e_pte = ((last_address - 1) >> 12) & 0x1FF;
SpinLockGuard _(m_lock);
// Try to find free page that can be mapped without
// allocations (page table with unused entries)
uint64_t* pdpt = reinterpret_cast<uint64_t*>(P2V(m_highest_paging_struct));
for (uint32_t pdpte = s_pdpte; pdpte < 4; pdpte++)
const uint64_t* pdpt = P2V(m_highest_paging_struct);
for (; pdpte <= e_pdpte; pdpte++)
{
if (pdpte > e_pdpte)
break;
if (!(pdpt[pdpte] & Flags::Present))
continue;
uint64_t* pd = reinterpret_cast<uint64_t*>(P2V(pdpt[pdpte] & PAGE_ADDR_MASK));
for (uint32_t pde = s_pde; pde < 512; pde++)
const uint64_t* pd = P2V(pdpt[pdpte] & s_page_addr_mask);
for (; pde < 512; pde++)
{
if (pdpte == e_pdpte && pde > e_pde)
break;
if (!(pd[pde] & Flags::Present))
continue;
uint64_t* pt = (uint64_t*)P2V(pd[pde] & PAGE_ADDR_MASK);
for (uint32_t pte = s_pte; pte < 512; pte++)
const uint64_t* pt = P2V(pd[pde] & s_page_addr_mask);
for (; pte < 512; pte++)
{
if (pdpte == e_pdpte && pde == e_pde && pte >= e_pte)
if (pdpte == e_pdpte && pde == e_pde && pte > e_pte)
break;
if (!(pt[pte] & Flags::Used))
{
vaddr_t vaddr = 0;
vaddr |= (vaddr_t)pdpte << 30;
vaddr |= (vaddr_t)pde << 21;
vaddr |= (vaddr_t)pte << 12;
ASSERT(reserve_page(vaddr));
return vaddr;
}
if (pt[pte] & Flags::Used)
continue;
vaddr_t vaddr = 0;
vaddr |= (vaddr_t)pdpte << 30;
vaddr |= (vaddr_t)pde << 21;
vaddr |= (vaddr_t)pte << 12;
ASSERT(reserve_page(vaddr));
return vaddr;
}
pte = 0;
}
pde = 0;
}
// Find any free page
@@ -659,7 +667,7 @@ namespace Kernel
flags_t flags = 0;
vaddr_t start = 0;
uint64_t* pdpt = reinterpret_cast<uint64_t*>(P2V(m_highest_paging_struct));
const uint64_t* pdpt = P2V(m_highest_paging_struct);
for (uint32_t pdpte = 0; pdpte < 4; pdpte++)
{
if (!(pdpt[pdpte] & Flags::Present))
@@ -668,7 +676,7 @@ namespace Kernel
start = 0;
continue;
}
uint64_t* pd = (uint64_t*)P2V(pdpt[pdpte] & PAGE_ADDR_MASK);
const uint64_t* pd = P2V(pdpt[pdpte] & s_page_addr_mask);
for (uint64_t pde = 0; pde < 512; pde++)
{
if (!(pd[pde] & Flags::Present))
@@ -677,7 +685,7 @@ namespace Kernel
start = 0;
continue;
}
uint64_t* pt = (uint64_t*)P2V(pd[pde] & PAGE_ADDR_MASK);
const uint64_t* pt = P2V(pd[pde] & s_page_addr_mask);
for (uint64_t pte = 0; pte < 512; pte++)
{
if (parse_flags(pt[pte]) != flags)