@@ -319,6 +319,10 @@ class HttpClient {
   public:
     int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
              const bool progress, std::string * response_str = nullptr) {
+        if (std::filesystem::exists(output_file)) {
+            return 0;
+        }
+
         std::string output_file_partial;
         curl = curl_easy_init();
         if (!curl) {
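
Note on the hunk above: HttpClient::init now returns 0 immediately when output_file already exists, so a model that was fully downloaded on an earlier run is not fetched again. Manifest requests go through the same path with an empty output_file, and an empty path never exists, so they are unaffected by the check. A minimal standalone sketch of the std::filesystem::exists behavior this relies on (the file name is hypothetical, not from the patch):

#include <filesystem>
#include <iostream>

int main() {
    // An empty path never exists, so manifest fetches (output_file == "")
    // are not short-circuited by the new early return.
    std::cout << std::filesystem::exists("") << '\n';            // prints 0
    // A previously downloaded file short-circuits init().
    std::cout << std::filesystem::exists("model.gguf") << '\n';  // 1 if the file is present
}
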
@@ -558,13 +562,14 @@ class LlamaData {
         }
 
         sampler = initialize_sampler(opt);
+
         return 0;
     }
 
   private:
 #ifdef LLAMA_USE_CURL
-    int download(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
-                 const bool progress, std::string * response_str = nullptr) {
+    int download(const std::string & url, const std::string & output_file, const bool progress,
+                 const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
         HttpClient http;
         if (http.init(url, headers, output_file, progress, response_str)) {
             return 1;
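
The reordered download signature above puts the required arguments first and defaults headers and response_str, so call sites that need no custom headers can simply omit them. A hedged sketch of the two call shapes the new signature allows, using a stub with the same shape (URLs and file names are placeholders, not values from the patch):

#include <string>
#include <vector>

// Stub with the same parameter order as the reordered download(); not the
// real implementation, just here to make the call shapes compile.
static int download(const std::string & url, const std::string & output_file, const bool progress,
                    const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
    (void) url; (void) output_file; (void) progress; (void) headers; (void) response_str;
    return 0;
}

int main() {
    std::string body;
    // File download, default (empty) headers:
    download("https://example.com/model.gguf", "model.gguf", true);
    // Manifest fetch into a string: no output file, custom headers:
    download("https://example.com/manifest", "", false, { "Accept: application/json" }, &body);
}
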
@@ -573,57 +578,95 @@ class LlamaData {
         return 0;
     }
 #else
-    int download(const std::string &, const std::vector<std::string> &, const std::string &, const bool,
+    int download(const std::string &, const std::string &, const bool, const std::vector<std::string> & = {},
                  std::string * = nullptr) {
         printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
+
         return 1;
     }
 #endif
 
-    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    // Helper function to handle model tag extraction and URL construction
+    std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
+        std::string model_tag = "latest";
+        const size_t colon_pos = model.find(':');
+        if (colon_pos != std::string::npos) {
+            model_tag = model.substr(colon_pos + 1);
+            model     = model.substr(0, colon_pos);
+        }
+
+        std::string url = base_url + model + "/manifests/" + model_tag;
+
+        return { model, url };
+    }
+
+    // Helper function to download and parse the manifest
+    int download_and_parse_manifest(const std::string & url, const std::vector<std::string> & headers,
+                                    nlohmann::json & manifest) {
+        std::string manifest_str;
+        int ret = download(url, "", false, headers, &manifest_str);
+        if (ret) {
+            return ret;
+        }
+
+        manifest = nlohmann::json::parse(manifest_str);
+
+        return 0;
+    }
+
+    int huggingface_dl(std::string & model, const std::string & bn) {
         // Find the second occurrence of '/' after protocol string
         size_t pos = model.find('/');
         pos        = model.find('/', pos + 1);
+        std::string hfr, hff;
+        std::vector<std::string> headers = { "User-Agent: llama-cpp", "Accept: application/json" };
+        std::string url;
+
         if (pos == std::string::npos) {
-            return 1;
+            auto [model_name, manifest_url] = extract_model_and_tag(model, "https://huggingface.co/v2/");
+            hfr = model_name;
+
+            nlohmann::json manifest;
+            int ret = download_and_parse_manifest(manifest_url, headers, manifest);
+            if (ret) {
+                return ret;
+            }
+
+            hff = manifest["ggufFile"]["rfilename"];
+        } else {
+            hfr = model.substr(0, pos);
+            hff = model.substr(pos + 1);
         }
 
-        const std::string hfr = model.substr(0, pos);
-        const std::string hff = model.substr(pos + 1);
-        const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
-        return download(url, headers, bn, true);
+        url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
+
+        return download(url, bn, true, headers);
     }
 
-    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int ollama_dl(std::string & model, const std::string & bn) {
+        const std::vector<std::string> headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" };
         if (model.find('/') == std::string::npos) {
             model = "library/" + model;
         }
 
-        std::string model_tag = "latest";
-        size_t colon_pos = model.find(':');
-        if (colon_pos != std::string::npos) {
-            model_tag = model.substr(colon_pos + 1);
-            model     = model.substr(0, colon_pos);
-        }
-
-        std::string manifest_url = "https://registry.ollama.ai/v2/" + model + "/manifests/" + model_tag;
-        std::string manifest_str;
-        const int ret = download(manifest_url, headers, "", false, &manifest_str);
+        auto [model_name, manifest_url] = extract_model_and_tag(model, "https://registry.ollama.ai/v2/");
+        nlohmann::json manifest;
+        int ret = download_and_parse_manifest(manifest_url, {}, manifest);
         if (ret) {
             return ret;
         }
 
-        nlohmann::json manifest = nlohmann::json::parse(manifest_str);
-        std::string    layer;
+        std::string layer;
         for (const auto & l : manifest["layers"]) {
             if (l["mediaType"] == "application/vnd.ollama.image.model") {
                 layer = l["digest"];
                 break;
             }
         }
 
-        std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
-        return download(blob_url, headers, bn, true);
+        std::string blob_url = "https://registry.ollama.ai/v2/" + model_name + "/blobs/" + layer;
+
+        return download(blob_url, bn, true, headers);
     }
 
     std::string basename(const std::string & path) {
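
The tag handling that huggingface_dl and ollama_dl now share lives in extract_model_and_tag: it splits an optional ":tag" suffix off the model name (defaulting to "latest") and builds the registry manifest URL from base_url. A standalone copy of that logic with example inputs (the model name below is illustrative, not taken from the diff):

#include <iostream>
#include <string>
#include <utility>

static std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
    std::string model_tag = "latest";
    const size_t colon_pos = model.find(':');
    if (colon_pos != std::string::npos) {
        model_tag = model.substr(colon_pos + 1);  // text after ':' becomes the tag
        model     = model.substr(0, colon_pos);   // model keeps only the name
    }

    return { model, base_url + model + "/manifests/" + model_tag };
}

int main() {
    std::string m1 = "library/smollm:135m";
    auto [name1, url1] = extract_model_and_tag(m1, "https://registry.ollama.ai/v2/");
    std::cout << name1 << " -> " << url1 << '\n';
    // library/smollm -> https://registry.ollama.ai/v2/library/smollm/manifests/135m

    std::string m2 = "library/smollm";  // no tag: falls back to "latest"
    auto [name2, url2] = extract_model_and_tag(m2, "https://registry.ollama.ai/v2/");
    std::cout << name2 << " -> " << url2 << '\n';
    // library/smollm -> https://registry.ollama.ai/v2/library/smollm/manifests/latest
}
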
@@ -653,22 +696,18 @@ class LlamaData {
             return ret;
         }
 
-        const std::string              bn      = basename(model_);
-        const std::vector<std::string> headers = { "--header",
-                                                   "Accept: application/vnd.docker.distribution.manifest.v2+json" };
+        const std::string bn = basename(model_);
         if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
             rm_until_substring(model_, "://");
-            ret = huggingface_dl(model_, headers, bn);
+            ret = huggingface_dl(model_, bn);
         } else if (string_starts_with(model_, "hf.co/")) {
             rm_until_substring(model_, "hf.co/");
-            ret = huggingface_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "ollama://")) {
-            rm_until_substring(model_, "://");
-            ret = ollama_dl(model_, headers, bn);
+            ret = huggingface_dl(model_, bn);
         } else if (string_starts_with(model_, "https://")) {
-            ret = download(model_, headers, bn, true);
-        } else {
-            ret = ollama_dl(model_, headers, bn);
+            ret = download(model_, bn, true);
+        } else { // ollama:// or nothing
+            rm_until_substring(model_, "://");
+            ret = ollama_dl(model_, bn);
         }
 
         model_ = bn;
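
Taken together with the earlier hunks, header plumbing is now per-downloader rather than passed down from the resolver, and the fallback branch strips an optional "ollama://" prefix before delegating. The resulting routing, sketched as comments (the model strings are examples, not values from the diff):

// Illustrative routing after this patch:
//   "hf://org/repo"            -> huggingface_dl("org/repo", bn)
//   "huggingface://org/repo"   -> huggingface_dl("org/repo", bn)
//   "hf.co/org/repo"           -> huggingface_dl("org/repo", bn)
//   "https://host/model.gguf"  -> download(url, bn, true)      // no extra headers
//   "ollama://smollm:135m"     -> ollama_dl("smollm:135m", bn)
//   "smollm"                   -> ollama_dl("smollm", bn)      // bare names default to Ollama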