@@ -26,7 +26,7 @@ use matrix_sdk_common::{
26
26
deserialized_responses:: WithheldCode , executor:: spawn, locks:: RwLock as StdRwLock ,
27
27
} ;
28
28
use ruma:: {
29
- events:: { AnyMessageLikeEventContent , ToDeviceEventType } ,
29
+ events:: { AnyMessageLikeEventContent , AnyToDeviceEventContent , ToDeviceEventType } ,
30
30
serde:: Raw ,
31
31
to_device:: DeviceIdOrAllDevices ,
32
32
OwnedDeviceId , OwnedRoomId , OwnedTransactionId , OwnedUserId , RoomId , TransactionId , UserId ,
@@ -254,29 +254,24 @@ impl GroupSessionManager {
254
254
}
255
255
}
256
256
257
- /// Encrypt the given content for the given devices and create a to-device
257
+ /// Encrypt the given content for the given devices and create to-device
258
258
/// requests that sends the encrypted content to them.
259
259
async fn encrypt_session_for (
260
260
store : Arc < CryptoStoreWrapper > ,
261
261
group_session : OutboundGroupSession ,
262
262
devices : Vec < DeviceData > ,
263
263
) -> OlmResult < (
264
- OwnedTransactionId ,
265
- ToDeviceRequest ,
264
+ EncryptForDevicesResult ,
266
265
BTreeMap < OwnedUserId , BTreeMap < OwnedDeviceId , ShareInfo > > ,
267
- Vec < Session > ,
268
- Vec < ( DeviceData , WithheldCode ) > ,
269
266
) > {
270
267
// Use a named type instead of a tuple with rather long type name
271
268
pub struct DeviceResult {
272
269
device : DeviceData ,
273
270
maybe_encrypted_room_key : MaybeEncryptedRoomKey ,
274
271
}
275
272
276
- let mut messages = BTreeMap :: new ( ) ;
277
- let mut changed_sessions = Vec :: new ( ) ;
273
+ let mut result_builder = EncryptForDevicesResultBuilder :: default ( ) ;
278
274
let mut share_infos = BTreeMap :: new ( ) ;
279
- let mut withheld_devices = Vec :: new ( ) ;
280
275
281
276
// XXX is there a way to do this that doesn't involve cloning the
282
277
// `Arc<CryptoStoreWrapper>` for each device?
@@ -300,35 +295,22 @@ impl GroupSessionManager {
300
295
301
296
match result. maybe_encrypted_room_key {
302
297
MaybeEncryptedRoomKey :: Encrypted { used_session, share_info, message } => {
303
- changed_sessions . push ( used_session) ;
298
+ result_builder . on_successful_encryption ( & result . device , used_session, message ) ;
304
299
305
300
let user_id = result. device . user_id ( ) . to_owned ( ) ;
306
301
let device_id = result. device . device_id ( ) . to_owned ( ) ;
307
-
308
- messages
309
- . entry ( user_id. to_owned ( ) )
310
- . or_insert_with ( BTreeMap :: new)
311
- . insert ( DeviceIdOrAllDevices :: DeviceId ( device_id. to_owned ( ) ) , message) ;
312
-
313
302
share_infos
314
303
. entry ( user_id)
315
304
. or_insert_with ( BTreeMap :: new)
316
305
. insert ( device_id, share_info) ;
317
306
}
318
- MaybeEncryptedRoomKey :: Withheld { code } => {
319
- withheld_devices . push ( ( result. device , code ) ) ;
307
+ MaybeEncryptedRoomKey :: MissingSession => {
308
+ result_builder . on_missing_session ( result. device ) ;
320
309
}
321
310
}
322
311
}
323
312
324
- let txn_id = TransactionId :: new ( ) ;
325
- let request = ToDeviceRequest {
326
- event_type : ToDeviceEventType :: RoomEncrypted ,
327
- txn_id : txn_id. to_owned ( ) ,
328
- messages,
329
- } ;
330
-
331
- Ok ( ( txn_id, request, share_infos, changed_sessions, withheld_devices) )
313
+ Ok ( ( result_builder. into_result ( ) , share_infos) )
332
314
}
333
315
334
316
/// Given a list of user and an outbound session, return the list of users
@@ -353,21 +335,16 @@ impl GroupSessionManager {
353
335
outbound : OutboundGroupSession ,
354
336
sessions : GroupSessionCache ,
355
337
) -> OlmResult < ( Vec < Session > , Vec < ( DeviceData , WithheldCode ) > ) > {
356
- let ( id , request , share_infos, used_sessions , no_olm ) =
338
+ let ( result , share_infos) =
357
339
Self :: encrypt_session_for ( store, outbound. clone ( ) , chunk) . await ?;
358
340
359
- if !request. messages . is_empty ( ) {
360
- trace ! (
361
- recipient_count = request. message_count( ) ,
362
- transaction_id = ?id,
363
- "Created a to-device request carrying a room_key"
364
- ) ;
365
-
341
+ if let Some ( request) = result. to_device_request {
342
+ let id = request. txn_id . clone ( ) ;
366
343
outbound. add_request ( id. clone ( ) , request. into ( ) , share_infos) ;
367
344
sessions. mark_as_being_shared ( id, outbound. clone ( ) ) ;
368
345
}
369
346
370
- Ok ( ( used_sessions , no_olm ) )
347
+ Ok ( ( result . updated_olm_sessions , result . no_olm_devices ) )
371
348
}
372
349
373
350
pub ( crate ) fn session_cache ( & self ) -> GroupSessionCache {
@@ -771,6 +748,88 @@ impl GroupSessionManager {
771
748
}
772
749
}
773
750
751
+ /// Result of [`GroupSessionManager::encrypt_session_for`]
752
+ #[ derive( Debug ) ]
753
+ struct EncryptForDevicesResult {
754
+ /// The request to send the to-device messages containing the encrypted
755
+ /// payload, if any devices were found.
756
+ to_device_request : Option < ToDeviceRequest > ,
757
+
758
+ /// The devices which lack an Olm session and therefore need a withheld code
759
+ no_olm_devices : Vec < ( DeviceData , WithheldCode ) > ,
760
+
761
+ /// The Olm sessions which were used to encrypt the requests and now need
762
+ /// persisting to the store.
763
+ updated_olm_sessions : Vec < Session > ,
764
+ }
765
+
766
+ /// A helper for building [`EncryptForDevicesResult`]
767
+ #[ derive( Debug , Default ) ]
768
+ struct EncryptForDevicesResultBuilder {
769
+ /// The payloads of the to-device messages
770
+ messages : BTreeMap < OwnedUserId , BTreeMap < DeviceIdOrAllDevices , Raw < AnyToDeviceEventContent > > > ,
771
+
772
+ /// The devices which lack an Olm session and therefore need a withheld code
773
+ no_olm_devices : Vec < ( DeviceData , WithheldCode ) > ,
774
+
775
+ /// The Olm sessions which were used to encrypt the requests and now need
776
+ /// persisting to the store.
777
+ updated_olm_sessions : Vec < Session > ,
778
+ }
779
+
780
+ impl EncryptForDevicesResultBuilder {
781
+ /// Record a successful encryption. The encrypted message is added to the
782
+ /// list to be sent, and the olm session is added to the list of those
783
+ /// that have been modified.
784
+ pub fn on_successful_encryption (
785
+ & mut self ,
786
+ device : & DeviceData ,
787
+ used_session : Session ,
788
+ message : Raw < AnyToDeviceEventContent > ,
789
+ ) {
790
+ self . updated_olm_sessions . push ( used_session) ;
791
+
792
+ self . messages
793
+ . entry ( device. user_id ( ) . to_owned ( ) )
794
+ . or_default ( )
795
+ . insert ( DeviceIdOrAllDevices :: DeviceId ( device. device_id ( ) . to_owned ( ) ) , message) ;
796
+ }
797
+
798
+ /// Record a device which didn't have an active Olm session.
799
+ pub fn on_missing_session ( & mut self , device : DeviceData ) {
800
+ self . no_olm_devices . push ( ( device, WithheldCode :: NoOlm ) ) ;
801
+ }
802
+
803
+ /// Transform the accumulated results into an [`EncryptForDevicesResult`],
804
+ /// wrapping the messages, if any, into a `ToDeviceRequest`.
805
+ pub fn into_result ( self ) -> EncryptForDevicesResult {
806
+ let EncryptForDevicesResultBuilder { updated_olm_sessions, no_olm_devices, messages } =
807
+ self ;
808
+
809
+ let mut encrypt_for_devices_result = EncryptForDevicesResult {
810
+ to_device_request : None ,
811
+ updated_olm_sessions,
812
+ no_olm_devices,
813
+ } ;
814
+
815
+ if !messages. is_empty ( ) {
816
+ let request = ToDeviceRequest {
817
+ event_type : ToDeviceEventType :: RoomEncrypted ,
818
+ txn_id : TransactionId :: new ( ) ,
819
+ messages,
820
+ } ;
821
+ trace ! (
822
+ recipient_count = request. message_count( ) ,
823
+ transaction_id = ?request. txn_id,
824
+ "Created a to-device request carrying room keys" ,
825
+ ) ;
826
+ encrypt_for_devices_result. to_device_request = Some ( request) ;
827
+ } ;
828
+
829
+ encrypt_for_devices_result
830
+ }
831
+ }
832
+
774
833
#[ cfg( test) ]
775
834
mod tests {
776
835
use std:: {
0 commit comments