Kernel: Cleanup accessing userspace memory

Instead of doing page validation and loading manually, we just do a simple
memcpy and handle the possible page faults.
This commit is contained in:
Bananymous 2026-04-02 16:36:33 +03:00
parent 9589b5984d
commit f77aa65dc5
5 changed files with 207 additions and 220 deletions

View File

@ -139,6 +139,7 @@ if("${BANAN_ARCH}" STREQUAL "x86_64")
arch/x86_64/Signal.S arch/x86_64/Signal.S
arch/x86_64/Syscall.S arch/x86_64/Syscall.S
arch/x86_64/Thread.S arch/x86_64/Thread.S
arch/x86_64/User.S
arch/x86_64/Yield.S arch/x86_64/Yield.S
) )
elseif("${BANAN_ARCH}" STREQUAL "i686") elseif("${BANAN_ARCH}" STREQUAL "i686")
@ -150,6 +151,7 @@ elseif("${BANAN_ARCH}" STREQUAL "i686")
arch/i686/Signal.S arch/i686/Signal.S
arch/i686/Syscall.S arch/i686/Syscall.S
arch/i686/Thread.S arch/i686/Thread.S
arch/i686/User.S
arch/i686/Yield.S arch/i686/Yield.S
) )
else() else()

54
kernel/arch/i686/User.S Normal file
View File

@ -0,0 +1,54 @@
# bool safe_user_memcpy(void* dst, const void* src, size_t size)
#
# Copies `size` bytes to/from user memory. Returns true (1) on success.
# If the copy faults on an unmapped user page, the page fault handler
# resumes execution at safe_user_memcpy_fault (the ip range
# [safe_user_memcpy, safe_user_memcpy_end) is registered for this),
# so the routine returns false (0) instead of panicking.
.global safe_user_memcpy
.global safe_user_memcpy_end
.global safe_user_memcpy_fault
safe_user_memcpy:
# cdecl: 4(%esp) = dst, 8(%esp) = src, 12(%esp) = count
# %eax = 0 first, so a fault taken during `rep movsb` reports failure
xorl %eax, %eax
# xchg loads dst/src while stashing the callee-saved %edi/%esi into the
# now-dead argument slots, to be restored before returning
xchgl 4(%esp), %edi
xchgl 8(%esp), %esi
movl 12(%esp), %ecx
rep movsb
incl %eax
safe_user_memcpy_fault:
# Restore callee-saved %edi/%esi on BOTH the success and the fault path.
# (Previously the fault path returned without restoring, leaking user
# pointers into the caller's callee-saved registers.)
movl 4(%esp), %edi
movl 8(%esp), %esi
ret
safe_user_memcpy_end:
# bool safe_user_strncpy(void* dst, const void* src, size_t max_size)
#
# Copies a NUL-terminated string of at most `max_size` bytes to/from user
# memory. Returns true (1) only when a NUL terminator was copied within
# the limit. Returns false (0) when max_size is zero, when the limit is
# exhausted before a NUL, or when a page fault occurs mid-copy (the fault
# handler resumes execution at safe_user_strncpy_fault for any ip inside
# [safe_user_strncpy, safe_user_strncpy_end)).
.global safe_user_strncpy
.global safe_user_strncpy_end
.global safe_user_strncpy_fault
safe_user_strncpy:
# cdecl: 4(%esp) = dst, 8(%esp) = src, 12(%esp) = max_size.
# xchg loads dst/src while stashing the callee-saved %edi/%esi into the
# now-dead argument slots; both return paths restore them below.
xchgl 4(%esp), %edi
xchgl 8(%esp), %esi
movl 12(%esp), %ecx
# max_size == 0 can never fit a terminator -> fail
testl %ecx, %ecx
jz safe_user_strncpy_fault
.safe_user_strncpy_loop:
movb (%esi), %al
movb %al, (%edi)
# stop with success once the NUL terminator has been copied
testb %al, %al
jz .safe_user_strncpy_done
incl %edi
incl %esi
decl %ecx
jnz .safe_user_strncpy_loop
# fall through: max_size bytes copied without seeing a NUL -> fail
safe_user_strncpy_fault:
xorl %eax, %eax
jmp .safe_user_strncpy_return
.safe_user_strncpy_done:
movl $1, %eax
.safe_user_strncpy_return:
# restore callee-saved %edi/%esi saved by the xchg at entry
movl 4(%esp), %edi
movl 8(%esp), %esi
ret
safe_user_strncpy_end:

