@@ -251,7 +251,6 @@ pub enum CrossProcessRefreshLockError {
251
251
252
252
#[ cfg( all( test, feature = "e2e-encryption" ) ) ]
253
253
mod tests {
254
- use std:: sync:: Arc ;
255
254
256
255
use anyhow:: Context as _;
257
256
use futures_util:: future:: join_all;
@@ -261,12 +260,7 @@ mod tests {
261
260
262
261
use super :: compute_session_hash;
263
262
use crate :: {
264
- authentication:: oidc:: {
265
- backend:: mock:: { MockImpl , ISSUER_URL } ,
266
- cross_process:: SessionHash ,
267
- tests:: prev_session_tokens,
268
- Oidc ,
269
- } ,
263
+ authentication:: oidc:: { cross_process:: SessionHash , tests:: prev_session_tokens} ,
270
264
test_utils:: {
271
265
client:: {
272
266
oauth:: { mock_session, mock_session_tokens} ,
@@ -302,7 +296,13 @@ mod tests {
302
296
) ?;
303
297
304
298
let session_hash = compute_session_hash ( & tokens) ;
305
- client. oidc ( ) . restore_session ( mock_session ( tokens. clone ( ) , ISSUER_URL . to_owned ( ) ) ) . await ?;
299
+ client
300
+ . oidc ( )
301
+ . restore_session ( mock_session (
302
+ tokens. clone ( ) ,
303
+ "https://oidc.example.com/issuer" . to_owned ( ) ,
304
+ ) )
305
+ . await ?;
306
306
307
307
assert_eq ! ( client. oidc( ) . session_tokens( ) . unwrap( ) , tokens) ;
308
308
@@ -376,37 +376,29 @@ mod tests {
376
376
// This tests that refresh token works, and that it doesn't cause multiple token
377
377
// refreshes whenever one spawns two refreshes around the same time.
378
378
379
+ let server = MatrixMockServer :: new ( ) . await ;
380
+
381
+ let oauth_server = server. oauth ( ) ;
382
+ oauth_server. mock_server_metadata ( ) . ok ( ) . expect ( 1 ..) . named ( "server_metadata" ) . mount ( ) . await ;
383
+ oauth_server. mock_token ( ) . ok ( ) . expect ( 1 ) . named ( "token" ) . mount ( ) . await ;
384
+
379
385
let tmp_dir = tempfile:: tempdir ( ) ?;
380
- let client = MockClientBuilder :: new ( "https://example.org" . to_owned ( ) )
381
- . sqlite_store ( & tmp_dir)
382
- . unlogged ( )
383
- . build ( )
384
- . await ;
386
+ let client = server. client_builder ( ) . sqlite_store ( & tmp_dir) . unlogged ( ) . build ( ) . await ;
387
+ let oidc = client. oidc ( ) ;
385
388
386
- let prev_tokens = prev_session_tokens ( ) ;
387
389
let next_tokens = mock_session_tokens ( ) ;
388
390
389
- let backend = Arc :: new (
390
- MockImpl :: new ( )
391
- . next_session_tokens ( next_tokens. clone ( ) )
392
- . expected_refresh_token ( prev_tokens. refresh_token . clone ( ) . unwrap ( ) ) ,
393
- ) ;
394
- let oidc = Oidc { client : client. clone ( ) , backend : backend. clone ( ) } ;
395
-
396
391
// Enable cross-process lock.
397
392
oidc. enable_cross_process_refresh_lock ( "lock" . to_owned ( ) ) . await ?;
398
393
399
394
// Restore the session.
400
- oidc. restore_session ( mock_session ( prev_tokens . clone ( ) , ISSUER_URL . to_owned ( ) ) ) . await ?;
395
+ oidc. restore_session ( mock_session ( prev_session_tokens ( ) , server . server ( ) . uri ( ) ) ) . await ?;
401
396
402
397
// Immediately try to refresh the access token twice in parallel.
403
398
for result in join_all ( [ oidc. refresh_access_token ( ) , oidc. refresh_access_token ( ) ] ) . await {
404
399
result?;
405
400
}
406
401
407
- // There should have been at most one refresh.
408
- assert_eq ! ( * backend. num_refreshes. lock( ) . unwrap( ) , 1 ) ;
409
-
410
402
{
411
403
// The cross process lock has been correctly updated, and the next attempt to
412
404
// take it won't result in a mismatch.
@@ -424,51 +416,40 @@ mod tests {
424
416
425
417
#[ async_test]
426
418
async fn test_cross_process_concurrent_refresh ( ) -> anyhow:: Result < ( ) > {
427
- // Create the backend.
419
+ let server = MatrixMockServer :: new ( ) . await ;
420
+ let issuer = server. server ( ) . uri ( ) ;
421
+
422
+ let oauth_server = server. oauth ( ) ;
423
+ oauth_server. mock_server_metadata ( ) . ok ( ) . expect ( 1 ..) . named ( "server_metadata" ) . mount ( ) . await ;
424
+ oauth_server. mock_token ( ) . ok ( ) . expect ( 1 ) . named ( "token" ) . mount ( ) . await ;
425
+
428
426
let prev_tokens = prev_session_tokens ( ) ;
429
427
let next_tokens = mock_session_tokens ( ) ;
430
428
431
- let backend = Arc :: new (
432
- MockImpl :: new ( )
433
- . next_session_tokens ( next_tokens. clone ( ) )
434
- . expected_refresh_token ( prev_tokens. refresh_token . clone ( ) . unwrap ( ) ) ,
435
- ) ;
436
-
437
429
// Create the first client.
438
430
let tmp_dir = tempfile:: tempdir ( ) ?;
439
- let client = MockClientBuilder :: new ( "https://example.org" . to_owned ( ) )
440
- . sqlite_store ( & tmp_dir)
441
- . unlogged ( )
442
- . build ( )
443
- . await ;
431
+ let client = server. client_builder ( ) . sqlite_store ( & tmp_dir) . unlogged ( ) . build ( ) . await ;
444
432
445
- let oidc = Oidc { client : client . clone ( ) , backend : backend . clone ( ) } ;
433
+ let oidc = client. oidc ( ) ;
446
434
oidc. enable_cross_process_refresh_lock ( "client1" . to_owned ( ) ) . await ?;
447
435
448
- oidc. restore_session ( mock_session ( prev_tokens. clone ( ) , ISSUER_URL . to_owned ( ) ) ) . await ?;
436
+ oidc. restore_session ( mock_session ( prev_tokens. clone ( ) , issuer . clone ( ) ) ) . await ?;
449
437
450
438
// Create a second client, without restoring it, to test that a token update
451
439
// before restoration doesn't cause new issues.
452
- let unrestored_client = MockClientBuilder :: new ( "https://example.org" . to_owned ( ) )
453
- . sqlite_store ( & tmp_dir)
454
- . unlogged ( )
455
- . build ( )
456
- . await ;
457
- let unrestored_oidc = Oidc { client : unrestored_client. clone ( ) , backend : backend. clone ( ) } ;
440
+ let unrestored_client =
441
+ server. client_builder ( ) . sqlite_store ( & tmp_dir) . unlogged ( ) . build ( ) . await ;
442
+ let unrestored_oidc = unrestored_client. oidc ( ) ;
458
443
unrestored_oidc. enable_cross_process_refresh_lock ( "unrestored_client" . to_owned ( ) ) . await ?;
459
444
460
445
{
461
446
// Create a third client that will run a refresh while the others two are doing
462
447
// nothing.
463
- let client3 = MockClientBuilder :: new ( "https://example.org" . to_owned ( ) )
464
- . sqlite_store ( & tmp_dir)
465
- . unlogged ( )
466
- . build ( )
467
- . await ;
448
+ let client3 = server. client_builder ( ) . sqlite_store ( & tmp_dir) . unlogged ( ) . build ( ) . await ;
468
449
469
- let oidc3 = Oidc { client : client3. clone ( ) , backend : backend . clone ( ) } ;
450
+ let oidc3 = client3. oidc ( ) ;
470
451
oidc3. enable_cross_process_refresh_lock ( "client3" . to_owned ( ) ) . await ?;
471
- oidc3. restore_session ( mock_session ( prev_tokens. clone ( ) , ISSUER_URL . to_owned ( ) ) ) . await ?;
452
+ oidc3. restore_session ( mock_session ( prev_tokens. clone ( ) , issuer . clone ( ) ) ) . await ?;
472
453
473
454
// Run a refresh in the second client; this will invalidate the tokens from the
474
455
// first token.
@@ -500,7 +481,7 @@ mod tests {
500
481
Box :: new ( |_| panic ! ( "save_session_callback shouldn't be called here" ) ) ,
501
482
) ?;
502
483
503
- oidc. restore_session ( mock_session ( prev_tokens. clone ( ) , ISSUER_URL . to_owned ( ) ) ) . await ?;
484
+ oidc. restore_session ( mock_session ( prev_tokens. clone ( ) , issuer ) ) . await ?;
504
485
505
486
// And this client is now aware of the latest tokens.
506
487
let xp_manager =
@@ -550,9 +531,6 @@ mod tests {
550
531
assert ! ( !guard. hash_mismatch) ;
551
532
}
552
533
553
- // There should have been at most one refresh.
554
- assert_eq ! ( * backend. num_refreshes. lock( ) . unwrap( ) , 1 ) ;
555
-
556
534
Ok ( ( ) )
557
535
}
558
536
0 commit comments