@@ -54,10 +54,11 @@ use std::{
54
54
io:: { self , Write } ,
55
55
sync:: {
56
56
atomic:: { AtomicBool , AtomicUsize , Ordering } ,
57
- Arc , RwLock ,
57
+ Arc ,
58
58
} ,
59
59
time:: { Duration , Instant } ,
60
60
} ;
61
+ use parking_lot:: RwLock ;
61
62
use url:: Url ;
62
63
63
64
/// Error that can occur when building a [Transport]
@@ -68,6 +69,9 @@ pub enum BuildError {
68
69
69
70
/// Certificate error
70
71
Cert ( reqwest:: Error ) ,
72
+
73
+ /// Configuration error
74
+ Config ( String ) ,
71
75
}
72
76
73
77
impl From < io:: Error > for BuildError {
@@ -88,13 +92,15 @@ impl error::Error for BuildError {
88
92
match * self {
89
93
BuildError :: Io ( ref err) => err. description ( ) ,
90
94
BuildError :: Cert ( ref err) => err. description ( ) ,
95
+ BuildError :: Config ( ref err) => err. as_str ( ) ,
91
96
}
92
97
}
93
98
94
99
fn cause ( & self ) -> Option < & dyn error:: Error > {
95
100
match * self {
96
101
BuildError :: Io ( ref err) => Some ( err as & dyn error:: Error ) ,
97
102
BuildError :: Cert ( ref err) => Some ( err as & dyn error:: Error ) ,
103
+ BuildError :: Config ( _) => None ,
98
104
}
99
105
}
100
106
}
@@ -104,6 +110,7 @@ impl fmt::Display for BuildError {
104
110
match * self {
105
111
BuildError :: Io ( ref err) => fmt:: Display :: fmt ( err, f) ,
106
112
BuildError :: Cert ( ref err) => fmt:: Display :: fmt ( err, f) ,
113
+ BuildError :: Config ( ref err) => fmt:: Display :: fmt ( err, f) ,
107
114
}
108
115
}
109
116
}
@@ -337,7 +344,7 @@ impl TransportBuilder {
337
344
if let Some ( c) = self . proxy_credentials {
338
345
proxy = match c {
339
346
Credentials :: Basic ( u, p) => proxy. basic_auth ( & u, & p) ,
340
- _ => proxy ,
347
+ _ => return Err ( BuildError :: Config ( "Only Basic Authentication is supported for proxies" . into ( ) ) ) ,
341
348
} ;
342
349
}
343
350
client_builder = client_builder. proxy ( proxy) ;
@@ -348,7 +355,7 @@ impl TransportBuilder {
348
355
client,
349
356
conn_pool : self . conn_pool ,
350
357
request_body_compression : self . request_body_compression ,
351
- credentials : self . credentials ,
358
+ credentials : Arc :: new ( RwLock :: new ( self . credentials ) ) ,
352
359
send_meta : self . meta_header ,
353
360
} )
354
361
}
@@ -393,7 +400,7 @@ impl Connection {
393
400
#[ derive( Debug , Clone ) ]
394
401
pub struct Transport {
395
402
client : reqwest:: Client ,
396
- credentials : Option < Credentials > ,
403
+ credentials : Arc < RwLock < Option < Credentials > > > ,
397
404
request_body_compression : bool ,
398
405
conn_pool : Arc < dyn ConnectionPool > ,
399
406
send_meta : bool ,
@@ -478,7 +485,7 @@ impl Transport {
478
485
/// [Elasticsearch service in Elastic Cloud](https://www.elastic.co/cloud/).
479
486
///
480
487
/// * `cloud_id`: The Elastic Cloud Id retrieved from the cloud web console, that uniquely
481
- /// identifies the deployment instance.
488
+ /// identifies the deployment instance.
482
489
/// * `credentials`: A set of credentials the client should use to authenticate to Elasticsearch service.
483
490
pub fn cloud ( cloud_id : & str , credentials : Credentials ) -> Result < Transport , Error > {
484
491
let conn_pool = CloudConnectionPool :: new ( cloud_id) ?;
@@ -513,7 +520,8 @@ impl Transport {
513
520
// set credentials before any headers, as credentials append to existing headers in reqwest,
514
521
// whilst setting headers() overwrites, so if an Authorization header has been specified
515
522
// on a specific request, we want it to overwrite.
516
- if let Some ( c) = & self . credentials {
523
+ let creds_guard = self . credentials . read ( ) ;
524
+ if let Some ( c) = creds_guard. as_ref ( ) {
517
525
request_builder = match c {
518
526
Credentials :: Basic ( u, p) => request_builder. basic_auth ( u, Some ( p) ) ,
519
527
Credentials :: Bearer ( t) => request_builder. bearer_auth ( t) ,
@@ -537,6 +545,7 @@ impl Transport {
537
545
}
538
546
}
539
547
}
548
+ drop ( creds_guard) ;
540
549
541
550
// default headers first, overwrite with any provided
542
551
let mut request_headers = HeaderMap :: with_capacity ( 4 + headers. len ( ) ) ;
@@ -696,6 +705,12 @@ impl Transport {
696
705
Err ( e) => Err ( e. into ( ) ) ,
697
706
}
698
707
}
708
+
709
+ /// Update the auth credentials for this transport and all its clones, and all clients
710
+ /// using them. Typically used to refresh a bearer token.
711
+ pub fn set_auth ( & self , credentials : Credentials ) {
712
+ * self . credentials . write ( ) = Some ( credentials) ;
713
+ }
699
714
}
700
715
701
716
impl Default for Transport {
@@ -895,14 +910,14 @@ where
895
910
ConnSelector : ConnectionSelector + Clone ,
896
911
{
897
912
fn next ( & self ) -> Connection {
898
- let inner = self . inner . read ( ) . expect ( "lock poisoned" ) ;
913
+ let inner = self . inner . read ( ) ;
899
914
self . connection_selector
900
915
. try_next ( & inner. connections )
901
916
. unwrap ( )
902
917
}
903
918
904
919
fn reseedable ( & self ) -> bool {
905
- let inner = self . inner . read ( ) . expect ( "lock poisoned" ) ;
920
+ let inner = self . inner . read ( ) ;
906
921
let reseed_frequency = match self . reseed_frequency {
907
922
Some ( wait) => wait,
908
923
None => return false ,
@@ -928,10 +943,11 @@ where
928
943
}
929
944
930
945
fn reseed ( & self , mut connection : Vec < Connection > ) {
931
- let mut inner = self . inner . write ( ) . expect ( "lock poisoned" ) ;
946
+ let mut inner = self . inner . write ( ) ;
932
947
inner. last_update = Some ( Instant :: now ( ) ) ;
933
948
inner. connections . clear ( ) ;
934
949
inner. connections . append ( & mut connection) ;
950
+ drop ( inner) ;
935
951
self . reseeding . store ( false , Ordering :: Relaxed ) ;
936
952
}
937
953
}
@@ -1210,7 +1226,7 @@ pub mod tests {
1210
1226
) ;
1211
1227
1212
1228
// Set internal last_update to a minute ago
1213
- let mut inner = connection_pool. inner . write ( ) . expect ( "lock poisoned" ) ;
1229
+ let mut inner = connection_pool. inner . write ( ) ;
1214
1230
inner. last_update = Some ( Instant :: now ( ) - Duration :: from_secs ( 60 ) ) ;
1215
1231
drop ( inner) ;
1216
1232
@@ -1249,4 +1265,37 @@ pub mod tests {
1249
1265
let connections = MultiNodeConnectionPool :: round_robin ( vec ! [ ] , None ) ;
1250
1266
connections. next ( ) ;
1251
1267
}
1268
+
1269
+ #[ test]
1270
+ fn set_credentials ( ) -> Result < ( ) , failure:: Error > {
1271
+ let t1: Transport = TransportBuilder :: new ( SingleNodeConnectionPool :: default ( ) )
1272
+ . auth ( Credentials :: Basic ( "foo" . to_string ( ) , "bar" . to_string ( ) ) )
1273
+ . build ( ) ?;
1274
+
1275
+ if let Some ( Credentials :: Basic ( login, password) ) = t1. credentials . read ( ) . as_ref ( ) {
1276
+ assert_eq ! ( login, "foo" ) ;
1277
+ assert_eq ! ( password, "bar" ) ;
1278
+ } else {
1279
+ panic ! ( "Expected Basic credentials" ) ;
1280
+ }
1281
+
1282
+ let t2 = t1. clone ( ) ;
1283
+
1284
+ t1. set_auth ( Credentials :: Bearer ( "The bear" . to_string ( ) ) ) ;
1285
+
1286
+ if let Some ( Credentials :: Bearer ( token) ) = t1. credentials . read ( ) . as_ref ( ) {
1287
+ assert_eq ! ( token, "The bear" ) ;
1288
+ } else {
1289
+ panic ! ( "Expected Bearer credentials" ) ;
1290
+ }
1291
+
1292
+ // Verify that cloned transport also has the same credentials
1293
+ if let Some ( Credentials :: Bearer ( token) ) = t2. credentials . read ( ) . as_ref ( ) {
1294
+ assert_eq ! ( token, "The bear" ) ;
1295
+ } else {
1296
+ panic ! ( "Expected Bearer credentials" ) ;
1297
+ }
1298
+
1299
+ Ok ( ( ) )
1300
+ }
1252
1301
}
0 commit comments