Skip to content

Allow updating transport credentials #254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 25, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions elasticsearch/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ native-tls = ["reqwest/native-tls"]
rustls-tls = ["reqwest/rustls-tls"]

[dependencies]
parking_lot = "0.12"
base64 = "0.22"
bytes = "1"
dyn-clone = "1"
Expand Down
67 changes: 58 additions & 9 deletions elasticsearch/src/http/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@
io::{self, Write},
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc, RwLock,
Arc,
},
time::{Duration, Instant},
};
use parking_lot::RwLock;
use url::Url;

/// Error that can occur when building a [Transport]
Expand All @@ -68,6 +69,9 @@

/// Certificate error
Cert(reqwest::Error),

/// Configuration error
Config(String),
}

impl From<io::Error> for BuildError {
Expand All @@ -88,13 +92,15 @@
match *self {
BuildError::Io(ref err) => err.description(),
BuildError::Cert(ref err) => err.description(),
BuildError::Config(ref err) => err.as_str(),
}
}

fn cause(&self) -> Option<&dyn error::Error> {
match *self {
BuildError::Io(ref err) => Some(err as &dyn error::Error),
BuildError::Cert(ref err) => Some(err as &dyn error::Error),
BuildError::Config(_) => None,
}
}
}
Expand All @@ -104,6 +110,7 @@
match *self {
BuildError::Io(ref err) => fmt::Display::fmt(err, f),
BuildError::Cert(ref err) => fmt::Display::fmt(err, f),
BuildError::Config(ref err) => fmt::Display::fmt(err, f),
}
}
}
Expand Down Expand Up @@ -131,7 +138,7 @@
}

let rustc = env!("RUSTC_VERSION");
let mut meta = format!("es={},rs={},t={}", version, rustc, version);

Check warning on line 141 in elasticsearch/src/http/transport.rs

View workflow job for this annotation

GitHub Actions / clippy

variables can be used directly in the `format!` string

warning: variables can be used directly in the `format!` string --> elasticsearch/src/http/transport.rs:141:20 | 141 | let mut meta = format!("es={},rs={},t={}", version, rustc, version); | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#uninlined_format_args = note: `#[warn(clippy::uninlined_format_args)]` on by default help: change this to | 141 - let mut meta = format!("es={},rs={},t={}", version, rustc, version); 141 + let mut meta = format!("es={version},rs={rustc},t={version}"); |

