Skip to content

Commit fcdcff7

Browse files
committed
Port some rust-lang/rust performance and safety PRs
Port changes from rust-lang/rust related to binary heap performance and use of unsafe:

- #81706: Document BinaryHeap unsafe functions
- #81127: Improve sift_down performance in BinaryHeap
- #58123: Avoid some bounds checks in binary_heap::{PeekMut,Hole}
- #72709: #[deny(unsafe_op_in_unsafe_fn)] in liballoc

Note that the following related rust-lang/rust PRs were already ported here in earlier PRs:

- (in sekineh#28) #78857: Improve BinaryHeap performance
- (in sekineh#27) #75974: Avoid useless sift_down when std::collections::binary_heap::PeekMut is never mutably dereferenced
1 parent fec9e24 commit fcdcff7

File tree

1 file changed

+143
-66
lines changed

1 file changed

+143
-66
lines changed

src/binary_heap.rs

Lines changed: 143 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@
155155
//! }
156156
//! ```
157157
158+
#![deny(unsafe_op_in_unsafe_fn)]
158159
#![allow(clippy::needless_doctest_main)]
159160
#![allow(missing_docs)]
160161
// #![stable(feature = "rust1", since = "1.0.0")]
@@ -319,7 +320,8 @@ impl<'a, T: fmt::Debug, C: Compare<T>> fmt::Debug for PeekMut<'a, T, C> {
319320
impl<'a, T, C: Compare<T>> Drop for PeekMut<'a, T, C> {
320321
fn drop(&mut self) {
321322
if self.sift {
322-
self.heap.sift_down(0);
323+
// SAFETY: PeekMut is only instantiated for non-empty heaps.
324+
unsafe { self.heap.sift_down(0) };
323325
}
324326
}
325327
}
@@ -328,15 +330,19 @@ impl<'a, T, C: Compare<T>> Drop for PeekMut<'a, T, C> {
328330
impl<'a, T, C: Compare<T>> Deref for PeekMut<'a, T, C> {
329331
type Target = T;
330332
fn deref(&self) -> &T {
331-
&self.heap.data[0]
333+
debug_assert!(!self.heap.is_empty());
334+
// SAFE: PeekMut is only instantiated for non-empty heaps
335+
unsafe { self.heap.data.get_unchecked(0) }
332336
}
333337
}
334338

335339
// #[stable(feature = "binary_heap_peek_mut", since = "1.12.0")]
336340
impl<'a, T, C: Compare<T>> DerefMut for PeekMut<'a, T, C> {
337341
fn deref_mut(&mut self) -> &mut T {
342+
debug_assert!(!self.heap.is_empty());
338343
self.sift = true;
339-
&mut self.heap.data[0]
344+
// SAFE: PeekMut is only instantiated for non-empty heaps
345+
unsafe { self.heap.data.get_unchecked_mut(0) }
340346
}
341347
}
342348

@@ -865,7 +871,8 @@ impl<T, C: Compare<T>> BinaryHeap<T, C> {
865871
self.data.pop().map(|mut item| {
866872
if !self.is_empty() {
867873
swap(&mut item, &mut self.data[0]);
868-
self.sift_down_to_bottom(0);
874+
// SAFETY: !self.is_empty() means that self.len() > 0
875+
unsafe { self.sift_down_to_bottom(0) };
869876
}
870877
item
871878
})
@@ -891,7 +898,9 @@ impl<T, C: Compare<T>> BinaryHeap<T, C> {
891898
pub fn push(&mut self, item: T) {
892899
let old_len = self.len();
893900
self.data.push(item);
894-
self.sift_up(0, old_len);
901+
// SAFETY: Since we pushed a new item it means that
902+
// old_len = self.len() - 1 < self.len()
903+
unsafe { self.sift_up(0, old_len) };
895904
}
896905

897906
/// Consumes the `BinaryHeap` and returns the underlying vector
@@ -946,7 +955,10 @@ impl<T, C: Compare<T>> BinaryHeap<T, C> {
946955
let ptr = self.data.as_mut_ptr();
947956
ptr::swap(ptr, ptr.add(end));
948957
}
949-
self.sift_down_range(0, end);
958+
// SAFETY: `end` goes from `self.len() - 1` to 1 (both included) so:
959+
// 0 < 1 <= end <= self.len() - 1 < self.len()
960+
// Which means 0 < end and end < self.len().
961+
unsafe { self.sift_down_range(0, end) };
950962
}
951963
self.into_vec()
952964
}
@@ -959,81 +971,139 @@ impl<T, C: Compare<T>> BinaryHeap<T, C> {
959971
// the hole is filled back at the end of its scope, even on panic.
960972
// Using a hole reduces the constant factor compared to using swaps,
961973
// which involves twice as many moves.
962-
fn sift_up(&mut self, start: usize, pos: usize) -> usize {
963-
unsafe {
964-
// Take out the value at `pos` and create a hole.
965-
let mut hole = Hole::new(&mut self.data, pos);
966-
967-
while hole.pos() > start {
968-
let parent = (hole.pos() - 1) / 2;
969-
// if hole.element() <= hole.get(parent) {
970-
if self.cmp.compare(hole.element(), hole.get(parent)) != Ordering::Greater {
971-
break;
972-
}
973-
hole.move_to(parent);
974+
975+
/// # Safety
976+
///
977+
/// The caller must guarantee that `pos < self.len()`.
978+
unsafe fn sift_up(&mut self, start: usize, pos: usize) -> usize {
979+
// Take out the value at `pos` and create a hole.
980+
// SAFETY: The caller guarantees that pos < self.len()
981+
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
982+
983+
while hole.pos() > start {
984+
let parent = (hole.pos() - 1) / 2;
985+
986+
// SAFETY: hole.pos() > start >= 0, which means hole.pos() > 0
987+
// and so hole.pos() - 1 can't underflow.
988+
// This guarantees that parent < hole.pos() so
989+
// it's a valid index and also != hole.pos().
990+
if self
991+
.cmp
992+
.compare(hole.element(), unsafe { hole.get(parent) })
993+
!= Ordering::Greater
994+
{
995+
break;
974996
}
975-
hole.pos()
997+
998+
// SAFETY: Same as above
999+
unsafe { hole.move_to(parent) };
9761000
}
1001+
1002+
hole.pos()
9771003
}
9781004

9791005
/// Take an element at `pos` and move it down the heap,
9801006
/// while its children are larger.
981-
fn sift_down_range(&mut self, pos: usize, end: usize) {
982-
unsafe {
983-
let mut hole = Hole::new(&mut self.data, pos);
984-
let mut child = 2 * pos + 1;
985-
while child < end - 1 {
986-
// compare with the greater of the two children
987-
// if !(hole.get(child) > hole.get(child + 1)) { child += 1 }
988-
child += (self.cmp.compare(hole.get(child), hole.get(child + 1))
989-
!= Ordering::Greater) as usize;
990-
// if we are already in order, stop.
991-
// if hole.element() >= hole.get(child) {
992-
if self.cmp.compare(hole.element(), hole.get(child)) != Ordering::Less {
993-
return;
994-
}
995-
hole.move_to(child);
996-
child = 2 * hole.pos() + 1;
997-
}
998-
if child == end - 1
999-
&& self.cmp.compare(hole.element(), hole.get(child)) == Ordering::Less
1000-
{
1001-
hole.move_to(child);
1007+
///
1008+
/// # Safety
1009+
///
1010+
/// The caller must guarantee that `pos < end <= self.len()`.
1011+
unsafe fn sift_down_range(&mut self, pos: usize, end: usize) {
1012+
// SAFETY: The caller guarantees that pos < end <= self.len().
1013+
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
1014+
let mut child = 2 * hole.pos() + 1;
1015+
1016+
// Loop invariant: child == 2 * hole.pos() + 1.
1017+
while child <= end.saturating_sub(2) {
1018+
// compare with the greater of the two children
1019+
// SAFETY: child < end - 1 < self.len() and
1020+
// child + 1 < end <= self.len(), so they're valid indexes.
1021+
// child == 2 * hole.pos() + 1 != hole.pos() and
1022+
// child + 1 == 2 * hole.pos() + 2 != hole.pos().
1023+
// FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
1024+
// if T is a ZST
1025+
child += unsafe {
1026+
self.cmp.compare(hole.get(child), hole.get(child + 1)) != Ordering::Greater
1027+
} as usize;
1028+
1029+
// if we are already in order, stop.
1030+
// SAFETY: child is now either the old child or the old child+1
1031+
// We already proven that both are < self.len() and != hole.pos()
1032+
if self.cmp.compare(hole.element(), unsafe { hole.get(child) }) != Ordering::Less {
1033+
return;
10021034
}
1035+
1036+
// SAFETY: same as above.
1037+
unsafe { hole.move_to(child) };
1038+
child = 2 * hole.pos() + 1;
1039+
}
1040+
1041+
// SAFETY: && short circuit, which means that in the
1042+
// second condition it's already true that child == end - 1 < self.len().
1043+
if child == end - 1
1044+
&& self.cmp.compare(hole.element(), unsafe { hole.get(child) }) == Ordering::Less
1045+
{
1046+
// SAFETY: child is already proven to be a valid index and
1047+
// child == 2 * hole.pos() + 1 != hole.pos().
1048+
unsafe { hole.move_to(child) };
10031049
}
10041050
}
10051051

1006-
fn sift_down(&mut self, pos: usize) {
1052+
/// # Safety
1053+
///
1054+
/// The caller must guarantee that `pos < self.len()`.
1055+
unsafe fn sift_down(&mut self, pos: usize) {
10071056
let len = self.len();
1008-
self.sift_down_range(pos, len);
1057+
// SAFETY: pos < len is guaranteed by the caller and
1058+
// obviously len = self.len() <= self.len().
1059+
unsafe { self.sift_down_range(pos, len) };
10091060
}
10101061

10111062
/// Take an element at `pos` and move it all the way down the heap,
10121063
/// then sift it up to its position.
10131064
///
10141065
/// Note: This is faster when the element is known to be large / should
10151066
/// be closer to the bottom.
1016-
fn sift_down_to_bottom(&mut self, mut pos: usize) {
1067+
///
1068+
/// # Safety
1069+
///
1070+
/// The caller must guarantee that `pos < self.len()`.
1071+
unsafe fn sift_down_to_bottom(&mut self, mut pos: usize) {
10171072
let end = self.len();
10181073
let start = pos;
1019-
unsafe {
1020-
let mut hole = Hole::new(&mut self.data, pos);
1021-
let mut child = 2 * pos + 1;
1022-
while child < end - 1 {
1023-
let right = child + 1;
1024-
// compare with the greater of the two children
1025-
// if !(hole.get(child) > hole.get(right)) { child += 1 }
1026-
child += (self.cmp.compare(hole.get(child), hole.get(right)) != Ordering::Greater)
1027-
as usize;
1028-
hole.move_to(child);
1029-
child = 2 * hole.pos() + 1;
1030-
}
1031-
if child == end - 1 {
1032-
hole.move_to(child);
1033-
}
1034-
pos = hole.pos;
1074+
1075+
// SAFETY: The caller guarantees that pos < self.len().
1076+
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
1077+
let mut child = 2 * hole.pos() + 1;
1078+
1079+
// Loop invariant: child == 2 * hole.pos() + 1.
1080+
while child <= end.saturating_sub(2) {
1081+
// SAFETY: child < end - 1 < self.len() and
1082+
// child + 1 < end <= self.len(), so they're valid indexes.
1083+
// child == 2 * hole.pos() + 1 != hole.pos() and
1084+
// child + 1 == 2 * hole.pos() + 2 != hole.pos().
1085+
// FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
1086+
// if T is a ZST
1087+
child += unsafe {
1088+
self.cmp.compare(hole.get(child), hole.get(child + 1)) != Ordering::Greater
1089+
} as usize;
1090+
1091+
// SAFETY: Same as above
1092+
unsafe { hole.move_to(child) };
1093+
child = 2 * hole.pos() + 1;
1094+
}
1095+
1096+
if child == end - 1 {
1097+
// SAFETY: child == end - 1 < self.len(), so it's a valid index
1098+
// and child == 2 * hole.pos() + 1 != hole.pos().
1099+
unsafe { hole.move_to(child) };
10351100
}
1036-
self.sift_up(start, pos);
1101+
pos = hole.pos();
1102+
drop(hole);
1103+
1104+
// SAFETY: pos is the position in the hole and was already proven
1105+
// to be a valid index.
1106+
unsafe { self.sift_up(start, pos) };
10371107
}
10381108

10391109
/// Returns the length of the binary heap.
@@ -1129,7 +1199,10 @@ impl<T, C: Compare<T>> BinaryHeap<T, C> {
11291199
let mut n = self.len() / 2;
11301200
while n > 0 {
11311201
n -= 1;
1132-
self.sift_down(n);
1202+
// SAFETY: n starts from self.len() / 2 and goes down to 0.
1203+
// The only case when !(n < self.len()) is if
1204+
// self.len() == 0, but it's ruled out by the loop condition.
1205+
unsafe { self.sift_down(n) };
11331206
}
11341207
}
11351208

@@ -1205,7 +1278,8 @@ impl<'a, T> Hole<'a, T> {
12051278
#[inline]
12061279
unsafe fn new(data: &'a mut [T], pos: usize) -> Self {
12071280
debug_assert!(pos < data.len());
1208-
let elt = ptr::read(&data[pos]);
1281+
// SAFE: pos should be inside the slice
1282+
let elt = unsafe { ptr::read(data.get_unchecked(pos)) };
12091283
Hole {
12101284
data,
12111285
elt: Some(elt),
@@ -1231,7 +1305,7 @@ impl<'a, T> Hole<'a, T> {
12311305
unsafe fn get(&self, index: usize) -> &T {
12321306
debug_assert!(index != self.pos);
12331307
debug_assert!(index < self.data.len());
1234-
self.data.get_unchecked(index)
1308+
unsafe { self.data.get_unchecked(index) }
12351309
}
12361310

12371311
/// Move hole to new location
@@ -1241,9 +1315,12 @@ impl<'a, T> Hole<'a, T> {
12411315
unsafe fn move_to(&mut self, index: usize) {
12421316
debug_assert!(index != self.pos);
12431317
debug_assert!(index < self.data.len());
1244-
let index_ptr: *const _ = self.data.get_unchecked(index);
1245-
let hole_ptr = self.data.get_unchecked_mut(self.pos);
1246-
ptr::copy_nonoverlapping(index_ptr, hole_ptr, 1);
1318+
unsafe {
1319+
let ptr = self.data.as_mut_ptr();
1320+
let index_ptr: *const _ = ptr.add(index);
1321+
let hole_ptr = ptr.add(self.pos);
1322+
ptr::copy_nonoverlapping(index_ptr, hole_ptr, 1);
1323+
}
12471324
self.pos = index;
12481325
}
12491326
}

0 commit comments

Comments (0)