15
15
// specific language governing permissions and limitations
16
16
// under the License.
17
17
18
+ use arrow:: datatypes:: ToByteSlice ;
18
19
use async_trait:: async_trait;
19
20
use bytes:: Bytes ;
20
21
use datafusion:: common:: { config_err, Result } ;
21
22
use datafusion:: config:: {
22
23
ConfigEntry , ConfigExtension , ConfigField , ExtensionOptions , Visit ,
23
24
} ;
24
25
use datafusion:: error:: DataFusionError ;
26
+ use futures:: future:: join_all;
25
27
use futures:: stream:: BoxStream ;
26
- use futures:: StreamExt ;
28
+ use futures:: { StreamExt , TryStreamExt } ;
27
29
use http:: { header, HeaderMap } ;
28
30
use object_store:: http:: { HttpBuilder , HttpStore } ;
29
31
use object_store:: path:: Path ;
@@ -32,6 +34,8 @@ use object_store::{
32
34
MultipartId , ObjectMeta , ObjectStore , PutOptions , PutResult ,
33
35
Result as ObjectStoreResult ,
34
36
} ;
37
+ use serde:: Deserialize ;
38
+ use serde_json;
35
39
use std:: any:: Any ;
36
40
use std:: env;
37
41
use std:: fmt:: Display ;
@@ -98,7 +102,7 @@ impl ParsedHFUrl {
98
102
/// If the endpoint is not provided, it defaults to `https://huggingface.co`.
99
103
///
100
104
/// url: The HuggingFace URL to parse.
101
- pub fn parse ( url : String ) -> Result < Self > {
105
+ pub fn parse_hf_style ( url : String ) -> Result < Self > {
102
106
let mut parsed_url = Self :: default ( ) ;
103
107
let mut last_delim = 0 ;
104
108
@@ -168,14 +172,101 @@ impl ParsedHFUrl {
168
172
Ok ( parsed_url)
169
173
}
170
174
175
+ /// Parse a http style HuggingFace URL into a ParsedHFUrl struct.
176
+ /// The URL should be in the format `https://huggingface.co/<repo_type>/<repository>/resolve/<revision>/<path>`
177
+ /// where `repo_type` is either `datasets` or `spaces`.
178
+ ///
179
+ /// url: The HuggingFace URL to parse.
180
+ fn parse_http_style ( url : String ) -> Result < Self > {
181
+ let mut parsed_url = Self :: default ( ) ;
182
+ let mut last_delim = 0 ;
183
+
184
+ // parse repository type.
185
+ if let Some ( curr_delim) = url[ last_delim..] . find ( '/' ) {
186
+ let repo_type = & url[ last_delim..last_delim + curr_delim] ;
187
+ if ( repo_type != "datasets" ) && ( repo_type != "spaces" ) {
188
+ return config_err ! (
189
+ "Invalid HuggingFace URL: {}, currently only 'datasets' or 'spaces' are supported" ,
190
+ url
191
+ ) ;
192
+ }
193
+
194
+ parsed_url. repo_type = Some ( repo_type. to_string ( ) ) ;
195
+ last_delim += curr_delim + 1 ;
196
+ } else {
197
+ return config_err ! ( "Invalid HuggingFace URL: {}, please format as 'https://huggingface.co/<repo_type>/<repository>/resolve/<revision>/<path>'" , url) ;
198
+ }
199
+
200
+ let start_delim = last_delim;
201
+ // parse repository and revision.
202
+ if let Some ( curr_delim) = url[ last_delim..] . find ( '/' ) {
203
+ last_delim += curr_delim + 1 ;
204
+ } else {
205
+ return config_err ! ( "Invalid HuggingFace URL: {}, please format as 'https://huggingface.co/<repo_type>/<repository>/resolve/<revision>/<path>'" , url) ;
206
+ }
207
+
208
+ let next_slash = url[ last_delim..] . find ( '/' ) ;
209
+
210
+ // next slash is not found
211
+ if next_slash. is_none ( ) {
212
+ return config_err ! ( "Invalid HuggingFace URL: {}, please format as 'https://huggingface.co/<repo_type>/<repository>/resolve/<revision>/<path>'" , url) ;
213
+ }
214
+
215
+ parsed_url. repository = Some ( url[ start_delim..last_delim + next_slash. unwrap ( ) ] . to_string ( ) ) ;
216
+ last_delim += next_slash. unwrap ( ) ;
217
+
218
+ let next_resolve = url[ last_delim..] . find ( "resolve" ) ;
219
+ if next_resolve. is_none ( ) {
220
+ return config_err ! ( "Invalid HuggingFace URL: {}, please format as 'https://huggingface.co/<repo_type>/<repository>/resolve/<revision>/<path>'" , url) ;
221
+ }
222
+
223
+ last_delim += next_resolve. unwrap ( ) + "resolve" . len ( ) ;
224
+
225
+ let next_slash = url[ last_delim + 1 ..] . find ( '/' ) ;
226
+ if next_slash. is_none ( ) {
227
+ return config_err ! ( "Invalid HuggingFace URL: {}, please format as 'https://huggingface.co/<repo_type>/<repository>/resolve/<revision>/<path>'" , url) ;
228
+ }
229
+
230
+ parsed_url. revision = Some ( url[ last_delim + 1 ..last_delim + 1 + next_slash. unwrap ( ) ] . to_string ( ) ) ;
231
+ last_delim += 1 + next_slash. unwrap ( ) ;
232
+
233
+ // parse path.
234
+ let path = & url[ last_delim + 1 ..] ;
235
+ parsed_url. path = Some ( path. to_string ( ) ) ;
236
+
237
+ Ok ( parsed_url)
238
+ }
239
+
240
+ pub fn hf_path ( & self ) -> String {
241
+ let mut url = self . repository . as_deref ( ) . unwrap ( ) . to_string ( ) ;
242
+
243
+ if let Some ( revision) = & self . revision {
244
+ if revision != "main" {
245
+ url. push ( '@' ) ;
246
+ url. push_str ( revision) ;
247
+ }
248
+ }
249
+
250
+ url. push ( '/' ) ;
251
+ url. push_str ( self . path . as_deref ( ) . unwrap ( ) ) ;
252
+
253
+ url
254
+ }
255
+
171
256
pub fn file_path ( & self ) -> String {
257
+ let mut url = self . file_path_prefix ( ) ;
258
+ url. push ( '/' ) ;
259
+ url. push_str ( self . path . as_deref ( ) . unwrap ( ) ) ;
260
+
261
+ url
262
+ }
263
+
264
+ pub fn file_path_prefix ( & self ) -> String {
172
265
let mut url = self . repo_type . clone ( ) . unwrap ( ) ;
173
266
url. push ( '/' ) ;
174
267
url. push_str ( self . repository . as_deref ( ) . unwrap ( ) ) ;
175
268
url. push_str ( "/resolve/" ) ;
176
269
url. push_str ( self . revision . as_deref ( ) . unwrap ( ) ) ;
177
- url. push ( '/' ) ;
178
- url. push_str ( self . path . as_deref ( ) . unwrap ( ) ) ;
179
270
180
271
url
181
272
}
@@ -386,7 +477,20 @@ pub fn get_hf_object_store_builder(
386
477
/// An `ObjectStore` implementation that serves files from HuggingFace
/// repositories by delegating HTTP requests to an inner store.
pub struct HFStore {
    // Base endpoint of the HuggingFace service — presumably defaults to
    // "https://huggingface.co" (see `ParsedHFUrl` docs); confirm at the builder.
    endpoint: String,
    // Repository type this store serves; the parsers only accept
    // "datasets" or "spaces".
    repo_type: String,
    // Inner object store (HTTP-backed) that performs the actual requests.
    store: Arc<dyn ObjectStore>,
}
482
+
483
/// One entry deserialized from the HuggingFace repository tree listing API
/// response (used by `HFStore::list` to enumerate files).
#[derive(Debug, Deserialize)]
pub struct HFTreeEntry {
    // Entry kind as reported by the API; this code only checks for "file".
    pub r#type: String,
    // Path of the entry relative to the repository root.
    pub path: String,
    // Object id reported by the API — presumably a content hash; not used by
    // the listing logic here. TODO confirm semantics against the HF API docs.
    pub oid: String,
}
489
+
490
+ impl HFTreeEntry {
491
+ pub fn is_file ( & self ) -> bool {
492
+ self . r#type == "file"
493
+ }
390
494
}
391
495
392
496
impl HFStore {
@@ -436,19 +540,16 @@ impl ObjectStore for HFStore {
436
540
location : & Path ,
437
541
options : GetOptions ,
438
542
) -> ObjectStoreResult < GetResult > {
439
- println ! ( "GETTING: {}" , location) ;
440
-
441
543
let formatted_location = format ! ( "{}/{}" , self . repo_type, location) ;
442
544
443
- let Ok ( parsed_url) = ParsedHFUrl :: parse ( formatted_location) else {
545
+ let Ok ( parsed_url) = ParsedHFUrl :: parse_hf_style ( formatted_location) else {
444
546
return Err ( ObjectStoreError :: Generic {
445
547
store : STORE ,
446
548
source : format ! ( "Unable to parse url {location}" ) . into ( ) ,
447
549
} ) ;
448
550
} ;
449
551
450
552
let file_path = parsed_url. file_path ( ) ;
451
- println ! ( "FILE_PATH: {:?}" , file_path) ;
452
553
453
554
let Ok ( file_path) = Path :: parse ( file_path. clone ( ) ) else {
454
555
return Err ( ObjectStoreError :: Generic {
@@ -458,7 +559,7 @@ impl ObjectStore for HFStore {
458
559
} ;
459
560
460
561
let mut res = self . store . get_opts ( & file_path, options) . await ?;
461
-
562
+
462
563
res. meta . location = location. clone ( ) ;
463
564
Ok ( res)
464
565
}
@@ -469,9 +570,8 @@ impl ObjectStore for HFStore {
469
570
470
571
/// Delimiter-based (directory-style) listing is not supported for
/// HuggingFace repositories; callers should use `list` instead.
async fn list_with_delimiter(
    &self,
    _prefix: Option<&Path>,
) -> ObjectStoreResult<ListResult> {
    Err(ObjectStoreError::NotImplemented)
}
@@ -480,6 +580,7 @@ impl ObjectStore for HFStore {
480
580
& self ,
481
581
prefix : Option < & Path > ,
482
582
) -> BoxStream < ' _ , ObjectStoreResult < ObjectMeta > > {
583
+
483
584
let Some ( prefix) = prefix else {
484
585
return futures:: stream:: once ( async {
485
586
Err ( ObjectStoreError :: Generic {
@@ -491,26 +592,71 @@ impl ObjectStore for HFStore {
491
592
} ;
492
593
493
594
let formatted_prefix = format ! ( "{}/{}" , self . repo_type, prefix) ;
494
- let Ok ( parsed_url) = ParsedHFUrl :: parse ( formatted_prefix. clone ( ) ) else {
595
+ let Ok ( parsed_url) = ParsedHFUrl :: parse_hf_style ( formatted_prefix. clone ( ) ) else {
495
596
return futures:: stream:: once ( async move {
496
597
Err ( ObjectStoreError :: Generic {
497
598
store : STORE ,
498
- source : format ! ( "Unable to parse url {}" , formatted_prefix. clone( ) ) . into ( ) ,
599
+ source : format ! ( "Unable to parse url {}" , formatted_prefix. clone( ) )
600
+ . into ( ) ,
499
601
} )
500
602
} )
501
603
. boxed ( ) ;
502
604
} ;
503
605
504
- let tree_path = Path :: from ( parsed_url. tree_path ( ) ) ;
505
- println ! ( "LISTING: {:?}" , tree_path ) ;
606
+ let tree_path = parsed_url. tree_path ( ) ;
607
+ let file_path_prefix = parsed_url . file_path_prefix ( ) ;
506
608
507
609
futures:: stream:: once ( async move {
508
- let result = self . store . get ( & tree_path) . await ;
509
-
510
- println ! ( "RESULT: {:?}" , result) ;
610
+ let result = self . store . get ( & Path :: from ( tree_path) ) . await ?;
611
+ let Ok ( bytes) = result. bytes ( ) . await else {
612
+ return Err ( ObjectStoreError :: Generic {
613
+ store : STORE ,
614
+ source : "Unable to get list body" . into ( ) ,
615
+ } ) ;
616
+ } ;
617
+
511
618
512
- Err ( ObjectStoreError :: NotImplemented )
619
+ let Ok ( tree_result) =
620
+ serde_json:: from_slice :: < Vec < HFTreeEntry > > ( bytes. to_byte_slice ( ) )
621
+ else {
622
+ return Err ( ObjectStoreError :: Generic {
623
+ store : STORE ,
624
+ source : "Unable to parse list body" . into ( ) ,
625
+ } ) ;
626
+ } ;
627
+
628
+ let iter = join_all (
629
+ tree_result
630
+ . into_iter ( )
631
+ . filter ( |entry| entry. is_file ( ) )
632
+ . map ( |entry| format ! ( "{}/{}" , file_path_prefix, entry. path. clone( ) ) )
633
+ . map ( |meta_location| async {
634
+ self . store . head ( & Path :: from ( meta_location) ) . await
635
+ } ) ,
636
+ )
637
+ . await
638
+ . into_iter ( )
639
+ . map ( |result| {
640
+ result. and_then ( |mut meta| {
641
+ let Ok ( location) = ParsedHFUrl :: parse_http_style ( meta. location . to_string ( ) ) else {
642
+ return Err ( ObjectStoreError :: Generic {
643
+ store : STORE ,
644
+ source : format ! ( "Unable to parse location {}" , meta. location)
645
+ . into ( ) ,
646
+ } ) ;
647
+ } ;
648
+ meta. location = Path :: from ( location. hf_path ( ) ) ;
649
+ if let Some ( e_tag) = meta. e_tag . as_deref ( ) {
650
+ meta. e_tag = Some ( e_tag. replace ( "\" " , "" ) ) ;
651
+ }
652
+
653
+ Ok ( meta)
654
+ } )
655
+ } ) ;
656
+
657
+ Ok :: < _ , ObjectStoreError > ( futures:: stream:: iter ( iter) )
513
658
} )
659
+ . try_flatten ( )
514
660
. boxed ( )
515
661
}
516
662
@@ -537,7 +683,7 @@ mod tests {
537
683
fn test_parse_hf_url ( ) {
538
684
let url = "datasets/datasets-examples/doc-formats-csv-1/data.csv" . to_string ( ) ;
539
685
540
- let parsed_url = ParsedHFUrl :: parse ( url) . unwrap ( ) ;
686
+ let parsed_url = ParsedHFUrl :: parse_hf_style ( url) . unwrap ( ) ;
541
687
542
688
assert_eq ! ( parsed_url. repo_type, Some ( "datasets" . to_string( ) ) ) ;
543
689
assert_eq ! (
@@ -553,7 +699,7 @@ mod tests {
553
699
let url =
554
700
"datasets/datasets-examples/doc-formats-csv-1@~csv/data.csv" . to_string ( ) ;
555
701
556
- let parsed_url = ParsedHFUrl :: parse ( url) . unwrap ( ) ;
702
+ let parsed_url = ParsedHFUrl :: parse_hf_style ( url) . unwrap ( ) ;
557
703
558
704
assert_eq ! ( parsed_url. repo_type, Some ( "datasets" . to_string( ) ) ) ;
559
705
assert_eq ! (
@@ -587,11 +733,27 @@ mod tests {
587
733
) ;
588
734
}
589
735
736
/// `parse_http_style` must split an http-style URL into repo type,
/// `<org>/<repo>` repository, revision, and file path.
#[test]
fn test_parse_http_url() {
    // Canonical URL: default branch, file at the repository root.
    let url =
        "datasets/datasets-examples/doc-formats-csv-1/resolve/main/data.csv".to_string();

    let parsed_url = ParsedHFUrl::parse_http_style(url).unwrap();

    assert_eq!(parsed_url.repo_type, Some("datasets".to_string()));
    assert_eq!(
        parsed_url.repository,
        Some("datasets-examples/doc-formats-csv-1".to_string())
    );
    assert_eq!(parsed_url.revision, Some("main".to_string()));
    assert_eq!(parsed_url.path, Some("data.csv".to_string()));

    // A non-default revision and a nested file path must be preserved verbatim.
    let url = "datasets/datasets-examples/doc-formats-csv-1/resolve/abc123/dir/data.csv"
        .to_string();

    let parsed_url = ParsedHFUrl::parse_http_style(url).unwrap();

    assert_eq!(parsed_url.revision, Some("abc123".to_string()));
    assert_eq!(parsed_url.path, Some("dir/data.csv".to_string()));
}
750
+
751
+
590
752
#[ test]
591
753
fn test_file_path ( ) {
592
754
let url = "datasets/datasets-examples/doc-formats-csv-1/data.csv" . to_string ( ) ;
593
755
594
- let parsed_url = ParsedHFUrl :: parse ( url) ;
756
+ let parsed_url = ParsedHFUrl :: parse_hf_style ( url) ;
595
757
596
758
assert ! ( parsed_url. is_ok( ) ) ;
597
759
@@ -607,7 +769,7 @@ mod tests {
607
769
fn test_tree_path ( ) {
608
770
let url = "datasets/datasets-examples/doc-formats-csv-1/data.csv" . to_string ( ) ;
609
771
610
- let parsed_url = ParsedHFUrl :: parse ( url) ;
772
+ let parsed_url = ParsedHFUrl :: parse_hf_style ( url) ;
611
773
612
774
assert ! ( parsed_url. is_ok( ) ) ;
613
775
@@ -620,7 +782,7 @@ mod tests {
620
782
}
621
783
622
784
fn test_error ( url : & str , expected : & str ) {
623
- let parsed_url_result = ParsedHFUrl :: parse ( url. to_string ( ) ) ;
785
+ let parsed_url_result = ParsedHFUrl :: parse_hf_style ( url. to_string ( ) ) ;
624
786
625
787
match parsed_url_result {
626
788
Ok ( _) => panic ! ( "Expected error, but got success" ) ,
0 commit comments