Skip to content

Commit 63d93fa

Browse files
committed
feat: download dataset
1 parent 054d2d0 commit 63d93fa

File tree

11 files changed

+407
-3
lines changed

11 files changed

+407
-3
lines changed

client/src/error.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ pub enum S3Error {
5151
InvalidContentDisposition,
5252
#[error("Invalid Content-Disposition")]
5353
MissingBody,
54+
#[error("Missing header {0}")]
55+
MissingHeader(&'static str),
56+
#[error("Invalid header {0}")]
57+
InvalidHeader(&'static str),
5458
}
5559

5660
#[derive(Error, Debug)]

client/src/s3.rs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,61 @@ impl TryFrom<Response> for S3GetResponse {
7878
}
7979
}
8080

81+
/// Object metadata extracted from the headers of a head-style S3 response.
///
/// Built through `TryFrom<Response>`; `last_modified` and `size` are mandatory
/// there, the remaining header-derived fields are best-effort.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct S3HeadResponse {
    /// Path component of the `http::Uri` stored in the response extensions;
    /// empty when no URI was attached to the response.
    pub location: String,
    /// Raw `Last-Modified` header value, kept as received (not parsed).
    pub last_modified: String,
    /// Value of the `Content-Length` header.
    pub size: u64,
    /// `ETag` header with leading/trailing double quotes stripped, if present
    /// and representable as a string.
    pub e_tag: Option<String>,
    /// `x-amz-version-id` header, if present and representable as a string.
    pub version: Option<String>,
}
88+
89+
impl TryFrom<Response> for S3HeadResponse {
90+
type Error = S3Error;
91+
92+
fn try_from(response: Response) -> Result<Self, Self::Error> {
93+
let headers = response.headers();
94+
95+
let last_modified = headers
96+
.get(http::header::LAST_MODIFIED)
97+
.ok_or(S3Error::MissingHeader("Last-Modified"))?
98+
.to_str()
99+
.map_err(|_| S3Error::InvalidHeader("Last-Modified"))?
100+
.to_string();
101+
102+
let size = headers
103+
.get(http::header::CONTENT_LENGTH)
104+
.ok_or(S3Error::MissingHeader("Content-Length"))?
105+
.to_str()
106+
.map_err(|_| S3Error::InvalidHeader("Content-Length"))?
107+
.parse::<u64>()
108+
.map_err(|_| S3Error::InvalidHeader("Content-Length"))?;
109+
110+
let e_tag = headers
111+
.get(http::header::ETAG)
112+
.and_then(|v| v.to_str().ok())
113+
.map(|v| v.trim_matches('"').to_string());
114+
115+
let version = headers
116+
.get("x-amz-version-id")
117+
.and_then(|v| v.to_str().ok())
118+
.map(|v| v.to_string());
119+
120+
let location = response
121+
.extensions()
122+
.get::<http::Uri>()
123+
.map(|uri| uri.path().to_string())
124+
.unwrap_or_default();
125+
126+
Ok(Self {
127+
location,
128+
last_modified,
129+
size,
130+
e_tag,
131+
version,
132+
})
133+
}
134+
}
135+
81136
impl Client {
82137
pub fn s3_layer<L, E>(&mut self, layer: L) -> &mut Self
83138
where
@@ -131,6 +186,25 @@ impl Client {
131186
check_status(&res.status())?;
132187
Ok(res.try_into()?)
133188
}
189+
190+
pub async fn s3_head(&self, url: Url) -> Result<S3HeadResponse> {
191+
self.validate_url_host(&url)?;
192+
193+
let body = Body::from(r#"{"head": true}"#.to_string());
194+
195+
let mut request = http::Request::builder()
196+
.method(http::Method::GET)
197+
.uri(url.to_string());
198+
199+
if let Some(content_length) = body.content_length() {
200+
if let Some(headers) = request.headers_mut() { headers.insert(reqwest::header::CONTENT_LENGTH, content_length.into()); }
201+
}
202+
203+
let request = request.body(body)?;
204+
let res = self.s3_service().oneshot(request).await?;
205+
check_status(&res.status())?;
206+
Ok(res.try_into()?)
207+
}
134208
}
135209

136210
pub struct S3Range {

src/commands/dataset/common.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ pub struct GetDatasetBySlug;
3232
pub struct GetDatasetSlugResponse {
3333
pub id: String,
3434
pub viewer_can_create_version: bool,
35+
pub viewer_can_read_dataset_version_file: bool,
3536
}
3637

3738
pub async fn get_dataset_by_slug(
@@ -60,11 +61,13 @@ pub async fn get_dataset_by_slug(
6061
return Ok(GetDatasetSlugResponse {
6162
id: new_dataset.id,
6263
viewer_can_create_version: new_dataset.viewer_can_create_version,
64+
viewer_can_read_dataset_version_file: new_dataset.viewer_can_read_dataset_version_file,
6365
});
6466
};
6567

6668
Ok(GetDatasetSlugResponse {
6769
id: dataset.id,
6870
viewer_can_create_version: dataset.viewer_can_create_version,
71+
viewer_can_read_dataset_version_file: dataset.viewer_can_read_dataset_version_file,
6972
})
7073
}

src/commands/dataset/download.rs

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
use clap::Args;
2+
use futures::{stream, StreamExt, TryStreamExt};
3+
use graphql_client::GraphQLQuery;
4+
use indicatif::{HumanBytes, MultiProgress, ProgressBar};
5+
use serde::Serialize;
6+
use std::path::PathBuf;
7+
use url::Url;
8+
9+
use crate::{
10+
commands::{
11+
dataset::{
12+
common::{get_dataset_by_slug, DatasetCommonArgs},
13+
download::get_dataset_version_files::GetDatasetVersionFilesNodeOnDatasetVersionFilesNodes,
14+
version::common::get_dataset_version,
15+
},
16+
GlobalArgs,
17+
},
18+
download::{multipart_download, MultipartOptions},
19+
error::{self, Result},
20+
};
21+
22+
#[derive(Args, Debug, Serialize)]
23+
pub struct Download {
24+
#[command(flatten)]
25+
common: DatasetCommonArgs,
26+
#[arg(short, long)]
27+
version: semver::Version,
28+
#[arg(short, long)]
29+
destination: PathBuf,
30+
#[clap(long, short = 'c', default_value_t = 10_000_000)]
31+
chunk_size: usize,
32+
#[clap(long, default_value_t = 10)]
33+
part_download_concurrency: usize,
34+
#[clap(long, default_value_t = 10)]
35+
file_download_concurrency: usize,
36+
}
37+
38+
#[derive(GraphQLQuery)]
39+
#[graphql(
40+
query_path = "src/graphql/get_dataset_version_files.graphql",
41+
schema_path = "schema.graphql",
42+
response_derives = "Debug"
43+
)]
44+
pub struct GetDatasetVersionFiles;
45+
46+
pub async fn download(args: Download, global: GlobalArgs) -> Result<()> {
47+
let m = MultiProgress::new();
48+
49+
let client = global.graphql_client().await?;
50+
51+
let (owner, local_slug) = args.common.slug_pair()?;
52+
let multipart_options = MultipartOptions::new(args.chunk_size, args.part_download_concurrency);
53+
54+
let dataset = get_dataset_by_slug(&global, owner, local_slug).await?;
55+
if !dataset.viewer_can_read_dataset_version_file {
56+
return Err(error::user(
57+
"Permission denied",
58+
"Cannot read dataset files",
59+
));
60+
}
61+
62+
let dataset_version = get_dataset_version(
63+
&client,
64+
dataset.id,
65+
args.version.major as _,
66+
args.version.minor as _,
67+
args.version.patch as _,
68+
)
69+
.await?
70+
.ok_or_else(|| error::user("Not found", "Dataset version not found"))?;
71+
72+
let response = client
73+
.send::<GetDatasetVersionFiles>(get_dataset_version_files::Variables {
74+
dataset_version_id: dataset_version.id,
75+
})
76+
.await?;
77+
78+
let dataset_version_files = match response.node {
79+
get_dataset_version_files::GetDatasetVersionFilesNode::DatasetVersion(v) => v,
80+
_ => {
81+
return Err(error::system(
82+
"Invalid node type",
83+
"Unexpected GraphQL response",
84+
))
85+
}
86+
};
87+
88+
let nodes = dataset_version_files.files.nodes;
89+
let dataset_name = dataset_version_files.dataset.name;
90+
91+
let dataset_dir = args.destination.join(&dataset_name);
92+
tokio::fs::create_dir_all(&dataset_dir).await?;
93+
94+
let total_size = dataset_version.size as u64;
95+
let total_files = nodes.len();
96+
97+
let overall_progress = m.add(global.spinner().with_message(format!(
98+
"Downloading '{}' ({} files, {})",
99+
dataset_name,
100+
total_files,
101+
HumanBytes(total_size)
102+
)));
103+
104+
stream::iter(nodes)
105+
.map(|node| {
106+
let client = &client;
107+
let m = &m;
108+
let multipart_options = &multipart_options;
109+
let dataset_dir = dataset_dir.to_owned();
110+
let dataset_name = dataset_name.to_owned();
111+
112+
async move {
113+
download_partition_file(
114+
&m,
115+
&client,
116+
&multipart_options,
117+
&dataset_dir,
118+
&dataset_name,
119+
node,
120+
)
121+
.await
122+
}
123+
})
124+
.buffer_unordered(args.file_download_concurrency)
125+
.try_collect::<()>()
126+
.await?;
127+
128+
overall_progress.finish_with_message("Done");
129+
130+
Ok(())
131+
}
132+
133+
async fn download_partition_file(
134+
m: &MultiProgress,
135+
client: &aqora_client::Client,
136+
multipart_options: &MultipartOptions,
137+
output_dir: &std::path::Path,
138+
dataset_name: &str,
139+
file_node: GetDatasetVersionFilesNodeOnDatasetVersionFilesNodes,
140+
) -> Result<()> {
141+
let metadata = client.s3_head(file_node.url.clone()).await?;
142+
let filename = format!("{}-{}.parquet", dataset_name, file_node.partition_num);
143+
let output_path = output_dir.join(&filename);
144+
145+
if let Ok(existing) = tokio::fs::metadata(&output_path).await {
146+
if existing.len() == metadata.size {
147+
return Ok(());
148+
}
149+
}
150+
151+
tokio::fs::create_dir_all(output_path.parent().unwrap()).await?;
152+
153+
let temp = tempfile::NamedTempFile::new_in(output_dir)?;
154+
let temp_path = temp.path().to_owned();
155+
156+
let pb = m.add(ProgressBar::new_spinner());
157+
pb.set_message(filename);
158+
159+
multipart_download(
160+
client,
161+
metadata.size,
162+
file_node.url.clone(),
163+
multipart_options,
164+
&temp_path,
165+
&pb,
166+
)
167+
.await?;
168+
169+
pb.finish_and_clear();
170+
tokio::fs::rename(&temp_path, &output_path).await?;
171+
172+
Ok(())
173+
}

