Kernel: Cleanup accessing userspace memory
Instead of doing page validation and loading manually, we just do a simple memcpy and handle the possible page faults
This commit is contained in:
parent
9589b5984d
commit
f77aa65dc5
|
|
@ -139,6 +139,7 @@ if("${BANAN_ARCH}" STREQUAL "x86_64")
|
|||
arch/x86_64/Signal.S
|
||||
arch/x86_64/Syscall.S
|
||||
arch/x86_64/Thread.S
|
||||
arch/x86_64/User.S
|
||||
arch/x86_64/Yield.S
|
||||
)
|
||||
elseif("${BANAN_ARCH}" STREQUAL "i686")
|
||||
|
|
@ -150,6 +151,7 @@ elseif("${BANAN_ARCH}" STREQUAL "i686")
|
|||
arch/i686/Signal.S
|
||||
arch/i686/Syscall.S
|
||||
arch/i686/Thread.S
|
||||
arch/i686/User.S
|
||||
arch/i686/Yield.S
|
||||
)
|
||||
else()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,54 @@
|
|||
# bool safe_user_memcpy(void*, const void*, size_t)
#   cdecl: 4(%esp) = dst, 8(%esp) = src, 12(%esp) = byte count
#   Returns 1 on success. A page fault inside the copy is resolved by the
#   ISR handler rewriting the saved ip to safe_user_memcpy_fault, making
#   the function return 0 (eax is still zero from the prologue).
.global safe_user_memcpy
.global safe_user_memcpy_end
.global safe_user_memcpy_fault
safe_user_memcpy:
	xorl %eax, %eax
	# xchg stashes the caller's callee-saved edi/esi into the argument
	# slots while loading dst/src into them for rep movsb
	xchgl 4(%esp), %edi
	xchgl 8(%esp), %esi
	movl 12(%esp), %ecx
	rep movsb
	incl %eax
safe_user_memcpy_fault:
	# restore the caller's edi/esi (stashed by the xchg above) on BOTH
	# the success and the fault path; the original code returned from the
	# fault label without restoring them, clobbering callee-saved
	# registers when a page fault aborted the copy
	movl 4(%esp), %edi
	movl 8(%esp), %esi
	ret
safe_user_memcpy_end:
|
||||
|
||||
# bool safe_user_strncpy(void*, const void*, size_t)
#   cdecl: 4(%esp) = dst, 8(%esp) = src, 12(%esp) = max byte count
#   Copies bytes until a NUL terminator has been stored or the buffer is
#   exhausted. Returns 1 only when the terminator was copied; returns 0
#   on buffer exhaustion, on a zero-sized buffer, or when a page fault
#   redirects execution to safe_user_strncpy_fault.
.global safe_user_strncpy
.global safe_user_strncpy_end
.global safe_user_strncpy_fault
safe_user_strncpy:
	# stash the caller's callee-saved edi/esi in the argument slots
	# while loading dst/src into them
	xchgl 4(%esp), %edi
	xchgl 8(%esp), %esi
	movl 12(%esp), %ecx
	# an empty buffer can never receive the terminator
	testl %ecx, %ecx
	jz safe_user_strncpy_fault
.Lstrncpy_copy_byte:
	movb (%esi), %al
	movb %al, (%edi)
	testb %al, %al
	jz .Lstrncpy_found_nul
	incl %edi
	incl %esi
	decl %ecx
	jnz .Lstrncpy_copy_byte
	# falls through: buffer exhausted without a terminator
safe_user_strncpy_fault:
	xorl %eax, %eax
	jmp .Lstrncpy_epilogue
.Lstrncpy_found_nul:
	movl $1, %eax
.Lstrncpy_epilogue:
	# restore the caller's edi/esi stashed by the xchg in the prologue
	movl 4(%esp), %edi
	movl 8(%esp), %esi
	ret
safe_user_strncpy_end:
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
# bool safe_user_memcpy(void*, const void*, size_t)
#   SysV x86_64: rdi = dst, rsi = src, rdx = byte count
#   Returns 1 on success. A page fault inside the copy is resolved by the
#   ISR handler rewriting the saved ip to safe_user_memcpy_fault, so the
#   function returns with rax still 0. rdi/rsi/rcx are caller-saved in
#   this ABI, so no register save/restore is needed.
.global safe_user_memcpy
.global safe_user_memcpy_end
.global safe_user_memcpy_fault
safe_user_memcpy:
	xorq %rax, %rax
	movq %rdx, %rcx
	rep movsb
	incq %rax
safe_user_memcpy_fault:
	ret
