diff --git a/userspace/programs/DynamicLoader/main.cpp b/userspace/programs/DynamicLoader/main.cpp index dc7ed670..0459a79c 100644 --- a/userspace/programs/DynamicLoader/main.cpp +++ b/userspace/programs/DynamicLoader/main.cpp @@ -4,6 +4,8 @@ #include #include +#include + #include #include #include @@ -210,8 +212,34 @@ static size_t s_loaded_file_count = 0; static const char* s_ld_library_path = nullptr; +static BAN::Atomic s_global_locker = 0; +static uint32_t s_global_lock_depth = 0; + constexpr uintptr_t SYM_NOT_FOUND = -1; +static void lock_global_lock() +{ + const pthread_t tid = syscall<>(SYS_PTHREAD_SELF); + + pthread_t expected = 0; + while (!s_global_locker.compare_exchange(expected, tid)) + { + if (expected == tid) + break; + syscall<>(SYS_YIELD); + expected = 0; + } + + s_global_lock_depth++; +} + +static void unlock_global_lock() +{ + s_global_lock_depth--; + if (s_global_lock_depth == 0) + s_global_locker.store(false); +} + static uint32_t elf_hash(const char* name) { uint32_t h = 0, g; @@ -725,11 +753,24 @@ extern "C" __attribute__((used)) uintptr_t resolve_symbol(const LoadedElf& elf, uintptr_t plt_entry) { - if (elf.pltrel == DT_REL) - return handle_relocation(elf, *reinterpret_cast(elf.jmprel + plt_entry), true); - if (elf.pltrel == DT_RELA) - return handle_relocation(elf, reinterpret_cast(elf.jmprel)[plt_entry], true); - print_error_and_exit("invalid value for DT_PLTREL", 0); + lock_global_lock(); + + uintptr_t result; + switch (elf.pltrel) + { + case DT_REL: + result = handle_relocation(elf, *reinterpret_cast(elf.jmprel + plt_entry), true); + break; + case DT_RELA: + result = handle_relocation(elf, reinterpret_cast(elf.jmprel)[plt_entry], true); + break; + default: + print_error_and_exit("invalid value for DT_PLTREL", 0); + } + + unlock_global_lock(); + + return result; } static LoadedElf& load_elf(const char* path, int fd); @@ -1022,9 +1063,16 @@ static bool load_symbol_table(LoadedElf& elf) static LoadedElf& load_elf(const char* path, int fd) { + lock_global_lock(); + for (size_t i = 0; i < s_loaded_file_count; i++) + { if (strcmp(s_loaded_files[i].path, path) == 0) + { + unlock_global_lock(); return s_loaded_files[i]; + } + } if (fd == -1 && (fd = syscall(SYS_OPENAT, AT_FDCWD, path, O_RDONLY)) < 0) print_error_and_exit("could not open library", fd); @@ -1167,6 +1215,8 @@ static LoadedElf& load_elf(const char* path, int fd) load_symbol_table(elf); + unlock_global_lock(); + return elf; }