@@ -56,7 +56,7 @@ pub struct MemoryStore {
56
56
inbound_group_sessions : GroupSessionStore ,
57
57
outbound_group_sessions : StdRwLock < BTreeMap < OwnedRoomId , OutboundGroupSession > > ,
58
58
private_identity : StdRwLock < Option < PrivateCrossSigningIdentity > > ,
59
- tracked_users : StdRwLock < Vec < TrackedUser > > ,
59
+ tracked_users : StdRwLock < HashMap < OwnedUserId , TrackedUser > > ,
60
60
olm_hashes : StdRwLock < HashMap < String , HashSet < String > > > ,
61
61
devices : DeviceStore ,
62
62
identities : StdRwLock < HashMap < OwnedUserId , ReadOnlyUserIdentities > > ,
@@ -324,12 +324,13 @@ impl CryptoStore for MemoryStore {
324
324
}
325
325
326
326
async fn load_tracked_users ( & self ) -> Result < Vec < TrackedUser > > {
327
- Ok ( self . tracked_users . read ( ) . unwrap ( ) . clone ( ) )
327
+ Ok ( self . tracked_users . read ( ) . unwrap ( ) . values ( ) . cloned ( ) . collect ( ) )
328
328
}
329
329
330
330
async fn save_tracked_users ( & self , tracked_users : & [ ( & UserId , bool ) ] ) -> Result < ( ) > {
331
331
self . tracked_users . write ( ) . unwrap ( ) . extend ( tracked_users. iter ( ) . map ( |( user_id, dirty) | {
332
- TrackedUser { user_id : user_id. to_owned ( ) . into ( ) , dirty : * dirty }
332
+ let user_id: OwnedUserId = user_id. to_owned ( ) . into ( ) ;
333
+ ( user_id. clone ( ) , TrackedUser { user_id, dirty : * dirty } )
333
334
} ) ) ;
334
335
Ok ( ( ) )
335
336
}
@@ -559,23 +560,29 @@ mod tests {
559
560
}
560
561
561
562
#[ async_test]
562
- async fn test_tracked_users_store ( ) {
563
- // Given some tracked users
564
- let tracked_users =
565
- & [ ( user_id ! ( "@dirty_user:s" ) , true ) , ( user_id ! ( "@clean_user:t" ) , false ) ] ;
566
-
567
- // When we save them to the store
563
+ async fn test_tracked_users_are_stored_once_per_user_id ( ) {
564
+ // Given a store containing 2 tracked users, both dirty
565
+ let user1 = user_id ! ( "@user1:s" ) ;
566
+ let user2 = user_id ! ( "@user2:s" ) ;
567
+ let user3 = user_id ! ( "@user3:s" ) ;
568
568
let store = MemoryStore :: new ( ) ;
569
- store. save_tracked_users ( tracked_users) . await . unwrap ( ) ;
569
+ store. save_tracked_users ( & [ ( user1, true ) , ( user2, true ) ] ) . await . unwrap ( ) ;
570
+
571
+ // When we mark one as clean and add another
572
+ store. save_tracked_users ( & [ ( user2, false ) , ( user3, false ) ] ) . await . unwrap ( ) ;
570
573
571
- // Then we can get them out again
574
+ // Then we can get them out again and their dirty flags are correct
572
575
let loaded_tracked_users =
573
576
store. load_tracked_users ( ) . await . expect ( "failed to load tracked users" ) ;
574
- assert_eq ! ( loaded_tracked_users[ 0 ] . user_id, user_id!( "@dirty_user:s" ) ) ;
575
- assert ! ( loaded_tracked_users[ 0 ] . dirty) ;
576
- assert_eq ! ( loaded_tracked_users[ 1 ] . user_id, user_id!( "@clean_user:t" ) ) ;
577
- assert ! ( !loaded_tracked_users[ 1 ] . dirty) ;
578
- assert_eq ! ( loaded_tracked_users. len( ) , 2 ) ;
577
+
578
+ let tracked_contains = |user_id, dirty| {
579
+ loaded_tracked_users. iter ( ) . any ( |u| u. user_id == user_id && u. dirty == dirty)
580
+ } ;
581
+
582
+ assert ! ( tracked_contains( user1, true ) ) ;
583
+ assert ! ( tracked_contains( user2, false ) ) ;
584
+ assert ! ( tracked_contains( user3, false ) ) ;
585
+ assert_eq ! ( loaded_tracked_users. len( ) , 3 ) ;
579
586
}
580
587
581
588
#[ async_test]
0 commit comments