safe_user_memcpy_end:
|
||||
|
||||
# bool safe_user_strncpy(void*, const void*, size_t)
#   SysV x86_64: rdi = dst, rsi = src, rdx = max byte count
#   Copies bytes until a NUL terminator has been stored or the buffer is
#   exhausted. Returns 1 only when the terminator was copied; returns 0
#   on exhaustion, a zero count, or a page fault (the ISR handler
#   redirects a faulting access to safe_user_strncpy_fault).
.global safe_user_strncpy
.global safe_user_strncpy_end
.global safe_user_strncpy_fault
safe_user_strncpy:
	movq %rdx, %rcx
	testq %rcx, %rcx
	jz safe_user_strncpy_fault

# copy byte-by-byte until rsi is 8-byte aligned so the qword loop below
# never reads across an alignment boundary
.safe_user_strncpy_align_loop:
	testb $0x7, %sil
	jz .safe_user_strncpy_align_done

	movb (%rsi), %al
	movb %al, (%rdi)
	testb %al, %al
	jz .safe_user_strncpy_done

	incq %rdi
	incq %rsi
	decq %rcx
	jnz .safe_user_strncpy_align_loop
	jmp safe_user_strncpy_fault

.safe_user_strncpy_align_done:
	# magic constants for the zero-byte detection below
	movq $0x0101010101010101, %r8
	movq $0x8080808080808080, %r9

# copy 8 bytes at a time while at least 8 bytes remain; the load stays
# inside the caller-validated range since rcx >= 8 here
.safe_user_strncpy_qword_loop:
	cmpq $8, %rcx
	jb .safe_user_strncpy_qword_done

	movq (%rsi), %rax
	movq %rax, %r10
	movq %rax, %r11

	# detect a zero byte anywhere in rax: (v - 0x01..01) & ~v & 0x80..80
	# https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord
	subq %r8, %r10
	notq %r11
	andq %r11, %r10
	andq %r9, %r10
	jnz .safe_user_strncpy_byte_loop

	movq %rax, (%rdi)

	addq $8, %rdi
	addq $8, %rsi
	subq $8, %rcx
	jnz .safe_user_strncpy_qword_loop
	jmp safe_user_strncpy_fault

.safe_user_strncpy_qword_done:
	testq %rcx, %rcx
	jz safe_user_strncpy_fault

# tail loop: fewer than 8 bytes remain, or the qword above contained the
# terminator and is re-copied a byte at a time to place it exactly
.safe_user_strncpy_byte_loop:
	movb (%rsi), %al
	movb %al, (%rdi)
	testb %al, %al
	jz .safe_user_strncpy_done

	incq %rdi
	incq %rsi
	decq %rcx
	jnz .safe_user_strncpy_byte_loop
	# falls through: buffer exhausted without a terminator

safe_user_strncpy_fault:
	xorq %rax, %rax
	ret

.safe_user_strncpy_done:
	movb $1, %al
	ret
