diff --git a/BAN/include/BAN/Sort.h b/BAN/include/BAN/Sort.h index 065eeb3f..0ac8c4a4 100644 --- a/BAN/include/BAN/Sort.h +++ b/BAN/include/BAN/Sort.h @@ -137,27 +137,20 @@ namespace BAN::sort template requires is_unsigned_v> && (radix > 0 && (radix & (radix - 1)) == 0) - BAN::ErrorOr radix_sort(It begin, It end) + void radix_sort(It begin, It end, BAN::Span> storage) { - using value_type = it_value_type_t; - const size_t len = distance(begin, end); if (len <= 1) - return {}; + return; - Vector temp; - TRY(temp.resize(len)); - - Vector counts; - TRY(counts.resize(radix)); + ASSERT(storage.size() >= len); constexpr size_t mask = radix - 1; constexpr size_t shift = detail::lsb_index(radix); - for (size_t s = 0; s < sizeof(value_type) * 8; s += shift) + for (size_t s = 0; s < sizeof(it_value_type_t) * 8; s += shift) { - for (auto& cnt : counts) - cnt = 0; + size_t counts[radix] {}; for (It it = begin; it != end; ++it) counts[(*it >> s) & mask]++; @@ -167,12 +160,27 @@ namespace BAN::sort for (It it = end; it != begin;) { --it; - temp[--counts[(*it >> s) & mask]] = *it; + storage[--counts[(*it >> s) & mask]] = *it; } - for (size_t j = 0; j < temp.size(); j++) - *next(begin, j) = temp[j]; + It it = begin; + for (size_t j = 0; j < storage.size(); j++, ++it) + *it = storage[j]; } + } + + template + requires is_unsigned_v> && (radix > 0 && (radix & (radix - 1)) == 0) + BAN::ErrorOr radix_sort(It begin, It end) + { + const size_t len = distance(begin, end); + if (len <= 1) + return {}; + + Vector> temp; + TRY(temp.resize(len)); + + radix_sort(begin, end, temp.span()); return {}; } diff --git a/userspace/tests/test-sort/main.cpp b/userspace/tests/test-sort/main.cpp index 6adbf753..523c75d8 100644 --- a/userspace/tests/test-sort/main.cpp +++ b/userspace/tests/test-sort/main.cpp @@ -39,6 +39,32 @@ bool is_sorted(BAN::Vector& vec) } \ } while (0) + +#define TEST_ALGORITHM_RADIX(ms) do { \ + uint64_t duration_us = 0; \ + printf("radix with preallocated buffer\n"); \ + for (size_t size = 100; duration_us < ms * 1000; size *= 10) { \ + BAN::Vector data(size, 0); \ + for (auto& val : data) \ + val = rand() % 100; \ + BAN::Vector temp(size); \ + uint64_t start_ns = CURRENT_NS(); \ + BAN::sort::radix_sort(data.begin(), data.end(), temp.span()); \ + uint64_t stop_ns = CURRENT_NS(); \ + if (!is_sorted(data)) { \ + printf(" \e[31mFAILED!\e[m\n"); \ + break; \ + } \ + duration_us = (stop_ns - start_ns) / 1'000; \ + printf(" %5d.%03d ms (%zu)\n", \ + (int)(duration_us / 1000), \ + (int)(duration_us % 1000), \ + size \ + ); \ + } \ + } while (0) + + #define TEST_ALGORITHM_QSORT(ms) do { \ uint64_t duration_us = 0; \ printf("qsort\n"); \ @@ -72,5 +98,6 @@ int main() TEST_ALGORITHM(100, BAN::sort::intro_sort); TEST_ALGORITHM(1000, BAN::sort::sort); TEST_ALGORITHM(1000, BAN::sort::radix_sort); + TEST_ALGORITHM_RADIX(1000); TEST_ALGORITHM_QSORT(100); }