Skip to content

Commit 55b2338

Browse files
author
tugtugtug
committed
feat: Add retain_with_break to HashSet/Table/Map
With the removal of the raw table, it is hard to implement an efficient loop that conditionally removes/keeps certain entries up to a limit, i.e. a loop that can be aborted and that does not require rehashing the key in order to remove the entry.
1 parent b74e3a7 commit 55b2338

File tree

3 files changed

+220
-0
lines changed

3 files changed

+220
-0
lines changed

src/map.rs

+71
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,53 @@ impl<K, V, S, A: Allocator> HashMap<K, V, S, A> {
929929
}
930930
}
931931

932+
/// Retains only the elements specified by the predicate and breaks the iteration when
933+
/// the predicate fails. Keeps the allocated memory for reuse.
934+
///
935+
/// In other words, remove all pairs `(k, v)` for which `f(&k, &mut v)` returns `Ok(false)`,
936+
/// stopping as soon as `f(&k, &mut v)` returns `Err(())`; that pair and all unvisited pairs are kept.
937+
/// The elements are visited in unsorted (and unspecified) order.
938+
///
939+
/// # Examples
940+
///
941+
/// ```
942+
/// use hashbrown::HashMap;
943+
///
944+
/// let mut map: HashMap<i32, i32> = (0..8).map(|x|(x, x*10)).collect();
945+
/// assert_eq!(map.len(), 8);
946+
/// let mut removed = 0;
947+
/// map.retain_with_break(|&k, _| if removed < 3 {
948+
/// if k % 2 == 0 {
949+
/// Ok(true)
950+
/// } else {
951+
/// removed += 1;
952+
/// Ok(false)
953+
/// }
954+
/// } else {
955+
/// Err(())
956+
/// });
957+
///
958+
/// // Exactly three odd keys were removed before the predicate returned
959+
/// // `Err(())` and stopped the iteration, so five of the eight entries remain.
960+
/// assert_eq!(map.len(), 5);
961+
/// ```
962+
pub fn retain_with_break<F>(&mut self, mut f: F)
963+
where
964+
F: FnMut(&K, &mut V) -> core::result::Result<bool, ()>,
965+
{
966+
// Here we only use `iter` as a temporary, preventing use-after-free
967+
unsafe {
968+
for item in self.table.iter() {
969+
let &mut (ref key, ref mut value) = item.as_mut();
970+
match f(key, value) {
971+
Ok(false) => self.table.erase(item), // remove this entry through its bucket handle
972+
Err(_) => break, // stop; this entry and every unvisited entry stay in the map
973+
_ => continue, // `Ok(true)`: keep the entry and move on
974+
}
975+
}
976+
}
977+
}
978+
932979
/// Drains elements which are true under the given predicate,
933980
/// and returns an iterator over the removed items.
934981
///
@@ -5909,6 +5956,30 @@ mod test_map {
59095956
assert_eq!(map[&6], 60);
59105957
}
59115958