87
kernel/arch/x86_64/User.S Normal file
View File

@ -0,0 +1,87 @@
# bool safe_user_memcpy(void* dst, const void* src, size_t size)
#
# Copies `size` bytes to/from user memory. Returns true (1) on success.
# If the copy faults on an unmapped user page, the page fault handler
# resumes execution at safe_user_memcpy_fault (the ip range
# [safe_user_memcpy, safe_user_memcpy_end) is registered for this),
# so the routine returns false (0) instead of panicking.
.global safe_user_memcpy
.global safe_user_memcpy_end
.global safe_user_memcpy_fault
safe_user_memcpy:
# SysV AMD64: %rdi = dst, %rsi = src, %rdx = count; %rdi/%rsi/%rcx are
# caller-saved here, so unlike the i686 version nothing needs restoring.
# %rax = 0 first, so a fault taken during `rep movsb` reports failure.
xorq %rax, %rax
movq %rdx, %rcx
rep movsb
incq %rax
safe_user_memcpy_fault:
ret
safe_user_memcpy_end:
# bool safe_user_strncpy(void* dst, const void* src, size_t max_size)
#
# Copies a NUL-terminated string of at most `max_size` bytes to/from user
# memory. Returns true (1) only when a NUL terminator was copied within
# the limit; false (0) when max_size is zero, when the limit is exhausted
# before a NUL, or when a page fault occurs mid-copy (the fault handler
# resumes execution at safe_user_strncpy_fault for any ip inside
# [safe_user_strncpy, safe_user_strncpy_end)).
#
# Strategy: byte-copy until %rsi is 8-byte aligned, then copy a qword at
# a time, dropping back to byte copies for the qword that contains the
# terminator (or for a trailing run shorter than 8 bytes).
.global safe_user_strncpy
.global safe_user_strncpy_end
.global safe_user_strncpy_fault
safe_user_strncpy:
# SysV AMD64: %rdi = dst, %rsi = src, %rdx = max_size; all scratch.
movq %rdx, %rcx
# max_size == 0 can never fit a terminator -> fail
testq %rcx, %rcx
jz safe_user_strncpy_fault
.safe_user_strncpy_align_loop:
# byte-copy until the SOURCE is 8-byte aligned (the qword stores to
# %rdi may stay unaligned; x86 handles that)
testb $0x7, %sil
jz .safe_user_strncpy_align_done
movb (%rsi), %al
movb %al, (%rdi)
testb %al, %al
jz .safe_user_strncpy_done
incq %rdi
incq %rsi
decq %rcx
jnz .safe_user_strncpy_align_loop
# limit exhausted before a NUL -> fail
jmp safe_user_strncpy_fault
.safe_user_strncpy_align_done:
# constants for the zero-byte-in-word test below
movq $0x0101010101010101, %r8
movq $0x8080808080808080, %r9
.safe_user_strncpy_qword_loop:
cmpq $8, %rcx
jb .safe_user_strncpy_qword_done
movq (%rsi), %rax
movq %rax, %r10
movq %rax, %r11
# https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord
# (v - 0x01..01) & ~v & 0x80..80 is nonzero iff v has a zero byte;
# if so, finish this qword byte-by-byte to place the NUL exactly
subq %r8, %r10
notq %r11
andq %r11, %r10
andq %r9, %r10
jnz .safe_user_strncpy_byte_loop
movq %rax, (%rdi)
addq $8, %rdi
addq $8, %rsi
subq $8, %rcx
jnz .safe_user_strncpy_qword_loop
# limit exhausted (multiple of 8) without a NUL -> fail
jmp safe_user_strncpy_fault
.safe_user_strncpy_qword_done:
# fewer than 8 bytes of budget left; none left at all -> fail
testq %rcx, %rcx
jz safe_user_strncpy_fault
.safe_user_strncpy_byte_loop:
movb (%rsi), %al
movb %al, (%rdi)
testb %al, %al
jz .safe_user_strncpy_done
incq %rdi
incq %rsi
decq %rcx
jnz .safe_user_strncpy_byte_loop
# fall through: limit exhausted before a NUL -> fail
safe_user_strncpy_fault:
xorq %rax, %rax
ret
.safe_user_strncpy_done:
# only %al is significant for a bool return; upper bits may hold
# leftovers from the qword load, which is fine
movb $1, %al
ret
safe_user_strncpy_end:

View File

