Skip to content

Commit 84ae969

Browse files
committed
Prevent Vec::drain_filter from double dropping on panic
Fixes: #60977
1 parent 5187be6 commit 84ae969

File tree

2 files changed

+162
-10
lines changed

2 files changed

+162
-10
lines changed

src/liballoc/tests/vec.rs

+99
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,105 @@ fn drain_filter_complex() {
945945
}
946946
}
947947

948+
#[test]
949+
fn drain_filter_consumed_panic() {
950+
use std::rc::Rc;
951+
use std::sync::Mutex;
952+
953+
struct Check {
954+
index: usize,
955+
drop_counts: Rc<Mutex<Vec<usize>>>,
956+
};
957+
958+
impl Drop for Check {
959+
fn drop(&mut self) {
960+
self.drop_counts.lock().unwrap()[self.index] += 1;
961+
println!("drop: {}", self.index);
962+
}
963+
}
964+
965+
let check_count = 10;
966+
let drop_counts = Rc::new(Mutex::new(vec![0_usize; check_count]));
967+
let mut data: Vec<Check> = (0..check_count)
968+
.map(|index| Check { index, drop_counts: Rc::clone(&drop_counts) })
969+
.collect();
970+
971+
let _ = std::panic::catch_unwind(move || {
972+
let filter = |c: &mut Check| {
973+
if c.index == 2 {
974+
panic!("panic at index: {}", c.index);
975+
}
976+
// Verify that if the filter could panic again on another element
977+
// that it would not cause a double panic and all elements of the
978+
// vec would still be dropped exactly once.
979+
if c.index == 4 {
980+
panic!("panic at index: {}", c.index);
981+
}
982+
c.index < 6
983+
};
984+
let drain = data.drain_filter(filter);
985+
986+
// NOTE: The DrainFilter is explictly consumed
987+
drain.for_each(drop);
988+
});
989+
990+
let drop_counts = drop_counts.lock().unwrap();
991+
assert_eq!(check_count, drop_counts.len());
992+
993+
for (index, count) in drop_counts.iter().cloned().enumerate() {
994+
assert_eq!(1, count, "unexpected drop count at index: {} (count: {})", index, count);
995+
}
996+
}
997+
998+
#[test]
999+
fn drain_filter_unconsumed_panic() {
1000+
use std::rc::Rc;
1001+
use std::sync::Mutex;
1002+
1003+
struct Check {
1004+
index: usize,
1005+
drop_counts: Rc<Mutex<Vec<usize>>>,
1006+
};
1007+
1008+
impl Drop for Check {
1009+
fn drop(&mut self) {
1010+
self.drop_counts.lock().unwrap()[self.index] += 1;
1011+
println!("drop: {}", self.index);
1012+
}
1013+
}
1014+
1015+
let check_count = 10;
1016+
let drop_counts = Rc::new(Mutex::new(vec![0_usize; check_count]));
1017+
let mut data: Vec<Check> = (0..check_count)
1018+
.map(|index| Check { index, drop_counts: Rc::clone(&drop_counts) })
1019+
.collect();
1020+
1021+
let _ = std::panic::catch_unwind(move || {
1022+
let filter = |c: &mut Check| {
1023+
if c.index == 2 {
1024+
panic!("panic at index: {}", c.index);
1025+
}
1026+
// Verify that if the filter could panic again on another element
1027+
// that it would not cause a double panic and all elements of the
1028+
// vec would still be dropped exactly once.
1029+
if c.index == 4 {
1030+
panic!("panic at index: {}", c.index);
1031+
}
1032+
c.index < 6
1033+
};
1034+
let _drain = data.drain_filter(filter);
1035+
1036+
// NOTE: The DrainFilter is dropped without being consumed
1037+
});
1038+
1039+
let drop_counts = drop_counts.lock().unwrap();
1040+
assert_eq!(check_count, drop_counts.len());
1041+
1042+
for (index, count) in drop_counts.iter().cloned().enumerate() {
1043+
assert_eq!(1, count, "unexpected drop count at index: {} (count: {})", index, count);
1044+
}
1045+
}
1046+
9481047
#[test]
9491048
fn test_reserve_exact() {
9501049
// This is all the same as test_reserve

src/liballoc/vec.rs

+63-10
Original file line numberDiff line numberDiff line change
@@ -2120,6 +2120,7 @@ impl<T> Vec<T> {
21202120
del: 0,
21212121
old_len,
21222122
pred: filter,
2123+
panic_flag: false,
21232124
}
21242125
}
21252126
}
@@ -2751,6 +2752,7 @@ pub struct DrainFilter<'a, T, F>
27512752
del: usize,
27522753
old_len: usize,
27532754
pred: F,
2755+
panic_flag: bool,
27542756
}
27552757