5959+
#[test]
5960+
fn test_retain_with_break() {
5961+
let mut map: HashMap<i32, i32> = (0..100).map(|x| (x, x * 10)).collect();
5962+
// Remove keys > 50, but stop as soon as 40 entries have been removed
5963+
let mut removed = 0;
5964+
map.retain_with_break(|&k, _| {
5965+
if removed < 40 {
5966+
if k > 50 {
5967+
removed += 1;
5968+
Ok(false)
5969+
} else {
5970+
Ok(true)
5971+
}
5972+
} else {
5973+
Err(())
5974+
}
5975+
});
5976+
assert_eq!(map.len(), 60); // 100 entries minus the 40 removals
5977+
// check nothing up to 50 is removed
5978+
for k in 0..=50 {
5979+
assert_eq!(map[&k], k * 10);
5980+
}
5981+
}
5982+
59125983
#[test]
59135984
fn test_extract_if() {
59145985
{

src/set.rs

+54
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,36 @@ impl<T, S, A: Allocator> HashSet<T, S, A> {
372372
self.map.retain(|k, _| f(k));
373373
}
374374

375+
/// Retains only the elements specified by the predicate until the predicate fails.
376+
///
377+
/// In other words, remove all elements `e` for which `f(&e)` returns `Ok(false)`,
/// stopping as soon as `f(&e)` returns `Err(())`; that element and all unvisited elements are kept.
378+
///
379+
/// # Examples
380+
///
381+
/// ```
382+
/// use hashbrown::HashSet;
383+
///
384+
/// let xs = [1,2,3,4,5,6];
385+
/// let mut set: HashSet<i32> = xs.into_iter().collect();
386+
/// let mut count = 0;
387+
/// set.retain_with_break(|&k| if count < 2 {
388+
/// if k % 2 == 0 {
389+
/// Ok(true)
390+
/// } else {
/// count += 1;
391+
/// Ok(false)
392+
/// }
393+
/// } else {
394+
/// Err(())
395+
/// });
/// // Only two of the three odd elements are removed before the iteration stops.
396+
/// assert_eq!(set.len(), 4);
397+
/// ```
398+
pub fn retain_with_break<F>(&mut self, mut f: F)
399+
where
400+
F: FnMut(&T) -> core::result::Result<bool, ()>,
401+
{
402+
self.map.retain_with_break(|k, _| f(k)); // forward to the underlying map; only the key is passed to `f`
403+
}
404+
375405
/// Drains elements which are true under the given predicate,
376406
/// and returns an iterator over the removed items.
377407
///
@@ -2980,6 +3010,30 @@ mod test_set {
29803010
assert!(set.contains(&6));
29813011
}
29823012

3013+
#[test]
3014+
fn test_retain_with_break() {
3015+
let mut set: HashSet<i32> = (0..100).collect();
3016+
// Remove keys > 50, but stop as soon as 40 elements have been removed
3017+
let mut removed = 0;
3018+
set.retain_with_break(|&k| {
3019+
if removed < 40 {
3020+
if k > 50 {
3021+
removed += 1;
3022+
Ok(false)
3023+
} else {
3024+
Ok(true)
3025+
}
3026+
} else {
3027+
Err(())
3028+
}
3029+
});
3030+
assert_eq!(set.len(), 60); // 100 elements minus the 40 removals
3031+
// check nothing up to 50 is removed
3032+
for k in 0..=50 {
3033+
assert!(set.contains(&k));
3034+
}
3035+
}
3036+
29833037
#[test]
29843038
fn test_extract_if() {
29853039
{

src/table.rs

+95
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,64 @@ where
870870
}
871871
}
872872

873+
/// Retains only the elements specified by the predicate until the predicate fails.
874+
///
875+
/// In other words, remove all elements `e` for which `f(&mut e)` returns `Ok(false)`, stopping as
876+
/// soon as `f(&mut e)` returns `Err(())`; that element and all unvisited elements are kept.
877+
///
878+
/// # Examples
879+
///
880+
/// ```
881+
/// # #[cfg(feature = "nightly")]
882+
/// # fn test() {
883+
/// use hashbrown::{HashTable, DefaultHashBuilder};
884+
/// use std::hash::BuildHasher;
885+
///
886+
/// let mut table = HashTable::new();
887+
/// let hasher = DefaultHashBuilder::default();
888+
/// let hasher = |val: &_| {
889+
/// use core::hash::Hasher;
890+
/// let mut state = hasher.build_hasher();
891+
/// core::hash::Hash::hash(&val, &mut state);
892+
/// state.finish()
893+
/// };
894+
/// let mut removed = 0;
895+
/// for x in 1..=8 {
896+
/// table.insert_unique(hasher(&x), x, hasher);
897+
/// }
898+
/// table.retain_with_break(|&mut v| if removed < 3 {
899+
/// if v % 2 == 0 {
900+
/// Ok(true)
901+
/// } else {
902+
/// removed += 1;
903+
/// Ok(false)
904+
/// }
905+
/// } else {
906+
/// Err(())
907+
/// });
908+
/// assert_eq!(table.len(), 5);
909+
/// # }
910+
/// # fn main() {
911+
/// # #[cfg(feature = "nightly")]
912+
/// # test()
913+
/// # }
914+
/// ```
915+
pub fn retain_with_break(
916+
&mut self,
917+
mut f: impl FnMut(&mut T) -> core::result::Result<bool, ()>,
918+
) {
919+
// Here we only use `iter` as a temporary, preventing use-after-free
920+
unsafe {
921+
for item in self.raw.iter() {
922+
match f(item.as_mut()) {
923+
Ok(false) => self.raw.erase(item), // remove this element through its bucket handle
924+
Err(_) => break, // stop; this element and every unvisited element stay in the table
925+
_ => continue, // `Ok(true)`: keep the element and move on
926+
}
927+
}
928+
}
929+
}
930+
873931
/// Clears the set, returning all elements in an iterator.
874932
///
875933
/// # Examples
@@ -2372,12 +2430,49 @@ impl<T, F, A: Allocator> FusedIterator for ExtractIf<'_, T, F, A> where F: FnMut
23722430

23732431
#[cfg(test)]
23742432
mod tests {
2433+
use crate::DefaultHashBuilder;
2434+
23752435
use super::HashTable;
23762436

2437+
use core::hash::BuildHasher;
23772438
#[test]
23782439
fn test_allocation_info() {
23792440
assert_eq!(HashTable::<()>::new().allocation_size(), 0);
23802441
assert_eq!(HashTable::<u32>::new().allocation_size(), 0);
23812442
assert!(HashTable::<u32>::with_capacity(1).allocation_size() > core::mem::size_of::<u32>());
23822443
}
2444+
2445+
#[test]
2446+
fn test_retain_with_break() {
2447+
let mut table = HashTable::new();
2448+
let hasher = DefaultHashBuilder::default();
2449+
// Shadow the builder with a closure that hashes a single value with it
let hasher = |val: &_| {
2450+
use core::hash::Hasher;
2451+
let mut state = hasher.build_hasher();
2452+
core::hash::Hash::hash(&val, &mut state);
2453+
state.finish()
2454+
};
2455+
for x in 0..100 {
2456+
table.insert_unique(hasher(&x), x, hasher);
2457+
}
2458+
// Remove values > 50, but stop as soon as 40 elements have been removed
2459+
let mut removed = 0;
2460+
table.retain_with_break(|&mut v| {
2461+
if removed < 40 {
2462+
if v > 50 {
2463+
removed += 1;
2464+
Ok(false)
2465+
} else {
2466+
Ok(true)
2467+
}
2468+
} else {
2469+
Err(())
2470+
}
2471+
});
2472+
assert_eq!(table.len(), 60); // 100 elements minus the 40 removals
2473+
// check nothing up to 50 is removed
2474+
for v in 0..=50 {
2475+
assert_eq!(table.find(hasher(&v), |&val| val == v), Some(&v));
2476+
}
2477+
}
23832478
}

0 commit comments

Comments
 (0)