@ -164,6 +164,33 @@ namespace Kernel
"Unkown Exception 0x1F", "Unkown Exception 0x1F",
}; };
// Symbols exported by the per-arch User.S assembly: each names an
// instruction address (start, one-past-end, and fault-recovery point)
// of a user-memory copy routine. Declared as arrays so taking the name
// yields the address rather than loading through it.
extern "C" uint8_t safe_user_memcpy[];
extern "C" uint8_t safe_user_memcpy_end[];
extern "C" uint8_t safe_user_memcpy_fault[];
extern "C" uint8_t safe_user_strncpy[];
extern "C" uint8_t safe_user_strncpy_end[];
extern "C" uint8_t safe_user_strncpy_fault[];
// Describes an instruction range that is allowed to fault on user memory:
// when a page fault's ip lies in [ip_start, ip_end), the handler resumes
// execution at ip_fault, making the routine return failure instead of
// panicking the kernel.
struct safe_user_page_fault
{
const uint8_t* ip_start;
const uint8_t* ip_end;
const uint8_t* ip_fault;
};
// Table consulted by the page fault path for unresolved user accesses.
static constexpr safe_user_page_fault s_safe_user_page_faults[] {
{
.ip_start = safe_user_memcpy,
.ip_end = safe_user_memcpy_end,
.ip_fault = safe_user_memcpy_fault,
},
{
.ip_start = safe_user_strncpy,
.ip_end = safe_user_strncpy_end,
.ip_fault = safe_user_strncpy_fault,
},
};
extern "C" void cpp_isr_handler(uint32_t isr, uint32_t error, InterruptStack* interrupt_stack, const Registers* regs) extern "C" void cpp_isr_handler(uint32_t isr, uint32_t error, InterruptStack* interrupt_stack, const Registers* regs)
{ {
if (g_paniced) if (g_paniced)
@ -201,6 +228,15 @@ namespace Kernel
if (result.value()) if (result.value())
return; return;
const uint8_t* ip = reinterpret_cast<const uint8_t*>(interrupt_stack->ip);
for (const auto& safe_user : s_safe_user_page_faults)
{
if (ip < safe_user.ip_start || ip >= safe_user.ip_end)
continue;
interrupt_stack->ip = reinterpret_cast<vaddr_t>(safe_user.ip_fault);
return;
}
break; break;
} }
case ISR::DeviceNotAvailable: case ISR::DeviceNotAvailable:

View File

