Skip to content

Commit f68b09c

Browse files
author
tugtugtug
committed
feat: Add retain_with_break to HashSet/Table/Map
With the removal of the raw table, it is hard to implement an efficient loop to conditionally remove/keep certain fields up to a limit. i.e. a loop that can be aborted and does not require rehash the key for removal of the entry.
1 parent b74e3a7 commit f68b09c

File tree

3 files changed

+218
-0
lines changed

3 files changed

+218
-0
lines changed

src/map.rs

+71
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,53 @@ impl<K, V, S, A: Allocator> HashMap<K, V, S, A> {
929929
}
930930
}
931931

932+
/// Retains only the elements specified by the predicate and breaks the iteration when
933+
/// the predicate fails. Keeps the allocated memory for reuse.
934+
///
935+
/// In other words, remove all pairs `(k, v)` such that `f(&k, &mut v)` returns `Some(false)` until
936+
/// `f(&k, &mut v)` returns `None`
937+
/// The elements are visited in unsorted (and unspecified) order.
938+
///
939+
/// # Examples
940+
///
941+
/// ```
942+
/// use hashbrown::HashMap;
943+
///
944+
/// let mut map: HashMap<i32, i32> = (0..8).map(|x|(x, x*10)).collect();
945+
/// assert_eq!(map.len(), 8);
946+
/// let mut removed = 0;
947+
/// map.retain_with_break(|&k, _| if removed < 3 {
948+
/// if k % 2 == 0 {
949+
/// Some(true)
950+
/// } else {
951+
/// removed += 1;
952+
/// Some(false)
953+
/// }
954+
/// } else {
955+
/// None
956+
/// });
957+
///
958+
/// // We can see, that the number of elements inside map is changed and the
959+
/// // length matches when we have aborted the retain with the return of `None`
960+
/// assert_eq!(map.len(), 5);
961+
/// ```
962+
pub fn retain_with_break<F>(&mut self, mut f: F)
963+
where
964+
F: FnMut(&K, &mut V) -> Option<bool>,
965+
{
966+
// Here we only use `iter` as a temporary, preventing use-after-free
967+
unsafe {
968+
for item in self.table.iter() {
969+
let &mut (ref key, ref mut value) = item.as_mut();
970+
match f(key, value) {
971+
Some(false) => self.table.erase(item),
972+
Some(true) => continue,
973+
None => break,
974+
}
975+
}
976+
}
977+
}
978+
932979
/// Drains elements which are true under the given predicate,
933980
/// and returns an iterator over the removed items.
934981
///
@@ -5909,6 +5956,30 @@ mod test_map {
59095956
assert_eq!(map[&6], 60);
59105957
}
59115958