src/commands/dataset/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
mod common;
22
mod convert;
3+
mod download;
34
mod infer;
45
mod new;
56
mod upload;
@@ -13,6 +14,7 @@ use crate::commands::GlobalArgs;
1314
use crate::error::Result;
1415

1516
use convert::{convert, Convert};
17+
use download::{download, Download};
1618
use infer::{infer, Infer};
1719
use new::{new, New};
1820
use upload::{upload, Upload};
@@ -26,6 +28,7 @@ pub enum Dataset {
2628
Convert(Convert),
2729
New(New),
2830
Upload(Upload),
31+
Download(Download),
2932
Version {
3033
#[command(subcommand)]
3134
args: Version,
@@ -38,6 +41,7 @@ pub async fn dataset(args: Dataset, global: GlobalArgs) -> Result<()> {
3841
Dataset::Convert(args) => convert(args, global).await,
3942
Dataset::New(args) => new(args, global).await,
4043
Dataset::Upload(args) => upload(args, global).await,
44+
Dataset::Download(args) => download(args, global).await,
4145
Dataset::Version { args } => version(args, global).await,
4246
}
4347
}

src/commands/global_args.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use std::path::PathBuf;
1616
use url::Url;
1717

1818
lazy_static::lazy_static! {
19-
static ref DEFAULT_PARALLELISM: usize = std::thread::available_parallelism()
19+
pub static ref DEFAULT_PARALLELISM: usize = std::thread::available_parallelism()
2020
.map(usize::from)
2121
.unwrap_or(1);
2222
}

0 commit comments

Comments
 (0)