|
17 | 17 |
|
18 | 18 | use std::collections::BTreeMap;
|
19 | 19 |
|
| 20 | +use futures_util::future::join_all; |
20 | 21 | use matrix_sdk_base::deserialized_responses::{EncryptionInfo, RawAnySyncOrStrippedState};
|
21 | 22 | use ruma::{
|
22 | 23 | api::client::{
|
@@ -273,9 +274,36 @@ impl MatrixDriver {
|
273 | 274 | let client = self.room.client();
|
274 | 275 |
|
275 | 276 | let request = if encrypted {
|
276 |
| - return Err(Error::UnknownError( |
277 |
| - "Sending encrypted to_device events is not supported by the widget driver.".into(), |
278 |
| - )); |
| 277 | + // We first want to get all missing session before we start any to device |
| 278 | + // sending! |
| 279 | + client.claim_one_time_keys(messages.keys().map(|u| u.as_ref())).await?; |
| 280 | + let encrypted_content: BTreeMap< |
| 281 | + OwnedUserId, |
| 282 | + BTreeMap<DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>>, |
| 283 | + > = join_all(messages.into_iter().map(|(user_id, device_content_map)| { |
| 284 | + let event_type = event_type.clone(); |
| 285 | + async move { |
| 286 | + ( |
| 287 | + user_id.clone(), |
| 288 | + to_device_crypto::encrypted_device_content_map( |
| 289 | + &self.room.client(), |
| 290 | + &user_id, |
| 291 | + &event_type, |
| 292 | + device_content_map, |
| 293 | + ) |
| 294 | + .await, |
| 295 | + ) |
| 296 | + } |
| 297 | + })) |
| 298 | + .await |
| 299 | + .into_iter() |
| 300 | + .collect(); |
| 301 | + |
| 302 | + RumaToDeviceRequest::new_raw( |
| 303 | + ToDeviceEventType::RoomEncrypted, |
| 304 | + TransactionId::new(), |
| 305 | + encrypted_content, |
| 306 | + ) |
279 | 307 | } else {
|
280 | 308 | RumaToDeviceRequest::new_raw(event_type, TransactionId::new(), messages)
|
281 | 309 | };
|
@@ -338,6 +366,116 @@ fn add_props_to_raw<T>(
|
338 | 366 | Err(e) => Err(Error::from(e)),
|
339 | 367 | }
|
340 | 368 | }
|
| 369 | + |
| 370 | +/// Move this into the `matrix_crypto` crate! |
| 371 | +/// This module contains helper functions to encrypt to device events. |
| 372 | +mod to_device_crypto { |
| 373 | + use std::collections::BTreeMap; |
| 374 | + |
| 375 | + use futures_util::future::join_all; |
| 376 | + use ruma::{ |
| 377 | + events::{AnyToDeviceEventContent, ToDeviceEventType}, |
| 378 | + serde::Raw, |
| 379 | + to_device::DeviceIdOrAllDevices, |
| 380 | + UserId, |
| 381 | + }; |
| 382 | + use serde_json::Value; |
| 383 | + use tracing::{info, warn}; |
| 384 | + |
| 385 | + use crate::{encryption::identities::Device, executor::spawn, Client, Error, Result}; |
| 386 | + |
| 387 | + /// This encrypts to device content for a collection of devices. |
| 388 | + /// It will ignore all devices where errors occurred or where the device |
| 389 | + /// is not verified or where th user has a has_verification_violation. |
| 390 | + async fn encrypted_content_for_devices( |
| 391 | + unencrypted_content: &Raw<AnyToDeviceEventContent>, |
| 392 | + devices: Vec<Device>, |
| 393 | + event_type: &ToDeviceEventType, |
| 394 | + ) -> Result<impl Iterator<Item = (DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>)>> { |
| 395 | + let content: Value = unencrypted_content.deserialize_as().map_err(Into::<Error>::into)?; |
| 396 | + let event_type = event_type.clone(); |
| 397 | + let device_content_tasks = devices.into_iter().map(|device| spawn({ |
| 398 | + let event_type = event_type.clone(); |
| 399 | + let content = content.clone(); |
| 400 | + |
| 401 | + async move { |
| 402 | + if !device.is_cross_signed_by_owner() { |
| 403 | + info!("Device {} is not verified, skipping encryption", device.device_id()); |
| 404 | + return None; |
| 405 | + } |
| 406 | + match device |
| 407 | + .inner |
| 408 | + .encrypt_event_raw(&event_type.to_string(), &content) |
| 409 | + .await { |
| 410 | + Ok(encrypted) => Some((device.device_id().to_owned().into(), encrypted.cast())), |
| 411 | + Err(e) =>{ info!("Failed to encrypt to_device event from widget for device: {} because, {}", device.device_id(), e); None}, |
| 412 | + } |
| 413 | + } |
| 414 | + })); |
| 415 | + let device_encrypted_content_map = |
| 416 | + join_all(device_content_tasks).await.into_iter().flatten().flatten(); |
| 417 | + Ok(device_encrypted_content_map) |
| 418 | + } |
| 419 | + |
| 420 | + /// Convert the device content map for one user into the same content |
| 421 | + /// map with encrypted content This needs to flatten the vectors |
| 422 | + /// we get from `encrypted_content_for_devices` |
| 423 | + /// since one `DeviceIdOrAllDevices` id can be multiple devices. |
| 424 | + pub(super) async fn encrypted_device_content_map( |
| 425 | + client: &Client, |
| 426 | + user_id: &UserId, |
| 427 | + event_type: &ToDeviceEventType, |
| 428 | + device_content_map: BTreeMap<DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>>, |
| 429 | + ) -> BTreeMap<DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>> { |
| 430 | + let device_map_futures = |
| 431 | + device_content_map.into_iter().map(|(device_or_all_id, content)| spawn({ |
| 432 | + let client = client.clone(); |
| 433 | + let user_id = user_id.to_owned(); |
| 434 | + let event_type = event_type.clone(); |
| 435 | + async move { |
| 436 | + let Ok(user_devices) = client.encryption().get_user_devices(&user_id).await else { |
| 437 | + warn!("Failed to get user devices for user: {}", user_id); |
| 438 | + return None; |
| 439 | + }; |
| 440 | + let Ok(user_identity) = client.encryption().get_user_identity(&user_id).await else{ |
| 441 | + warn!("Failed to get user identity for user: {}", user_id); |
| 442 | + return None; |
| 443 | + }; |
| 444 | + if user_identity.map(|i|i.has_verification_violation()).unwrap_or(false) { |
| 445 | + info!("User {} has a verification violation, skipping encryption", user_id); |
| 446 | + return None; |
| 447 | + } |
| 448 | + let devices: Vec<Device> = match device_or_all_id { |
| 449 | + DeviceIdOrAllDevices::DeviceId(device_id) => { |
| 450 | + vec![user_devices.get(&device_id)].into_iter().flatten().collect() |
| 451 | + } |
| 452 | + DeviceIdOrAllDevices::AllDevices => user_devices.devices().collect(), |
| 453 | + }; |
| 454 | + encrypted_content_for_devices( |
| 455 | + &content, |
| 456 | + devices, |
| 457 | + &event_type, |
| 458 | + ) |
| 459 | + .await |
| 460 | + .map_err(|e| info!("WidgetDriver: could not encrypt content for to device widget event content: {}. because, {}", content.json(), e)) |
| 461 | + .ok() |
| 462 | + }})); |
| 463 | + let content_map_iterator = join_all(device_map_futures).await.into_iter(); |
| 464 | + |
| 465 | + // The first flatten takes the iterator over Result<Option<impl Iterator<Item = |
| 466 | + // (DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>)>>, JoinError>> |
| 467 | + // and flattens the Result (drops Err() items) |
| 468 | + // The second takes the iterator over: Option<impl Iterator<Item = |
| 469 | + // (DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>)>> |
| 470 | + // and flattens the Option (drops None items) |
| 471 | + // The third takes the iterator over iterators: impl Iterator<Item = |
| 472 | + // (DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>)> |
| 473 | + // and flattens it to just an iterator over (DeviceIdOrAllDevices, |
| 474 | + // Raw<AnyToDeviceEventContent>) |
| 475 | + content_map_iterator.flatten().flatten().flatten().collect() |
| 476 | + } |
| 477 | +} |
| 478 | + |
341 | 479 | #[cfg(test)]
|
342 | 480 | mod tests {
|
343 | 481 | use ruma::{room_id, serde::Raw};
|
|
0 commit comments