if cfg!(feature = "native-tls") {
meta.push_str(",tls=n");
Expand Down Expand Up @@ -337,7 +344,7 @@
if let Some(c) = self.proxy_credentials {
proxy = match c {
Credentials::Basic(u, p) => proxy.basic_auth(&u, &p),
_ => proxy,
_ => return Err(BuildError::Config("Only Basic Authentication is supported for proxies".into())),
};
}
client_builder = client_builder.proxy(proxy);
Expand All @@ -348,7 +355,7 @@
client,
conn_pool: self.conn_pool,
request_body_compression: self.request_body_compression,
credentials: self.credentials,
credentials: Arc::new(RwLock::new(self.credentials)),
send_meta: self.meta_header,
})
}
Expand Down Expand Up @@ -393,7 +400,7 @@
#[derive(Debug, Clone)]
pub struct Transport {
client: reqwest::Client,
credentials: Option<Credentials>,
credentials: Arc<RwLock<Option<Credentials>>>,
request_body_compression: bool,
conn_pool: Arc<dyn ConnectionPool>,
send_meta: bool,
Expand Down Expand Up @@ -478,7 +485,7 @@
/// [Elasticsearch service in Elastic Cloud](https://www.elastic.co/cloud/).
///
/// * `cloud_id`: The Elastic Cloud Id retrieved from the cloud web console, that uniquely
/// identifies the deployment instance.

Check warning on line 488 in elasticsearch/src/http/transport.rs

View workflow job for this annotation

GitHub Actions / clippy

doc list item overindented

warning: doc list item overindented --> elasticsearch/src/http/transport.rs:488:9 | 488 | /// identifies the deployment instance. | ^^^^^^^^^^^^^^ help: try using ` ` (2 spaces) | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#doc_overindented_list_items
/// * `credentials`: A set of credentials the client should use to authenticate to Elasticsearch service.
pub fn cloud(cloud_id: &str, credentials: Credentials) -> Result<Transport, Error> {
let conn_pool = CloudConnectionPool::new(cloud_id)?;
Expand Down Expand Up @@ -513,7 +520,8 @@
// set credentials before any headers, as credentials append to existing headers in reqwest,
// whilst setting headers() overwrites, so if an Authorization header has been specified
// on a specific request, we want it to overwrite.
if let Some(c) = &self.credentials {
let creds_guard = self.credentials.read();
if let Some(c) = creds_guard.as_ref() {
request_builder = match c {
Credentials::Basic(u, p) => request_builder.basic_auth(u, Some(p)),
Credentials::Bearer(t) => request_builder.bearer_auth(t),
Expand All @@ -523,20 +531,21 @@
let mut header_value = b"ApiKey ".to_vec();
{
let mut encoder = EncoderWriter::new(&mut header_value, &BASE64_STANDARD);
write!(encoder, "{}:", i).unwrap();

Check warning on line 534 in elasticsearch/src/http/transport.rs

View workflow job for this annotation

GitHub Actions / clippy

variables can be used directly in the `format!` string

warning: variables can be used directly in the `format!` string --> elasticsearch/src/http/transport.rs:534:25 | 534 | write!(encoder, "{}:", i).unwrap(); | ^^^^^^^^^^^^^^^^^^^^^^^^^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#uninlined_format_args help: change this to | 534 - write!(encoder, "{}:", i).unwrap(); 534 + write!(encoder, "{i}:").unwrap(); |
write!(encoder, "{}", k).unwrap();

Check warning on line 535 in elasticsearch/src/http/transport.rs

View workflow job for this annotation

GitHub Actions / clippy

variables can be used directly in the `format!` string

warning: variables can be used directly in the `format!` string --> elasticsearch/src/http/transport.rs:535:25 | 535 | write!(encoder, "{}", k).unwrap(); | ^^^^^^^^^^^^^^^^^^^^^^^^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#uninlined_format_args help: change this to | 535 - write!(encoder, "{}", k).unwrap(); 535 + write!(encoder, "{k}").unwrap(); |
}
let mut header_value = HeaderValue::from_bytes(&header_value).unwrap();
header_value.set_sensitive(true);
request_builder.header(AUTHORIZATION, header_value)
}
Credentials::EncodedApiKey(k) => {
let mut header_value = HeaderValue::try_from(format!("ApiKey {}", k)).unwrap();

Check warning on line 542 in elasticsearch/src/http/transport.rs

View workflow job for this annotation

GitHub Actions / clippy

variables can be used directly in the `format!` string

warning: variables can be used directly in the `format!` string --> elasticsearch/src/http/transport.rs:542:66 | 542 | let mut header_value = HeaderValue::try_from(format!("ApiKey {}", k)).unwrap(); | ^^^^^^^^^^^^^^^^^^^^^^^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#uninlined_format_args help: change this to | 542 - let mut header_value = HeaderValue::try_from(format!("ApiKey {}", k)).unwrap(); 542 + let mut header_value = HeaderValue::try_from(format!("ApiKey {k}")).unwrap(); |
header_value.set_sensitive(true);
request_builder.header(AUTHORIZATION, header_value)
}
}
}
drop(creds_guard);

// default headers first, overwrite with any provided
let mut request_headers = HeaderMap::with_capacity(4 + headers.len());
Expand Down Expand Up @@ -599,11 +608,11 @@
}

let (host, port) = host_port.ok_or_else(|| {
crate::error::lib(format!("error parsing address into url: {}", address))

Check warning on line 611 in elasticsearch/src/http/transport.rs

View workflow job for this annotation

GitHub Actions / clippy

variables can be used directly in the `format!` string

warning: variables can be used directly in the `format!` string --> elasticsearch/src/http/transport.rs:611:31 | 611 | crate::error::lib(format!("error parsing address into url: {}", address)) | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#uninlined_format_args help: change this to | 611 - crate::error::lib(format!("error parsing address into url: {}", address)) 611 + crate::error::lib(format!("error parsing address into url: {address}")) |
})?;

Ok(Url::parse(
format!("{}://{}:{}", scheme, host, port).as_str(),

Check warning on line 615 in elasticsearch/src/http/transport.rs

View workflow job for this annotation

GitHub Actions / clippy

variables can be used directly in the `format!` string

warning: variables can be used directly in the `format!` string --> elasticsearch/src/http/transport.rs:615:13 | 615 | format!("{}://{}:{}", scheme, host, port).as_str(), | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#uninlined_format_args help: change this to | 615 - format!("{}://{}:{}", scheme, host, port).as_str(), 615 + format!("{scheme}://{host}:{port}").as_str(), |
)?)
}

