BAN: Expose radix sort with user provided buffer

This can be nice if user has memory for a the temporary buffer and
doesnt want the sorting to allocate or be able to fail.

Also counts are now stack allocated, there isn't really any reason to
allocate them on the heap as 256x 64 bit values only adds up to 2 KiB
This commit is contained in:
2026-05-13 04:55:32 +03:00
parent d345f96387
commit 212ab010a5
2 changed files with 50 additions and 15 deletions

View File

@@ -137,27 +137,20 @@ namespace BAN::sort
template<typename It, size_t radix = 256>
requires is_unsigned_v<it_value_type_t<It>> && (radix > 0 && (radix & (radix - 1)) == 0)
BAN::ErrorOr<void> radix_sort(It begin, It end)
void radix_sort(It begin, It end, BAN::Span<it_value_type_t<It>> storage)
{
using value_type = it_value_type_t<It>;
const size_t len = distance(begin, end);
if (len <= 1)
return {};
return;
Vector<value_type> temp;
TRY(temp.resize(len));
Vector<size_t> counts;
TRY(counts.resize(radix));
ASSERT(storage.size() >= len);
constexpr size_t mask = radix - 1;
constexpr size_t shift = detail::lsb_index(radix);
for (size_t s = 0; s < sizeof(value_type) * 8; s += shift)
for (size_t s = 0; s < sizeof(it_value_type_t<It>) * 8; s += shift)
{
for (auto& cnt : counts)
cnt = 0;
size_t counts[radix] {};
for (It it = begin; it != end; ++it)
counts[(*it >> s) & mask]++;
@@ -167,12 +160,27 @@ namespace BAN::sort
for (It it = end; it != begin;)
{
--it;
temp[--counts[(*it >> s) & mask]] = *it;
storage[--counts[(*it >> s) & mask]] = *it;
}
for (size_t j = 0; j < temp.size(); j++)
*next(begin, j) = temp[j];
It it = begin;
for (size_t j = 0; j < storage.size(); j++, ++it)
*it = storage[j];
}
}
template<typename It, size_t radix = 256>
requires is_unsigned_v<it_value_type_t<It>> && (radix > 0 && (radix & (radix - 1)) == 0)
BAN::ErrorOr<void> radix_sort(It begin, It end)
{
const size_t len = distance(begin, end);
if (len <= 1)
return {};
Vector<it_value_type_t<It>> temp;
TRY(temp.resize(len));
radix_sort(begin, end, temp.span());
return {};
}

View File

@@ -39,6 +39,32 @@ bool is_sorted(BAN::Vector<T>& vec)
} \
} while (0)
#define TEST_ALGORITHM_RADIX(ms) do { \
uint64_t duration_us = 0; \
printf("radix with preallocated buffer\n"); \
for (size_t size = 100; duration_us < ms * 1000; size *= 10) { \
BAN::Vector<unsigned> data(size, 0); \
for (auto& val : data) \
val = rand() % 100; \
BAN::Vector<unsigned> temp(size); \
uint64_t start_ns = CURRENT_NS(); \
BAN::sort::radix_sort(data.begin(), data.end(), temp.span()); \
uint64_t stop_ns = CURRENT_NS(); \
if (!is_sorted(data)) { \
printf(" \e[31mFAILED!\e[m\n"); \
break; \
} \
duration_us = (stop_ns - start_ns) / 1'000; \
printf(" %5d.%03d ms (%zu)\n", \
(int)(duration_us / 1000), \
(int)(duration_us % 1000), \
size \
); \
} \
} while (0)
#define TEST_ALGORITHM_QSORT(ms) do { \
uint64_t duration_us = 0; \
printf("qsort\n"); \
@@ -72,5 +98,6 @@ int main()
TEST_ALGORITHM(100, BAN::sort::intro_sort);
TEST_ALGORITHM(1000, BAN::sort::sort);
TEST_ALGORITHM(1000, BAN::sort::radix_sort);
TEST_ALGORITHM_RADIX(1000);
TEST_ALGORITHM_QSORT(100);
}