diff --git a/BAN/include/BAN/Sort.h b/BAN/include/BAN/Sort.h index 4f3a3d7a..065eeb3f 100644 --- a/BAN/include/BAN/Sort.h +++ b/BAN/include/BAN/Sort.h @@ -21,26 +21,47 @@ namespace BAN::sort namespace detail { - template - It partition(It begin, It end, Comp comp) + template + struct partition_pair { - It pivot = prev(end, 1); + It lt; + It gt; + }; - It it1 = begin; - for (It it2 = begin; it2 != pivot; ++it2) + template + partition_pair partition(It begin, It end, Comp comp) + { + It pivot = next(begin, distance(begin, end) / 2); + + It lt = begin; + It eq = begin; + It gt = end; + + while (eq != gt) { - if (comp(*it2, *pivot)) + if (comp(*eq, *pivot)) { - swap(*it1, *it2); - ++it1; + swap(*eq, *lt); + if (pivot == lt) + pivot = eq; + ++lt; + ++eq; + } + else if (comp(*pivot, *eq)) + { + --gt; + swap(*eq, *gt); + if (pivot == gt) + pivot = eq; + } + else + { + ++eq; } } - swap(*it1, *pivot); - - return it1; + return { lt, gt }; } - } template>> @@ -48,9 +69,9 @@ namespace BAN::sort { if (distance(begin, end) <= 1) return; - It mid = detail::partition(begin, end, comp); - quick_sort(begin, mid, comp); - quick_sort(++mid, end, comp); + const auto [lt, gt] = detail::partition(begin, end, comp); + quick_sort(begin, lt, comp); + quick_sort(gt, end, comp); } template>> @@ -85,9 +106,9 @@ namespace BAN::sort return insertion_sort(begin, end, comp); if (max_depth == 0) return heap_sort(begin, end, comp); - It mid = detail::partition(begin, end, comp); - intro_sort_impl(begin, mid, max_depth - 1, comp); - intro_sort_impl(++mid, end, max_depth - 1, comp); + const auto [lt, gt] = detail::partition(begin, end, comp); + intro_sort_impl(begin, lt, max_depth - 1, comp); + intro_sort_impl(gt, end, max_depth - 1, comp); } }