Skip to content

Commit 7a8c073

Browse files
committed
Allow updating transport credentials
1 parent e8a928f commit 7a8c073

File tree

3 files changed

+99
-9
lines changed

3 files changed

+99
-9
lines changed

Cargo.lock

+40
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

elasticsearch/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ native-tls = ["reqwest/native-tls"]
2626
rustls-tls = ["reqwest/rustls-tls"]
2727

2828
[dependencies]
29+
parking_lot = "0.12"
2930
base64 = "0.22"
3031
bytes = "1"
3132
dyn-clone = "1"

elasticsearch/src/http/transport.rs

+58-9
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,11 @@ use std::{
5454
io::{self, Write},
5555
sync::{
5656
atomic::{AtomicBool, AtomicUsize, Ordering},
57-
Arc, RwLock,
57+
Arc,
5858
},
5959
time::{Duration, Instant},
6060
};
61+
use parking_lot::RwLock;
6162
use url::Url;
6263

6364
/// Error that can occur when building a [Transport]
@@ -68,6 +69,9 @@ pub enum BuildError {
6869

6970
/// Certificate error
7071
Cert(reqwest::Error),
72+
73+
/// Configuration error
74+
Config(String),
7175
}
7276

7377
impl From<io::Error> for BuildError {
@@ -88,13 +92,15 @@ impl error::Error for BuildError {
8892
match *self {
8993
BuildError::Io(ref err) => err.description(),
9094
BuildError::Cert(ref err) => err.description(),
95+
BuildError::Config(ref err) => err.as_str(),
9196
}
9297
}
9398

9499
fn cause(&self) -> Option<&dyn error::Error> {
95100
match *self {
96101
BuildError::Io(ref err) => Some(err as &dyn error::Error),
97102
BuildError::Cert(ref err) => Some(err as &dyn error::Error),
103+
BuildError::Config(_) => None,
98104
}
99105
}
100106
}
@@ -104,6 +110,7 @@ impl fmt::Display for BuildError {
104110
match *self {
105111
BuildError::Io(ref err) => fmt::Display::fmt(err, f),
106112
BuildError::Cert(ref err) => fmt::Display::fmt(err, f),
113+
BuildError::Config(ref err) => fmt::Display::fmt(err, f),
107114
}
108115
}
109116
}
@@ -337,7 +344,7 @@ impl TransportBuilder {
337344
if let Some(c) = self.proxy_credentials {
338345
proxy = match c {
339346
Credentials::Basic(u, p) => proxy.basic_auth(&u, &p),
340-
_ => proxy,
347+
_ => return Err(BuildError::Config("Only Basic Authentication is supported for proxies".into())),
341348
};
342349
}
343350
client_builder = client_builder.proxy(proxy);
@@ -348,7 +355,7 @@ impl TransportBuilder {
348355
client,
349356
conn_pool: self.conn_pool,
350357
request_body_compression: self.request_body_compression,
351-
credentials: self.credentials,
358+
credentials: Arc::new(RwLock::new(self.credentials)),
352359
send_meta: self.meta_header,
353360
})
354361
}
@@ -393,7 +400,7 @@ impl Connection {
393400
#[derive(Debug, Clone)]
394401
pub struct Transport {
395402
client: reqwest::Client,
396-
credentials: Option<Credentials>,
403+
credentials: Arc<RwLock<Option<Credentials>>>,
397404
request_body_compression: bool,
398405
conn_pool: Arc<dyn ConnectionPool>,
399406
send_meta: bool,
@@ -513,7 +520,8 @@ impl Transport {
513520
// set credentials before any headers, as credentials append to existing headers in reqwest,
514521
// whilst setting headers() overwrites, so if an Authorization header has been specified
515522
// on a specific request, we want it to overwrite.
516-
if let Some(c) = &self.credentials {
523+
let creds_guard = self.credentials.read();
524+
if let Some(c) = creds_guard.as_ref() {
517525
request_builder = match c {
518526
Credentials::Basic(u, p) => request_builder.basic_auth(u, Some(p)),
519527
Credentials::Bearer(t) => request_builder.bearer_auth(t),
@@ -537,6 +545,7 @@ impl Transport {
537545
}
538546
}
539547
}
548+
drop(creds_guard);
540549

541550
// default headers first, overwrite with any provided
542551
let mut request_headers = HeaderMap::with_capacity(4 + headers.len());
@@ -696,6 +705,12 @@ impl Transport {
696705
Err(e) => Err(e.into()),
697706
}
698707
}
708+
709+
/// Update the auth credentials for this transport and all its clones, and all clients
710+
/// using them. Typically used to refresh a bearer token.
711+
pub fn set_auth(&self, credentials: Credentials) {
712+
*self.credentials.write() = Some(credentials);
713+
}
699714
}
700715

701716
impl Default for Transport {
@@ -895,14 +910,14 @@ where
895910
ConnSelector: ConnectionSelector + Clone,
896911
{
897912
fn next(&self) -> Connection {
898-
let inner = self.inner.read().expect("lock poisoned");
913+
let inner = self.inner.read();
899914
self.connection_selector
900915
.try_next(&inner.connections)
901916
.unwrap()
902917
}
903918

904919
fn reseedable(&self) -> bool {
905-
let inner = self.inner.read().expect("lock poisoned");
920+
let inner = self.inner.read();
906921
let reseed_frequency = match self.reseed_frequency {
907922
Some(wait) => wait,
908923
None => return false,
@@ -928,10 +943,11 @@ where
928943
}
929944

930945
fn reseed(&self, mut connection: Vec<Connection>) {
931-
let mut inner = self.inner.write().expect("lock poisoned");
946+
let mut inner = self.inner.write();
932947
inner.last_update = Some(Instant::now());
933948
inner.connections.clear();
934949
inner.connections.append(&mut connection);
950+
drop(inner);
935951
self.reseeding.store(false, Ordering::Relaxed);
936952
}
937953
}
@@ -1210,7 +1226,7 @@ pub mod tests {
12101226
);
12111227

12121228
// Set internal last_update to a minute ago
1213-
let mut inner = connection_pool.inner.write().expect("lock poisoned");
1229+
let mut inner = connection_pool.inner.write();
12141230
inner.last_update = Some(Instant::now() - Duration::from_secs(60));
12151231
drop(inner);
12161232

@@ -1249,4 +1265,37 @@ pub mod tests {
12491265
let connections = MultiNodeConnectionPool::round_robin(vec![], None);
12501266
connections.next();
12511267
}
1268+
1269+
#[test]
1270+
fn set_credentials() -> anyhow::Result<()> {
1271+
let t1: Transport = TransportBuilder::new(SingleNodeConnectionPool::default())
1272+
.auth(Credentials::Basic("foo".to_string(), "bar".to_string()))
1273+
.build()?;
1274+
1275+
if let Some(Credentials::Basic(login, password)) = t1.credentials.read().as_ref() {
1276+
assert_eq!(login, "foo");
1277+
assert_eq!(password, "bar");
1278+
} else {
1279+
panic!("Expected Basic credentials");
1280+
}
1281+
1282+
let t2 = t1.clone();
1283+
1284+
t1.set_auth(Credentials::Bearer("The bear".to_string()));
1285+
1286+
if let Some(Credentials::Bearer(token)) = t1.credentials.read().as_ref() {
1287+
assert_eq!(token, "The bear");
1288+
} else {
1289+
panic!("Expected Bearer credentials");
1290+
}
1291+
1292+
// Verify that cloned transport also has the same credentials
1293+
if let Some(Credentials::Bearer(token)) = t2.credentials.read().as_ref() {
1294+
assert_eq!(token, "The bear");
1295+
} else {
1296+
panic!("Expected Bearer credentials");
1297+
}
1298+
1299+
Ok(())
1300+
}
12521301
}

0 commit comments

Comments
 (0)