Skip to content

Commit 67b0cf0

Browse files
committed
WidgetDriver: temp, add crypto functions that are still missing from the crypto crate in a temp crate.
1 parent f1facc6 commit 67b0cf0

File tree

1 file changed

+141
-3
lines changed

1 file changed

+141
-3
lines changed

crates/matrix-sdk/src/widget/matrix.rs

+141-3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
1818
use std::collections::BTreeMap;
1919

20+
use futures_util::future::join_all;
2021
use matrix_sdk_base::deserialized_responses::{EncryptionInfo, RawAnySyncOrStrippedState};
2122
use ruma::{
2223
api::client::{
@@ -273,9 +274,36 @@ impl MatrixDriver {
273274
let client = self.room.client();
274275

275276
let request = if encrypted {
276-
return Err(Error::UnknownError(
277-
"Sending encrypted to_device events is not supported by the widget driver.".into(),
278-
));
277+
// We first want to get all missing session before we start any to device
278+
// sending!
279+
client.claim_one_time_keys(messages.keys().map(|u| u.as_ref())).await?;
280+
let encrypted_content: BTreeMap<
281+
OwnedUserId,
282+
BTreeMap<DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>>,
283+
> = join_all(messages.into_iter().map(|(user_id, device_content_map)| {
284+
let event_type = event_type.clone();
285+
async move {
286+
(
287+
user_id.clone(),
288+
to_device_crypto::encrypted_device_content_map(
289+
&self.room.client(),
290+
&user_id,
291+
&event_type,
292+
device_content_map,
293+
)
294+
.await,
295+
)
296+
}
297+
}))
298+
.await
299+
.into_iter()
300+
.collect();
301+
302+
RumaToDeviceRequest::new_raw(
303+
ToDeviceEventType::RoomEncrypted,
304+
TransactionId::new(),
305+
encrypted_content,
306+
)
279307
} else {
280308
RumaToDeviceRequest::new_raw(event_type, TransactionId::new(), messages)
281309
};
@@ -338,6 +366,116 @@ fn add_props_to_raw<T>(
338366
Err(e) => Err(Error::from(e)),
339367
}
340368
}
369+
370+
/// Move this into the `matrix_crypto` crate!
371+
/// This module contains helper functions to encrypt to device events.
372+
mod to_device_crypto {
373+
use std::collections::BTreeMap;
374+
375+
use futures_util::future::join_all;
376+
use ruma::{
377+
events::{AnyToDeviceEventContent, ToDeviceEventType},
378+
serde::Raw,
379+
to_device::DeviceIdOrAllDevices,
380+
UserId,
381+
};
382+
use serde_json::Value;
383+
use tracing::{info, warn};
384+
385+
use crate::{encryption::identities::Device, executor::spawn, Client, Error, Result};
386+
387+
/// This encrypts to device content for a collection of devices.
388+
/// It will ignore all devices where errors occurred or where the device
389+
/// is not verified or where th user has a has_verification_violation.
390+
async fn encrypted_content_for_devices(
391+
unencrypted_content: &Raw<AnyToDeviceEventContent>,
392+
devices: Vec<Device>,
393+
event_type: &ToDeviceEventType,
394+
) -> Result<impl Iterator<Item = (DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>)>> {
395+
let content: Value = unencrypted_content.deserialize_as().map_err(Into::<Error>::into)?;
396+
let event_type = event_type.clone();
397+
let device_content_tasks = devices.into_iter().map(|device| spawn({
398+
let event_type = event_type.clone();
399+
let content = content.clone();
400+
401+
async move {
402+
if !device.is_cross_signed_by_owner() {
403+
info!("Device {} is not verified, skipping encryption", device.device_id());
404+
return None;
405+
}
406+
match device
407+
.inner
408+
.encrypt_event_raw(&event_type.to_string(), &content)
409+
.await {
410+
Ok(encrypted) => Some((device.device_id().to_owned().into(), encrypted.cast())),
411+
Err(e) =>{ info!("Failed to encrypt to_device event from widget for device: {} because, {}", device.device_id(), e); None},
412+
}
413+
}
414+
}));
415+
let device_encrypted_content_map =
416+
join_all(device_content_tasks).await.into_iter().flatten().flatten();
417+
Ok(device_encrypted_content_map)
418+
}
419+
420+
/// Convert the device content map for one user into the same content
421+
/// map with encrypted content This needs to flatten the vectors
422+
/// we get from `encrypted_content_for_devices`
423+
/// since one `DeviceIdOrAllDevices` id can be multiple devices.
424+
pub(super) async fn encrypted_device_content_map(
425+
client: &Client,
426+
user_id: &UserId,
427+
event_type: &ToDeviceEventType,
428+
device_content_map: BTreeMap<DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>>,
429+
) -> BTreeMap<DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>> {
430+
let device_map_futures =
431+
device_content_map.into_iter().map(|(device_or_all_id, content)| spawn({
432+
let client = client.clone();
433+
let user_id = user_id.to_owned();
434+
let event_type = event_type.clone();
435+
async move {
436+
let Ok(user_devices) = client.encryption().get_user_devices(&user_id).await else {
437+
warn!("Failed to get user devices for user: {}", user_id);
438+
return None;
439+
};
440+
let Ok(user_identity) = client.encryption().get_user_identity(&user_id).await else{
441+
warn!("Failed to get user identity for user: {}", user_id);
442+
return None;
443+
};
444+
if user_identity.map(|i|i.has_verification_violation()).unwrap_or(false) {
445+
info!("User {} has a verification violation, skipping encryption", user_id);
446+
return None;
447+
}
448+
let devices: Vec<Device> = match device_or_all_id {
449+
DeviceIdOrAllDevices::DeviceId(device_id) => {
450+
vec![user_devices.get(&device_id)].into_iter().flatten().collect()
451+
}
452+
DeviceIdOrAllDevices::AllDevices => user_devices.devices().collect(),
453+
};
454+
encrypted_content_for_devices(
455+
&content,
456+
devices,
457+
&event_type,
458+
)
459+
.await
460+
.map_err(|e| info!("WidgetDriver: could not encrypt content for to device widget event content: {}. because, {}", content.json(), e))
461+
.ok()
462+
}}));
463+
let content_map_iterator = join_all(device_map_futures).await.into_iter();
464+
465+
// The first flatten takes the iterator over Result<Option<impl Iterator<Item =
466+
// (DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>)>>, JoinError>>
467+
// and flattens the Result (drops Err() items)
468+
// The second takes the iterator over: Option<impl Iterator<Item =
469+
// (DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>)>>
470+
// and flattens the Option (drops None items)
471+
// The third takes the iterator over iterators: impl Iterator<Item =
472+
// (DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>)>
473+
// and flattens it to just an iterator over (DeviceIdOrAllDevices,
474+
// Raw<AnyToDeviceEventContent>)
475+
content_map_iterator.flatten().flatten().flatten().collect()
476+
}
477+
}
478+
341479
#[cfg(test)]
342480
mod tests {
343481
use ruma::{room_id, serde::Raw};

0 commit comments

Comments
 (0)