Skip to content

Commit 6a82788

Browse files
authored
[8.18] Allow updating transport credentials (#254) (#256)
1 parent a4f2fea commit 6a82788

File tree

9 files changed

+114
-13
lines changed

9 files changed

+114
-13
lines changed

Cargo.lock

+40
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

api_generator/Cargo.toml

+3
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,6 @@ void = "1"
3838

3939
[dev-dependencies]
4040
tempfile = "3.12"
41+
42+
[lints.clippy]
43+
uninlined_format_args = "allow" # too pedantic

api_generator/src/generator/code_gen/request/request_builder.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ impl<'a> RequestBuilder<'a> {
154154
}
155155
});
156156

157-
let query_ctor = endpoint_params.iter().map(|(param_name, _)| {
157+
let query_ctor = endpoint_params.keys().map(|param_name| {
158158
let field_name = ident(valid_name(param_name).to_lowercase());
159159
quote! {
160160
#field_name: self.#field_name

elasticsearch/Cargo.toml

+5
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ native-tls = ["reqwest/native-tls"]
2626
rustls-tls = ["reqwest/rustls-tls"]
2727

2828
[dependencies]
29+
parking_lot = "0.12"
2930
base64 = "0.22"
3031
bytes = "1"
3132
dyn-clone = "1"
@@ -66,3 +67,7 @@ xml-rs = "0.8"
6667

6768
[build-dependencies]
6869
rustc_version = "0.4"
70+
71+
[lints.clippy]
72+
needless_lifetimes = "allow" # generated lifetimes
73+
uninlined_format_args = "allow" # too pedantic

elasticsearch/src/http/transport.rs

+59-10
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,11 @@ use std::{
5454
io::{self, Write},
5555
sync::{
5656
atomic::{AtomicBool, AtomicUsize, Ordering},
57-
Arc, RwLock,
57+
Arc,
5858
},
5959
time::{Duration, Instant},
6060
};
61+
use parking_lot::RwLock;
6162
use url::Url;
6263

6364
/// Error that can occur when building a [Transport]
@@ -68,6 +69,9 @@ pub enum BuildError {
6869

6970
/// Certificate error
7071
Cert(reqwest::Error),
72+
73+
/// Configuration error
74+
Config(String),
7175
}
7276

7377
impl From<io::Error> for BuildError {
@@ -88,13 +92,15 @@ impl error::Error for BuildError {
8892
match *self {
8993
BuildError::Io(ref err) => err.description(),
9094
BuildError::Cert(ref err) => err.description(),
95+
BuildError::Config(ref err) => err.as_str(),
9196
}
9297
}
9398

9499
fn cause(&self) -> Option<&dyn error::Error> {
95100
match *self {
96101
BuildError::Io(ref err) => Some(err as &dyn error::Error),
97102
BuildError::Cert(ref err) => Some(err as &dyn error::Error),
103+
BuildError::Config(_) => None,
98104
}
99105
}
100106
}
@@ -104,6 +110,7 @@ impl fmt::Display for BuildError {
104110
match *self {
105111
BuildError::Io(ref err) => fmt::Display::fmt(err, f),
106112
BuildError::Cert(ref err) => fmt::Display::fmt(err, f),
113+
BuildError::Config(ref err) => fmt::Display::fmt(err, f),
107114
}
108115
}
109116
}
@@ -337,7 +344,7 @@ impl TransportBuilder {
337344
if let Some(c) = self.proxy_credentials {
338345
proxy = match c {
339346
Credentials::Basic(u, p) => proxy.basic_auth(&u, &p),
340-
_ => proxy,
347+
_ => return Err(BuildError::Config("Only Basic Authentication is supported for proxies".into())),
341348
};
342349
}
343350
client_builder = client_builder.proxy(proxy);
@@ -348,7 +355,7 @@ impl TransportBuilder {
348355
client,
349356
conn_pool: self.conn_pool,
350357
request_body_compression: self.request_body_compression,
351-
credentials: self.credentials,
358+
credentials: Arc::new(RwLock::new(self.credentials)),
352359
send_meta: self.meta_header,
353360
})
354361
}
@@ -393,7 +400,7 @@ impl Connection {
393400
#[derive(Debug, Clone)]
394401
pub struct Transport {
395402
client: reqwest::Client,
396-
credentials: Option<Credentials>,
403+
credentials: Arc<RwLock<Option<Credentials>>>,
397404
request_body_compression: bool,
398405
conn_pool: Arc<dyn ConnectionPool>,
399406
send_meta: bool,
@@ -478,7 +485,7 @@ impl Transport {
478485
/// [Elasticsearch service in Elastic Cloud](https://www.elastic.co/cloud/).
479486
///
480487
/// * `cloud_id`: The Elastic Cloud Id retrieved from the cloud web console, that uniquely
481-
/// identifies the deployment instance.
488+
/// identifies the deployment instance.
482489
/// * `credentials`: A set of credentials the client should use to authenticate to Elasticsearch service.
483490
pub fn cloud(cloud_id: &str, credentials: Credentials) -> Result<Transport, Error> {
484491
let conn_pool = CloudConnectionPool::new(cloud_id)?;
@@ -513,7 +520,8 @@ impl Transport {
513520
// set credentials before any headers, as credentials append to existing headers in reqwest,
514521
// whilst setting headers() overwrites, so if an Authorization header has been specified
515522
// on a specific request, we want it to overwrite.
516-
if let Some(c) = &self.credentials {
523+
let creds_guard = self.credentials.read();
524+
if let Some(c) = creds_guard.as_ref() {
517525
request_builder = match c {
518526
Credentials::Basic(u, p) => request_builder.basic_auth(u, Some(p)),
519527
Credentials::Bearer(t) => request_builder.bearer_auth(t),
@@ -537,6 +545,7 @@ impl Transport {
537545
}
538546
}
539547
}
548+
drop(creds_guard);
540549

541550
// default headers first, overwrite with any provided
542551
let mut request_headers = HeaderMap::with_capacity(4 + headers.len());
@@ -696,6 +705,12 @@ impl Transport {
696705
Err(e) => Err(e.into()),
697706
}
698707
}
708+
709+
/// Update the auth credentials for this transport and all its clones, and all clients
710+
/// using them. Typically used to refresh a bearer token.
711+
pub fn set_auth(&self, credentials: Credentials) {
712+
*self.credentials.write() = Some(credentials);
713+
}
699714
}
700715

701716
impl Default for Transport {
@@ -895,14 +910,14 @@ where
895910
ConnSelector: ConnectionSelector + Clone,
896911
{
897912
fn next(&self) -> Connection {
898-
let inner = self.inner.read().expect("lock poisoned");
913+
let inner = self.inner.read();
899914
self.connection_selector
900915
.try_next(&inner.connections)
901916
.unwrap()
902917
}
903918

904919
fn reseedable(&self) -> bool {
905-
let inner = self.inner.read().expect("lock poisoned");
920+
let inner = self.inner.read();
906921
let reseed_frequency = match self.reseed_frequency {
907922
Some(wait) => wait,
908923
None => return false,
@@ -928,10 +943,11 @@ where
928943
}
929944

930945
fn reseed(&self, mut connection: Vec<Connection>) {
931-
let mut inner = self.inner.write().expect("lock poisoned");
946+
let mut inner = self.inner.write();
932947
inner.last_update = Some(Instant::now());
933948
inner.connections.clear();
934949
inner.connections.append(&mut connection);
950+
drop(inner);
935951
self.reseeding.store(false, Ordering::Relaxed);
936952
}
937953
}
@@ -1210,7 +1226,7 @@ pub mod tests {
12101226
);
12111227

12121228
// Set internal last_update to a minute ago
1213-
let mut inner = connection_pool.inner.write().expect("lock poisoned");
1229+
let mut inner = connection_pool.inner.write();
12141230
inner.last_update = Some(Instant::now() - Duration::from_secs(60));
12151231
drop(inner);
12161232

@@ -1249,4 +1265,37 @@ pub mod tests {
12491265
let connections = MultiNodeConnectionPool::round_robin(vec![], None);
12501266
connections.next();
12511267
}
1268+
1269+
#[test]
1270+
fn set_credentials() -> Result<(), failure::Error> {
1271+
let t1: Transport = TransportBuilder::new(SingleNodeConnectionPool::default())
1272+
.auth(Credentials::Basic("foo".to_string(), "bar".to_string()))
1273+
.build()?;
1274+
1275+
if let Some(Credentials::Basic(login, password)) = t1.credentials.read().as_ref() {
1276+
assert_eq!(login, "foo");
1277+
assert_eq!(password, "bar");
1278+
} else {
1279+
panic!("Expected Basic credentials");
1280+
}
1281+
1282+
let t2 = t1.clone();
1283+
1284+
t1.set_auth(Credentials::Bearer("The bear".to_string()));
1285+
1286+
if let Some(Credentials::Bearer(token)) = t1.credentials.read().as_ref() {
1287+
assert_eq!(token, "The bear");
1288+
} else {
1289+
panic!("Expected Bearer credentials");
1290+
}
1291+
1292+
// Verify that cloned transport also has the same credentials
1293+
if let Some(Credentials::Bearer(token)) = t2.credentials.read().as_ref() {
1294+
assert_eq!(token, "The bear");
1295+
} else {
1296+
panic!("Expected Bearer credentials");
1297+
}
1298+
1299+
Ok(())
1300+
}
12521301
}

elasticsearch/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
//! - **native-tls** *(enabled by default)*: Enables TLS functionality provided by `native-tls`.
5555
//! - **rustls-tls**: Enables TLS functionality provided by `rustls`.
5656
//! - **beta-apis**: Enables beta APIs. Beta APIs are on track to become stable and permanent features. Use them with
57-
//! caution because it is possible that breaking changes are made to these APIs in a minor version.
57+
//! caution because it is possible that breaking changes are made to these APIs in a minor version.
5858
//! - **experimental-apis**: Enables experimental APIs. Experimental APIs are just that - an experiment. An experimental
5959
//! API might have breaking changes in any future version, or it might even be removed entirely. This feature also
6060
//! enables `beta-apis`.

xtask/Cargo.toml

+3
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,6 @@ chrono = { version = "0.4", features = ["serde"] }
1515
zip = "2"
1616
regex = "1"
1717
xshell = "0.2"
18+
19+
[lints.clippy]
20+
uninlined_format_args = "allow" # too pedantic

yaml_test_runner/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,4 @@ tokio = { version = "1", default-features = false, features = ["macros", "net",
4343
[lints.clippy]
4444
# yaml tests contain approximate values of PI
4545
approx_constant = "allow"
46+
uninlined_format_args = "allow" # too pedantic

yaml_test_runner/src/generator.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ pub fn generate_tests_from_yaml(
424424
error!(
425425
"skipping {}. cannot read as Yaml struct: {}",
426426
relative_path.to_slash_lossy(),
427-
result.err().unwrap().to_string()
427+
result.err().unwrap()
428428
);
429429
continue;
430430
}

0 commit comments

Comments
 (0)