14
14
// See the License for the specific language governing permissions and
15
15
// limitations under the License.
16
16
17
+ #[ cfg( feature = "e2e-encryption" ) ]
18
+ use std:: ops:: Deref ;
17
19
use std:: {
18
20
collections:: { btree_map, BTreeMap } ,
19
21
fmt:: { self , Debug } ,
@@ -37,8 +39,6 @@ use matrix_sdk_base::{
37
39
StateStoreDataKey , StateStoreDataValue , SyncOutsideWasm ,
38
40
} ;
39
41
use matrix_sdk_common:: ttl_cache:: TtlCache ;
40
- #[ cfg( feature = "e2e-encryption" ) ]
41
- use ruma:: events:: { room:: encryption:: RoomEncryptionEventContent , InitialStateEvent } ;
42
42
use ruma:: {
43
43
api:: {
44
44
client:: {
@@ -69,6 +69,15 @@ use ruma::{
69
69
DeviceId , OwnedDeviceId , OwnedEventId , OwnedRoomId , OwnedRoomOrAliasId , OwnedServerName ,
70
70
RoomAliasId , RoomId , RoomOrAliasId , ServerName , UInt , UserId ,
71
71
} ;
72
+ #[ cfg( feature = "e2e-encryption" ) ]
73
+ use ruma:: {
74
+ events:: {
75
+ room:: encryption:: RoomEncryptionEventContent , AnyToDeviceEventContent , InitialStateEvent ,
76
+ } ,
77
+ serde:: Raw ,
78
+ to_device:: DeviceIdOrAllDevices ,
79
+ OwnedUserId ,
80
+ } ;
72
81
use serde:: de:: DeserializeOwned ;
73
82
use tokio:: sync:: { broadcast, Mutex , OnceCell , RwLock , RwLockReadGuard } ;
74
83
use tracing:: { debug, error, instrument, trace, warn, Instrument , Span } ;
@@ -99,7 +108,9 @@ use crate::{
99
108
} ;
100
109
#[ cfg( feature = "e2e-encryption" ) ]
101
110
use crate :: {
102
- encryption:: { Encryption , EncryptionData , EncryptionSettings , VerificationState } ,
111
+ encryption:: {
112
+ identities:: Device , Encryption , EncryptionData , EncryptionSettings , VerificationState ,
113
+ } ,
103
114
store_locks:: CrossProcessStoreLock ,
104
115
} ;
105
116
@@ -2513,6 +2524,74 @@ impl Client {
2513
2524
let base_room = self . inner . base_client . room_knocked ( & response. room_id ) . await ?;
2514
2525
Ok ( Room :: new ( self . clone ( ) , base_room) )
2515
2526
}
2527
+
2528
+ /// Encrypts then send the given content via the `sendToDevice` end-point
2529
+ /// using olm encryption.
2530
+ ///
2531
+ /// If there are a lot of targets this will be break down by chunks.
2532
+ ///
2533
+ /// # Returns
2534
+ /// A list of `ToDeviceRequest` to send out the event, and the list of
2535
+ /// devices where encryption did not succeed (device excluded or no olm)
2536
+ #[ cfg( feature = "e2e-encryption" ) ]
2537
+ pub async fn encrypt_and_send_custom_to_device (
2538
+ & self ,
2539
+ targets : Vec < & Device > ,
2540
+ event_type : & str ,
2541
+ content : Raw < AnyToDeviceEventContent > ,
2542
+ ) -> Result < Vec < ( OwnedUserId , OwnedDeviceId ) > > {
2543
+ let users = targets. iter ( ) . map ( |device| device. user_id ( ) ) ;
2544
+
2545
+ // Will claim one-time-key for users that needs it
2546
+ // TODO: For later optimisation: This will establish missing olm sessions with
2547
+ // all this users devices, but we just want for some devices.
2548
+ self . claim_one_time_keys ( users) . await ?;
2549
+
2550
+ let olm = self . olm_machine ( ) . await ;
2551
+ let olm = olm. as_ref ( ) . expect ( "Olm machine wasn't started" ) ;
2552
+
2553
+ let ( requests, withhelds) = olm
2554
+ . encrypt_content_for_devices (
2555
+ targets. into_iter ( ) . map ( |d| d. deref ( ) . clone ( ) ) . collect ( ) ,
2556
+ event_type,
2557
+ & content
2558
+ . deserialize_as :: < serde_json:: Value > ( )
2559
+ . expect ( "Deserialize as Value will always work" ) ,
2560
+ )
2561
+ . await ?;
2562
+
2563
+ let mut failures: Vec < ( OwnedUserId , OwnedDeviceId ) > = Default :: default ( ) ;
2564
+
2565
+ // Push the withhelds in the failures
2566
+ withhelds. iter ( ) . for_each ( |( d, _) | {
2567
+ failures. push ( ( d. user_id ( ) . to_owned ( ) , d. device_id ( ) . to_owned ( ) ) ) ;
2568
+ } ) ;
2569
+
2570
+ // TODO: parallelize that? it's already grouping 250 devices per chunk.
2571
+ for request in requests {
2572
+ let send_result =
2573
+ self . send_to_device_with_config ( & request, RequestConfig :: short_retry ( ) ) . await ;
2574
+
2575
+ // If the sending failed we need to collect the failures to report them
2576
+ if send_result. is_err ( ) {
2577
+ // Mark the sending as failed
2578
+ for ( user_id, device_map) in request. messages {
2579
+ for device_id in device_map. keys ( ) {
2580
+ match device_id {
2581
+ DeviceIdOrAllDevices :: DeviceId ( device_id) => {
2582
+ failures. push ( ( user_id. clone ( ) , device_id. to_owned ( ) ) ) ;
2583
+ }
2584
+ DeviceIdOrAllDevices :: AllDevices => {
2585
+ // Cannot happen in this case
2586
+ }
2587
+ }
2588
+ }
2589
+ }
2590
+ }
2591
+ }
2592
+
2593
+ Ok ( failures)
2594
+ }
2516
2595
}
2517
2596
2518
2597
/// A weak reference to the inner client, useful when trying to get a handle
0 commit comments