diff --git a/BAN/include/BAN/Sort.h b/BAN/include/BAN/Sort.h index e9a91e2e..8182bf0f 100644 --- a/BAN/include/BAN/Sort.h +++ b/BAN/include/BAN/Sort.h @@ -44,7 +44,7 @@ namespace BAN::sort template> void quick_sort(It begin, It end, Comp comp = {}) { - if (begin == end || next(begin, 1) == end) + if (distance(begin, end) <= 1) return; It mid = detail::partition(begin, end, comp); quick_sort(begin, mid, comp); @@ -54,7 +54,7 @@ namespace BAN::sort template> void insertion_sort(It begin, It end, Comp comp = {}) { - if (begin == end || next(begin, 1) == end) + if (distance(begin, end) <= 1) return; for (It it1 = next(begin, 1); it1 != end; ++it1) { @@ -66,39 +66,87 @@ namespace BAN::sort } } + namespace detail + { + + template + void push_heap(It begin, size_t hole_index, size_t top_index, typename It::value_type value, Comp comp) + { + size_t parent = (hole_index - 1) / 2; + while (hole_index > top_index && comp(*next(begin, parent), value)) + { + *next(begin, hole_index) = move(*next(begin, parent)); + hole_index = parent; + parent = (hole_index - 1) / 2; + } + *next(begin, hole_index) = move(value); + } + + template + void adjust_heap(It begin, size_t hole_index, size_t len, typename It::value_type value, Comp comp) + { + const size_t top_index = hole_index; + size_t child = hole_index; + while (child < (len - 1) / 2) + { + child = 2 * (child + 1); + if (comp(*next(begin, child), *next(begin, child - 1))) + child--; + *next(begin, hole_index) = move(*next(begin, child)); + hole_index = child; + } + if (len % 2 == 0 && child == (len - 2) / 2) + { + child = 2 * (child + 1); + *next(begin, hole_index) = move(*next(begin, child - 1)); + hole_index = child - 1; + } + push_heap(begin, hole_index, top_index, move(value), comp); + } + + } + + template> + void make_heap(It begin, It end, Comp comp = {}) + { + const size_t len = distance(begin, end); + if (len <= 1) + return; + + size_t parent = (len - 2) / 2; + while (true) + { + detail::adjust_heap(begin, parent, len, move(*next(begin, parent)), comp); + + if (parent == 0) + break; + + parent--; + } + } + + template> + void sort_heap(It begin, It end, Comp comp = {}) + { + const size_t len = distance(begin, end); + if (len <= 1) + return; + + size_t last = len; + while (last > 1) + { + last--; + typename It::value_type x = move(*next(begin, last)); + *next(begin, last) = move(*begin); + detail::adjust_heap(begin, 0, last, move(x), comp); + } + } + template> void heap_sort(It begin, It end, Comp comp = {}) { - if (begin == end || next(begin, 1) == end) - return; - - It start = next(begin, distance(begin, end) / 2); - - while (prev(end, 1) != begin) - { - if (start != begin) - --start; - else - swap(*(--end), *begin); - - It root = start; - while (true) - { - size_t left_child = 2 * distance(begin, root) + 1; - if (left_child >= distance(begin, end)) - break; - - It child = next(begin, left_child); - if (next(child, 1) != end && comp(*child, *next(child, 1))) - ++child; - - if (!comp(*root, *child)) - break; - - swap(*root, *child); - root = child; - } - } + make_heap(begin, end, comp); + sort_heap(begin, end, comp); } namespace detail @@ -107,7 +155,7 @@ namespace BAN::sort template void intro_sort_impl(It begin, It end, size_t max_depth, Comp comp) { - if (distance(begin, end) < 16) + if (distance(begin, end) <= 16) return insertion_sort(begin, end, comp); if (max_depth == 0) return heap_sort(begin, end, comp); @@ -121,14 +169,16 @@ namespace BAN::sort template> void intro_sort(It begin, It end, Comp comp = {}) { - size_t max_depth = Math::ilog2(distance(begin, end)); - detail::intro_sort_impl(begin, end, max_depth, comp); + const size_t len = distance(begin, end); + if (len <= 1) + return; + detail::intro_sort_impl(begin, end, 2 * Math::ilog2(len), comp); } template> void sort(It begin, It end, Comp comp = {}) { - return sort::intro_sort(begin, end, comp); + return intro_sort(begin, end, comp); } }