diff --git a/libc/stdlib.cpp b/libc/stdlib.cpp index d36f9cf2..18217c4e 100644 --- a/libc/stdlib.cpp +++ b/libc/stdlib.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -51,35 +52,66 @@ int atexit(void (*func)(void)) return 0; } +static constexpr int get_base_digit(char c, int base) +{ + int digit = -1; + if (isdigit(c)) + digit = c - '0'; + else if (isalpha(c)) + digit = 10 + tolower(c) - 'a'; + if (digit < base) + return digit; + return -1; +} + template -static T strtoT(const char* str, char** endp, int base) +static constexpr bool will_multiplication_overflow(T a, T b) +{ + if (a == 0 || b == 0) + return false; + if ((a > 0) == (b > 0)) + return a > BAN::numeric_limits::max() / b; + else + return a < BAN::numeric_limits::min() / b; +} + +template +static constexpr bool will_addition_overflow(T a, T b) +{ + if (a > 0 && b > 0) + return a > BAN::numeric_limits::max() - b; + if (a < 0 && b < 0) + return a < BAN::numeric_limits::min() - b; + return false; +} + +template +static constexpr bool will_digit_append_overflow(bool negative, T current, int digit, int base) +{ + if (BAN::is_unsigned_v && negative && digit) + return true; + if (will_multiplication_overflow(current, base)) + return true; + if (will_addition_overflow(current * base, current < 0 ? -digit : digit)) + return true; + return false; +} + +template +static T strtoT(const char* str, char** endp, int base, int& error) { // validate base if (base != 0 && (base < 2 || base > 36)) { - errno = EINVAL; + error = EINVAL; return 0; } - // parse character to its value in base - // if digit is not of base, return -1 - auto get_base_digit = [](char c, int base) -> int - { - int digit = -1; - if (isdigit(c)) - digit = c - '0'; - else if (isalpha(c)) - digit = 10 + tolower(c) - 'a'; - if (digit < base) - return digit; - return -1; - }; - // skip whitespace while (isspace(*str)) str++; - // get sign and skip in + // get sign and skip it bool negative = (*str == '-'); if (*str == '-' || *str == '+') str++; @@ -100,52 +132,28 @@ static T strtoT(const char* str, char** endp, int base) { if (endp) *endp = const_cast(str); - errno = EINVAL; + error = EINVAL; return 0; } // remove "0x" prefix from hexadecimal - if (base == 16) - { - if (strncasecmp(str, "0x", 2) == 0) - str += 2; - } - - // limits of type T - constexpr T max_val = BAN::numeric_limits::max(); - constexpr T min_val = BAN::is_signed_v ? -max_val - 1 : 0; + if (base == 16 && strncasecmp(str, "0x", 2) == 0 && get_base_digit(str[2], base) != -1) + str += 2; bool overflow = false; T result = 0; + // calculate the value of the number in string while (!overflow) { int digit = get_base_digit(*str, base); if (digit == -1) break; + str++; - // check for overflow - if (negative) - { - if (result < min_val / base) - overflow = true; - else if (result * base < min_val + digit) - overflow = true; - } - else - { - if (result > max_val / base) - overflow = true; - else if (result * base > max_val - digit) - overflow = true; - } - - // calculate result's next step and move to next character + overflow = will_digit_append_overflow(negative, result, digit, base); if (!overflow) - { result = result * base + (negative ? -digit : digit); - str++; - } } // save endp if asked @@ -159,11 +167,164 @@ static T strtoT(const char* str, char** endp, int base) // return error on overflow if (overflow) { - errno = ERANGE; - return negative ? min_val : max_val; + error = ERANGE; + if constexpr(BAN::is_unsigned_v) + return BAN::numeric_limits::max(); + return negative ? BAN::numeric_limits::min() : BAN::numeric_limits::max(); } - return negative ? -result : result; + return result; +} + +template +static T strtoT(const char* str, char** endp, int& error) +{ + // find nan end including possible n-char-sequence + auto get_nan_end = [](const char* str) -> const char* + { + ASSERT(strcasecmp(str, "nan") == 0); + if (str[3] != '(') + return str + 3; + for (size_t i = 4; isalnum(str[i]) || str[i] == '_'; i++) + if (str[i] == ')') + return str + i + 1; + return str + 3; + }; + + // skip whitespace + while (isspace(*str)) + str++; + + // get sign and skip it + bool negative = (*str == '-'); + if (*str == '-' || *str == '+') + str++; + + // check for infinity or nan + { + T result = 0; + + if (strncasecmp(str, "inf", 3) == 0) + { + result = BAN::numeric_limits::infinity(); + str += strncasecmp(str, "infinity", 8) ? 3 : 8; + } + else if (strncasecmp(str, "nan", 3) == 0) + { + result = BAN::numeric_limits::quiet_NaN(); + str = get_nan_end(str); + } + + if (result != 0) + { + if (endp) + *endp = const_cast(str); + return negative ? -result : result; + } + } + + // no conversion can be performed -- not ([digit] || .[digit]) + if (!(isdigit(*str) || (str[0] == '.' && isdigit(str[1])))) + { + error = EINVAL; + return 0; + } + + int base = 10; + int exponent = 0; + int exponents_per_digit = 1; + + // check whether we have base 16 value -- (0x[xdigit] || 0x.[xdigit]) + if (strncasecmp(str, "0x", 2) == 0 && (isxdigit(str[2]) || (str[2] == '.' && isxdigit(str[3])))) + { + base = 16; + exponents_per_digit = 4; + str += 2; + } + + // parse whole part + T result = 0; + T multiplier = 1; + while (true) + { + int digit = get_base_digit(*str, base); + if (digit == -1) + break; + str++; + + if (result) + exponent += exponents_per_digit; + if (digit) + result += multiplier * digit; + if (result) + multiplier /= base; + } + + if (*str == '.') + str++; + + while (true) + { + int digit = get_base_digit(*str, base); + if (digit == -1) + break; + str++; + + if (result == 0) + exponent -= exponents_per_digit; + if (digit) + result += multiplier * digit; + if (result) + multiplier /= base; + } + + if (tolower(*str) == (base == 10 ? 'e' : 'p')) + { + char* maybe_end = nullptr; + int exp_error = 0; + + int extra_exponent = strtoT(str + 1, &maybe_end, 10, exp_error); + if (exp_error != EINVAL) + { + if (exp_error == ERANGE || will_addition_overflow(exponent, extra_exponent)) + exponent = negative ? BAN::numeric_limits::min() : BAN::numeric_limits::max(); + else + exponent += extra_exponent; + str = maybe_end; + } + } + + if (endp) + *endp = const_cast(str); + + // no over/underflow can happed with zero + if (result == 0) + return 0; + + const int max_exponent = (base == 10) ? BAN::numeric_limits::max_exponent10() : BAN::numeric_limits::max_exponent2(); + if (exponent > max_exponent) + { + error = ERANGE; + result = BAN::numeric_limits::infinity(); + return negative ? -result : result; + } + + const int min_exponent = (base == 10) ? BAN::numeric_limits::min_exponent10() : BAN::numeric_limits::min_exponent2(); + if (exponent < min_exponent) + { + error = ERANGE; + result = 0; + return negative ? -result : result; + } + + if (exponent) + result *= BAN::Math::pow((base == 10) ? 10 : 2, exponent); + return result; +} + +double atof(const char* str) +{ + return strtod(str, nullptr); } int atoi(const char* str) @@ -181,24 +342,39 @@ long long atoll(const char* str) return strtoll(str, nullptr, 10); } +float strtof(const char* __restrict str, char** __restrict endp) +{ + return strtoT(str, endp, errno); +} + +double strtod(const char* __restrict str, char** __restrict endp) +{ + return strtoT(str, endp, errno); +} + +long double strtold(const char* __restrict str, char** __restrict endp) +{ + return strtoT(str, endp, errno); +} + long strtol(const char* __restrict str, char** __restrict endp, int base) { - return strtoT(str, endp, base); + return strtoT(str, endp, base, errno); } long long strtoll(const char* __restrict str, char** __restrict endp, int base) { - return strtoT(str, endp, base); + return strtoT(str, endp, base, errno); } unsigned long strtoul(const char* __restrict str, char** __restrict endp, int base) { - return strtoT(str, endp, base); + return strtoT(str, endp, base, errno); } unsigned long long strtoull(const char* __restrict str, char** __restrict endp, int base) { - return strtoT(str, endp, base); + return strtoT(str, endp, base, errno); } char* getenv(const char* name)