Improve performance of unstable sort #104116

993 changes: 826 additions & 167 deletions library/core/src/slice/sort.rs
@@ -6,9 +6,10 @@
//! Unstable sorting is compatible with libcore because it doesn't allocate memory, unlike our
//! stable sorting implementation.
use crate::cmp;
use crate::mem::{self, MaybeUninit, SizedTypeProperties};
use crate::ptr;
use core::cmp;
use core::intrinsics;
use core::mem::{self, MaybeUninit, SizedTypeProperties};
use core::ptr;

/// When dropped, copies from `src` into `dest`.
struct CopyOnDrop<T> {
@@ -27,98 +28,6 @@ impl<T> Drop for CopyOnDrop<T> {
}
}

/// Shifts the first element to the right until it encounters a greater or equal element.
fn shift_head<T, F>(v: &mut [T], is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
let len = v.len();
// SAFETY: The unsafe operations below involve indexing without a bounds check (by offsetting a
// pointer) and copying memory (`ptr::copy_nonoverlapping`).
//
// a. Indexing:
// 1. We checked the size of the array to >=2.
// 2. All the indexing that we will do is always between {0 <= index < len} at most.
//
// b. Memory copying
// 1. We are obtaining pointers to references which are guaranteed to be valid.
// 2. They cannot overlap because we obtain pointers to different indices of the slice.
// Namely, `i` and `i-1`.
// 3. If the slice is properly aligned, the elements are properly aligned.
// It is the caller's responsibility to make sure the slice is properly aligned.
//
// See comments below for further detail.
unsafe {
// If the first two elements are out-of-order...
if len >= 2 && is_less(v.get_unchecked(1), v.get_unchecked(0)) {
// Read the first element into a stack-allocated variable. If a following comparison
// operation panics, `hole` will get dropped and automatically write the element back
// into the slice.
let tmp = mem::ManuallyDrop::new(ptr::read(v.get_unchecked(0)));
let v = v.as_mut_ptr();
let mut hole = CopyOnDrop { src: &*tmp, dest: v.add(1) };
ptr::copy_nonoverlapping(v.add(1), v.add(0), 1);

for i in 2..len {
if !is_less(&*v.add(i), &*tmp) {
break;
}

// Move `i`-th element one place to the left, thus shifting the hole to the right.
ptr::copy_nonoverlapping(v.add(i), v.add(i - 1), 1);
hole.dest = v.add(i);
}
// `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`.
}
}
}

/// Shifts the last element to the left until it encounters a smaller or equal element.
fn shift_tail<T, F>(v: &mut [T], is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
let len = v.len();
// SAFETY: The unsafe operations below involve indexing without a bounds check (by offsetting a
// pointer) and copying memory (`ptr::copy_nonoverlapping`).
//
// a. Indexing:
// 1. We checked the size of the array to >= 2.
// 2. All the indexing that we will do is always between `0 <= index < len-1` at most.
//
// b. Memory copying
// 1. We are obtaining pointers to references which are guaranteed to be valid.
// 2. They cannot overlap because we obtain pointers to different indices of the slice.
// Namely, `i` and `i+1`.
// 3. If the slice is properly aligned, the elements are properly aligned.
// It is the caller's responsibility to make sure the slice is properly aligned.
//
// See comments below for further detail.
unsafe {
// If the last two elements are out-of-order...
if len >= 2 && is_less(v.get_unchecked(len - 1), v.get_unchecked(len - 2)) {
// Read the last element into a stack-allocated variable. If a following comparison
// operation panics, `hole` will get dropped and automatically write the element back
// into the slice.
let tmp = mem::ManuallyDrop::new(ptr::read(v.get_unchecked(len - 1)));
let v = v.as_mut_ptr();
let mut hole = CopyOnDrop { src: &*tmp, dest: v.add(len - 2) };
ptr::copy_nonoverlapping(v.add(len - 2), v.add(len - 1), 1);

for i in (0..len - 2).rev() {
if !is_less(&*tmp, &*v.add(i)) {
break;
}

// Move `i`-th element one place to the right, thus shifting the hole to the left.
ptr::copy_nonoverlapping(v.add(i), v.add(i + 1), 1);
hole.dest = v.add(i);
}
// `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`.
}
}
}

/// Partially sorts a slice by shifting several out-of-order elements around.
///
/// Returns `true` if the slice is sorted at the end. This function is *O*(*n*) worst-case.
@@ -158,26 +67,35 @@ where
// Swap the found pair of elements. This puts them in correct order.
v.swap(i - 1, i);

// Shift the smaller element to the left.
shift_tail(&mut v[..i], is_less);
if i >= 2 {
// SAFETY: We check that the slice len is >= 2.
unsafe {
insert_tail(&mut v[..i], is_less);
}
}

// Shift the greater element to the right.
shift_head(&mut v[i..], is_less);
if i < (len - 1) {
// SAFETY: We check that the slice len is >= 2.
unsafe {
insert_head(&mut v[i..], is_less);
}
}
}

// Didn't manage to sort the slice in the limited number of steps.
false
}

/// Sorts a slice using insertion sort, which is *O*(*n*^2) worst-case.
fn insertion_sort<T, F>(v: &mut [T], is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
for i in 1..v.len() {
shift_tail(&mut v[..i + 1], is_less);
}
}

/// Sorts `v` using heapsort, which guarantees *O*(*n* \* log(*n*)) worst-case.
#[cold]
#[unstable(feature = "sort_internals", reason = "internal to sort module", issue = "none")]
@@ -326,8 +244,8 @@ where
unsafe {
// Branchless comparison.
*end_l = i as u8;
end_l = end_l.add(!is_less(&*elem, pivot) as usize);
elem = elem.add(1);
end_l = end_l.offset(!is_less(&*elem, pivot) as isize);
elem = elem.offset(1);
}
}
}
@@ -352,9 +270,9 @@ where
// Plus, `block_r` was asserted to be less than `BLOCK` and `elem` will therefore at most be pointing to the beginning of the slice.
unsafe {
// Branchless comparison.
elem = elem.sub(1);
elem = elem.offset(-1);
*end_r = i as u8;
end_r = end_r.add(is_less(&*elem, pivot) as usize);
end_r = end_r.offset(is_less(&*elem, pivot) as isize);
}
}
}
@@ -365,12 +283,12 @@ where
if count > 0 {
macro_rules! left {
() => {
l.add(usize::from(*start_l))
l.offset(*start_l as isize)
};
}
macro_rules! right {
() => {
r.sub(usize::from(*start_r) + 1)
r.offset(-(*start_r as isize) - 1)
};
}

@@ -398,16 +316,16 @@ where
ptr::copy_nonoverlapping(right!(), left!(), 1);

for _ in 1..count {
start_l = start_l.add(1);
start_l = start_l.offset(1);
ptr::copy_nonoverlapping(left!(), right!(), 1);
start_r = start_r.add(1);
start_r = start_r.offset(1);
ptr::copy_nonoverlapping(right!(), left!(), 1);
}

ptr::copy_nonoverlapping(&tmp, right!(), 1);
mem::forget(tmp);
start_l = start_l.add(1);
start_r = start_r.add(1);
start_l = start_l.offset(1);
start_r = start_r.offset(1);
}
}

@@ -420,15 +338,15 @@ where
// safe. Otherwise, the debug assertions in the `is_done` case guarantee that
// `width(l, r) == block_l + block_r`, namely, that the block sizes have been adjusted to account
// for the smaller number of remaining elements.
l = unsafe { l.add(block_l) };
l = unsafe { l.offset(block_l as isize) };
}

if start_r == end_r {
// All out-of-order elements in the right block were moved. Move to the previous block.

// SAFETY: Same argument as [block-width-guarantee]. Either this is a full block `2*BLOCK`-wide,
// or `block_r` has been adjusted for the last handful of elements.
r = unsafe { r.sub(block_r) };
r = unsafe { r.offset(-(block_r as isize)) };
}

if is_done {
@@ -457,9 +375,9 @@ where
// - `offsets_l` contains valid offsets into `v` collected during the partitioning of
// the last block, so the `l.offset` calls are valid.
unsafe {
end_l = end_l.sub(1);
ptr::swap(l.add(usize::from(*end_l)), r.sub(1));
r = r.sub(1);
end_l = end_l.offset(-1);
ptr::swap(l.offset(*end_l as isize), r.offset(-1));
r = r.offset(-1);
}
}
width(v.as_mut_ptr(), r)
@@ -470,9 +388,9 @@ where
while start_r < end_r {
// SAFETY: See the reasoning in [remaining-elements-safety].
unsafe {
end_r = end_r.sub(1);
ptr::swap(l, r.sub(usize::from(*end_r) + 1));
l = l.add(1);
end_r = end_r.offset(-1);
ptr::swap(l, r.offset(-(*end_r as isize) - 1));
l = l.offset(1);
}
}
width(v.as_mut_ptr(), l)
@@ -659,6 +577,12 @@ where

let len = v.len();

if len <= MAX_INSERTION {
// It's a logic bug if this gets called on a slice that would be small-sorted.
debug_assert!(false);
return (10, false);
}

// Three indices near which we are going to choose a pivot.
let mut a = len / 4 * 1;
let mut b = len / 4 * 2;
@@ -667,45 +591,46 @@ where
// Counts the total number of swaps we are about to perform while sorting indices.
let mut swaps = 0;

if len >= 8 {
// Swaps indices so that `v[a] <= v[b]`.
// SAFETY: `len >= 8` so there are at least two elements in the neighborhoods of
// `a`, `b` and `c`. This means the three calls to `sort_adjacent` result in
// corresponding calls to `sort3` with valid 3-item neighborhoods around each
// pointer, which in turn means the calls to `sort2` are done with valid
// references. Thus the `v.get_unchecked` calls are safe, as is the `ptr::swap`
// call.
let mut sort2 = |a: &mut usize, b: &mut usize| unsafe {
if is_less(v.get_unchecked(*b), v.get_unchecked(*a)) {
ptr::swap(a, b);
swaps += 1;
}
};

// Swaps indices so that `v[a] <= v[b] <= v[c]`.
let mut sort3 = |a: &mut usize, b: &mut usize, c: &mut usize| {
sort2(a, b);
sort2(b, c);
sort2(a, b);
};
// Swaps indices so that `v[a] <= v[b]`.
// SAFETY: `len > 20` so there are at least two elements in the neighborhoods of
// `a`, `b` and `c`. This means the three calls to `sort_adjacent` result in
// corresponding calls to `sort3` with valid 3-item neighborhoods around each
// pointer, which in turn means the calls to `sort2` are done with valid
// references. Thus the `v.get_unchecked` calls are safe, as is the `ptr::swap`
// call.
let mut sort2_idx = |a: &mut usize, b: &mut usize| unsafe {
let should_swap = is_less(v.get_unchecked(*b), v.get_unchecked(*a));

// Generate branchless cmov code; it's not super important, but it reduces BHB and BTB pressure.
let tmp_idx = if should_swap { *a } else { *b };
*a = if should_swap { *b } else { *a };
*b = tmp_idx;
swaps += should_swap as usize;
};

if len >= SHORTEST_MEDIAN_OF_MEDIANS {
// Finds the median of `v[a - 1], v[a], v[a + 1]` and stores the index into `a`.
let mut sort_adjacent = |a: &mut usize| {
let tmp = *a;
sort3(&mut (tmp - 1), a, &mut (tmp + 1));
};
// Swaps indices so that `v[a] <= v[b] <= v[c]`.
let mut sort3_idx = |a: &mut usize, b: &mut usize, c: &mut usize| {
sort2_idx(a, b);
sort2_idx(b, c);
sort2_idx(a, b);
};

// Find medians in the neighborhoods of `a`, `b`, and `c`.
sort_adjacent(&mut a);
sort_adjacent(&mut b);
sort_adjacent(&mut c);
}
if len >= SHORTEST_MEDIAN_OF_MEDIANS {
// Finds the median of `v[a - 1], v[a], v[a + 1]` and stores the index into `a`.
let mut sort_adjacent = |a: &mut usize| {
let tmp = *a;
sort3_idx(&mut (tmp - 1), a, &mut (tmp + 1));
};

// Find the median among `a`, `b`, and `c`.
sort3(&mut a, &mut b, &mut c);
// Find medians in the neighborhoods of `a`, `b`, and `c`.
sort_adjacent(&mut a);
sort_adjacent(&mut b);
sort_adjacent(&mut c);
}

// Find the median among `a`, `b`, and `c`.
sort3_idx(&mut a, &mut b, &mut c);

if swaps < MAX_SWAPS {
(b, swaps == 0)
} else {
@@ -716,6 +641,9 @@ where
}
}

// Slices of up to this length get sorted using insertion sort.
const MAX_INSERTION: usize = 20;

/// Sorts `v` recursively.
///
/// If the slice had a predecessor in the original array, it is specified as `pred`.
@@ -726,9 +654,6 @@ fn recurse<'a, T, F>(mut v: &'a mut [T], is_less: &mut F, mut pred: Option<&'a T
where
F: FnMut(&T, &T) -> bool,
{
// Slices of up to this length get sorted using insertion sort.
const MAX_INSERTION: usize = 20;

// True if the last partitioning was reasonably balanced.
let mut was_balanced = true;
// True if the last partitioning didn't shuffle elements (the slice was already partitioned).
@@ -737,9 +662,9 @@ where
loop {
let len = v.len();

// Very short slices get sorted using insertion sort.
if len <= MAX_INSERTION {
insertion_sort(v, is_less);

if sort_small(v, is_less) {
return;
}

@@ -807,13 +732,140 @@ where
}
}

/// Sorts `v` using strategies optimized for small sizes.
pub fn sort_small<T, F>(v: &mut [T], is_less: &mut F) -> bool
where
F: FnMut(&T, &T) -> bool,
{
let len = v.len();

const MAX_BRANCHLESS_SMALL_SORT: usize = 40;

if len < 2 {
return true;
}

if qualifies_for_branchless_sort::<T>() && len <= MAX_BRANCHLESS_SMALL_SORT {
if len < 8 {
// For small sizes it's better to just sort unconditionally. In the worst case, len == 7,
// an already sorted input only goes from 6 to 8 comparisons.
let start = if len >= 4 {
// SAFETY: We just checked the len.
unsafe {
sort4_optimal(&mut v[0..4], is_less);
}
4
} else {
1
};

insertion_sort_shift_left(v, start, is_less);

return true;
}

// Analyze patterns to minimize the comparison count for already sorted or reversed inputs.
// For larger inputs, pdqsort's own pattern analysis is used.

let mut start = len - 1;
if start > 0 {
start -= 1;
unsafe {
if is_less(v.get_unchecked(start + 1), v.get_unchecked(start)) {
while start > 0 && is_less(v.get_unchecked(start), v.get_unchecked(start - 1)) {
start -= 1;
}
v[start..len].reverse();
} else {
while start > 0 && !is_less(v.get_unchecked(start), v.get_unchecked(start - 1))
{
start -= 1;
}
}
}
}

debug_assert!(start < len);

let already_sorted = len - start;

if already_sorted <= 6 {
// SAFETY: We check the len.
unsafe {
match len {
8..=15 => {
sort8_plus(v, is_less);
}
16..=31 => {
sort16_plus(v, is_less);
}
32..=40 => {
sort32_plus(v, is_less);
}
_ => {
unreachable!()
}
}
}
} else {
// Potentially highly or fully sorted. We know that already_sorted >= 7 and len >= 8,
// which leaves start <= 33.
debug_assert!(start <= 33);

if start == 0 {
return true;
} else if start <= 3 {
insertion_sort_shift_right(v, start, is_less);
return true;
}

match start {
4..=7 => {
// SAFETY: We just checked start >= 4.
unsafe {
sort4_plus(&mut v[0..start], is_less);
}
}
8..=15 => {
// SAFETY: We just checked start >= 8.
unsafe {
sort8_plus(&mut v[0..start], is_less);
}
}
16..=33 => {
// SAFETY: We just checked start >= 16.
unsafe {
sort16_plus(&mut v[0..start], is_less);
}
}
_ => unreachable!(),
}

// The longest possible shorter side is 20 elements, occurring at len == 40 with start == 20.
let mut swap = mem::MaybeUninit::<[T; 20]>::uninit();
let swap_ptr = swap.as_mut_ptr() as *mut T;

// SAFETY: swap is long enough and both sides are len >= 1.
unsafe {
merge(v, start, swap_ptr, is_less);
}
}
return true;
} else if len <= MAX_INSERTION {
insertion_sort_shift_left(v, 1, is_less);
return true;
}

false
}
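
// Illustrative sketch, not part of the patch: the calling contract of `sort_small`.
// It returns `true` when it fully handled the slice, so a caller only falls through
// to the quicksort machinery on `false`.
#[allow(dead_code)]
fn sort_small_usage_sketch(v: &mut [u32]) {
    let mut is_less = |a: &u32, b: &u32| a < b;
    if !sort_small(v, &mut is_less) {
        // Too long for the small-sort paths; a real caller (see `recurse`) continues
        // with pivot selection and partitioning here.
    }
}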

/// Sorts `v` using pattern-defeating quicksort, which is *O*(*n* \* log(*n*)) worst-case.
pub fn quicksort<T, F>(v: &mut [T], mut is_less: F)
where
F: FnMut(&T, &T) -> bool,
{
// Sorting has no meaningful behavior on zero-sized types.
if T::IS_ZST {
if mem::size_of::<T>() == 0 {
return;
}

@@ -823,6 +875,609 @@ where
recurse(v, &mut is_less, None, limit);
}

// --- Insertion sorts ---

// TODO unified sort module.

// When dropped, copies from `src` into `dest`.
struct InsertionHole<T> {
src: *const T,
dest: *mut T,
}

impl<T> Drop for InsertionHole<T> {
fn drop(&mut self) {
unsafe {
ptr::copy_nonoverlapping(self.src, self.dest, 1);
}
}
}

/// Inserts `v[v.len() - 1]` into pre-sorted sequence `v[..v.len() - 1]` so that whole `v[..]`
/// becomes sorted.
unsafe fn insert_tail<T, F>(v: &mut [T], is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
debug_assert!(v.len() >= 2);

let arr_ptr = v.as_mut_ptr();
let i = v.len() - 1;

// SAFETY: the caller must ensure `v` has at least 2 elements.
unsafe {
// See insert_head which talks about why this approach is beneficial.
let i_ptr = arr_ptr.add(i);

// It's important that we use `i_ptr` here. If this check is positive and we continue,
// we want to make sure that no other copy of the value was seen by `is_less`.
// Otherwise we would have to copy it back.
if is_less(&*i_ptr, &*i_ptr.sub(1)) {
// It's important that we use `tmp` for comparisons from now on, as it is the value
// that will be copied back. Otherwise we could notionally create a divergence by
// copying back the wrong value.
let tmp = mem::ManuallyDrop::new(ptr::read(i_ptr));
// Intermediate state of the insertion process is always tracked by `hole`, which
// serves two purposes:
// 1. Protects integrity of `v` from panics in `is_less`.
// 2. Fills the remaining hole in `v` in the end.
//
// Panic safety:
//
// If `is_less` panics at any point during the process, `hole` will get dropped and
// fill the hole in `v` with `tmp`, thus ensuring that `v` still holds every object it
// initially held exactly once.
let mut hole = InsertionHole { src: &*tmp, dest: i_ptr.sub(1) };
ptr::copy_nonoverlapping(hole.dest, i_ptr, 1);

// SAFETY: We know i is at least 1.
for j in (0..(i - 1)).rev() {
let j_ptr = arr_ptr.add(j);
if !is_less(&*tmp, &*j_ptr) {
break;
}

ptr::copy_nonoverlapping(j_ptr, hole.dest, 1);
hole.dest = j_ptr;
}
// `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`.
}
}
}

/// Inserts `v[0]` into pre-sorted sequence `v[1..]` so that whole `v[..]` becomes sorted.
///
/// This is the integral subroutine of insertion sort.
unsafe fn insert_head<T, F>(v: &mut [T], is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
debug_assert!(v.len() >= 2);

unsafe {
if is_less(v.get_unchecked(1), v.get_unchecked(0)) {
let arr_ptr = v.as_mut_ptr();

// There are three ways to implement insertion here:
//
// 1. Swap adjacent elements until the first one gets to its final destination.
// However, this way we copy data around more than is necessary. If elements are big
// structures (costly to copy), this method will be slow.
//
// 2. Iterate until the right place for the first element is found. Then shift the
// elements succeeding it to make room for it and finally place it into the
// remaining hole. This is a good method.
//
// 3. Copy the first element into a temporary variable. Iterate until the right place
// for it is found. As we go along, copy every traversed element into the slot
// preceding it. Finally, copy data from the temporary variable into the remaining
// hole. This method is very good. Benchmarks demonstrated slightly better
// performance than with the 2nd method.
//
// All methods were benchmarked, and the 3rd showed best results. So we chose that one.
let tmp = mem::ManuallyDrop::new(ptr::read(arr_ptr));

// Intermediate state of the insertion process is always tracked by `hole`, which
// serves two purposes:
// 1. Protects integrity of `v` from panics in `is_less`.
// 2. Fills the remaining hole in `v` in the end.
//
// Panic safety:
//
// If `is_less` panics at any point during the process, `hole` will get dropped and
// fill the hole in `v` with `tmp`, thus ensuring that `v` still holds every object it
// initially held exactly once.
let mut hole = InsertionHole { src: &*tmp, dest: arr_ptr.add(1) };
ptr::copy_nonoverlapping(arr_ptr.add(1), arr_ptr.add(0), 1);

for i in 2..v.len() {
if !is_less(v.get_unchecked(i), &*tmp) {
break;
}
ptr::copy_nonoverlapping(arr_ptr.add(i), arr_ptr.add(i - 1), 1);
hole.dest = arr_ptr.add(i);
}
// `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`.
}
}
}

/// Sort `v` assuming `v[..offset]` is already sorted.
///
/// Never inline this function to avoid code bloat. It still optimizes nicely, has practically no
/// performance impact, and can even improve performance in some cases.
#[inline(never)]
fn insertion_sort_shift_left<T, F>(v: &mut [T], offset: usize, is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
let len = v.len();

// Violating this is a logic bug, not a safety bug.
debug_assert!(offset != 0 && offset <= len);

if intrinsics::unlikely(((len < 2) as u8 + (offset == 0) as u8) != 0) {
return;
}

// Shift each element of the unsorted region v[i..] as far left as is needed to make v sorted.
for i in offset..len {
// SAFETY: we tested that len >= 2.
unsafe {
// Maybe use insert_head here and avoid additional code.
insert_tail(&mut v[..=i], is_less);
}
}
}
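
// Illustrative sketch, not part of the patch: the offset contract of the shift-left
// variant. The prefix `v[..offset]` must already be sorted; every following element
// is then inserted into it via `insert_tail`.
#[allow(dead_code)]
fn insertion_sort_shift_left_sketch() {
    let mut v = [1, 4, 9, 3, 7];
    let mut is_less = |a: &i32, b: &i32| a < b;
    // `v[..3]` == [1, 4, 9] is already sorted, so `offset == 3` is valid.
    insertion_sort_shift_left(&mut v[..], 3, &mut is_less);
    debug_assert!(v == [1, 3, 4, 7, 9]);
}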

/// Sort `v` assuming `v[offset..]` is already sorted.
///
/// Never inline this function to avoid code bloat. It still optimizes nicely, has practically no
/// performance impact, and can even improve performance in some cases.
#[inline(never)]
fn insertion_sort_shift_right<T, F>(v: &mut [T], offset: usize, is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
let len = v.len();

// Violating this is a logic bug, not a safety bug.
debug_assert!(offset != 0 && offset <= len);

if intrinsics::unlikely(((len < 2) as u8 + (offset == 0) as u8) != 0) {
return;
}

// Shift each element of the unsorted region v[..offset] as far right as is needed to make v sorted.
for i in (0..offset).rev() {
// SAFETY: we ensured above that len >= 2, and the callers pass offset < len, so the
// subslice `v[i..len]` handed to `insert_head` always holds at least 2 elements.
unsafe {
insert_head(&mut v[i..len], is_less);
}
}
}
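
// Illustrative sketch, not part of the patch: the mirror contract. Here the suffix
// `v[offset..]` must already be sorted; the leading `offset` elements are inserted
// into it from right to left via `insert_head`.
#[allow(dead_code)]
fn insertion_sort_shift_right_sketch() {
    let mut v = [9, 2, 1, 3, 7];
    let mut is_less = |a: &i32, b: &i32| a < b;
    // `v[2..]` == [1, 3, 7] is already sorted, so `offset == 2` is valid.
    insertion_sort_shift_right(&mut v[..], 2, &mut is_less);
    debug_assert!(v == [1, 2, 3, 7, 9]);
}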

/// Merges non-decreasing runs `v[..mid]` and `v[mid..]` using `buf` as temporary storage, and
/// stores the result into `v[..]`.
///
/// # Safety
///
/// The two slices must be non-empty and `mid` must be in bounds. Buffer `buf` must be long enough
/// to hold a copy of the shorter slice. Also, `T` must not be a zero-sized type.
///
/// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no
/// performance impact.
#[inline(never)]
#[cfg(not(no_global_oom_handling))]
unsafe fn merge<T, F>(v: &mut [T], mid: usize, buf: *mut T, is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
let len = v.len();
let arr_ptr = v.as_mut_ptr();
let (v_mid, v_end) = unsafe { (arr_ptr.add(mid), arr_ptr.add(len)) };

// The merge process first copies the shorter run into `buf`. Then it traces the newly copied
// run and the longer run forwards (or backwards), comparing their next unconsumed elements and
// copying the lesser (or greater) one into `v`.
//
// As soon as the shorter run is fully consumed, the process is done. If the longer run gets
// consumed first, then we must copy whatever is left of the shorter run into the remaining
// hole in `v`.
//
// Intermediate state of the process is always tracked by `hole`, which serves two purposes:
// 1. Protects integrity of `v` from panics in `is_less`.
// 2. Fills the remaining hole in `v` if the longer run gets consumed first.
//
// Panic safety:
//
// If `is_less` panics at any point during the process, `hole` will get dropped and fill the
// hole in `v` with the unconsumed range in `buf`, thus ensuring that `v` still holds every
// object it initially held exactly once.
let mut hole;

if mid <= len - mid {
// The left run is shorter.
unsafe {
ptr::copy_nonoverlapping(arr_ptr, buf, mid);
hole = MergeHole { start: buf, end: buf.add(mid), dest: arr_ptr };
}

// Initially, these pointers point to the beginnings of their arrays.
let left = &mut hole.start;
let mut right = v_mid;
let out = &mut hole.dest;

while *left < hole.end && right < v_end {
// Consume the lesser side.
// If equal, prefer the left run to maintain stability.
unsafe {
let to_copy = if is_less(&*right, &**left) {
get_and_increment(&mut right)
} else {
get_and_increment(left)
};
ptr::copy_nonoverlapping(to_copy, get_and_increment(out), 1);
}
}
} else {
// The right run is shorter.
unsafe {
ptr::copy_nonoverlapping(v_mid, buf, len - mid);
hole = MergeHole { start: buf, end: buf.add(len - mid), dest: v_mid };
}

// Initially, these pointers point past the ends of their arrays.
let left = &mut hole.dest;
let right = &mut hole.end;
let mut out = v_end;

while arr_ptr < *left && buf < *right {
// Consume the greater side.
// If equal, prefer the right run to maintain stability.
unsafe {
let to_copy = if is_less(&*right.offset(-1), &*left.offset(-1)) {
decrement_and_get(left)
} else {
decrement_and_get(right)
};
ptr::copy_nonoverlapping(to_copy, decrement_and_get(&mut out), 1);
}
}
}
// Finally, `hole` gets dropped. If the shorter run was not fully consumed, whatever remains of
// it will now be copied into the hole in `v`.

unsafe fn get_and_increment<T>(ptr: &mut *mut T) -> *mut T {
let old = *ptr;
*ptr = unsafe { ptr.offset(1) };
old
}

unsafe fn decrement_and_get<T>(ptr: &mut *mut T) -> *mut T {
*ptr = unsafe { ptr.offset(-1) };
*ptr
}

// When dropped, copies the range `start..end` into `dest..`.
struct MergeHole<T> {
start: *mut T,
end: *mut T,
dest: *mut T,
}

impl<T> Drop for MergeHole<T> {
fn drop(&mut self) {
// `T` is not a zero-sized type, and these are pointers into a slice's elements.
unsafe {
let len = self.end.sub_ptr(self.start);
ptr::copy_nonoverlapping(self.start, self.dest, len);
}
}
}
}

// --- Branchless sorting (fewer branches, not zero) ---

#[inline]
fn qualifies_for_branchless_sort<T>() -> bool {
// This is a heuristic, and as such it will guess wrong from time to time. Broken down:
//
// - Type size: Large types are more expensive to move, and the time won by avoiding branches
//   can be offset by the increased cost of moving the values.
//
// In contrast to stable sort, using sorting networks here allows us to do fewer comparisons.
mem::size_of::<T>() <= mem::size_of::<[usize; 4]>()
}
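
// Illustrative sketch, not part of the patch: on a 64-bit target the threshold above is
// 32 bytes, so e.g. `u64` and `[u32; 8]` qualify for the branchless small-sorts while a
// 64-byte element type does not.
#[allow(dead_code)]
fn qualifies_for_branchless_sort_sketch() {
    debug_assert!(qualifies_for_branchless_sort::<u64>());
    debug_assert!(qualifies_for_branchless_sort::<[u32; 8]>());
    debug_assert!(!qualifies_for_branchless_sort::<[u8; 64]>());
}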

/// Swaps the values pointed to by `a_ptr` and `b_ptr` if `should_swap` is true, without branching.
#[inline]
unsafe fn branchless_swap<T>(a_ptr: *mut T, b_ptr: *mut T, should_swap: bool) {
// This is a branchless version of a conditional swap.
// The equivalent code with a branch would be:
//
// if should_swap {
// ptr::swap_nonoverlapping(a_ptr, b_ptr, 1);
// }

// Give ourselves some scratch space to work with.
// We do not have to worry about drops: `MaybeUninit` does nothing when dropped.
let mut tmp = mem::MaybeUninit::<T>::uninit();

// The goal is to generate cmov instructions here.
let a_swap_ptr = if should_swap { b_ptr } else { a_ptr };
let b_swap_ptr = if should_swap { a_ptr } else { b_ptr };

// SAFETY: the caller must guarantee that `a_ptr` and `b_ptr` are valid for writes
// and properly aligned, and part of the same allocation, and do not alias.
unsafe {
ptr::copy_nonoverlapping(b_swap_ptr, tmp.as_mut_ptr(), 1);
ptr::copy(a_swap_ptr, a_ptr, 1);
ptr::copy_nonoverlapping(tmp.as_ptr(), b_ptr, 1);
}
}
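
// Illustrative sketch, not part of the patch: the same "select, then write unconditionally"
// idea expressed in safe code on plain integers. Because both stores happen on every call,
// the optimizer can lower the two `if` expressions to conditional moves instead of a branch.
#[allow(dead_code)]
fn branchless_sort2_sketch(a: &mut u64, b: &mut u64) {
    let should_swap = *b < *a;
    let lo = if should_swap { *b } else { *a };
    let hi = if should_swap { *a } else { *b };
    *a = lo;
    *b = hi;
}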

/// Swaps the elements of `arr_ptr` at indices `a` and `b` if the element at `b` is less than the
/// one at `a`.
#[inline]
unsafe fn swap_if_less<T, F>(arr_ptr: *mut T, a: usize, b: usize, is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
// SAFETY: the caller must guarantee that `a` and `b`, each added to `arr_ptr`, yield valid
// pointers into the slice, properly aligned, part of the same allocation, and non-aliasing.
// `a` and `b` must be different indices.
unsafe {
debug_assert!(a != b);

let a_ptr = arr_ptr.add(a);
let b_ptr = arr_ptr.add(b);

// PANIC SAFETY: if `is_less` panics, no scratch memory has been created yet and the slice is
// still in a well-defined state, without duplicates.

// It is important to only swap if `b` is strictly less and not if the elements are equal;
// `is_less` returns false for equal elements, so we don't swap.
let should_swap = is_less(&*b_ptr, &*a_ptr);

branchless_swap(a_ptr, b_ptr, should_swap);
}
}

// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no
// performance impact.
#[inline(never)]
unsafe fn sort4_optimal<T, F>(v: &mut [T], is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
// SAFETY: caller must ensure v.len() >= 4.
unsafe {
debug_assert!(v.len() == 4);

let arr_ptr = v.as_mut_ptr();

// Optimal sorting network see:
// https://bertdobbelaere.github.io/sorting_networks.html.

swap_if_less(arr_ptr, 0, 2, is_less);
swap_if_less(arr_ptr, 1, 3, is_less);
swap_if_less(arr_ptr, 0, 1, is_less);
swap_if_less(arr_ptr, 2, 3, is_less);
swap_if_less(arr_ptr, 1, 2, is_less);
}
}
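
// Illustrative sketch, not part of the patch: the 5-comparator network above on a concrete
// input. The sequence of `swap_if_less` calls is data-independent, which is what makes it
// amenable to branchless code generation.
#[allow(dead_code)]
fn sort4_optimal_sketch() {
    let mut v = [3u32, 1, 4, 2];
    let mut is_less = |a: &u32, b: &u32| a < b;
    // SAFETY: the slice passed in is exactly 4 elements long.
    unsafe { sort4_optimal(&mut v[..], &mut is_less) };
    debug_assert!(v == [1, 2, 3, 4]);
}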

// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no
// performance impact.
#[inline(never)]
unsafe fn sort8_optimal<T, F>(v: &mut [T], is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
// SAFETY: caller must ensure v.len() >= 8.
unsafe {
debug_assert!(v.len() == 8);

let arr_ptr = v.as_mut_ptr();

// Optimal sorting network see:
// https://bertdobbelaere.github.io/sorting_networks.html.

swap_if_less(arr_ptr, 0, 2, is_less);
swap_if_less(arr_ptr, 1, 3, is_less);
swap_if_less(arr_ptr, 4, 6, is_less);
swap_if_less(arr_ptr, 5, 7, is_less);
swap_if_less(arr_ptr, 0, 4, is_less);
swap_if_less(arr_ptr, 1, 5, is_less);
swap_if_less(arr_ptr, 2, 6, is_less);
swap_if_less(arr_ptr, 3, 7, is_less);
swap_if_less(arr_ptr, 0, 1, is_less);
swap_if_less(arr_ptr, 2, 3, is_less);
swap_if_less(arr_ptr, 4, 5, is_less);
swap_if_less(arr_ptr, 6, 7, is_less);
swap_if_less(arr_ptr, 2, 4, is_less);
swap_if_less(arr_ptr, 3, 5, is_less);
swap_if_less(arr_ptr, 1, 4, is_less);
swap_if_less(arr_ptr, 3, 6, is_less);
swap_if_less(arr_ptr, 1, 2, is_less);
swap_if_less(arr_ptr, 3, 4, is_less);
swap_if_less(arr_ptr, 5, 6, is_less);
}
}

// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no
// performance impact.
#[inline(never)]
unsafe fn sort16_optimal<T, F>(v: &mut [T], is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
// SAFETY: caller must ensure v.len() >= 16.
unsafe {
debug_assert!(v.len() == 16);

let arr_ptr = v.as_mut_ptr();

// Optimal sorting network see:
// https://bertdobbelaere.github.io/sorting_networks.html#N16L60D10

swap_if_less(arr_ptr, 0, 13, is_less);
swap_if_less(arr_ptr, 1, 12, is_less);
swap_if_less(arr_ptr, 2, 15, is_less);
swap_if_less(arr_ptr, 3, 14, is_less);
swap_if_less(arr_ptr, 4, 8, is_less);
swap_if_less(arr_ptr, 5, 6, is_less);
swap_if_less(arr_ptr, 7, 11, is_less);
swap_if_less(arr_ptr, 9, 10, is_less);
swap_if_less(arr_ptr, 0, 5, is_less);
swap_if_less(arr_ptr, 1, 7, is_less);
swap_if_less(arr_ptr, 2, 9, is_less);
swap_if_less(arr_ptr, 3, 4, is_less);
swap_if_less(arr_ptr, 6, 13, is_less);
swap_if_less(arr_ptr, 8, 14, is_less);
swap_if_less(arr_ptr, 10, 15, is_less);
swap_if_less(arr_ptr, 11, 12, is_less);
swap_if_less(arr_ptr, 0, 1, is_less);
swap_if_less(arr_ptr, 2, 3, is_less);
swap_if_less(arr_ptr, 4, 5, is_less);
swap_if_less(arr_ptr, 6, 8, is_less);
swap_if_less(arr_ptr, 7, 9, is_less);
swap_if_less(arr_ptr, 10, 11, is_less);
swap_if_less(arr_ptr, 12, 13, is_less);
swap_if_less(arr_ptr, 14, 15, is_less);
swap_if_less(arr_ptr, 0, 2, is_less);
swap_if_less(arr_ptr, 1, 3, is_less);
swap_if_less(arr_ptr, 4, 10, is_less);
swap_if_less(arr_ptr, 5, 11, is_less);
swap_if_less(arr_ptr, 6, 7, is_less);
swap_if_less(arr_ptr, 8, 9, is_less);
swap_if_less(arr_ptr, 12, 14, is_less);
swap_if_less(arr_ptr, 13, 15, is_less);
swap_if_less(arr_ptr, 1, 2, is_less);
swap_if_less(arr_ptr, 3, 12, is_less);
swap_if_less(arr_ptr, 4, 6, is_less);
swap_if_less(arr_ptr, 5, 7, is_less);
swap_if_less(arr_ptr, 8, 10, is_less);
swap_if_less(arr_ptr, 9, 11, is_less);
swap_if_less(arr_ptr, 13, 14, is_less);
swap_if_less(arr_ptr, 1, 4, is_less);
swap_if_less(arr_ptr, 2, 6, is_less);
swap_if_less(arr_ptr, 5, 8, is_less);
swap_if_less(arr_ptr, 7, 10, is_less);
swap_if_less(arr_ptr, 9, 13, is_less);
swap_if_less(arr_ptr, 11, 14, is_less);
swap_if_less(arr_ptr, 2, 4, is_less);
swap_if_less(arr_ptr, 3, 6, is_less);
swap_if_less(arr_ptr, 9, 12, is_less);
swap_if_less(arr_ptr, 11, 13, is_less);
swap_if_less(arr_ptr, 3, 5, is_less);
swap_if_less(arr_ptr, 6, 8, is_less);
swap_if_less(arr_ptr, 7, 9, is_less);
swap_if_less(arr_ptr, 10, 12, is_less);
swap_if_less(arr_ptr, 3, 4, is_less);
swap_if_less(arr_ptr, 5, 6, is_less);
swap_if_less(arr_ptr, 7, 8, is_less);
swap_if_less(arr_ptr, 9, 10, is_less);
swap_if_less(arr_ptr, 11, 12, is_less);
swap_if_less(arr_ptr, 6, 7, is_less);
swap_if_less(arr_ptr, 8, 9, is_less);
}
}

unsafe fn sort4_plus<T, F>(v: &mut [T], is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
// SAFETY: caller must ensure v.len() >= 4.
unsafe {
let len = v.len();
debug_assert!(len >= 4);

sort4_optimal(&mut v[0..4], is_less);
insertion_sort_shift_left(v, 4, is_less);
}
}

unsafe fn sort8_plus<T, F>(v: &mut [T], is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
// SAFETY: caller must ensure v.len() >= 8.
unsafe {
let len = v.len();
debug_assert!(len >= 8);

sort8_optimal(&mut v[0..8], is_less);

if len >= 9 {
insertion_sort_shift_left(&mut v[8..], 1, is_less);

// We only need space for 8 entries because we know the shorter side is at most 8 long.
let mut swap = mem::MaybeUninit::<[T; 8]>::uninit();
let swap_ptr = swap.as_mut_ptr() as *mut T;

merge(v, 8, swap_ptr, is_less);
}
}
}

unsafe fn sort16_plus<T, F>(v: &mut [T], is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
// SAFETY: caller must ensure v.len() >= 16.
unsafe {
let len = v.len();
debug_assert!(len >= 16);

sort16_optimal(&mut v[0..16], is_less);

if len >= 17 {
let start = if len >= 24 {
sort8_optimal(&mut v[16..24], is_less);
8
} else if len >= 20 {
sort4_optimal(&mut v[16..20], is_less);
4
} else {
1
};

insertion_sort_shift_left(&mut v[16..], start, is_less);

// We only need space for 16 entries because we know the shorter side is at most 16 long.
let mut swap = mem::MaybeUninit::<[T; 16]>::uninit();
let swap_ptr = swap.as_mut_ptr() as *mut T;

merge(v, 16, swap_ptr, is_less);
}
}
}

unsafe fn sort32_plus<T, F>(v: &mut [T], is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
// SAFETY: caller must ensure v.len() >= 32.
unsafe {
debug_assert!(v.len() >= 32 && v.len() <= 40);

sort16_optimal(&mut v[0..16], is_less);
sort16_optimal(&mut v[16..32], is_less);

insertion_sort_shift_left(&mut v[16..], 16, is_less);

// We only need space for 16 entries because we know the shorter side is exactly 16 long.
let mut swap = mem::MaybeUninit::<[T; 16]>::uninit();
let swap_ptr = swap.as_mut_ptr() as *mut T;

merge(v, 16, swap_ptr, is_less);
}
}

fn partition_at_index_loop<'a, T, F>(
mut v: &'a mut [T],
mut index: usize,
@@ -833,9 +1488,13 @@ fn partition_at_index_loop<'a, T, F>(
{
loop {
// For slices of up to this length it's probably faster to simply sort them.

// TODO use sort_small here?
const MAX_INSERTION: usize = 10;
if v.len() <= MAX_INSERTION {
insertion_sort(v, is_less);
if v.len() >= 2 {
insertion_sort_shift_left(v, 1, is_less);
}
return;
}