5959+
#[test]
5960+
fn test_retain_with_break() {
5961+
let mut map: HashMap<i32, i32> = (0..100).map(|x| (x, x * 10)).collect();
5962+
// looping and removing any key > 50, but stop after 40 iterations
5963+
let mut removed = 0;
5964+
map.retain_with_break(|&k, _| {
5965+
if removed < 40 {
5966+
if k > 50 {
5967+
removed += 1;
5968+
Some(false)
5969+
} else {
5970+
Some(true)
5971+
}
5972+
} else {
5973+
None
5974+
}
5975+
});
5976+
assert_eq!(map.len(), 60);
5977+
// check nothing up to 50 is removed
5978+
for k in 0..=50 {
5979+
assert_eq!(map[&k], k * 10);
5980+
}
5981+
}
5982+
59125983
#[test]
59135984
fn test_extract_if() {
59145985
{

src/set.rs

+55
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,37 @@ impl<T, S, A: Allocator> HashSet<T, S, A> {
372372
self.map.retain(|k, _| f(k));
373373
}
374374

375+
/// Retains only the elements specified by the predicate until the predicate returns `None`.
376+
///
377+
/// In other words, remove all elements `e` such that `f(&e)` returns `Some(false)` until
378+
/// `f(&e)` returns `None`.
379+
///
380+
/// # Examples
381+
///
382+
/// ```
383+
/// use hashbrown::HashSet;
384+
///
385+
/// let xs = [1,2,3,4,5,6];
386+
/// let mut set: HashSet<i32> = xs.into_iter().collect();
387+
/// let mut count = 0;
388+
/// set.retain_with_break(|&k| if count < 2 {
389+
/// if k % 2 == 0 {
390+
/// Some(true)
391+
/// } else {
392+
/// Some(false)
393+
/// }
394+
/// } else {
395+
/// None
396+
/// });
397+
/// assert_eq!(set.len(), 3);
398+
/// ```
399+
pub fn retain_with_break<F>(&mut self, mut f: F)
400+
where
401+
F: FnMut(&T) -> Option<bool>,
402+
{
403+
self.map.retain_with_break(|k, _| f(k));
404+
}
405+
375406
/// Drains elements which are true under the given predicate,
376407
/// and returns an iterator over the removed items.
377408
///
@@ -2980,6 +3011,30 @@ mod test_set {
29803011
assert!(set.contains(&6));
29813012
}
29823013

3014+
#[test]
3015+
fn test_retain_with_break() {
3016+
let mut set: HashSet<i32> = (0..100).collect();
3017+
// looping and removing any key > 50, but stop after 40 iterations
3018+
let mut removed = 0;
3019+
set.retain_with_break(|&k| {
3020+
if removed < 40 {
3021+
if k > 50 {
3022+
removed += 1;
3023+
Some(false)
3024+
} else {
3025+
Some(true)
3026+
}
3027+
} else {
3028+
None
3029+
}
3030+
});
3031+
assert_eq!(set.len(), 60);
3032+
// check nothing up to 50 is removed
3033+
for k in 0..=50 {
3034+
assert!(set.contains(&k));
3035+
}
3036+
}
3037+
29833038
#[test]
29843039
fn test_extract_if() {
29853040
{

src/table.rs

+92
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,61 @@ where
870870
}
871871
}
872872

873+
/// Retains only the elements specified by the predicate until the predicate returns `None`.
874+
///
875+
/// In other words, remove all elements `e` such that `f(&e)` returns `Ok(false)` until
876+
/// `f(&e)` returns `None`.
877+
///
878+
/// # Examples
879+
///
880+
/// ```
881+
/// # #[cfg(feature = "nightly")]
882+
/// # fn test() {
883+
/// use hashbrown::{HashTable, DefaultHashBuilder};
884+
/// use std::hash::BuildHasher;
885+
///
886+
/// let mut table = HashTable::new();
887+
/// let hasher = DefaultHashBuilder::default();
888+
/// let hasher = |val: &_| {
889+
/// use core::hash::Hasher;
890+
/// let mut state = hasher.build_hasher();
891+
/// core::hash::Hash::hash(&val, &mut state);
892+
/// state.finish()
893+
/// };
894+
/// let mut removed = 0;
895+
/// for x in 1..=8 {
896+
/// table.insert_unique(hasher(&x), x, hasher);
897+
/// }
898+
/// table.retain_with_break(|&mut v| if removed < 3 {
899+
/// if v % 2 == 0 {
900+
/// Some(true)
901+
/// } else {
902+
/// removed += 1;
903+
/// Some(false)
904+
/// }
905+
/// } else {
906+
/// None
907+
/// });
908+
/// assert_eq!(table.len(), 5);
909+
/// # }
910+
/// # fn main() {
911+
/// # #[cfg(feature = "nightly")]
912+
/// # test()
913+
/// # }
914+
/// ```
915+
pub fn retain_with_break(&mut self, mut f: impl FnMut(&mut T) -> Option<bool>) {
916+
// Here we only use `iter` as a temporary, preventing use-after-free
917+
unsafe {
918+
for item in self.raw.iter() {
919+
match f(item.as_mut()) {
920+
Some(false) => self.raw.erase(item),
921+
Some(true) => continue,
922+
None => break,
923+
}
924+
}
925+
}
926+
}
927+
873928
/// Clears the set, returning all elements in an iterator.
874929
///
875930
/// # Examples
@@ -2372,12 +2427,49 @@ impl<T, F, A: Allocator> FusedIterator for ExtractIf<'_, T, F, A> where F: FnMut
23722427

23732428
#[cfg(test)]
23742429
mod tests {
2430+
use crate::DefaultHashBuilder;
2431+
23752432
use super::HashTable;
23762433

2434+
use core::hash::BuildHasher;
23772435
#[test]
23782436
fn test_allocation_info() {
23792437
assert_eq!(HashTable::<()>::new().allocation_size(), 0);
23802438
assert_eq!(HashTable::<u32>::new().allocation_size(), 0);
23812439
assert!(HashTable::<u32>::with_capacity(1).allocation_size() > core::mem::size_of::<u32>());
23822440
}
2441+
2442+
#[test]
2443+
fn test_retain_with_break() {
2444+
let mut table = HashTable::new();
2445+
let hasher = DefaultHashBuilder::default();
2446+
let hasher = |val: &_| {
2447+
use core::hash::Hasher;
2448+
let mut state = hasher.build_hasher();
2449+
core::hash::Hash::hash(&val, &mut state);
2450+
state.finish()
2451+
};
2452+
for x in 0..100 {
2453+
table.insert_unique(hasher(&x), x, hasher);
2454+
}
2455+
// looping and removing any value > 50, but stop after 40 iterations
2456+
let mut removed = 0;
2457+
table.retain_with_break(|&mut v| {
2458+
if removed < 40 {
2459+
if v > 50 {
2460+
removed += 1;
2461+
Some(false)
2462+
} else {
2463+
Some(true)
2464+
}
2465+
} else {
2466+
None
2467+
}
2468+
});
2469+
assert_eq!(table.len(), 60);
2470+
// check nothing up to 50 is removed
2471+
for v in 0..=50 {
2472+
assert_eq!(table.find(hasher(&v), |&val| val == v), Some(&v));
2473+
}
2474+
}
23832475
}

0 commit comments

Comments
 (0)