
Commit 4b7bec8

Authored Jan 20, 2023
Rollup merge of rust-lang#104672 - Voultapher:unify-sort-modules, r=thomcc
Unify stable and unstable sort implementations in the same core module

This moves the stable sort implementation into the `core::slice::sort` module. By virtue of being in `core` it can't access `Vec`. The two `Vec`s used by merge sort, `buf` and `runs`, are modelled as custom types that implement the very limited required `Vec` interface with the help of provided allocation and free functions. This is done to allow future re-use of functions and logic between stable and unstable sort, such as `insert_head`.

This is in preparation for rust-lang#100856 and rust-lang#104116. It only moves code; it *doesn't* change any of the sort-related logic. This unlocks the ability to share `insert_head`, `insert_tail`, `swap_if_less`, `merge` and more.

Tagging `@Mark-Simulacrum`. I hope this allows progress on rust-lang#100856; by moving `merge_sort` here, I hope future changes will be easier to review.
2 parents 993932b + 4b5844f commit 4b7bec8
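A note on the pattern for readers skimming the diff below: because the stable sort now lives in `core`, it cannot allocate for itself, so the `alloc`-side wrapper passes allocation and deallocation callbacks down to it. The sketch that follows is illustrative only; `merge_with_hooks`, `elem_alloc` and `elem_dealloc` are made-up names, the real entry point is `core::slice::sort::merge_sort` as shown in the diff, and the real code copies only the shorter run rather than the whole slice. It just shows the callback-passing shape around a single stable merge step:

```rust
use std::alloc::{alloc, dealloc, Layout};
use std::ptr;

// Illustration only -- not the patch's API. The core sort cannot create a `Vec`,
// so the caller hands it allocation/deallocation hooks for its scratch buffer.
fn merge_with_hooks<T: Copy + Ord>(
    v: &mut [T],
    mid: usize,
    alloc_fn: impl Fn(usize) -> *mut T,
    dealloc_fn: impl Fn(*mut T, usize),
) {
    let len = v.len();
    assert!(mid <= len);
    // Scratch memory comes from the caller-provided hook, not from `Vec`.
    let buf = alloc_fn(len);
    assert!(!buf.is_null(), "allocation hook failed");
    unsafe {
        // Copy `v` into the scratch buffer, then merge the two sorted halves back into `v`.
        ptr::copy_nonoverlapping(v.as_ptr(), buf, len);
        let (mut i, mut j, mut out) = (0, mid, 0);
        while i < mid && j < len {
            // `<=` keeps the merge stable: ties are taken from the left half first.
            if *buf.add(i) <= *buf.add(j) {
                v[out] = *buf.add(i);
                i += 1;
            } else {
                v[out] = *buf.add(j);
                j += 1;
            }
            out += 1;
        }
        while i < mid {
            v[out] = *buf.add(i);
            i += 1;
            out += 1;
        }
        while j < len {
            v[out] = *buf.add(j);
            j += 1;
            out += 1;
        }
    }
    dealloc_fn(buf, len);
}

fn main() {
    // The wrapper supplies real allocation, the way `stable_sort` in the diff below does.
    let elem_alloc = |len: usize| unsafe { alloc(Layout::array::<i32>(len).unwrap()) as *mut i32 };
    let elem_dealloc = |ptr: *mut i32, len: usize| unsafe {
        dealloc(ptr as *mut u8, Layout::array::<i32>(len).unwrap());
    };

    let mut v = [1, 4, 7, 2, 3, 9];
    merge_with_hooks(&mut v, 3, elem_alloc, elem_dealloc);
    assert_eq!(v, [1, 2, 3, 4, 7, 9]);
}
```

In the patch itself, `stable_sort` in `library/alloc/src/slice.rs` builds four such closures (element-buffer alloc/dealloc and run-stack alloc/dealloc) and hands them to `sort::merge_sort`.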

File tree

3 files changed: +560, -310 lines


library/alloc/src/slice.rs

Lines changed: 39 additions & 309 deletions

@@ -19,10 +19,12 @@ use core::cmp::Ordering::{self, Less};
 use core::mem::{self, SizedTypeProperties};
 #[cfg(not(no_global_oom_handling))]
 use core::ptr;
+#[cfg(not(no_global_oom_handling))]
+use core::slice::sort;

 use crate::alloc::Allocator;
 #[cfg(not(no_global_oom_handling))]
-use crate::alloc::Global;
+use crate::alloc::{self, Global};
 #[cfg(not(no_global_oom_handling))]
 use crate::borrow::ToOwned;
 use crate::boxed::Box;
@@ -206,7 +208,7 @@ impl<T> [T] {
     where
         T: Ord,
     {
-        merge_sort(self, T::lt);
+        stable_sort(self, T::lt);
     }

     /// Sorts the slice with a comparator function.
@@ -262,7 +264,7 @@ impl<T> [T] {
     where
         F: FnMut(&T, &T) -> Ordering,
     {
-        merge_sort(self, |a, b| compare(a, b) == Less);
+        stable_sort(self, |a, b| compare(a, b) == Less);
     }

     /// Sorts the slice with a key extraction function.
@@ -305,7 +307,7 @@ impl<T> [T] {
         F: FnMut(&T) -> K,
         K: Ord,
     {
-        merge_sort(self, |a, b| f(a).lt(&f(b)));
+        stable_sort(self, |a, b| f(a).lt(&f(b)));
     }

     /// Sorts the slice with a key extraction function.
@@ -812,324 +814,52 @@ impl<T: Clone> ToOwned for [T] {
 // Sorting
 ////////////////////////////////////////////////////////////////////////////////

-/// Inserts `v[0]` into pre-sorted sequence `v[1..]` so that whole `v[..]` becomes sorted.
-///
-/// This is the integral subroutine of insertion sort.
-#[cfg(not(no_global_oom_handling))]
-fn insert_head<T, F>(v: &mut [T], is_less: &mut F)
-where
-    F: FnMut(&T, &T) -> bool,
-{
-    if v.len() >= 2 && is_less(&v[1], &v[0]) {
-        unsafe {
-            // There are three ways to implement insertion here:
-            //
-            // 1. Swap adjacent elements until the first one gets to its final destination.
-            // However, this way we copy data around more than is necessary. If elements are big
-            // structures (costly to copy), this method will be slow.
-            //
-            // 2. Iterate until the right place for the first element is found. Then shift the
-            // elements succeeding it to make room for it and finally place it into the
-            // remaining hole. This is a good method.
-            //
-            // 3. Copy the first element into a temporary variable. Iterate until the right place
-            // for it is found. As we go along, copy every traversed element into the slot
-            // preceding it. Finally, copy data from the temporary variable into the remaining
-            // hole. This method is very good. Benchmarks demonstrated slightly better
-            // performance than with the 2nd method.
-            //
-            // All methods were benchmarked, and the 3rd showed best results. So we chose that one.
-            let tmp = mem::ManuallyDrop::new(ptr::read(&v[0]));
-
-            // Intermediate state of the insertion process is always tracked by `hole`, which
-            // serves two purposes:
-            // 1. Protects integrity of `v` from panics in `is_less`.
-            // 2. Fills the remaining hole in `v` in the end.
-            //
-            // Panic safety:
-            //
-            // If `is_less` panics at any point during the process, `hole` will get dropped and
-            // fill the hole in `v` with `tmp`, thus ensuring that `v` still holds every object it
-            // initially held exactly once.
-            let mut hole = InsertionHole { src: &*tmp, dest: &mut v[1] };
-            ptr::copy_nonoverlapping(&v[1], &mut v[0], 1);
-
-            for i in 2..v.len() {
-                if !is_less(&v[i], &*tmp) {
-                    break;
-                }
-                ptr::copy_nonoverlapping(&v[i], &mut v[i - 1], 1);
-                hole.dest = &mut v[i];
-            }
-            // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`.
-        }
-    }
-
-    // When dropped, copies from `src` into `dest`.
-    struct InsertionHole<T> {
-        src: *const T,
-        dest: *mut T,
-    }
-
-    impl<T> Drop for InsertionHole<T> {
-        fn drop(&mut self) {
-            unsafe {
-                ptr::copy_nonoverlapping(self.src, self.dest, 1);
-            }
-        }
-    }
-}
-
-/// Merges non-decreasing runs `v[..mid]` and `v[mid..]` using `buf` as temporary storage, and
-/// stores the result into `v[..]`.
-///
-/// # Safety
-///
-/// The two slices must be non-empty and `mid` must be in bounds. Buffer `buf` must be long enough
-/// to hold a copy of the shorter slice. Also, `T` must not be a zero-sized type.
-#[cfg(not(no_global_oom_handling))]
-unsafe fn merge<T, F>(v: &mut [T], mid: usize, buf: *mut T, is_less: &mut F)
-where
-    F: FnMut(&T, &T) -> bool,
-{
-    let len = v.len();
-    let v = v.as_mut_ptr();
-    let (v_mid, v_end) = unsafe { (v.add(mid), v.add(len)) };
-
-    // The merge process first copies the shorter run into `buf`. Then it traces the newly copied
-    // run and the longer run forwards (or backwards), comparing their next unconsumed elements and
-    // copying the lesser (or greater) one into `v`.
-    //
-    // As soon as the shorter run is fully consumed, the process is done. If the longer run gets
-    // consumed first, then we must copy whatever is left of the shorter run into the remaining
-    // hole in `v`.
-    //
-    // Intermediate state of the process is always tracked by `hole`, which serves two purposes:
-    // 1. Protects integrity of `v` from panics in `is_less`.
-    // 2. Fills the remaining hole in `v` if the longer run gets consumed first.
-    //
-    // Panic safety:
-    //
-    // If `is_less` panics at any point during the process, `hole` will get dropped and fill the
-    // hole in `v` with the unconsumed range in `buf`, thus ensuring that `v` still holds every
-    // object it initially held exactly once.
-    let mut hole;
-
-    if mid <= len - mid {
-        // The left run is shorter.
-        unsafe {
-            ptr::copy_nonoverlapping(v, buf, mid);
-            hole = MergeHole { start: buf, end: buf.add(mid), dest: v };
-        }
-
-        // Initially, these pointers point to the beginnings of their arrays.
-        let left = &mut hole.start;
-        let mut right = v_mid;
-        let out = &mut hole.dest;
-
-        while *left < hole.end && right < v_end {
-            // Consume the lesser side.
-            // If equal, prefer the left run to maintain stability.
-            unsafe {
-                let to_copy = if is_less(&*right, &**left) {
-                    get_and_increment(&mut right)
-                } else {
-                    get_and_increment(left)
-                };
-                ptr::copy_nonoverlapping(to_copy, get_and_increment(out), 1);
-            }
-        }
-    } else {
-        // The right run is shorter.
-        unsafe {
-            ptr::copy_nonoverlapping(v_mid, buf, len - mid);
-            hole = MergeHole { start: buf, end: buf.add(len - mid), dest: v_mid };
-        }
-
-        // Initially, these pointers point past the ends of their arrays.
-        let left = &mut hole.dest;
-        let right = &mut hole.end;
-        let mut out = v_end;
-
-        while v < *left && buf < *right {
-            // Consume the greater side.
-            // If equal, prefer the right run to maintain stability.
-            unsafe {
-                let to_copy = if is_less(&*right.sub(1), &*left.sub(1)) {
-                    decrement_and_get(left)
-                } else {
-                    decrement_and_get(right)
-                };
-                ptr::copy_nonoverlapping(to_copy, decrement_and_get(&mut out), 1);
-            }
-        }
-    }
-    // Finally, `hole` gets dropped. If the shorter run was not fully consumed, whatever remains of
-    // it will now be copied into the hole in `v`.
-
-    unsafe fn get_and_increment<T>(ptr: &mut *mut T) -> *mut T {
-        let old = *ptr;
-        *ptr = unsafe { ptr.add(1) };
-        old
-    }
-
-    unsafe fn decrement_and_get<T>(ptr: &mut *mut T) -> *mut T {
-        *ptr = unsafe { ptr.sub(1) };
-        *ptr
-    }
-
-    // When dropped, copies the range `start..end` into `dest..`.
-    struct MergeHole<T> {
-        start: *mut T,
-        end: *mut T,
-        dest: *mut T,
-    }
-
-    impl<T> Drop for MergeHole<T> {
-        fn drop(&mut self) {
-            // `T` is not a zero-sized type, and these are pointers into a slice's elements.
-            unsafe {
-                let len = self.end.sub_ptr(self.start);
-                ptr::copy_nonoverlapping(self.start, self.dest, len);
-            }
-        }
-    }
-}
-
-/// This merge sort borrows some (but not all) ideas from TimSort, which is described in detail
-/// [here](https://github.com/python/cpython/blob/main/Objects/listsort.txt).
-///
-/// The algorithm identifies strictly descending and non-descending subsequences, which are called
-/// natural runs. There is a stack of pending runs yet to be merged. Each newly found run is pushed
-/// onto the stack, and then some pairs of adjacent runs are merged until these two invariants are
-/// satisfied:
-///
-/// 1. for every `i` in `1..runs.len()`: `runs[i - 1].len > runs[i].len`
-/// 2. for every `i` in `2..runs.len()`: `runs[i - 2].len > runs[i - 1].len + runs[i].len`
-///
-/// The invariants ensure that the total running time is *O*(*n* \* log(*n*)) worst-case.
+#[inline]
 #[cfg(not(no_global_oom_handling))]
-fn merge_sort<T, F>(v: &mut [T], mut is_less: F)
+fn stable_sort<T, F>(v: &mut [T], mut is_less: F)
 where
     F: FnMut(&T, &T) -> bool,
 {
-    // Slices of up to this length get sorted using insertion sort.
-    const MAX_INSERTION: usize = 20;
-    // Very short runs are extended using insertion sort to span at least this many elements.
-    const MIN_RUN: usize = 10;
-
-    // Sorting has no meaningful behavior on zero-sized types.
     if T::IS_ZST {
+        // Sorting has no meaningful behavior on zero-sized types. Do nothing.
         return;
     }

-    let len = v.len();
-
-    // Short arrays get sorted in-place via insertion sort to avoid allocations.
-    if len <= MAX_INSERTION {
-        if len >= 2 {
-            for i in (0..len - 1).rev() {
-                insert_head(&mut v[i..], &mut is_less);
-            }
-        }
-        return;
-    }
-
-    // Allocate a buffer to use as scratch memory. We keep the length 0 so we can keep in it
-    // shallow copies of the contents of `v` without risking the dtors running on copies if
-    // `is_less` panics. When merging two sorted runs, this buffer holds a copy of the shorter run,
-    // which will always have length at most `len / 2`.
-    let mut buf = Vec::with_capacity(len / 2);
+    let elem_alloc_fn = |len: usize| -> *mut T {
+        // SAFETY: Creating the layout is safe as long as merge_sort never calls this with len >
+        // v.len(). Alloc in general will only be used as 'shadow-region' to store temporary swap
+        // elements.
+        unsafe { alloc::alloc(alloc::Layout::array::<T>(len).unwrap_unchecked()) as *mut T }
+    };

-    // In order to identify natural runs in `v`, we traverse it backwards. That might seem like a
-    // strange decision, but consider the fact that merges more often go in the opposite direction
-    // (forwards). According to benchmarks, merging forwards is slightly faster than merging
-    // backwards. To conclude, identifying runs by traversing backwards improves performance.
-    let mut runs = vec![];
-    let mut end = len;
-    while end > 0 {
-        // Find the next natural run, and reverse it if it's strictly descending.
-        let mut start = end - 1;
-        if start > 0 {
-            start -= 1;
-            unsafe {
-                if is_less(v.get_unchecked(start + 1), v.get_unchecked(start)) {
-                    while start > 0 && is_less(v.get_unchecked(start), v.get_unchecked(start - 1)) {
-                        start -= 1;
-                    }
-                    v[start..end].reverse();
-                } else {
-                    while start > 0 && !is_less(v.get_unchecked(start), v.get_unchecked(start - 1))
-                    {
-                        start -= 1;
-                    }
-                }
-            }
-        }
-
-        // Insert some more elements into the run if it's too short. Insertion sort is faster than
-        // merge sort on short sequences, so this significantly improves performance.
-        while start > 0 && end - start < MIN_RUN {
-            start -= 1;
-            insert_head(&mut v[start..end], &mut is_less);
+    let elem_dealloc_fn = |buf_ptr: *mut T, len: usize| {
+        // SAFETY: Creating the layout is safe as long as merge_sort never calls this with len >
+        // v.len(). The caller must ensure that buf_ptr was created by elem_alloc_fn with the same
+        // len.
+        unsafe {
+            alloc::dealloc(buf_ptr as *mut u8, alloc::Layout::array::<T>(len).unwrap_unchecked());
         }
+    };

-        // Push this run onto the stack.
-        runs.push(Run { start, len: end - start });
-        end = start;
-
-        // Merge some pairs of adjacent runs to satisfy the invariants.
-        while let Some(r) = collapse(&runs) {
-            let left = runs[r + 1];
-            let right = runs[r];
-            unsafe {
-                merge(
-                    &mut v[left.start..right.start + right.len],
-                    left.len,
-                    buf.as_mut_ptr(),
-                    &mut is_less,
-                );
-            }
-            runs[r] = Run { start: left.start, len: left.len + right.len };
-            runs.remove(r + 1);
+    let run_alloc_fn = |len: usize| -> *mut sort::TimSortRun {
+        // SAFETY: Creating the layout is safe as long as merge_sort never calls this with an
+        // obscene length or 0.
+        unsafe {
+            alloc::alloc(alloc::Layout::array::<sort::TimSortRun>(len).unwrap_unchecked())
+                as *mut sort::TimSortRun
         }
-    }
-
-    // Finally, exactly one run must remain in the stack.
-    debug_assert!(runs.len() == 1 && runs[0].start == 0 && runs[0].len == len);
+    };

-    // Examines the stack of runs and identifies the next pair of runs to merge. More specifically,
-    // if `Some(r)` is returned, that means `runs[r]` and `runs[r + 1]` must be merged next. If the
-    // algorithm should continue building a new run instead, `None` is returned.
-    //
-    // TimSort is infamous for its buggy implementations, as described here:
-    // http://envisage-project.eu/timsort-specification-and-verification/
-    //
-    // The gist of the story is: we must enforce the invariants on the top four runs on the stack.
-    // Enforcing them on just top three is not sufficient to ensure that the invariants will still
-    // hold for *all* runs in the stack.
-    //
-    // This function correctly checks invariants for the top four runs. Additionally, if the top
-    // run starts at index 0, it will always demand a merge operation until the stack is fully
-    // collapsed, in order to complete the sort.
-    #[inline]
-    fn collapse(runs: &[Run]) -> Option<usize> {
-        let n = runs.len();
-        if n >= 2
-            && (runs[n - 1].start == 0
-                || runs[n - 2].len <= runs[n - 1].len
-                || (n >= 3 && runs[n - 3].len <= runs[n - 2].len + runs[n - 1].len)
-                || (n >= 4 && runs[n - 4].len <= runs[n - 3].len + runs[n - 2].len))
-        {
-            if n >= 3 && runs[n - 3].len < runs[n - 1].len { Some(n - 3) } else { Some(n - 2) }
-        } else {
-            None
+    let run_dealloc_fn = |buf_ptr: *mut sort::TimSortRun, len: usize| {
+        // SAFETY: The caller must ensure that buf_ptr was created by elem_alloc_fn with the same
+        // len.
+        unsafe {
+            alloc::dealloc(
+                buf_ptr as *mut u8,
+                alloc::Layout::array::<sort::TimSortRun>(len).unwrap_unchecked(),
            );
         }
-    }
+    };

-    #[derive(Clone, Copy)]
-    struct Run {
-        start: usize,
-        len: usize,
-    }
+    sort::merge_sort(v, &mut is_less, elem_alloc_fn, elem_dealloc_fn, run_alloc_fn, run_dealloc_fn);
 }

library/core/src/slice/mod.rs

Lines changed: 7 additions & 1 deletion

@@ -29,13 +29,19 @@ use crate::slice;
 /// Pure rust memchr implementation, taken from rust-memchr
 pub mod memchr;

+#[unstable(
+    feature = "slice_internals",
+    issue = "none",
+    reason = "exposed from core to be reused in std;"
+)]
+pub mod sort;
+
 mod ascii;
 mod cmp;
 mod index;
 mod iter;
 mod raw;
 mod rotate;
-mod sort;
 mod specialize;

 #[stable(feature = "rust1", since = "1.0.0")]
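
The module is exposed with `#[unstable(feature = "slice_internals", issue = "none")]`, so it does not become part of the stable public API; only code that enables that internal feature gate (in practice, `std` itself) can reach it. As a rough illustration of what the gate means on a nightly toolchain (not something ordinary crates should do, and purely an assumption about how one would poke at it):

```rust
// Illustration only: `slice_internals` is an internal, perma-unstable feature.
// Enabling it is assumed to require a nightly toolchain; normal crates should not rely on it.
#![feature(slice_internals)]

use core::slice::sort::TimSortRun;

fn main() {
    // The only point here is that the path resolves now that `sort` is a public
    // (but unstable) module; the type itself is an internal detail.
    println!("TimSortRun is {} bytes", core::mem::size_of::<TimSortRun>());
}
```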

library/core/src/slice/sort.rs

Lines changed: 514 additions & 0 deletions

@@ -5,6 +5,9 @@
 //!
 //! Unstable sorting is compatible with core because it doesn't allocate memory, unlike our
 //! stable sorting implementation.
+//!
+//! In addition it also contains the core logic of the stable sort used by `slice::sort` based on
+//! TimSort.

 use crate::cmp;
 use crate::mem::{self, MaybeUninit, SizedTypeProperties};
@@ -905,6 +908,7 @@ fn partition_at_index_loop<'a, T, F>(
     }
 }

+/// Reorder the slice such that the element at `index` is at its final sorted position.
 pub fn partition_at_index<T, F>(
     v: &mut [T],
     index: usize,
@@ -949,3 +953,513 @@ where
     let pivot = &mut pivot[0];
     (left, pivot, right)
 }
+
+/// Inserts `v[0]` into pre-sorted sequence `v[1..]` so that whole `v[..]` becomes sorted.
+///
+/// This is the integral subroutine of insertion sort.
+fn insert_head<T, F>(v: &mut [T], is_less: &mut F)
+where
+    F: FnMut(&T, &T) -> bool,
+{
+    if v.len() >= 2 && is_less(&v[1], &v[0]) {
+        // SAFETY: Copy tmp back even if panic, and ensure unique observation.
+        unsafe {
+            // There are three ways to implement insertion here:
+            //
+            // 1. Swap adjacent elements until the first one gets to its final destination.
+            // However, this way we copy data around more than is necessary. If elements are big
+            // structures (costly to copy), this method will be slow.
+            //
+            // 2. Iterate until the right place for the first element is found. Then shift the
+            // elements succeeding it to make room for it and finally place it into the
+            // remaining hole. This is a good method.
+            //
+            // 3. Copy the first element into a temporary variable. Iterate until the right place
+            // for it is found. As we go along, copy every traversed element into the slot
+            // preceding it. Finally, copy data from the temporary variable into the remaining
+            // hole. This method is very good. Benchmarks demonstrated slightly better
+            // performance than with the 2nd method.
+            //
+            // All methods were benchmarked, and the 3rd showed best results. So we chose that one.
+            let tmp = mem::ManuallyDrop::new(ptr::read(&v[0]));
+
+            // Intermediate state of the insertion process is always tracked by `hole`, which
+            // serves two purposes:
+            // 1. Protects integrity of `v` from panics in `is_less`.
+            // 2. Fills the remaining hole in `v` in the end.
+            //
+            // Panic safety:
+            //
+            // If `is_less` panics at any point during the process, `hole` will get dropped and
+            // fill the hole in `v` with `tmp`, thus ensuring that `v` still holds every object it
+            // initially held exactly once.
+            let mut hole = InsertionHole { src: &*tmp, dest: &mut v[1] };
+            ptr::copy_nonoverlapping(&v[1], &mut v[0], 1);
+
+            for i in 2..v.len() {
+                if !is_less(&v[i], &*tmp) {
+                    break;
+                }
+                ptr::copy_nonoverlapping(&v[i], &mut v[i - 1], 1);
+                hole.dest = &mut v[i];
+            }
+            // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`.
+        }
+    }
+
+    // When dropped, copies from `src` into `dest`.
+    struct InsertionHole<T> {
+        src: *const T,
+        dest: *mut T,
+    }
+
+    impl<T> Drop for InsertionHole<T> {
+        fn drop(&mut self) {
+            // SAFETY: The caller must ensure that src and dest are correctly set.
+            unsafe {
+                ptr::copy_nonoverlapping(self.src, self.dest, 1);
+            }
+        }
+    }
+}
+
+/// Merges non-decreasing runs `v[..mid]` and `v[mid..]` using `buf` as temporary storage, and
+/// stores the result into `v[..]`.
+///
+/// # Safety
+///
+/// The two slices must be non-empty and `mid` must be in bounds. Buffer `buf` must be long enough
+/// to hold a copy of the shorter slice. Also, `T` must not be a zero-sized type.
+unsafe fn merge<T, F>(v: &mut [T], mid: usize, buf: *mut T, is_less: &mut F)
+where
+    F: FnMut(&T, &T) -> bool,
+{
+    let len = v.len();
+    let v = v.as_mut_ptr();
+
+    // SAFETY: mid and len must be in-bounds of v.
+    let (v_mid, v_end) = unsafe { (v.add(mid), v.add(len)) };
+
+    // The merge process first copies the shorter run into `buf`. Then it traces the newly copied
+    // run and the longer run forwards (or backwards), comparing their next unconsumed elements and
+    // copying the lesser (or greater) one into `v`.
+    //
+    // As soon as the shorter run is fully consumed, the process is done. If the longer run gets
+    // consumed first, then we must copy whatever is left of the shorter run into the remaining
+    // hole in `v`.
+    //
+    // Intermediate state of the process is always tracked by `hole`, which serves two purposes:
+    // 1. Protects integrity of `v` from panics in `is_less`.
+    // 2. Fills the remaining hole in `v` if the longer run gets consumed first.
+    //
+    // Panic safety:
+    //
+    // If `is_less` panics at any point during the process, `hole` will get dropped and fill the
+    // hole in `v` with the unconsumed range in `buf`, thus ensuring that `v` still holds every
+    // object it initially held exactly once.
+    let mut hole;
+
+    if mid <= len - mid {
+        // The left run is shorter.
+
+        // SAFETY: buf must have enough capacity for `v[..mid]`.
+        unsafe {
+            ptr::copy_nonoverlapping(v, buf, mid);
+            hole = MergeHole { start: buf, end: buf.add(mid), dest: v };
+        }
+
+        // Initially, these pointers point to the beginnings of their arrays.
+        let left = &mut hole.start;
+        let mut right = v_mid;
+        let out = &mut hole.dest;
+
+        while *left < hole.end && right < v_end {
+            // Consume the lesser side.
+            // If equal, prefer the left run to maintain stability.
+
+            // SAFETY: left and right must be valid and part of v same for out.
+            unsafe {
+                let to_copy = if is_less(&*right, &**left) {
+                    get_and_increment(&mut right)
+                } else {
+                    get_and_increment(left)
+                };
+                ptr::copy_nonoverlapping(to_copy, get_and_increment(out), 1);
+            }
+        }
+    } else {
+        // The right run is shorter.
+
+        // SAFETY: buf must have enough capacity for `v[mid..]`.
+        unsafe {
+            ptr::copy_nonoverlapping(v_mid, buf, len - mid);
+            hole = MergeHole { start: buf, end: buf.add(len - mid), dest: v_mid };
+        }
+
+        // Initially, these pointers point past the ends of their arrays.
+        let left = &mut hole.dest;
+        let right = &mut hole.end;
+        let mut out = v_end;
+
+        while v < *left && buf < *right {
+            // Consume the greater side.
+            // If equal, prefer the right run to maintain stability.
+
+            // SAFETY: left and right must be valid and part of v same for out.
+            unsafe {
+                let to_copy = if is_less(&*right.sub(1), &*left.sub(1)) {
+                    decrement_and_get(left)
+                } else {
+                    decrement_and_get(right)
+                };
+                ptr::copy_nonoverlapping(to_copy, decrement_and_get(&mut out), 1);
+            }
+        }
+    }
+    // Finally, `hole` gets dropped. If the shorter run was not fully consumed, whatever remains of
+    // it will now be copied into the hole in `v`.
+
+    unsafe fn get_and_increment<T>(ptr: &mut *mut T) -> *mut T {
+        let old = *ptr;
+
+        // SAFETY: ptr.add(1) must still be a valid pointer and part of `v`.
+        *ptr = unsafe { ptr.add(1) };
+        old
+    }
+
+    unsafe fn decrement_and_get<T>(ptr: &mut *mut T) -> *mut T {
+        // SAFETY: ptr.sub(1) must still be a valid pointer and part of `v`.
+        *ptr = unsafe { ptr.sub(1) };
+        *ptr
+    }
+
+    // When dropped, copies the range `start..end` into `dest..`.
+    struct MergeHole<T> {
+        start: *mut T,
+        end: *mut T,
+        dest: *mut T,
+    }
+
+    impl<T> Drop for MergeHole<T> {
+        fn drop(&mut self) {
+            // SAFETY: `T` is not a zero-sized type, and these are pointers into a slice's elements.
+            unsafe {
+                let len = self.end.sub_ptr(self.start);
+                ptr::copy_nonoverlapping(self.start, self.dest, len);
+            }
+        }
+    }
+}
+
+/// This merge sort borrows some (but not all) ideas from TimSort, which used to be described in
+/// detail [here](https://github.com/python/cpython/blob/main/Objects/listsort.txt). However Python
+/// has switched to a Powersort based implementation.
+///
+/// The algorithm identifies strictly descending and non-descending subsequences, which are called
+/// natural runs. There is a stack of pending runs yet to be merged. Each newly found run is pushed
+/// onto the stack, and then some pairs of adjacent runs are merged until these two invariants are
+/// satisfied:
+///
+/// 1. for every `i` in `1..runs.len()`: `runs[i - 1].len > runs[i].len`
+/// 2. for every `i` in `2..runs.len()`: `runs[i - 2].len > runs[i - 1].len + runs[i].len`
+///
+/// The invariants ensure that the total running time is *O*(*n* \* log(*n*)) worst-case.
+pub fn merge_sort<T, CmpF, ElemAllocF, ElemDeallocF, RunAllocF, RunDeallocF>(
+    v: &mut [T],
+    is_less: &mut CmpF,
+    elem_alloc_fn: ElemAllocF,
+    elem_dealloc_fn: ElemDeallocF,
+    run_alloc_fn: RunAllocF,
+    run_dealloc_fn: RunDeallocF,
+) where
+    CmpF: FnMut(&T, &T) -> bool,
+    ElemAllocF: Fn(usize) -> *mut T,
+    ElemDeallocF: Fn(*mut T, usize),
+    RunAllocF: Fn(usize) -> *mut TimSortRun,
+    RunDeallocF: Fn(*mut TimSortRun, usize),
+{
+    // Slices of up to this length get sorted using insertion sort.
+    const MAX_INSERTION: usize = 20;
+    // Very short runs are extended using insertion sort to span at least this many elements.
+    const MIN_RUN: usize = 10;
+
+    // The caller should have already checked that.
+    debug_assert!(!T::IS_ZST);
+
+    let len = v.len();
+
+    // Short arrays get sorted in-place via insertion sort to avoid allocations.
+    if len <= MAX_INSERTION {
+        if len >= 2 {
+            for i in (0..len - 1).rev() {
+                insert_head(&mut v[i..], is_less);
+            }
+        }
+        return;
+    }
+
+    // Allocate a buffer to use as scratch memory. We keep the length 0 so we can keep in it
+    // shallow copies of the contents of `v` without risking the dtors running on copies if
+    // `is_less` panics. When merging two sorted runs, this buffer holds a copy of the shorter run,
+    // which will always have length at most `len / 2`.
+    let buf = BufGuard::new(len / 2, elem_alloc_fn, elem_dealloc_fn);
+    let buf_ptr = buf.buf_ptr;
+
+    let mut runs = RunVec::new(run_alloc_fn, run_dealloc_fn);
+
+    // In order to identify natural runs in `v`, we traverse it backwards. That might seem like a
+    // strange decision, but consider the fact that merges more often go in the opposite direction
+    // (forwards). According to benchmarks, merging forwards is slightly faster than merging
+    // backwards. To conclude, identifying runs by traversing backwards improves performance.
+    let mut end = len;
+    while end > 0 {
+        // Find the next natural run, and reverse it if it's strictly descending.
+        let mut start = end - 1;
+        if start > 0 {
+            start -= 1;
+
+            // SAFETY: The v.get_unchecked must be fed with correct inbound indicies.
+            unsafe {
+                if is_less(v.get_unchecked(start + 1), v.get_unchecked(start)) {
+                    while start > 0 && is_less(v.get_unchecked(start), v.get_unchecked(start - 1)) {
+                        start -= 1;
+                    }
+                    v[start..end].reverse();
+                } else {
+                    while start > 0 && !is_less(v.get_unchecked(start), v.get_unchecked(start - 1))
+                    {
+                        start -= 1;
+                    }
+                }
+            }
+        }
+
+        // Insert some more elements into the run if it's too short. Insertion sort is faster than
+        // merge sort on short sequences, so this significantly improves performance.
+        while start > 0 && end - start < MIN_RUN {
+            start -= 1;
+            insert_head(&mut v[start..end], is_less);
+        }
+
+        // Push this run onto the stack.
+        runs.push(TimSortRun { start, len: end - start });
+        end = start;
+
+        // Merge some pairs of adjacent runs to satisfy the invariants.
+        while let Some(r) = collapse(runs.as_slice()) {
+            let left = runs[r + 1];
+            let right = runs[r];
+            // SAFETY: `buf_ptr` must hold enough capacity for the shorter of the two sides, and
+            // neither side may be on length 0.
+            unsafe {
+                merge(&mut v[left.start..right.start + right.len], left.len, buf_ptr, is_less);
+            }
+            runs[r] = TimSortRun { start: left.start, len: left.len + right.len };
+            runs.remove(r + 1);
+        }
+    }
+
+    // Finally, exactly one run must remain in the stack.
+    debug_assert!(runs.len() == 1 && runs[0].start == 0 && runs[0].len == len);
+
+    // Examines the stack of runs and identifies the next pair of runs to merge. More specifically,
+    // if `Some(r)` is returned, that means `runs[r]` and `runs[r + 1]` must be merged next. If the
+    // algorithm should continue building a new run instead, `None` is returned.
+    //
+    // TimSort is infamous for its buggy implementations, as described here:
+    // http://envisage-project.eu/timsort-specification-and-verification/
+    //
+    // The gist of the story is: we must enforce the invariants on the top four runs on the stack.
+    // Enforcing them on just top three is not sufficient to ensure that the invariants will still
+    // hold for *all* runs in the stack.
+    //
+    // This function correctly checks invariants for the top four runs. Additionally, if the top
+    // run starts at index 0, it will always demand a merge operation until the stack is fully
+    // collapsed, in order to complete the sort.
+    #[inline]
+    fn collapse(runs: &[TimSortRun]) -> Option<usize> {
+        let n = runs.len();
+        if n >= 2
+            && (runs[n - 1].start == 0
+                || runs[n - 2].len <= runs[n - 1].len
+                || (n >= 3 && runs[n - 3].len <= runs[n - 2].len + runs[n - 1].len)
+                || (n >= 4 && runs[n - 4].len <= runs[n - 3].len + runs[n - 2].len))
+        {
+            if n >= 3 && runs[n - 3].len < runs[n - 1].len { Some(n - 3) } else { Some(n - 2) }
+        } else {
+            None
+        }
+    }
+
+    // Extremely basic versions of Vec.
+    // Their use is super limited and by having the code here, it allows reuse between the sort
+    // implementations.
+    struct BufGuard<T, ElemDeallocF>
+    where
+        ElemDeallocF: Fn(*mut T, usize),
+    {
+        buf_ptr: *mut T,
+        capacity: usize,
+        elem_dealloc_fn: ElemDeallocF,
+    }
+
+    impl<T, ElemDeallocF> BufGuard<T, ElemDeallocF>
+    where
+        ElemDeallocF: Fn(*mut T, usize),
+    {
+        fn new<ElemAllocF>(
+            len: usize,
+            elem_alloc_fn: ElemAllocF,
+            elem_dealloc_fn: ElemDeallocF,
+        ) -> Self
+        where
+            ElemAllocF: Fn(usize) -> *mut T,
+        {
+            Self { buf_ptr: elem_alloc_fn(len), capacity: len, elem_dealloc_fn }
+        }
+    }
+
+    impl<T, ElemDeallocF> Drop for BufGuard<T, ElemDeallocF>
+    where
+        ElemDeallocF: Fn(*mut T, usize),
+    {
+        fn drop(&mut self) {
+            (self.elem_dealloc_fn)(self.buf_ptr, self.capacity);
+        }
+    }
+
+    struct RunVec<RunAllocF, RunDeallocF>
+    where
+        RunAllocF: Fn(usize) -> *mut TimSortRun,
+        RunDeallocF: Fn(*mut TimSortRun, usize),
+    {
+        buf_ptr: *mut TimSortRun,
+        capacity: usize,
+        len: usize,
+        run_alloc_fn: RunAllocF,
+        run_dealloc_fn: RunDeallocF,
+    }
+
+    impl<RunAllocF, RunDeallocF> RunVec<RunAllocF, RunDeallocF>
+    where
+        RunAllocF: Fn(usize) -> *mut TimSortRun,
+        RunDeallocF: Fn(*mut TimSortRun, usize),
+    {
+        fn new(run_alloc_fn: RunAllocF, run_dealloc_fn: RunDeallocF) -> Self {
+            // Most slices can be sorted with at most 16 runs in-flight.
+            const START_RUN_CAPACITY: usize = 16;
+
+            Self {
+                buf_ptr: run_alloc_fn(START_RUN_CAPACITY),
+                capacity: START_RUN_CAPACITY,
+                len: 0,
+                run_alloc_fn,
+                run_dealloc_fn,
+            }
+        }
+
+        fn push(&mut self, val: TimSortRun) {
+            if self.len == self.capacity {
+                let old_capacity = self.capacity;
+                let old_buf_ptr = self.buf_ptr;
+
+                self.capacity = self.capacity * 2;
+                self.buf_ptr = (self.run_alloc_fn)(self.capacity);
+
+                // SAFETY: buf_ptr new and old were correctly allocated and old_buf_ptr has
+                // old_capacity valid elements.
+                unsafe {
+                    ptr::copy_nonoverlapping(old_buf_ptr, self.buf_ptr, old_capacity);
+                }
+
+                (self.run_dealloc_fn)(old_buf_ptr, old_capacity);
+            }
+
+            // SAFETY: The invariant was just checked.
+            unsafe {
+                self.buf_ptr.add(self.len).write(val);
+            }
+            self.len += 1;
+        }
+
+        fn remove(&mut self, index: usize) {
+            if index >= self.len {
+                panic!("Index out of bounds");
+            }
+
+            // SAFETY: buf_ptr needs to be valid and len invariant upheld.
+            unsafe {
+                // the place we are taking from.
+                let ptr = self.buf_ptr.add(index);
+
+                // Shift everything down to fill in that spot.
+                ptr::copy(ptr.add(1), ptr, self.len - index - 1);
+            }
+            self.len -= 1;
+        }
+
+        fn as_slice(&self) -> &[TimSortRun] {
+            // SAFETY: Safe as long as buf_ptr is valid and len invariant was upheld.
+            unsafe { &*ptr::slice_from_raw_parts(self.buf_ptr, self.len) }
+        }
+
+        fn len(&self) -> usize {
+            self.len
+        }
+    }
+
+    impl<RunAllocF, RunDeallocF> core::ops::Index<usize> for RunVec<RunAllocF, RunDeallocF>
+    where
+        RunAllocF: Fn(usize) -> *mut TimSortRun,
+        RunDeallocF: Fn(*mut TimSortRun, usize),
+    {
+        type Output = TimSortRun;
+
+        fn index(&self, index: usize) -> &Self::Output {
+            if index < self.len {
+                // SAFETY: buf_ptr and len invariant must be upheld.
+                unsafe {
+                    return &*(self.buf_ptr.add(index));
+                }
+            }
+
+            panic!("Index out of bounds");
+        }
+    }
+
+    impl<RunAllocF, RunDeallocF> core::ops::IndexMut<usize> for RunVec<RunAllocF, RunDeallocF>
+    where
+        RunAllocF: Fn(usize) -> *mut TimSortRun,
+        RunDeallocF: Fn(*mut TimSortRun, usize),
+    {
+        fn index_mut(&mut self, index: usize) -> &mut Self::Output {
+            if index < self.len {
+                // SAFETY: buf_ptr and len invariant must be upheld.
+                unsafe {
+                    return &mut *(self.buf_ptr.add(index));
+                }
+            }
+
+            panic!("Index out of bounds");
+        }
+    }
+
+    impl<RunAllocF, RunDeallocF> Drop for RunVec<RunAllocF, RunDeallocF>
+    where
+        RunAllocF: Fn(usize) -> *mut TimSortRun,
+        RunDeallocF: Fn(*mut TimSortRun, usize),
+    {
+        fn drop(&mut self) {
+            // As long as TimSortRun is Copy we don't need to drop them individually but just the
+            // whole allocation.
+            (self.run_dealloc_fn)(self.buf_ptr, self.capacity);
+        }
+    }
+}
+
+/// Internal type used by merge_sort.
+#[derive(Clone, Copy, Debug)]
+pub struct TimSortRun {
+    len: usize,
+    start: usize,
+}
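
One property worth keeping in mind while reviewing the moved code: all of this implements the *stable* sort, so the user-facing guarantee of `slice::sort`, `sort_by` and `sort_by_key` is unchanged, namely that equal elements keep their original relative order. A minimal check of that guarantee using only the public API (independent of the internals above):

```rust
fn main() {
    // Stable sort: entries that compare equal keep their original relative order.
    let mut v = [(1, "a"), (0, "b"), (1, "c"), (0, "d")];
    v.sort_by_key(|&(k, _)| k);
    assert_eq!(v, [(0, "b"), (0, "d"), (1, "a"), (1, "c")]);
}
```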
