Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ otel = [
# Exports code dependent on private interfaces for the integration test suite
test = ["dep:snapbox", "dep:walkdir", "clap-cargo/testing_colors"]

# Run the tests that require containers.
test-with-containers = ["test"]

# Sorted by alphabetic order
[dependencies]
anstream = "0.6.20"
Expand Down
10 changes: 10 additions & 0 deletions doc/user-guide/src/environment-variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@
- `RUSTUP_VERSION` (default: none). Overrides the rustup version (e.g. `1.27.1`)
to be downloaded when executing `rustup-init.sh` or `rustup self update`.

- `RUSTUP_AUTHORIZATION_HEADER` (default: none). The value to an `Authorization` HTTP
header that should be added to all requests made by rustup. This is meant for use when
using an alternate rustup distribution server (through the `RUSTUP_DIST_SERVER`
environment variable) which requires authentication such as basic username:password
credentials or a bearer token.

- `RUSTUP_PROXY_AUTHORIZATION_HEADER` (default: none). This is like the `RUSTUP_AUTHORIZATION_HEADER` except
this will add a `Proxy-Authorization` HTTP header. This is for authenticating to forward
proxies (via the `HTTP_PROXY` or `HTTPS_PROXY`) environment variables.

- `RUSTUP_IO_THREADS` *unstable* (default: reported cpu count, max 8). Sets the
number of threads to perform close IO in. Set to `1` to force
single-threaded IO for troubleshooting, or an arbitrary number to override
Expand Down
181 changes: 154 additions & 27 deletions src/download/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,14 @@ async fn download_file_(
};

let res = backend
.download_to_path(url, path, resume_from_partial, Some(callback), timeout)
.download_to_path(
url,
path,
resume_from_partial,
Some(callback),
timeout,
process,
)
.await;

// The notification should only be sent if the download was successful (i.e. didn't timeout)
Expand All @@ -221,6 +228,19 @@ async fn download_file_(
res
}

#[cfg(any(
feature = "curl-backend",
feature = "reqwest-rustls-tls",
feature = "reqwest-native-tls"
))]
const RUSTUP_AUTHORIZATION_HEADER_ENV_VAR: &str = "RUSTUP_AUTHORIZATION_HEADER";
#[cfg(any(
feature = "curl-backend",
feature = "reqwest-rustls-tls",
feature = "reqwest-native-tls"
))]
const RUSTUP_PROXY_AUTHORIZATION_HEADER_ENV_VAR: &str = "RUSTUP_PROXY_AUTHORIZATION_HEADER";

