Skip to content

Allow updating transport credentials #254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions api_generator/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,6 @@ void = "1"

[dev-dependencies]
tempfile = "3.12"

[lints.clippy]
uninlined_format_args = "allow" # too pedantic
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ impl<'a> RequestBuilder<'a> {
}
});

let query_ctor = endpoint_params.iter().map(|(param_name, _)| {
let query_ctor = endpoint_params.keys().map(|param_name| {
let field_name = ident(valid_name(param_name).to_lowercase());
quote! {
#field_name: self.#field_name
Expand Down
5 changes: 5 additions & 0 deletions elasticsearch/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ native-tls = ["reqwest/native-tls"]
rustls-tls = ["reqwest/rustls-tls"]

[dependencies]
parking_lot = "0.12"
base64 = "0.22"
bytes = "1"
dyn-clone = "1"
Expand Down Expand Up @@ -61,3 +62,7 @@ xml-rs = "0.8"

[build-dependencies]
rustc_version = "0.4"

[lints.clippy]
needless_lifetimes = "allow" # generated lifetimes
uninlined_format_args = "allow" # too pedantic
69 changes: 59 additions & 10 deletions elasticsearch/src/http/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ use std::{
io::{self, Write},
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc, RwLock,
Arc,
},
time::{Duration, Instant},
};
use parking_lot::RwLock;
use url::Url;

/// Error that can occur when building a [Transport]
Expand All @@ -68,6 +69,9 @@ pub enum BuildError {

/// Certificate error
Cert(reqwest::Error),

/// Configuration error
Config(String),
}

impl From<io::Error> for BuildError {
Expand All @@ -88,13 +92,15 @@ impl error::Error for BuildError {
match *self {
BuildError::Io(ref err) => err.description(),
BuildError::Cert(ref err) => err.description(),
BuildError::Config(ref err) => err.as_str(),
}
}

fn cause(&self) -> Option<&dyn error::Error> {
match *self {
BuildError::Io(ref err) => Some(err as &dyn error::Error),
BuildError::Cert(ref err) => Some(err as &dyn error::Error),
BuildError::Config(_) => None,
}
}
}
Expand All @@ -104,6 +110,7 @@ impl fmt::Display for BuildError {
match *self {
BuildError::Io(ref err) => fmt::Display::fmt(err, f),
BuildError::Cert(ref err) => fmt::Display::fmt(err, f),
BuildError::Config(ref err) => fmt::Display::fmt(err, f),
}
}
}
Expand Down Expand Up @@ -337,7 +344,7 @@ impl TransportBuilder {
if let Some(c) = self.proxy_credentials {
proxy = match c {
Credentials::Basic(u, p) => proxy.basic_auth(&u, &p),
_ => proxy,
_ => return Err(BuildError::Config("Only Basic Authentication is supported for proxies".into())),
};
}
client_builder = client_builder.proxy(proxy);
Expand All @@ -348,7 +355,7 @@ impl TransportBuilder {
client,
conn_pool: self.conn_pool,
request_body_compression: self.request_body_compression,
credentials: self.credentials,
credentials: Arc::new(RwLock::new(self.credentials)),
send_meta: self.meta_header,
})
}
Expand Down Expand Up @@ -393,7 +400,7 @@ impl Connection {
#[derive(Debug, Clone)]
pub struct Transport {
client: reqwest::Client,
credentials: Option<Credentials>,
credentials: Arc<RwLock<Option<Credentials>>>,
request_body_compression: bool,
conn_pool: Arc<dyn ConnectionPool>,
send_meta: bool,
Expand Down Expand Up @@ -478,7 +485,7 @@ impl Transport {
/// [Elasticsearch service in Elastic Cloud](https://www.elastic.co/cloud/).
///
/// * `cloud_id`: The Elastic Cloud Id retrieved from the cloud web console, that uniquely
/// identifies the deployment instance.
/// identifies the deployment instance.
/// * `credentials`: A set of credentials the client should use to authenticate to Elasticsearch service.
pub fn cloud(cloud_id: &str, credentials: Credentials) -> Result<Transport, Error> {
let conn_pool = CloudConnectionPool::new(cloud_id)?;
Expand Down Expand Up @@ -513,7 +520,8 @@ impl Transport {
// set credentials before any headers, as credentials append to existing headers in reqwest,
// whilst setting headers() overwrites, so if an Authorization header has been specified
// on a specific request, we want it to overwrite.
if let Some(c) = &self.credentials {
let creds_guard = self.credentials.read();
if let Some(c) = creds_guard.as_ref() {
request_builder = match c {
Credentials::Basic(u, p) => request_builder.basic_auth(u, Some(p)),
Credentials::Bearer(t) => request_builder.bearer_auth(t),
Expand All @@ -537,6 +545,7 @@ impl Transport {
}
}
}
drop(creds_guard);

// default headers first, overwrite with any provided
let mut request_headers = HeaderMap::with_capacity(4 + headers.len());
Expand Down Expand Up @@ -696,6 +705,12 @@ impl Transport {
Err(e) => Err(e.into()),
}
}

/// Update the auth credentials for this transport and all its clones, and all clients
/// using them. Typically used to refresh a bearer token.
pub fn set_auth(&self, credentials: Credentials) {
*self.credentials.write() = Some(credentials);
}
}

impl Default for Transport {
Expand Down Expand Up @@ -895,14 +910,14 @@ where
ConnSelector: ConnectionSelector + Clone,
{
fn next(&self) -> Connection {
let inner = self.inner.read().expect("lock poisoned");
let inner = self.inner.read();
self.connection_selector
.try_next(&inner.connections)
.unwrap()
}

fn reseedable(&self) -> bool {
let inner = self.inner.read().expect("lock poisoned");
let inner = self.inner.read();
let reseed_frequency = match self.reseed_frequency {
Some(wait) => wait,
None => return false,
Expand All @@ -928,10 +943,11 @@ where
}

fn reseed(&self, mut connection: Vec<Connection>) {
let mut inner = self.inner.write().expect("lock poisoned");
let mut inner = self.inner.write();
inner.last_update = Some(Instant::now());
inner.connections.clear();
inner.connections.append(&mut connection);
drop(inner);
self.reseeding.store(false, Ordering::Relaxed);
}
}
Expand Down Expand Up @@ -1210,7 +1226,7 @@ pub mod tests {
);

// Set internal last_update to a minute ago
let mut inner = connection_pool.inner.write().expect("lock poisoned");
let mut inner = connection_pool.inner.write();
inner.last_update = Some(Instant::now() - Duration::from_secs(60));
drop(inner);

Expand Down Expand Up @@ -1249,4 +1265,37 @@ pub mod tests {
let connections = MultiNodeConnectionPool::round_robin(vec![], None);
connections.next();
}

#[test]
fn set_credentials() -> anyhow::Result<()> {
let t1: Transport = TransportBuilder::new(SingleNodeConnectionPool::default())
.auth(Credentials::Basic("foo".to_string(), "bar".to_string()))
.build()?;

if let Some(Credentials::Basic(login, password)) = t1.credentials.read().as_ref() {
assert_eq!(login, "foo");
assert_eq!(password, "bar");
} else {
panic!("Expected Basic credentials");
}

let t2 = t1.clone();

t1.set_auth(Credentials::Bearer("The bear".to_string()));

if let Some(Credentials::Bearer(token)) = t1.credentials.read().as_ref() {
assert_eq!(token, "The bear");
} else {
panic!("Expected Bearer credentials");
}

// Verify that cloned transport also has the same credentials
if let Some(Credentials::Bearer(token)) = t2.credentials.read().as_ref() {
assert_eq!(token, "The bear");
} else {
panic!("Expected Bearer credentials");
}

Ok(())
}
}
2 changes: 1 addition & 1 deletion elasticsearch/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
//! - **native-tls** *(enabled by default)*: Enables TLS functionality provided by `native-tls`.
//! - **rustls-tls**: Enables TLS functionality provided by `rustls`.
//! - **beta-apis**: Enables beta APIs. Beta APIs are on track to become stable and permanent features. Use them with
//! caution because it is possible that breaking changes are made to these APIs in a minor version.
//! caution because it is possible that breaking changes are made to these APIs in a minor version.
//! - **experimental-apis**: Enables experimental APIs. Experimental APIs are just that - an experiment. An experimental
//! API might have breaking changes in any future version, or it might even be removed entirely. This feature also
//! enables `beta-apis`.
Expand Down
3 changes: 3 additions & 0 deletions xtask/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@ chrono = { version = "0.4", features = ["serde"] }
zip = "2"
regex = "1"
xshell = "0.2"

[lints.clippy]
uninlined_format_args = "allow" # too pedantic
1 change: 1 addition & 0 deletions yaml_test_runner/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@ tokio = { version = "1", default-features = false, features = ["macros", "net",
[lints.clippy]
# yaml tests contain approximate values of PI
approx_constant = "allow"
uninlined_format_args = "allow" # too pedantic
2 changes: 1 addition & 1 deletion yaml_test_runner/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ pub fn generate_tests_from_yaml(
error!(
"skipping {}. cannot read as Yaml struct: {}",
relative_path.to_slash_lossy(),
result.err().unwrap().to_string()
result.err().unwrap()
);
continue;
}
Expand Down
Loading