@ -3887,237 +3887,45 @@ namespace Kernel
return region->allocate_page_containing(address, wants_write); return region->allocate_page_containing(address, wants_write);
} }
// TODO: The following 3 functions could be simplified into one generic helper function extern "C" bool safe_user_memcpy(void*, const void*, size_t);
extern "C" bool safe_user_strncpy(void*, const void*, size_t);
// Returns true iff [user_addr, user_addr + size) lies entirely below
// USERSPACE_END without wrapping around the address space. This only
// checks the range; whether the pages are actually mapped is discovered
// by the safe_user_* routines via the page fault handler.
static inline bool is_valid_user_address(const void* user_addr, size_t size)
{
const vaddr_t user_vaddr = reinterpret_cast<vaddr_t>(user_addr);
// reject ranges whose end computation would overflow vaddr_t
if (BAN::Math::will_addition_overflow<vaddr_t>(user_vaddr, size))
return false;
// reject ranges that extend into kernel address space
if (user_vaddr + size > USERSPACE_END)
return false;
return true;
}
BAN::ErrorOr<void> Process::read_from_user(const void* user_addr, void* out, size_t size) BAN::ErrorOr<void> Process::read_from_user(const void* user_addr, void* out, size_t size)
{ {
const vaddr_t user_vaddr = reinterpret_cast<vaddr_t>(user_addr); if (!is_valid_user_address(user_addr, size))
return BAN::Error::from_errno(EFAULT);
auto* out_u8 = static_cast<uint8_t*>(out); if (!safe_user_memcpy(out, user_addr, size))
size_t ncopied = 0;
{
RWLockRDGuard _(m_memory_region_lock);
const size_t first_index = find_mapped_region(user_vaddr);
for (size_t i = first_index; ncopied < size && i < m_mapped_regions.size(); i++)
{
auto& region = m_mapped_regions[i];
if (!region->contains(user_vaddr + ncopied))
return BAN::Error::from_errno(EFAULT); return BAN::Error::from_errno(EFAULT);
const size_t ncopy = BAN::Math::min<size_t>(
(region->vaddr() + region->size()) - (user_vaddr + ncopied),
size - ncopied
);
const size_t page_count = range_page_count(user_vaddr + ncopied, ncopy);
const vaddr_t page_base = (user_vaddr + ncopied) & PAGE_ADDR_MASK;
for (size_t p = 0; p < page_count; p++)
{
const auto flags = PageTable::UserSupervisor | PageTable::Present;
if ((m_page_table->get_page_flags(page_base + p * PAGE_SIZE) & flags) != flags)
goto read_from_user_with_allocation;
}
memcpy(out_u8 + ncopied, reinterpret_cast<void*>(user_vaddr + ncopied), ncopy);
ncopied += ncopy;
}
if (ncopied >= size)
return {}; return {};
if (ncopied > 0)
return BAN::Error::from_errno(EFAULT);
}
read_from_user_with_allocation:
RWLockWRGuard _(m_memory_region_lock);
const size_t first_index = find_mapped_region(user_vaddr + ncopied);
for (size_t i = first_index; ncopied < size && i < m_mapped_regions.size(); i++)
{
auto& region = m_mapped_regions[i];
if (!region->contains(user_vaddr + ncopied))
return BAN::Error::from_errno(EFAULT);
const size_t ncopy = BAN::Math::min<size_t>(
(region->vaddr() + region->size()) - (user_vaddr + ncopied),
size - ncopied
);
const size_t page_count = range_page_count(user_vaddr + ncopied, ncopy);
const vaddr_t page_base = (user_vaddr + ncopied) & PAGE_ADDR_MASK;
for (size_t p = 0; p < page_count; p++)
{
const auto flags = PageTable::UserSupervisor | PageTable::Present;
if ((m_page_table->get_page_flags(page_base + p * PAGE_SIZE) & flags) == flags)
continue;
if (!TRY(region->allocate_page_containing(page_base + p * PAGE_SIZE, false)))
return BAN::Error::from_errno(EFAULT);
}
memcpy(out_u8 + ncopied, reinterpret_cast<void*>(user_vaddr + ncopied), ncopy);
ncopied += ncopy;
}
if (ncopied >= size)
return {};
return BAN::Error::from_errno(EFAULT);
} }
BAN::ErrorOr<void> Process::read_string_from_user(const char* user_addr, char* out, size_t max_size) BAN::ErrorOr<void> Process::read_string_from_user(const char* user_addr, char* out, size_t max_size)
{ {
const vaddr_t user_vaddr = reinterpret_cast<vaddr_t>(user_addr); max_size = BAN::Math::min<size_t>(max_size, USERSPACE_END - reinterpret_cast<vaddr_t>(user_addr));
if (!is_valid_user_address(user_addr, max_size))
size_t ncopied = 0; return BAN::Error::from_errno(EFAULT);
if (!safe_user_strncpy(out, user_addr, max_size))
{
RWLockRDGuard _(m_memory_region_lock);
const size_t first_index = find_mapped_region(user_vaddr);
for (size_t i = first_index; ncopied < max_size && i < m_mapped_regions.size(); i++)
{
auto& region = m_mapped_regions[i];
if (!region->contains(user_vaddr + ncopied))
return BAN::Error::from_errno(EFAULT); return BAN::Error::from_errno(EFAULT);
vaddr_t last_page = 0;
for (; ncopied < max_size; ncopied++)
{
const vaddr_t curr_page = (user_vaddr + ncopied) & PAGE_ADDR_MASK;
if (curr_page != last_page)
{
const auto flags = PageTable::UserSupervisor | PageTable::Present;
if ((m_page_table->get_page_flags(curr_page) & flags) != flags)
goto read_string_from_user_with_allocation;
}
out[ncopied] = user_addr[ncopied];
if (out[ncopied] == '\0')
return {}; return {};
last_page = curr_page;
}
}
if (ncopied >= max_size)
return BAN::Error::from_errno(ENAMETOOLONG);
if (ncopied > 0)
return BAN::Error::from_errno(EFAULT);
}
read_string_from_user_with_allocation:
RWLockWRGuard _(m_memory_region_lock);
const size_t first_index = find_mapped_region(user_vaddr + ncopied);
for (size_t i = first_index; ncopied < max_size && i < m_mapped_regions.size(); i++)
{
auto& region = m_mapped_regions[i];
if (!region->contains(user_vaddr + ncopied))
return BAN::Error::from_errno(EFAULT);
vaddr_t last_page = 0;
for (; ncopied < max_size; ncopied++)
{
const vaddr_t curr_page = (user_vaddr + ncopied) & PAGE_ADDR_MASK;
if (curr_page != last_page)
{
const auto flags = PageTable::UserSupervisor | PageTable::Present;
if ((m_page_table->get_page_flags(curr_page) & flags) == flags)
;
else if (!TRY(region->allocate_page_containing(curr_page, false)))
return BAN::Error::from_errno(EFAULT);
}
out[ncopied] = user_addr[ncopied];
if (out[ncopied] == '\0')
return {};
last_page = curr_page;
}
}
if (ncopied >= max_size)
return BAN::Error::from_errno(ENAMETOOLONG);
return BAN::Error::from_errno(EFAULT);
} }
BAN::ErrorOr<void> Process::write_to_user(void* user_addr, const void* in, size_t size) BAN::ErrorOr<void> Process::write_to_user(void* user_addr, const void* in, size_t size)
{ {
const vaddr_t user_vaddr = reinterpret_cast<vaddr_t>(user_addr); if (!is_valid_user_address(user_addr, size))
return BAN::Error::from_errno(EFAULT);
const auto* in_u8 = static_cast<const uint8_t*>(in); if (!safe_user_memcpy(user_addr, in, size))
size_t ncopied = 0;
{
RWLockRDGuard _(m_memory_region_lock);
const size_t first_index = find_mapped_region(user_vaddr);
for (size_t i = first_index; ncopied < size && i < m_mapped_regions.size(); i++)
{
auto& region = m_mapped_regions[i];
if (!region->contains(user_vaddr + ncopied))
return BAN::Error::from_errno(EFAULT); return BAN::Error::from_errno(EFAULT);
const size_t ncopy = BAN::Math::min<size_t>(
(region->vaddr() + region->size()) - (user_vaddr + ncopied),
size - ncopied
);
const size_t page_count = range_page_count(user_vaddr + ncopied, ncopy);
const vaddr_t page_base = (user_vaddr + ncopied) & PAGE_ADDR_MASK;
for (size_t i = 0; i < page_count; i++)
{
const auto flags = PageTable::UserSupervisor | PageTable::ReadWrite | PageTable::Present;
if ((m_page_table->get_page_flags(page_base + i * PAGE_SIZE) & flags) != flags)
goto write_to_user_with_allocation;
}
memcpy(reinterpret_cast<void*>(user_vaddr + ncopied), in_u8 + ncopied, ncopy);
ncopied += ncopy;
}
if (ncopied >= size)
return {}; return {};
if (ncopied > 0)
return BAN::Error::from_errno(EFAULT);
}
write_to_user_with_allocation:
RWLockWRGuard _(m_memory_region_lock);
const size_t first_index = find_mapped_region(user_vaddr + ncopied);
for (size_t i = first_index; ncopied < size && i < m_mapped_regions.size(); i++)
{
auto& region = m_mapped_regions[i];
if (!region->contains(user_vaddr + ncopied))
return BAN::Error::from_errno(EFAULT);
const size_t ncopy = BAN::Math::min<size_t>(
(region->vaddr() + region->size()) - (user_vaddr + ncopied),
size - ncopied
);
const size_t page_count = range_page_count(user_vaddr + ncopied, ncopy);
const vaddr_t page_base = (user_vaddr + ncopied) & PAGE_ADDR_MASK;
for (size_t p = 0; p < page_count; p++)
{
const auto flags = PageTable::UserSupervisor | PageTable::ReadWrite | PageTable::Present;
if ((m_page_table->get_page_flags(page_base + p * PAGE_SIZE) & flags) == flags)
continue;
if (!TRY(region->allocate_page_containing(page_base + p * PAGE_SIZE, true)))
return BAN::Error::from_errno(EFAULT);
}
memcpy(reinterpret_cast<void*>(user_vaddr + ncopied), in_u8 + ncopied, ncopy);
ncopied += ncopy;
}
if (ncopied >= size)
return {};
return BAN::Error::from_errno(EFAULT);
} }
BAN::ErrorOr<MemoryRegion*> Process::validate_and_pin_pointer_access(const void* ptr, size_t size, bool needs_write) BAN::ErrorOr<MemoryRegion*> Process::validate_and_pin_pointer_access(const void* ptr, size_t size, bool needs_write)