safe_user_strncpy_end:
|
||||
|
|
@ -164,6 +164,33 @@ namespace Kernel
|
|||
"Unkown Exception 0x1F",
|
||||
};
|
||||
|
||||
extern "C" uint8_t safe_user_memcpy[];
|
||||
extern "C" uint8_t safe_user_memcpy_end[];
|
||||
extern "C" uint8_t safe_user_memcpy_fault[];
|
||||
|
||||
extern "C" uint8_t safe_user_strncpy[];
|
||||
extern "C" uint8_t safe_user_strncpy_end[];
|
||||
extern "C" uint8_t safe_user_strncpy_fault[];
|
||||
|
||||
// Describes a kernel routine that is allowed to page fault while
// touching userspace memory: a fault with ip inside [ip_start, ip_end)
// is handled by rewriting the saved ip to ip_fault, letting the routine
// return failure instead of the kernel treating it as a fatal fault.
struct safe_user_page_fault
{
	const uint8_t* ip_start; // first instruction byte of the routine
	const uint8_t* ip_end;   // one past the last instruction byte
	const uint8_t* ip_fault; // recovery entry point within the routine
};
|
||||
// Table of the routines whose page faults the ISR handler may recover
// from; the symbols are defined in the architecture's User.S.
static constexpr safe_user_page_fault s_safe_user_page_faults[] {
	{
		.ip_start = safe_user_memcpy,
		.ip_end = safe_user_memcpy_end,
		.ip_fault = safe_user_memcpy_fault,
	},
	{
		.ip_start = safe_user_strncpy,
		.ip_end = safe_user_strncpy_end,
		.ip_fault = safe_user_strncpy_fault,
	},
};
|
||||
|
||||
extern "C" void cpp_isr_handler(uint32_t isr, uint32_t error, InterruptStack* interrupt_stack, const Registers* regs)
|
||||
{
|
||||
if (g_paniced)
|
||||
|
|
@ -201,6 +228,15 @@ namespace Kernel
|
|||
if (result.value())
|
||||
return;
|
||||
|
||||
const uint8_t* ip = reinterpret_cast<const uint8_t*>(interrupt_stack->ip);
|
||||
for (const auto& safe_user : s_safe_user_page_faults)
|
||||
{
|
||||
if (ip < safe_user.ip_start || ip >= safe_user.ip_end)
|
||||
continue;
|
||||
interrupt_stack->ip = reinterpret_cast<vaddr_t>(safe_user.ip_fault);
|
||||
return;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
case ISR::DeviceNotAvailable:
|
||||
|
|
|
|||
|
|
@ -3887,237 +3887,45 @@ namespace Kernel
|
|||
return region->allocate_page_containing(address, wants_write);
|
||||
}
|
||||
|
||||
// TODO: The following 3 functions could be simplified into one generic helper function
|
||||
extern "C" bool safe_user_memcpy(void*, const void*, size_t);
|
||||
extern "C" bool safe_user_strncpy(void*, const void*, size_t);
|
||||
|
||||
static inline bool is_valid_user_address(const void* user_addr, size_t size)
|
||||
{
|
||||
const vaddr_t user_vaddr = reinterpret_cast<vaddr_t>(user_addr);
|
||||
if (BAN::Math::will_addition_overflow<vaddr_t>(user_vaddr, size))
|
||||
return false;
|
||||
if (user_vaddr + size > USERSPACE_END)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
BAN::ErrorOr<void> Process::read_from_user(const void* user_addr, void* out, size_t size)
{
	// Reads `size` bytes from userspace `user_addr` into the kernel
	// buffer `out`.
	//
	// The range check only validates that the range lies below
	// USERSPACE_END; actual page accessibility is determined by
	// safe_user_memcpy, whose page faults are either resolved by the
	// fault handler (demand paging) or turned into a `false` return by
	// redirecting execution to its fault label.
	//
	// NOTE(review): this block previously contained the whole
	// pre-refactor implementation (manual page validation/allocation)
	// as dead code ahead of this path; only the live implementation is
	// kept.
	if (!is_valid_user_address(user_addr, size))
		return BAN::Error::from_errno(EFAULT);
	if (!safe_user_memcpy(out, user_addr, size))
		return BAN::Error::from_errno(EFAULT);
	return {};
}
|
||||
|
||||
BAN::ErrorOr<void> Process::read_string_from_user(const char* user_addr, char* out, size_t max_size)
{
	// Copies a NUL-terminated string from userspace `user_addr` into
	// `out`, copying at most `max_size` bytes (terminator included).
	//
	// NOTE(review): this block previously contained the whole
	// pre-refactor implementation (manual page validation/allocation)
	// as dead code ahead of this path; only the live implementation is
	// kept.

	// Clamp so the copy cannot run past the userspace boundary even if
	// the string is missing its terminator.
	max_size = BAN::Math::min<size_t>(max_size, USERSPACE_END - reinterpret_cast<vaddr_t>(user_addr));
	if (!is_valid_user_address(user_addr, max_size))
		return BAN::Error::from_errno(EFAULT);
	// NOTE(review): safe_user_strncpy returns false both on a page
	// fault and when max_size is exhausted before a NUL, so an
	// over-long string now reports EFAULT (the previous implementation
	// distinguished it as ENAMETOOLONG) — confirm callers accept this.
	if (!safe_user_strncpy(out, user_addr, max_size))
		return BAN::Error::from_errno(EFAULT);
	return {};
}
|
||||
|
||||
BAN::ErrorOr<void> Process::write_to_user(void* user_addr, const void* in, size_t size)
{
	// Writes `size` bytes from the kernel buffer `in` to userspace
	// `user_addr`.
	//
	// The range check only validates that the range lies below
	// USERSPACE_END; writability of the pages is determined by
	// safe_user_memcpy, whose page faults are either resolved by the
	// fault handler (demand paging / CoW) or turned into a `false`
	// return by redirecting execution to its fault label.
	//
	// NOTE(review): this block previously contained the whole
	// pre-refactor implementation (manual page validation/allocation)
	// as dead code ahead of this path; only the live implementation is
	// kept.
	if (!is_valid_user_address(user_addr, size))
		return BAN::Error::from_errno(EFAULT);
	if (!safe_user_memcpy(user_addr, in, size))
		return BAN::Error::from_errno(EFAULT);
	return {};
}
|
||||
|
||||
BAN::ErrorOr<MemoryRegion*> Process::validate_and_pin_pointer_access(const void* ptr, size_t size, bool needs_write)
|
||||
|
|
|
|||
Loading…
Reference in New Issue