diff --git a/BAN/include/BAN/Sort.h b/BAN/include/BAN/Sort.h index b85bda6e80..299eca36c1 100644 --- a/BAN/include/BAN/Sort.h +++ b/BAN/include/BAN/Sort.h @@ -1,8 +1,9 @@ #pragma once +#include #include #include -#include +#include namespace BAN::sort { @@ -175,6 +176,61 @@ namespace BAN::sort detail::intro_sort_impl(begin, end, 2 * Math::ilog2(len), comp); } + namespace detail + { + + template + consteval T lsb_index(T value) + { + for (T result = 0;; result++) + if (value & (1 << result)) + return result; + } + + } + + template + requires is_unsigned_v && (radix > 0 && (radix & (radix - 1)) == 0) + BAN::ErrorOr radix_sort(It begin, It end) + { + using value_type = typename It::value_type; + + const size_t len = distance(begin, end); + if (len <= 1) + return {}; + + Vector temp; + TRY(temp.resize(len)); + + Vector counts; + TRY(counts.resize(radix)); + + constexpr size_t mask = radix - 1; + constexpr size_t shift = detail::lsb_index(radix); + + for (size_t s = 0; s < sizeof(value_type) * 8; s += shift) + { + for (auto& cnt : counts) + cnt = 0; + for (It it = begin; it != end; ++it) + counts[(*it >> s) & mask]++; + + for (size_t i = 0; i < radix - 1; i++) + counts[i + 1] += counts[i]; + + for (It it = end; it != begin;) + { + --it; + temp[--counts[(*it >> s) & mask]] = *it; + } + + for (size_t j = 0; j < temp.size(); j++) + *next(begin, j) = temp[j]; + } + + return {}; + } + template> void sort(It begin, It end, Comp comp = {}) {