@@ -12,7 +12,7 @@ use core::{
1212 cmp:: { Ord , Ordering } ,
1313 marker:: PhantomData ,
1414 mem:: MaybeUninit ,
15- ptr:: { addr_of_mut, NonNull } ,
15+ ptr:: { addr_of_mut, from_mut , NonNull } ,
1616} ;
1717
1818/// A red-black tree with owned nodes.
@@ -194,11 +194,31 @@ impl<K, V> RBTree<K, V> {
194194
195195 /// Returns an iterator over the tree nodes, sorted by key.
196196 pub fn iter ( & self ) -> Iter < ' _ , K , V > {
197- // INVARIANT: `bindings::rb_first` returns a valid pointer to a tree node given a valid pointer to a tree root.
198197 Iter {
199198 _tree : PhantomData ,
200- // SAFETY: `self.root` is a valid pointer to the tree root.
201- next : unsafe { bindings:: rb_first ( & self . root ) } ,
199+ // INVARIANT:
200+ // - `self.root` is a valid pointer to a tree root.
201+ // - `bindings::rb_first` produces a valid pointer to a node given `root` is valid.
202+ iter_raw : IterRaw {
203+ // SAFETY: by the invariants, all pointers are valid.
204+ next : unsafe { bindings:: rb_first ( & self . root ) } ,
205+ _phantom : PhantomData ,
206+ } ,
207+ }
208+ }
209+
210+ /// Returns a mutable iterator over the tree nodes, sorted by key.
211+ pub fn iter_mut ( & mut self ) -> IterMut < ' _ , K , V > {
212+ IterMut {
213+ _tree : PhantomData ,
214+ // INVARIANT:
215+ // - `self.root` is a valid pointer to a tree root.
216+ // - `bindings::rb_first` produces a valid pointer to a node given `root` is valid.
217+ iter_raw : IterRaw {
218+ // SAFETY: by the invariants, all pointers are valid.
219+ next : unsafe { bindings:: rb_first ( from_mut ( & mut self . root ) ) } ,
220+ _phantom : PhantomData ,
221+ } ,
202222 }
203223 }
204224
@@ -211,6 +231,11 @@ impl<K, V> RBTree<K, V> {
211231 pub fn values ( & self ) -> impl Iterator < Item = & ' _ V > {
212232 self . iter ( ) . map ( |( _, v) | v)
213233 }
234+
235+ /// Returns a mutable iterator over the values of the nodes in the tree, sorted by key.
236+ pub fn values_mut ( & mut self ) -> impl Iterator < Item = & ' _ mut V > {
237+ self . iter_mut ( ) . map ( |( _, v) | v)
238+ }
214239}
215240
216241impl < K , V > RBTree < K , V >
@@ -414,13 +439,9 @@ impl<'a, K, V> IntoIterator for &'a RBTree<K, V> {
414439/// An iterator over the nodes of a [`RBTree`].
415440///
416441/// Instances are created by calling [`RBTree::iter`].
417- ///
418- /// # Invariants
419- /// - `self.next` is a valid pointer.
420- /// - `self.next` points to a node stored inside of a valid `RBTree`.
421442pub struct Iter < ' a , K , V > {
422443 _tree : PhantomData < & ' a RBTree < K , V > > ,
423- next : * mut bindings :: rb_node ,
444+ iter_raw : IterRaw < K , V > ,
424445}
425446
426447// SAFETY: The [`Iter`] gives out immutable references to K and V, so it has the same
@@ -434,21 +455,76 @@ unsafe impl<'a, K: Sync, V: Sync> Sync for Iter<'a, K, V> {}
434455impl < ' a , K , V > Iterator for Iter < ' a , K , V > {
435456 type Item = ( & ' a K , & ' a V ) ;
436457
458+ fn next ( & mut self ) -> Option < Self :: Item > {
459+ // SAFETY: Due to `self._tree`, `k` and `v` are valid for the lifetime of `'a`.
460+ self . iter_raw . next ( ) . map ( |( k, v) | unsafe { ( & * k, & * v) } )
461+ }
462+ }
463+
464+ impl < ' a , K , V > IntoIterator for & ' a mut RBTree < K , V > {
465+ type Item = ( & ' a K , & ' a mut V ) ;
466+ type IntoIter = IterMut < ' a , K , V > ;
467+
468+ fn into_iter ( self ) -> Self :: IntoIter {
469+ self . iter_mut ( )
470+ }
471+ }
472+
473+ /// A mutable iterator over the nodes of a [`RBTree`].
474+ ///
475+ /// Instances are created by calling [`RBTree::iter_mut`].
476+ pub struct IterMut < ' a , K , V > {
477+ _tree : PhantomData < & ' a mut RBTree < K , V > > ,
478+ iter_raw : IterRaw < K , V > ,
479+ }
480+
481+ // SAFETY: The [`IterMut`] has exclusive access to both `K` and `V`, so it is sufficient to require them to be `Send`.
482+ // The iterator only gives out immutable references to the keys, but since the iterator has excusive access to those same
483+ // keys, `Send` is sufficient. `Sync` would be okay, but it is more restrictive to the user.
484+ unsafe impl < ' a , K : Send , V : Send > Send for IterMut < ' a , K , V > { }
485+
486+ // SAFETY: The [`IterMut`] gives out immutable references to K and mutable references to V, so it has the same
487+ // thread safety requirements as mutable references.
488+ unsafe impl < ' a , K : Sync , V : Sync > Sync for IterMut < ' a , K , V > { }
489+
490+ impl < ' a , K , V > Iterator for IterMut < ' a , K , V > {
491+ type Item = ( & ' a K , & ' a mut V ) ;
492+
493+ fn next ( & mut self ) -> Option < Self :: Item > {
494+ self . iter_raw . next ( ) . map ( |( k, v) |
495+ // SAFETY: Due to `&mut self`, we have exclusive access to `k` and `v`, for the lifetime of `'a`.
496+ unsafe { ( & * k, & mut * v) } )
497+ }
498+ }
499+
500+ /// A raw iterator over the nodes of a [`RBTree`].
501+ ///
502+ /// # Invariants
503+ /// - `self.next` is a valid pointer.
504+ /// - `self.next` points to a node stored inside of a valid `RBTree`.
505+ struct IterRaw < K , V > {
506+ next : * mut bindings:: rb_node ,
507+ _phantom : PhantomData < fn ( ) -> ( K , V ) > ,
508+ }
509+
510+ impl < K , V > Iterator for IterRaw < K , V > {
511+ type Item = ( * mut K , * mut V ) ;
512+
437513 fn next ( & mut self ) -> Option < Self :: Item > {
438514 if self . next . is_null ( ) {
439515 return None ;
440516 }
441517
442- // SAFETY: By the type invariant of `Iter `, `self.next` is a valid node in an `RBTree`,
518+ // SAFETY: By the type invariant of `IterRaw `, `self.next` is a valid node in an `RBTree`,
443519 // and by the type invariant of `RBTree`, all nodes point to the links field of `Node<K, V>` objects.
444- let cur = unsafe { container_of ! ( self . next, Node <K , V >, links) } ;
520+ let cur: * mut Node < K , V > =
521+ unsafe { container_of ! ( self . next, Node <K , V >, links) } . cast_mut ( ) ;
445522
446523 // SAFETY: `self.next` is a valid tree node by the type invariants.
447524 self . next = unsafe { bindings:: rb_next ( self . next ) } ;
448525
449- // SAFETY: By the same reasoning above, it is safe to dereference the node. Additionally,
450- // it is ok to return a reference to members because the iterator must outlive it.
451- Some ( unsafe { ( & ( * cur) . key , & ( * cur) . value ) } )
526+ // SAFETY: By the same reasoning above, it is safe to dereference the node.
527+ Some ( unsafe { ( addr_of_mut ! ( ( * cur) . key) , addr_of_mut ! ( ( * cur) . value) ) } )
452528 }
453529}
454530
0 commit comments