@@ -43,6 +43,14 @@ pub struct Download {
4343) ]
4444pub struct GetDatasetVersionFiles ;
4545
/// Marker type for the GraphQL query stored in
/// `src/graphql/get_dataset_version_file_by_partition.graphql`.
///
/// The `GraphQLQuery` derive is configured against `schema.graphql` and
/// requests `Debug` through `response_derives`. The partition download code
/// passes this type to `client.send`, supplying a dataset version id and a
/// partition number as variables, and reads the file `url` from the response.
46+ #[ derive( GraphQLQuery ) ]
47+ #[ graphql(
48+ query_path = "src/graphql/get_dataset_version_file_by_partition.graphql" ,
49+ schema_path = "schema.graphql" ,
50+ response_derives = "Debug"
51+ ) ]
52+ pub struct GetDatasetVersionFileByPartition ;
53+
4654pub async fn download ( args : Download , global : GlobalArgs ) -> Result < ( ) > {
4755 let m = MultiProgress :: new ( ) ;
4856
@@ -103,9 +111,9 @@ pub async fn download(args: Download, global: GlobalArgs) -> Result<()> {
103111
104112 stream:: iter ( nodes)
105113 . map ( |node| {
106- let client = & client;
107- let m = & m ;
108- let multipart_options = & multipart_options;
114+ let client = client. to_owned ( ) ;
115+ let m = m . to_owned ( ) ;
116+ let multipart_options = multipart_options. to_owned ( ) ;
109117 let dataset_dir = dataset_dir. to_owned ( ) ;
110118 let dataset_name = dataset_name. to_owned ( ) ;
111119
@@ -138,7 +146,44 @@ async fn download_partition_file(
138146 dataset_name : & str ,
139147 file_node : GetDatasetVersionFilesNodeOnDatasetVersionFilesNodes ,
140148) -> Result < ( ) > {
141- let metadata = client. s3_head ( file_node. url . clone ( ) ) . await ?;
149+ let ( metadata, url) = match client. s3_head ( file_node. url . clone ( ) ) . await {
150+ Ok ( metadata) => ( metadata, file_node. url . clone ( ) ) ,
151+ // The presigned URL may have expired during a long-running dataset download, so on any HEAD failure request a fresh URL for this partition and retry once.
152+ Err ( e) => {
153+ tracing:: warn!( error = %e, "Retrying: failed to fetch object header" ) ;
154+ let response = client
155+ . send :: < GetDatasetVersionFileByPartition > (
156+ get_dataset_version_file_by_partition:: Variables {
157+ dataset_version_id : file_node. dataset_version . id ,
158+ partition_num : file_node. partition_num ,
159+ } ,
160+ )
161+ . await ?;
162+
163+ let dataset_version_file = match response. node {
164+ get_dataset_version_file_by_partition:: GetDatasetVersionFileByPartitionNode :: DatasetVersion ( v) => v,
165+ _ => {
166+ return Err ( error:: system (
167+ "Invalid node type" ,
168+ "Unexpected GraphQL response" ,
169+ ) ) ;
170+ }
171+ } ;
172+ let file_by_partition_num = match dataset_version_file. file_by_partition_num {
173+ Some ( file_by_partition_num) => file_by_partition_num,
174+ None => {
175+ return Err ( error:: system (
176+ "Invalid partition number" ,
177+ "The partition does not exist" ,
178+ ) )
179+ }
180+ } ;
181+
182+ let file_url = file_by_partition_num. url ;
183+ ( client. s3_head ( file_url. clone ( ) ) . await ?, file_url)
184+ }
185+ } ;
186+
142187 let filename = format ! ( "{}-{}.parquet" , dataset_name, file_node. partition_num) ;
143188 let output_path = output_dir. join ( & filename) ;
144189
@@ -159,7 +204,7 @@ async fn download_partition_file(
159204 multipart_download (
160205 client,
161206 metadata. size ,
162- file_node . url . clone ( ) ,
207+ url,
163208 multipart_options,
164209 & temp_path,
165210 & pb,
0 commit comments