27562758
#[unstable(feature = "drain_filter", reason = "recently added", issue = "43244")]
@@ -2760,21 +2762,34 @@ impl<T, F> Iterator for DrainFilter<'_, T, F>
27602762
type Item = T;
27612763

27622764
fn next(&mut self) -> Option<T> {
2765+
struct SetIdxOnDrop<'a> {
2766+
idx: &'a mut usize,
2767+
new_idx: usize,
2768+
}
2769+
2770+
impl<'a> Drop for SetIdxOnDrop<'a> {
2771+
fn drop(&mut self) {
2772+
*self.idx = self.new_idx;
2773+
}
2774+
}
2775+
27632776
unsafe {
2764-
while self.idx != self.old_len {
2777+
while self.idx < self.old_len {
27652778
let i = self.idx;
2766-
self.idx += 1;
27672779
let v = slice::from_raw_parts_mut(self.vec.as_mut_ptr(), self.old_len);
2768-
if (self.pred)(&mut v[i]) {
2780+
let mut set_idx = SetIdxOnDrop { new_idx: self.idx, idx: &mut self.idx };
2781+
self.panic_flag = true;
2782+
let drained = (self.pred)(&mut v[i]);
2783+
self.panic_flag = false;
2784+
set_idx.new_idx += 1;
2785+
if drained {
27692786
self.del += 1;
27702787
return Some(ptr::read(&v[i]));
2771-
} else if self.del > 0 {
2788+
}
2789+
else if self.del > 0 {
27722790
let del = self.del;
27732791
let src: *const T = &v[i];
27742792
let dst: *mut T = &mut v[i - del];
2775-
// This is safe because self.vec has length 0
2776-
// thus its elements will not have Drop::drop
2777-
// called on them in the event of a panic.
27782793
ptr::copy_nonoverlapping(src, dst, 1);
27792794
}
27802795
}
@@ -2792,9 +2807,47 @@ impl<T, F> Drop for DrainFilter<'_, T, F>
27922807
where F: FnMut(&mut T) -> bool,
27932808
{
27942809
fn drop(&mut self) {
2795-
self.for_each(drop);
2796-
unsafe {
2797-
self.vec.set_len(self.old_len - self.del);
2810+
// If the predicate panics, we still need to backshift everything
2811+
// down after the last successfully drained element, but no additional
2812+
// elements are drained or checked.
2813+
struct BackshiftOnDrop<'a, 'b, T, F>
2814+
where
2815+
F: FnMut(&mut T) -> bool,
2816+
{
2817+
drain: &'b mut DrainFilter<'a, T, F>,
2818+
}
2819+
2820+
impl<'a, 'b, T, F> Drop for BackshiftOnDrop<'a, 'b, T, F>
2821+
where
2822+
F: FnMut(&mut T) -> bool
2823+
{
2824+
fn drop(&mut self) {
2825+
unsafe {
2826+
while self.drain.idx < self.drain.old_len {
2827+
let i = self.drain.idx;
2828+
self.drain.idx += 1;
2829+
let v = slice::from_raw_parts_mut(
2830+
self.drain.vec.as_mut_ptr(),
2831+
self.drain.old_len,
2832+
);
2833+
if self.drain.del > 0 {
2834+
let del = self.drain.del;
2835+
let src: *const T = &v[i];
2836+
let dst: *mut T = &mut v[i - del];
2837+
ptr::copy_nonoverlapping(src, dst, 1);
2838+
}
2839+
}
2840+
self.drain.vec.set_len(self.drain.old_len - self.drain.del);
2841+
}
2842+
}
2843+
}
2844+
2845+
let backshift = BackshiftOnDrop {
2846+
drain: self
2847+
};
2848+
2849+
if !backshift.drain.panic_flag {
2850+
backshift.drain.for_each(drop);
27982851
}
27992852
}
28002853
}

0 commit comments

Comments
 (0)