/// User agent header value for HTTP request.
/// See: https://github.com/rust-lang/rustup/issues/2860.
#[cfg(feature = "curl-backend")]
Expand Down Expand Up @@ -253,9 +273,10 @@ impl Backend {
resume_from_partial: bool,
callback: Option<DownloadCallback<'_>>,
timeout: Duration,
process: &Process,
) -> anyhow::Result<()> {
let Err(err) = self
.download_impl(url, path, resume_from_partial, callback, timeout)
.download_impl(url, path, resume_from_partial, callback, timeout, process)
.await
else {
return Ok(());
Expand All @@ -278,6 +299,7 @@ impl Backend {
resume_from_partial: bool,
callback: Option<DownloadCallback<'_>>,
timeout: Duration,
process: &Process,
) -> anyhow::Result<()> {
use std::cell::RefCell;
use std::fs::OpenOptions;
Expand Down Expand Up @@ -337,17 +359,23 @@ impl Backend {
let file = RefCell::new(file);

// TODO: the sync callback will stall the async runtime if IO calls block, which is OS dependent. Rearrange.
self.download(url, resume_from, timeout, &|event| {
if let Event::DownloadDataReceived(data) = event {
file.borrow_mut()
.write_all(data)
.context("unable to write download to disk")?;
}
match callback {
Some(cb) => cb(event),
None => Ok(()),
}
})
self.download(
url,
resume_from,
timeout,
&|event| {
if let Event::DownloadDataReceived(data) = event {
file.borrow_mut()
.write_all(data)
.context("unable to write download to disk")?;
}
match callback {
Some(cb) => cb(event),
None => Ok(()),
}
},
process,
)
.await?;

file.borrow_mut()
Expand All @@ -371,12 +399,16 @@ impl Backend {
resume_from: u64,
timeout: Duration,
callback: DownloadCallback<'_>,
process: &Process,
) -> anyhow::Result<()> {
match self {
#[cfg(feature = "curl-backend")]
Self::Curl => curl::download(url, resume_from, callback, timeout),
Self::Curl => curl::download(url, resume_from, callback, timeout, process),
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
Self::Reqwest(tls) => tls.download(url, resume_from, callback, timeout).await,
Self::Reqwest(tls) => {
tls.download(url, resume_from, callback, timeout, process)
.await
}
}
}
}
Expand All @@ -398,12 +430,13 @@ impl TlsBackend {
resume_from: u64,
callback: DownloadCallback<'_>,
timeout: Duration,
process: &Process,
) -> anyhow::Result<()> {
let client = match self {
#[cfg(feature = "reqwest-rustls-tls")]
Self::Rustls => reqwest_be::rustls_client(timeout)?,
Self::Rustls => reqwest_be::rustls_client(timeout, process)?,
#[cfg(feature = "reqwest-native-tls")]
Self::NativeTls => reqwest_be::native_tls_client(timeout)?,
Self::NativeTls => reqwest_be::native_tls_client(timeout, process)?,
};

reqwest_be::download(url, resume_from, callback, client).await
Expand All @@ -430,16 +463,41 @@ mod curl {
use std::time::Duration;

use anyhow::{Context, Result};
use curl::easy::Easy;
use curl::easy::{Easy, List};
use tracing::debug;
use url::Url;

use super::{DownloadError, Event};
use super::{
DownloadError, Event, Process, RUSTUP_AUTHORIZATION_HEADER_ENV_VAR,
RUSTUP_PROXY_AUTHORIZATION_HEADER_ENV_VAR,
};

macro_rules! add_header_for_curl_easy_handle {
($handle:ident, $process:ident, $env_var:ident, $header_name:literal, $header_list:ident) => {
if let Some(rustup_header_value) = $process.var_opt($env_var).map_err(|error| {
anyhow::anyhow!(
"Internal error getting `{}` environment variable: {}",
$env_var,
anyhow::format_err!(error)
)
})? {
let list = $header_list.get_or_insert(List::new());
list.append(format!("{}: {}", $header_name, rustup_header_value).as_str())
.map_err(|_| {
// The error could contain sensitive data so give a generic error instead.
anyhow::anyhow!("Failed to add `{}` HTTP header.", $header_name)
})?;
debug!("Adding `{}` header.", $header_name);
}
};
}

pub(super) fn download(
url: &Url,
resume_from: u64,
callback: &dyn Fn(Event<'_>) -> Result<()>,
timeout: Duration,
process: &Process,
) -> Result<()> {
// Fetch either a cached libcurl handle (which will preserve open
// connections) or create a new one if it isn't listed.
Expand All @@ -453,6 +511,27 @@ mod curl {
handle.url(url.as_ref())?;
handle.follow_location(true)?;
handle.useragent(super::CURL_USER_AGENT)?;
let mut header_list: Option<List> = None;
add_header_for_curl_easy_handle!(
handle,
process,
RUSTUP_AUTHORIZATION_HEADER_ENV_VAR,
"Authorization",
header_list
);
add_header_for_curl_easy_handle!(
handle,
process,
RUSTUP_PROXY_AUTHORIZATION_HEADER_ENV_VAR,
"Proxy-Authorization",
header_list
);
if let Some(list) = header_list {
handle.http_headers(list).map_err(|_| {
// The error could contain sensitive data so give a generic error instead.
anyhow::anyhow!("Failed to add headers to curl easy handle.")
})?;
}

if resume_from > 0 {
handle.resume_from(resume_from)?;
Expand Down Expand Up @@ -557,7 +636,36 @@ mod reqwest_be {
use tokio_stream::StreamExt;
use url::Url;

use super::{DownloadError, Event};
use super::{
DownloadError, Event, Process, RUSTUP_AUTHORIZATION_HEADER_ENV_VAR,
RUSTUP_PROXY_AUTHORIZATION_HEADER_ENV_VAR, debug,
};

macro_rules! add_header_for_client_builder {
($client_builder:ident, $process:ident, $env_var:ident, $header_name:path) => {
if let Some(rustup_header_value) = $process.var_opt($env_var).map_err(|_| {
// The error could contain sensitive data so give a generic error instead.
DownloadError::Message(format!(
"Internal error getting `{}` environment variable",
$env_var
))
})? {
let mut headers = header::HeaderMap::new();
let mut auth_value =
header::HeaderValue::from_str(&rustup_header_value).map_err(|_| {
// The error could contain sensitive data so give a generic error instead.
DownloadError::Message(format!(
"The `{}` environment variable set to an invalid HTTP header value.",
$env_var
))
})?;
auth_value.set_sensitive(true);
headers.insert($header_name, auth_value);
$client_builder = $client_builder.default_headers(headers);
debug!("Added `{}` header.", $header_name);
}
};
}

pub(super) async fn download(
url: &Url,
Expand Down Expand Up @@ -592,18 +700,34 @@ mod reqwest_be {
Ok(())
}

fn client_generic() -> ClientBuilder {
Client::builder()
fn client_generic(process: &Process) -> Result<ClientBuilder, DownloadError> {
let mut client_builder = Client::builder()
// HACK: set `pool_max_idle_per_host` to `0` to avoid an issue in the underlying
// `hyper` library that causes the `reqwest` client to hang in some cases.
// See <https://github.com/hyperium/hyper/issues/2312> for more details.
.pool_max_idle_per_host(0)
.gzip(false)
.proxy(Proxy::custom(env_proxy))
.proxy(Proxy::custom(env_proxy));
add_header_for_client_builder!(
client_builder,
process,
RUSTUP_AUTHORIZATION_HEADER_ENV_VAR,
header::AUTHORIZATION
);
add_header_for_client_builder!(
client_builder,
process,
RUSTUP_PROXY_AUTHORIZATION_HEADER_ENV_VAR,
header::PROXY_AUTHORIZATION
);
Ok(client_builder)
}

#[cfg(feature = "reqwest-rustls-tls")]
pub(super) fn rustls_client(timeout: Duration) -> Result<&'static Client, DownloadError> {
pub(super) fn rustls_client(
timeout: Duration,
process: &Process,
) -> Result<&'static Client, DownloadError> {
// If the client is already initialized, the passed timeout is ignored.
if let Some(client) = CLIENT_RUSTLS_TLS.get() {
return Ok(client);
Expand All @@ -627,7 +751,7 @@ mod reqwest_be {
.with_no_client_auth();
tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

let client = client_generic()
let client = client_generic(process)?
.read_timeout(timeout)
.use_preconfigured_tls(tls_config)
.user_agent(super::REQWEST_RUSTLS_TLS_USER_AGENT)
Expand All @@ -644,13 +768,16 @@ mod reqwest_be {
static CLIENT_RUSTLS_TLS: OnceLock<Client> = OnceLock::new();

#[cfg(feature = "reqwest-native-tls")]
pub(super) fn native_tls_client(timeout: Duration) -> Result<&'static Client, DownloadError> {
pub(super) fn native_tls_client(
timeout: Duration,
process: &Process,
) -> Result<&'static Client, DownloadError> {
// If the client is already initialized, the passed timeout is ignored.
if let Some(client) = CLIENT_NATIVE_TLS.get() {
return Ok(client);
}

let client = client_generic()
let client = client_generic(process)?
.read_timeout(timeout)
.user_agent(super::REQWEST_DEFAULT_TLS_USER_AGENT)
.build()
Expand Down
10 changes: 10 additions & 0 deletions src/download/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ mod curl {
use super::{scrub_env, serve_file, tmp_dir, write_file};
use crate::download::{Backend, Event};

#[cfg(feature = "test")]
use crate::process::TestProcess;

#[tokio::test]
async fn partially_downloaded_file_gets_resumed_from_byte_offset() {
let tmpdir = tmp_dir();
Expand All @@ -43,6 +46,7 @@ mod curl {
true,
None,
Duration::from_secs(180),
&TestProcess::default().process,
)
.await
.expect("Test download failed");
Expand Down Expand Up @@ -91,6 +95,7 @@ mod curl {
Ok(())
}),
Duration::from_secs(180),
&TestProcess::default().process,
)
.await
.expect("Test download failed");
Expand Down Expand Up @@ -120,6 +125,9 @@ mod reqwest {
use super::{scrub_env, serve_file, tmp_dir, write_file};
use crate::download::{Backend, Event, TlsBackend};

#[cfg(feature = "test")]
use crate::process::TestProcess;

// Tests for correctly retrieving the proxy (host, port) tuple from $https_proxy
#[tokio::test]
async fn read_basic_proxy_params() {
Expand Down Expand Up @@ -199,6 +207,7 @@ mod reqwest {
true,
None,
Duration::from_secs(180),
&TestProcess::default().process,
)
.await
.expect("Test download failed");
Expand Down Expand Up @@ -247,6 +256,7 @@ mod reqwest {
Ok(())
}),
Duration::from_secs(180),
&TestProcess::default().process,
)
.await
.expect("Test download failed");
Expand Down
2 changes: 1 addition & 1 deletion src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub use crate::cli::self_update::{RegistryGuard, RegistryValueId, USER_PATH, get
mod clitools;
pub use clitools::{
Assert, CliTestContext, Config, SanitizedOutput, Scenario, SelfUpdateTestContext,
output_release_file, print_command, print_indented,
TestContainer, TestContainerContext, output_release_file, print_command, print_indented,
};
pub(crate) mod dist;
pub use dist::DistContext;
Expand Down
Loading