Skip to content

Commit 90a80c3

Browse files
committed
WidgetDriver: temp, add crypto functions that are still missing from the crypto crate in a temp crate.
1 parent 60941be commit 90a80c3

File tree

1 file changed

+140
-3
lines changed

1 file changed

+140
-3
lines changed

crates/matrix-sdk/src/widget/matrix.rs

Lines changed: 140 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
1818
use std::collections::BTreeMap;
1919

20+
use futures_util::future::join_all;
2021
use matrix_sdk_base::deserialized_responses::{EncryptionInfo, RawAnySyncOrStrippedState};
2122
use ruma::{
2223
api::client::{
@@ -273,9 +274,36 @@ impl MatrixDriver {
273274
let client = self.room.client();
274275

275276
let request = if encrypted {
276-
return Err(Error::UnknownError(
277-
"Sending encrypted to_device events is not supported by the widget driver.".into(),
278-
));
277+
// We first want to get all missing session before we start any to device
278+
// sending!
279+
client.claim_one_time_keys(messages.keys().map(|u| u.as_ref())).await?;
280+
let encrypted_content: BTreeMap<
281+
OwnedUserId,
282+
BTreeMap<DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>>,
283+
> = join_all(messages.into_iter().map(|(user_id, device_content_map)| {
284+
let event_type = event_type.clone();
285+
async move {
286+
(
287+
user_id.clone(),
288+
to_device_crypto::encrypted_device_content_map(
289+
&self.room.client(),
290+
&user_id,
291+
&event_type,
292+
device_content_map,
293+
)
294+
.await,
295+
)
296+
}
297+
}))
298+
.await
299+
.into_iter()
300+
.collect();
301+
302+
RumaToDeviceRequest::new_raw(
303+
ToDeviceEventType::RoomEncrypted,
304+
TransactionId::new(),
305+
encrypted_content,
306+
)
279307
} else {
280308
RumaToDeviceRequest::new_raw(event_type, TransactionId::new(), messages)
281309
};
@@ -337,3 +365,112 @@ fn add_props_to_raw<T>(
337365
Err(e) => Err(Error::from(e)),
338366
}
339367
}
368+
369+
/// Move this into the `matrix_crypto` crate!
370+
/// This module contains helper functions to encrypt to device events.
371+
mod to_device_crypto {
372+
use std::collections::BTreeMap;
373+
374+
use futures_util::future::join_all;
375+
use ruma::{
376+
events::{AnyToDeviceEventContent, ToDeviceEventType},
377+
serde::Raw,
378+
to_device::DeviceIdOrAllDevices,
379+
UserId,
380+
};
381+
use serde_json::Value;
382+
use tracing::{info, warn};
383+
384+
use crate::{encryption::identities::Device, executor::spawn, Client, Error, Result};
385+
386+
/// This encrypts to device content for a collection of devices.
387+
/// It will ignore all devices where errors occurred or where the device
388+
/// is not verified or where th user has a has_verification_violation.
389+
async fn encrypted_content_for_devices(
390+
unencrypted_content: &Raw<AnyToDeviceEventContent>,
391+
devices: Vec<Device>,
392+
event_type: &ToDeviceEventType,
393+
) -> Result<impl Iterator<Item = (DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>)>> {
394+
let content: Value = unencrypted_content.deserialize_as().map_err(Into::<Error>::into)?;
395+
let event_type = event_type.clone();
396+
let device_content_tasks = devices.into_iter().map(|device| spawn({
397+
let event_type = event_type.clone();
398+
let content = content.clone();
399+
400+
async move {
401+
if !device.is_cross_signed_by_owner() {
402+
info!("Device {} is not verified, skipping encryption", device.device_id());
403+
return None;
404+
}
405+
match device
406+
.inner
407+
.encrypt_event_raw(&event_type.to_string(), &content)
408+
.await {
409+
Ok(encrypted) => Some((device.device_id().to_owned().into(), encrypted.cast())),
410+
Err(e) =>{ info!("Failed to encrypt to_device event from widget for device: {} because, {}", device.device_id(), e); None},
411+
}
412+
}
413+
}));
414+
let device_encrypted_content_map =
415+
join_all(device_content_tasks).await.into_iter().flatten().flatten();
416+
Ok(device_encrypted_content_map)
417+
}
418+
419+
/// Convert the device content map for one user into the same content
420+
/// map with encrypted content This needs to flatten the vectors
421+
/// we get from `encrypted_content_for_devices`
422+
/// since one `DeviceIdOrAllDevices` id can be multiple devices.
423+
pub(super) async fn encrypted_device_content_map(
424+
client: &Client,
425+
user_id: &UserId,
426+
event_type: &ToDeviceEventType,
427+
device_content_map: BTreeMap<DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>>,
428+
) -> BTreeMap<DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>> {
429+
let device_map_futures =
430+
device_content_map.into_iter().map(|(device_or_all_id, content)| spawn({
431+
let client = client.clone();
432+
let user_id = user_id.to_owned();
433+
let event_type = event_type.clone();
434+
async move {
435+
let Ok(user_devices) = client.encryption().get_user_devices(&user_id).await else {
436+
warn!("Failed to get user devices for user: {}", user_id);
437+
return None;
438+
};
439+
let Ok(user_identity) = client.encryption().get_user_identity(&user_id).await else{
440+
warn!("Failed to get user identity for user: {}", user_id);
441+
return None;
442+
};
443+
if user_identity.map(|i|i.has_verification_violation()).unwrap_or(false) {
444+
info!("User {} has a verification violation, skipping encryption", user_id);
445+
return None;
446+
}
447+
let devices: Vec<Device> = match device_or_all_id {
448+
DeviceIdOrAllDevices::DeviceId(device_id) => {
449+
vec![user_devices.get(&device_id)].into_iter().flatten().collect()
450+
}
451+
DeviceIdOrAllDevices::AllDevices => user_devices.devices().collect(),
452+
};
453+
encrypted_content_for_devices(
454+
&content,
455+
devices,
456+
&event_type,
457+
)
458+
.await
459+
.map_err(|e| info!("WidgetDriver: could not encrypt content for to device widget event content: {}. because, {}", content.json(), e))
460+
.ok()
461+
}}));
462+
let content_map_iterator = join_all(device_map_futures).await.into_iter();
463+
464+
// The first flatten takes the iterator over Result<Option<impl Iterator<Item =
465+
// (DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>)>>, JoinError>>
466+
// and flattens the Result (drops Err() items)
467+
// The second takes the iterator over: Option<impl Iterator<Item =
468+
// (DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>)>>
469+
// and flattens the Option (drops None items)
470+
// The third takes the iterator over iterators: impl Iterator<Item =
471+
// (DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>)>
472+
// and flattens it to just an iterator over (DeviceIdOrAllDevices,
473+
// Raw<AnyToDeviceEventContent>)
474+
content_map_iterator.flatten().flatten().flatten().collect()
475+
}
476+
}

0 commit comments

Comments
 (0)