Expand Down Expand Up @@ -696,6 +705,12 @@
Err(e) => Err(e.into()),
}
}

/// Update the auth credentials for this transport and all its clones, and all clients
/// using them. Typically used to refresh a bearer token.
pub fn set_auth(&self, credentials: Credentials) {
*self.credentials.write() = Some(credentials);
}
}

impl Default for Transport {
Expand Down Expand Up @@ -785,10 +800,10 @@
let data = parts[1];
let decoded_result = BASE64_STANDARD.decode(data);
if decoded_result.is_err() {
return Err(crate::error::lib(format!(
"cannot base 64 decode '{}'",
data
)));

Check warning on line 806 in elasticsearch/src/http/transport.rs

View workflow job for this annotation

GitHub Actions / clippy

variables can be used directly in the `format!` string

warning: variables can be used directly in the `format!` string --> elasticsearch/src/http/transport.rs:803:42 | 803 | return Err(crate::error::lib(format!( | __________________________________________^ 804 | | "cannot base 64 decode '{}'", 805 | | data 806 | | ))); | |_____________^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#uninlined_format_args
}

let decoded = decoded_result.unwrap();
Expand Down Expand Up @@ -842,7 +857,7 @@
}
}

let url = Url::parse(format!("https://{}.{}", uuid, domain_name).as_ref())?;

Check warning on line 860 in elasticsearch/src/http/transport.rs

View workflow job for this annotation

GitHub Actions / clippy

variables can be used directly in the `format!` string

warning: variables can be used directly in the `format!` string --> elasticsearch/src/http/transport.rs:860:30 | 860 | let url = Url::parse(format!("https://{}.{}", uuid, domain_name).as_ref())?; | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#uninlined_format_args help: change this to | 860 - let url = Url::parse(format!("https://{}.{}", uuid, domain_name).as_ref())?; 860 + let url = Url::parse(format!("https://{uuid}.{domain_name}").as_ref())?; |
Ok(CloudId { name, url })
}
}
Expand Down Expand Up @@ -895,14 +910,14 @@
ConnSelector: ConnectionSelector + Clone,
{
fn next(&self) -> Connection {
let inner = self.inner.read().expect("lock poisoned");
let inner = self.inner.read();
self.connection_selector
.try_next(&inner.connections)
.unwrap()
}

fn reseedable(&self) -> bool {
let inner = self.inner.read().expect("lock poisoned");
let inner = self.inner.read();
let reseed_frequency = match self.reseed_frequency {
Some(wait) => wait,
None => return false,
Expand All @@ -928,10 +943,11 @@
}

fn reseed(&self, mut connection: Vec<Connection>) {
let mut inner = self.inner.write().expect("lock poisoned");
let mut inner = self.inner.write();
inner.last_update = Some(Instant::now());
inner.connections.clear();
inner.connections.append(&mut connection);
drop(inner);
self.reseeding.store(false, Ordering::Relaxed);
}
}
Expand Down Expand Up @@ -1210,7 +1226,7 @@
);

// Set internal last_update to a minute ago
let mut inner = connection_pool.inner.write().expect("lock poisoned");
let mut inner = connection_pool.inner.write();
inner.last_update = Some(Instant::now() - Duration::from_secs(60));
drop(inner);

Expand Down Expand Up @@ -1249,4 +1265,37 @@
let connections = MultiNodeConnectionPool::round_robin(vec![], None);
connections.next();
}

#[test]
fn set_credentials() -> anyhow::Result<()> {
let t1: Transport = TransportBuilder::new(SingleNodeConnectionPool::default())
.auth(Credentials::Basic("foo".to_string(), "bar".to_string()))
.build()?;

if let Some(Credentials::Basic(login, password)) = t1.credentials.read().as_ref() {
assert_eq!(login, "foo");
assert_eq!(password, "bar");
} else {
panic!("Expected Basic credentials");
}

let t2 = t1.clone();

t1.set_auth(Credentials::Bearer("The bear".to_string()));

if let Some(Credentials::Bearer(token)) = t1.credentials.read().as_ref() {
assert_eq!(token, "The bear");
} else {
panic!("Expected Bearer credentials");
}

// Verify that cloned transport also has the same credentials
if let Some(Credentials::Bearer(token)) = t2.credentials.read().as_ref() {
assert_eq!(token, "The bear");
} else {
panic!("Expected Bearer credentials");
}

Ok(())
}
}
Loading