155
155
//! }
156
156
//! ```
157
157
158
+ #![ deny( unsafe_op_in_unsafe_fn) ]
158
159
#![ allow( clippy:: needless_doctest_main) ]
159
160
#![ allow( missing_docs) ]
160
161
// #![stable(feature = "rust1", since = "1.0.0")]
@@ -319,7 +320,8 @@ impl<'a, T: fmt::Debug, C: Compare<T>> fmt::Debug for PeekMut<'a, T, C> {
319
320
impl < ' a , T , C : Compare < T > > Drop for PeekMut < ' a , T , C > {
320
321
fn drop ( & mut self ) {
321
322
if self . sift {
322
- self . heap . sift_down ( 0 ) ;
323
+ // SAFETY: PeekMut is only instantiated for non-empty heaps.
324
+ unsafe { self . heap . sift_down ( 0 ) } ;
323
325
}
324
326
}
325
327
}
@@ -328,15 +330,19 @@ impl<'a, T, C: Compare<T>> Drop for PeekMut<'a, T, C> {
328
330
impl < ' a , T , C : Compare < T > > Deref for PeekMut < ' a , T , C > {
329
331
type Target = T ;
330
332
fn deref ( & self ) -> & T {
331
- & self . heap . data [ 0 ]
333
+ debug_assert ! ( !self . heap. is_empty( ) ) ;
334
+ // SAFE: PeekMut is only instantiated for non-empty heaps
335
+ unsafe { self . heap . data . get_unchecked ( 0 ) }
332
336
}
333
337
}
334
338
335
339
// #[stable(feature = "binary_heap_peek_mut", since = "1.12.0")]
336
340
impl < ' a , T , C : Compare < T > > DerefMut for PeekMut < ' a , T , C > {
337
341
fn deref_mut ( & mut self ) -> & mut T {
342
+ debug_assert ! ( !self . heap. is_empty( ) ) ;
338
343
self . sift = true ;
339
- & mut self . heap . data [ 0 ]
344
+ // SAFE: PeekMut is only instantiated for non-empty heaps
345
+ unsafe { self . heap . data . get_unchecked_mut ( 0 ) }
340
346
}
341
347
}
342
348
@@ -865,7 +871,8 @@ impl<T, C: Compare<T>> BinaryHeap<T, C> {
865
871
self . data . pop ( ) . map ( |mut item| {
866
872
if !self . is_empty ( ) {
867
873
swap ( & mut item, & mut self . data [ 0 ] ) ;
868
- self . sift_down_to_bottom ( 0 ) ;
874
+ // SAFETY: !self.is_empty() means that self.len() > 0
875
+ unsafe { self . sift_down_to_bottom ( 0 ) } ;
869
876
}
870
877
item
871
878
} )
@@ -891,7 +898,9 @@ impl<T, C: Compare<T>> BinaryHeap<T, C> {
891
898
pub fn push ( & mut self , item : T ) {
892
899
let old_len = self . len ( ) ;
893
900
self . data . push ( item) ;
894
- self . sift_up ( 0 , old_len) ;
901
+ // SAFETY: Since we pushed a new item it means that
902
+ // old_len = self.len() - 1 < self.len()
903
+ unsafe { self . sift_up ( 0 , old_len) } ;
895
904
}
896
905
897
906
/// Consumes the `BinaryHeap` and returns the underlying vector
@@ -946,7 +955,10 @@ impl<T, C: Compare<T>> BinaryHeap<T, C> {
946
955
let ptr = self . data . as_mut_ptr ( ) ;
947
956
ptr:: swap ( ptr, ptr. add ( end) ) ;
948
957
}
949
- self . sift_down_range ( 0 , end) ;
958
+ // SAFETY: `end` goes from `self.len() - 1` to 1 (both included) so:
959
+ // 0 < 1 <= end <= self.len() - 1 < self.len()
960
+ // Which means 0 < end and end < self.len().
961
+ unsafe { self . sift_down_range ( 0 , end) } ;
950
962
}
951
963
self . into_vec ( )
952
964
}
@@ -959,81 +971,139 @@ impl<T, C: Compare<T>> BinaryHeap<T, C> {
959
971
// the hole is filled back at the end of its scope, even on panic.
960
972
// Using a hole reduces the constant factor compared to using swaps,
961
973
// which involves twice as many moves.
962
- fn sift_up ( & mut self , start : usize , pos : usize ) -> usize {
963
- unsafe {
964
- // Take out the value at `pos` and create a hole.
965
- let mut hole = Hole :: new ( & mut self . data , pos) ;
966
-
967
- while hole. pos ( ) > start {
968
- let parent = ( hole. pos ( ) - 1 ) / 2 ;
969
- // if hole.element() <= hole.get(parent) {
970
- if self . cmp . compare ( hole. element ( ) , hole. get ( parent) ) != Ordering :: Greater {
971
- break ;
972
- }
973
- hole. move_to ( parent) ;
974
+
975
+ /// # Safety
976
+ ///
977
+ /// The caller must guarantee that `pos < self.len()`.
978
+ unsafe fn sift_up ( & mut self , start : usize , pos : usize ) -> usize {
979
+ // Take out the value at `pos` and create a hole.
980
+ // SAFETY: The caller guarantees that pos < self.len()
981
+ let mut hole = unsafe { Hole :: new ( & mut self . data , pos) } ;
982
+
983
+ while hole. pos ( ) > start {
984
+ let parent = ( hole. pos ( ) - 1 ) / 2 ;
985
+
986
+ // SAFETY: hole.pos() > start >= 0, which means hole.pos() > 0
987
+ // and so hole.pos() - 1 can't underflow.
988
+ // This guarantees that parent < hole.pos() so
989
+ // it's a valid index and also != hole.pos().
990
+ if self
991
+ . cmp
992
+ . compare ( hole. element ( ) , unsafe { hole. get ( parent) } )
993
+ != Ordering :: Greater
994
+ {
995
+ break ;
974
996
}
975
- hole. pos ( )
997
+
998
+ // SAFETY: Same as above
999
+ unsafe { hole. move_to ( parent) } ;
976
1000
}
1001
+
1002
+ hole. pos ( )
977
1003
}
978
1004
979
1005
/// Take an element at `pos` and move it down the heap,
980
1006
/// while its children are larger.
981
- fn sift_down_range ( & mut self , pos : usize , end : usize ) {
982
- unsafe {
983
- let mut hole = Hole :: new ( & mut self . data , pos) ;
984
- let mut child = 2 * pos + 1 ;
985
- while child < end - 1 {
986
- // compare with the greater of the two children
987
- // if !(hole.get(child) > hole.get(child + 1)) { child += 1 }
988
- child += ( self . cmp . compare ( hole. get ( child) , hole. get ( child + 1 ) )
989
- != Ordering :: Greater ) as usize ;
990
- // if we are already in order, stop.
991
- // if hole.element() >= hole.get(child) {
992
- if self . cmp . compare ( hole. element ( ) , hole. get ( child) ) != Ordering :: Less {
993
- return ;
994
- }
995
- hole. move_to ( child) ;
996
- child = 2 * hole. pos ( ) + 1 ;
997
- }
998
- if child == end - 1
999
- && self . cmp . compare ( hole. element ( ) , hole. get ( child) ) == Ordering :: Less
1000
- {
1001
- hole. move_to ( child) ;
1007
+ ///
1008
+ /// # Safety
1009
+ ///
1010
+ /// The caller must guarantee that `pos < end <= self.len()`.
1011
+ unsafe fn sift_down_range ( & mut self , pos : usize , end : usize ) {
1012
+ // SAFETY: The caller guarantees that pos < end <= self.len().
1013
+ let mut hole = unsafe { Hole :: new ( & mut self . data , pos) } ;
1014
+ let mut child = 2 * hole. pos ( ) + 1 ;
1015
+
1016
+ // Loop invariant: child == 2 * hole.pos() + 1.
1017
+ while child <= end. saturating_sub ( 2 ) {
1018
+ // compare with the greater of the two children
1019
+ // SAFETY: child < end - 1 < self.len() and
1020
+ // child + 1 < end <= self.len(), so they're valid indexes.
1021
+ // child == 2 * hole.pos() + 1 != hole.pos() and
1022
+ // child + 1 == 2 * hole.pos() + 2 != hole.pos().
1023
+ // FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
1024
+ // if T is a ZST
1025
+ child += unsafe {
1026
+ self . cmp . compare ( hole. get ( child) , hole. get ( child + 1 ) ) != Ordering :: Greater
1027
+ } as usize ;
1028
+
1029
+ // if we are already in order, stop.
1030
+ // SAFETY: child is now either the old child or the old child+1
1031
+ // We already proven that both are < self.len() and != hole.pos()
1032
+ if self . cmp . compare ( hole. element ( ) , unsafe { hole. get ( child) } ) != Ordering :: Less {
1033
+ return ;
1002
1034
}
1035
+
1036
+ // SAFETY: same as above.
1037
+ unsafe { hole. move_to ( child) } ;
1038
+ child = 2 * hole. pos ( ) + 1 ;
1039
+ }
1040
+
1041
+ // SAFETY: && short circuit, which means that in the
1042
+ // second condition it's already true that child == end - 1 < self.len().
1043
+ if child == end - 1
1044
+ && self . cmp . compare ( hole. element ( ) , unsafe { hole. get ( child) } ) == Ordering :: Less
1045
+ {
1046
+ // SAFETY: child is already proven to be a valid index and
1047
+ // child == 2 * hole.pos() + 1 != hole.pos().
1048
+ unsafe { hole. move_to ( child) } ;
1003
1049
}
1004
1050
}
1005
1051
1006
- fn sift_down ( & mut self , pos : usize ) {
1052
+ /// # Safety
1053
+ ///
1054
+ /// The caller must guarantee that `pos < self.len()`.
1055
+ unsafe fn sift_down ( & mut self , pos : usize ) {
1007
1056
let len = self . len ( ) ;
1008
- self . sift_down_range ( pos, len) ;
1057
+ // SAFETY: pos < len is guaranteed by the caller and
1058
+ // obviously len = self.len() <= self.len().
1059
+ unsafe { self . sift_down_range ( pos, len) } ;
1009
1060
}
1010
1061
1011
1062
/// Take an element at `pos` and move it all the way down the heap,
1012
1063
/// then sift it up to its position.
1013
1064
///
1014
1065
/// Note: This is faster when the element is known to be large / should
1015
1066
/// be closer to the bottom.
1016
- fn sift_down_to_bottom ( & mut self , mut pos : usize ) {
1067
+ ///
1068
+ /// # Safety
1069
+ ///
1070
+ /// The caller must guarantee that `pos < self.len()`.
1071
+ unsafe fn sift_down_to_bottom ( & mut self , mut pos : usize ) {
1017
1072
let end = self . len ( ) ;
1018
1073
let start = pos;
1019
- unsafe {
1020
- let mut hole = Hole :: new ( & mut self . data , pos) ;
1021
- let mut child = 2 * pos + 1 ;
1022
- while child < end - 1 {
1023
- let right = child + 1 ;
1024
- // compare with the greater of the two children
1025
- // if !(hole.get(child) > hole.get(right)) { child += 1 }
1026
- child += ( self . cmp . compare ( hole. get ( child) , hole. get ( right) ) != Ordering :: Greater )
1027
- as usize ;
1028
- hole. move_to ( child) ;
1029
- child = 2 * hole. pos ( ) + 1 ;
1030
- }
1031
- if child == end - 1 {
1032
- hole. move_to ( child) ;
1033
- }
1034
- pos = hole. pos ;
1074
+
1075
+ // SAFETY: The caller guarantees that pos < self.len().
1076
+ let mut hole = unsafe { Hole :: new ( & mut self . data , pos) } ;
1077
+ let mut child = 2 * hole. pos ( ) + 1 ;
1078
+
1079
+ // Loop invariant: child == 2 * hole.pos() + 1.
1080
+ while child <= end. saturating_sub ( 2 ) {
1081
+ // SAFETY: child < end - 1 < self.len() and
1082
+ // child + 1 < end <= self.len(), so they're valid indexes.
1083
+ // child == 2 * hole.pos() + 1 != hole.pos() and
1084
+ // child + 1 == 2 * hole.pos() + 2 != hole.pos().
1085
+ // FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
1086
+ // if T is a ZST
1087
+ child += unsafe {
1088
+ self . cmp . compare ( hole. get ( child) , hole. get ( child + 1 ) ) != Ordering :: Greater
1089
+ } as usize ;
1090
+
1091
+ // SAFETY: Same as above
1092
+ unsafe { hole. move_to ( child) } ;
1093
+ child = 2 * hole. pos ( ) + 1 ;
1094
+ }
1095
+
1096
+ if child == end - 1 {
1097
+ // SAFETY: child == end - 1 < self.len(), so it's a valid index
1098
+ // and child == 2 * hole.pos() + 1 != hole.pos().
1099
+ unsafe { hole. move_to ( child) } ;
1035
1100
}
1036
- self . sift_up ( start, pos) ;
1101
+ pos = hole. pos ( ) ;
1102
+ drop ( hole) ;
1103
+
1104
+ // SAFETY: pos is the position in the hole and was already proven
1105
+ // to be a valid index.
1106
+ unsafe { self . sift_up ( start, pos) } ;
1037
1107
}
1038
1108
1039
1109
/// Returns the length of the binary heap.
@@ -1129,7 +1199,10 @@ impl<T, C: Compare<T>> BinaryHeap<T, C> {
1129
1199
let mut n = self . len ( ) / 2 ;
1130
1200
while n > 0 {
1131
1201
n -= 1 ;
1132
- self . sift_down ( n) ;
1202
+ // SAFETY: n starts from self.len() / 2 and goes down to 0.
1203
+ // The only case when !(n < self.len()) is if
1204
+ // self.len() == 0, but it's ruled out by the loop condition.
1205
+ unsafe { self . sift_down ( n) } ;
1133
1206
}
1134
1207
}
1135
1208
@@ -1205,7 +1278,8 @@ impl<'a, T> Hole<'a, T> {
1205
1278
#[ inline]
1206
1279
unsafe fn new ( data : & ' a mut [ T ] , pos : usize ) -> Self {
1207
1280
debug_assert ! ( pos < data. len( ) ) ;
1208
- let elt = ptr:: read ( & data[ pos] ) ;
1281
+ // SAFE: pos should be inside the slice
1282
+ let elt = unsafe { ptr:: read ( data. get_unchecked ( pos) ) } ;
1209
1283
Hole {
1210
1284
data,
1211
1285
elt : Some ( elt) ,
@@ -1231,7 +1305,7 @@ impl<'a, T> Hole<'a, T> {
1231
1305
unsafe fn get ( & self , index : usize ) -> & T {
1232
1306
debug_assert ! ( index != self . pos) ;
1233
1307
debug_assert ! ( index < self . data. len( ) ) ;
1234
- self . data . get_unchecked ( index)
1308
+ unsafe { self . data . get_unchecked ( index) }
1235
1309
}
1236
1310
1237
1311
/// Move hole to new location
@@ -1241,9 +1315,12 @@ impl<'a, T> Hole<'a, T> {
1241
1315
unsafe fn move_to ( & mut self , index : usize ) {
1242
1316
debug_assert ! ( index != self . pos) ;
1243
1317
debug_assert ! ( index < self . data. len( ) ) ;
1244
- let index_ptr: * const _ = self . data . get_unchecked ( index) ;
1245
- let hole_ptr = self . data . get_unchecked_mut ( self . pos ) ;
1246
- ptr:: copy_nonoverlapping ( index_ptr, hole_ptr, 1 ) ;
1318
+ unsafe {
1319
+ let ptr = self . data . as_mut_ptr ( ) ;
1320
+ let index_ptr: * const _ = ptr. add ( index) ;
1321
+ let hole_ptr = ptr. add ( self . pos ) ;
1322
+ ptr:: copy_nonoverlapping ( index_ptr, hole_ptr, 1 ) ;
1323
+ }
1247
1324
self . pos = index;
1248
1325
}
1249
1326
}
0 commit comments