diff --git a/userspace/libraries/LibC/stdlib.cpp b/userspace/libraries/LibC/stdlib.cpp index 41317673..0cb34180 100644 --- a/userspace/libraries/LibC/stdlib.cpp +++ b/userspace/libraries/LibC/stdlib.cpp @@ -572,20 +572,21 @@ int mblen(const char* s, size_t n) size_t mbstowcs(wchar_t* __restrict pwcs, const char* __restrict s, size_t n) { - auto* us = reinterpret_cast(s); - - size_t len = 0; + size_t written = 0; switch (__getlocale(LC_CTYPE)) { case LOCALE_INVALID: ASSERT_NOT_REACHED(); case LOCALE_POSIX: - while (*us && len < n) - pwcs[len++] = *us++; + if (pwcs == nullptr) + written = strlen(s); + else for (; s[written] && written < n; written++) + pwcs[written] = s[written]; break; case LOCALE_UTF8: - while (*us && len < n) + const auto* us = reinterpret_cast(s); + for (; *us && (pwcs == nullptr || written < n); written++) { auto wch = BAN::UTF8::to_codepoint(us); if (wch == BAN::UTF8::invalid) @@ -593,16 +594,57 @@ size_t mbstowcs(wchar_t* __restrict pwcs, const char* __restrict s, size_t n) errno = EILSEQ; return -1; } - pwcs[len++] = wch; + if (pwcs != nullptr) + pwcs[written] = wch; us += BAN::UTF8::byte_length(*us); } break; } - if (len < n) - pwcs[len] = 0; + if (pwcs != nullptr && written < n) + pwcs[written] = L'\0'; + return written; +} - return len; +size_t wcstombs(char* __restrict s, const wchar_t* __restrict pwcs, size_t n) +{ + size_t written = 0; + + switch (__getlocale(LC_CTYPE)) + { + case locale_t::LOCALE_INVALID: + ASSERT_NOT_REACHED(); + case locale_t::LOCALE_POSIX: + for (size_t i = 0; pwcs[i] && (s == nullptr || written < n); i++) + { + if (pwcs[i] > 0xFF) + return -1; + if (s != nullptr) + s[written] = pwcs[i]; + written++; + } + break; + case locale_t::LOCALE_UTF8: + for (size_t i = 0; pwcs[i] && (s == nullptr || written < n); i++) + { + char buffer[5]; + if (!BAN::UTF8::from_codepoints(pwcs + i, 1, buffer)) + return -1; + + const size_t len = strlen(buffer); + if (written + len > n) + return len; + + if (s != nullptr) + memcpy(s + written, buffer, len); + written += len; + } + break; + } + + if (s && written < n) + s[written] = '\0'; + return written; } void* bsearch(const void* key, const void* base, size_t nel, size_t width, int (*compar)(const void*, const void*))