Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion ampup/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ console = "0.16"
dialoguer = "0.12"
fs-err = "3.0.0"
futures = "0.3"
indicatif = "0.18"
reqwest = { version = "0.13", default-features = false, features = [
"json",
"query",
Expand Down
57 changes: 48 additions & 9 deletions ampup/src/download_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ use anyhow::{Context, Result};
use fs_err as fs;
use tokio::{sync::Semaphore, task::JoinSet};

use crate::github::{GitHubClient, ResolvedAsset};
use crate::{
github::{GitHubClient, ResolvedAsset},
progress::ProgressReporter,
};

// ---------------------------------------------------------------------------
// Public types
Expand Down Expand Up @@ -168,6 +171,7 @@ impl DownloadManager {
tasks: Vec<DownloadTask>,
version: &str,
version_dir: PathBuf,
reporter: Arc<dyn ProgressReporter>,
) -> Result<()> {
// Resolve all asset metadata with a single API call so that each
// spawned task can download directly without re-fetching the release.
Expand All @@ -185,13 +189,17 @@ impl DownloadManager {
let staging_dir =
tempfile::tempdir_in(parent).context("Failed to create staging directory")?;

let names: Vec<String> = tasks.iter().map(|t| t.artifact_name.clone()).collect();
reporter.set_total(tasks.len(), names);

let semaphore = Arc::new(Semaphore::new(self.max_concurrent));
let mut join_set: JoinSet<std::result::Result<(), DownloadError>> = JoinSet::new();
let mut join_set: JoinSet<std::result::Result<String, DownloadError>> = JoinSet::new();

for (task, asset) in tasks.into_iter().zip(resolved) {
let github = self.github.clone();
let sem = semaphore.clone();
let staging_path = staging_dir.path().to_path_buf();
let reporter = reporter.clone();

join_set.spawn(async move {
let _permit = sem
Expand All @@ -201,29 +209,39 @@ impl DownloadManager {
artifact_name: task.artifact_name.clone(),
})?;

reporter.component_started(&task.artifact_name);

let data = download_with_retry(&github, &asset).await?;
verify_artifact(&task.artifact_name, &data)?;
write_to_staging(&staging_path, &task.dest_filename, &data)?;

Ok(())
Ok(task.artifact_name)
});
}

// Collect results — fail fast on first error
while let Some(result) = join_set.join_next().await {
match result {
Ok(Ok(())) => {}
Ok(Ok(artifact_name)) => {
reporter.component_completed(&artifact_name);
}
Ok(Err(e)) => {
let artifact_name = download_error_artifact_name(&e);
reporter.component_failed(artifact_name);
reporter.finish();
join_set.shutdown().await;
return Err(e.into());
}
Err(join_err) => {
reporter.finish();
join_set.shutdown().await;
return Err(anyhow::anyhow!("download task panicked: {}", join_err));
}
}
}

reporter.finish();

// Set executable permissions on all staged files
#[cfg(unix)]
set_executable_permissions(staging_dir.path())?;
Expand Down Expand Up @@ -313,15 +331,13 @@ async fn download_with_retry(
github: &GitHubClient,
asset: &ResolvedAsset,
) -> std::result::Result<Vec<u8>, DownloadError> {
// `false` suppresses per-file progress bars — DownloadManager will
// provide aggregate progress reporting in a future PR.
match github.download_resolved_asset(asset, false).await {
match github.download_resolved_asset(asset).await {
Ok(data) => Ok(data),
Err(first_err) => {
crate::ui::warn!("Download failed for {}, retrying once...", asset.name);

github
.download_resolved_asset(asset, false)
.download_resolved_asset(asset)
.await
.map_err(|retry_err| DownloadError::TaskFailed {
artifact_name: asset.name.clone(),
Expand Down Expand Up @@ -357,6 +373,16 @@ fn write_to_staging(
})
}

/// Extract the artifact name from a [`DownloadError`].
fn download_error_artifact_name(err: &DownloadError) -> &str {
match err {
DownloadError::TaskFailed { artifact_name, .. }
| DownloadError::EmptyArtifact { artifact_name }
| DownloadError::StagingWrite { artifact_name, .. }
| DownloadError::SemaphoreClosed { artifact_name } => artifact_name,
}
}

/// Set executable permissions (0o755) on all files in a directory.
#[cfg(unix)]
fn set_executable_permissions(dir: &Path) -> Result<()> {
Expand Down Expand Up @@ -548,6 +574,18 @@ mod tests {
use tokio::io::{AsyncReadExt, AsyncWriteExt};

use super::*;
use crate::progress::ProgressReporter;

/// No-op reporter for tests that don't need progress output.
struct NoopReporter;

impl ProgressReporter for NoopReporter {
fn set_total(&self, _total: usize, _names: Vec<String>) {}
fn component_started(&self, _name: &str) {}
fn component_completed(&self, _name: &str) {}
fn component_failed(&self, _name: &str) {}
fn finish(&self) {}
}

/// Route configuration for the mock HTTP server.
#[derive(Clone)]
Expand Down Expand Up @@ -729,8 +767,9 @@ mod tests {

/// Run `download_all` with the given tasks.
async fn download(&self, tasks: Vec<DownloadTask>) -> Result<()> {
let reporter: Arc<dyn ProgressReporter> = Arc::new(NoopReporter);
self.manager
.download_all(tasks, "v1.0.0", self.version_dir.clone())
.download_all(tasks, "v1.0.0", self.version_dir.clone(), reporter)
.await
}
}
Expand Down
92 changes: 12 additions & 80 deletions ampup/src/github.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::sync::Arc;

use anyhow::{Context, Result};
use futures::StreamExt;
use indicatif::{ProgressBar, ProgressStyle};
use serde::Deserialize;

use crate::rate_limiter::GitHubRateLimiter;
Expand Down Expand Up @@ -353,17 +352,11 @@ impl GitHubClient {

/// Download a previously resolved asset without re-fetching release
/// metadata.
pub async fn download_resolved_asset(
&self,
asset: &ResolvedAsset,
show_progress: bool,
) -> Result<Vec<u8>> {
pub async fn download_resolved_asset(&self, asset: &ResolvedAsset) -> Result<Vec<u8>> {
if self.token.is_some() {
self.download_asset_via_api(asset.id, &asset.name, show_progress)
.await
self.download_asset_via_api(asset.id, &asset.name).await
} else {
self.download_asset_direct(&asset.url, &asset.name, show_progress)
.await
self.download_asset_direct(&asset.url, &asset.name).await
}
}

Expand Down Expand Up @@ -506,37 +499,21 @@ impl GitHubClient {
}

/// Download a release asset by name.
///
/// When `show_progress` is `true`, an indicatif progress bar is rendered for
/// this individual download. Set to `false` when downloads are managed by
/// `DownloadManager` (which will provide aggregate progress in the future).
pub async fn download_release_asset(
&self,
version: &str,
asset_name: &str,
show_progress: bool,
) -> Result<Vec<u8>> {
pub async fn download_release_asset(&self, version: &str, asset_name: &str) -> Result<Vec<u8>> {
let release = self.get_tagged_release(version).await?;
let asset = self.find_asset(&release, asset_name, version)?;

if self.token.is_some() {
// For private repositories, we need to use the API to download
self.download_asset_via_api(asset.id, asset_name, show_progress)
.await
self.download_asset_via_api(asset.id, asset_name).await
} else {
// For public repositories, use direct download URL
self.download_asset_direct(&asset.url, asset_name, show_progress)
.await
self.download_asset_direct(&asset.url, asset_name).await
}
}

/// Download asset via GitHub API (for private repos)
async fn download_asset_via_api(
&self,
asset_id: u64,
asset_name: &str,
show_progress: bool,
) -> Result<Vec<u8>> {
async fn download_asset_via_api(&self, asset_id: u64, asset_name: &str) -> Result<Vec<u8>> {
let url = format!(
"https://api.github.com/repos/{}/releases/assets/{}",
self.repo, asset_id
Expand All @@ -553,35 +530,24 @@ impl GitHubClient {
)
.await?;

self.download_with_progress(response, &url, asset_name, show_progress)
.await
self.download_response(response, &url, asset_name).await
}

/// Download asset directly (for public repos)
async fn download_asset_direct(
&self,
url: &str,
asset_name: &str,
show_progress: bool,
) -> Result<Vec<u8>> {
async fn download_asset_direct(&self, url: &str, asset_name: &str) -> Result<Vec<u8>> {
let response = self
.send_with_rate_limit(|| self.client.get(url), "Failed to download asset")
.await?;

self.download_with_progress(response, url, asset_name, show_progress)
.await
self.download_response(response, url, asset_name).await
}

/// Download with optional progress bar from a response.
///
/// When `show_progress` is `false`, bytes are collected silently (used by
/// `DownloadManager` which manages its own aggregate progress reporting).
async fn download_with_progress(
/// Stream a response body into a buffer.
async fn download_response(
&self,
response: reqwest::Response,
url: &str,
asset_name: &str,
show_progress: bool,
) -> Result<Vec<u8>> {
if !response.status().is_success() {
let status = response.status();
Expand All @@ -594,47 +560,13 @@ impl GitHubClient {
.into());
}

// Setup progress bar (hidden when DownloadManager handles progress)
let pb = if show_progress {
let total_size = response.content_length();
if let Some(size) = total_size {
let pb = ProgressBar::new(size);
pb.set_style(
ProgressStyle::default_bar()
.template(
"{msg} [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})",
)
.context("Invalid progress bar template")?
.progress_chars("#>-"),
);
pb.set_message(format!("{} Downloading", console::style("→").cyan()));
pb
} else {
let pb = ProgressBar::new_spinner();
pb.set_message(format!(
"{} Downloading (size unknown)",
console::style("→").cyan()
));
pb
}
} else {
ProgressBar::hidden()
};

// Stream and collect chunks
let mut downloaded: u64 = 0;
let mut buffer = Vec::new();
let mut stream = response.bytes_stream();

while let Some(chunk) = stream.next().await {
let chunk = chunk.context("Error while downloading file")?;
buffer.extend_from_slice(&chunk);
downloaded += chunk.len() as u64;
pb.set_position(downloaded);
}

if show_progress {
pb.finish_with_message(format!("{} Downloaded", console::style("✓").green().bold()));
}

Ok(buffer)
Expand Down
5 changes: 3 additions & 2 deletions ampup/src/install.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use anyhow::Result;
use crate::{
download_manager::{DownloadManager, DownloadTask},
platform::{Architecture, Platform},
ui,
progress, ui,
version_manager::VersionManager,
};

Expand Down Expand Up @@ -50,10 +50,11 @@ impl Installer {
},
];

let reporter = progress::create_reporter();
let version_dir = self.version_manager.config().versions_dir.join(version);

self.download_manager
.download_all(tasks, version, version_dir)
.download_all(tasks, version, version_dir, reporter)
.await?;

// Activation barrier: all downloads succeeded, now create symlinks
Expand Down
1 change: 1 addition & 0 deletions ampup/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod download_manager;
pub mod github;
pub mod install;
pub mod platform;
pub mod progress;
pub mod rate_limiter;
pub mod shell;
pub mod token;
Expand Down
Loading