Skip to content

Commit a4417dd

Browse files
authored
Add new hf protocol for ollama (ggml-org#11449)
https://huggingface.co/docs/hub/en/ollama

Signed-off-by: Eric Curtin <[email protected]>
1 parent d6d24cd commit a4417dd

File tree

1 file changed

+74
-35
lines changed

1 file changed

+74
-35
lines changed

examples/run/run.cpp

Lines changed: 74 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -319,6 +319,10 @@ class HttpClient {
319319
public:
320320
int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
321321
const bool progress, std::string * response_str = nullptr) {
322+
if (std::filesystem::exists(output_file)) {
323+
return 0;
324+
}
325+
322326
std::string output_file_partial;
323327
curl = curl_easy_init();
324328
if (!curl) {
@@ -558,13 +562,14 @@ class LlamaData {
558562
}
559563

560564
sampler = initialize_sampler(opt);
565+
561566
return 0;
562567
}
563568

564569
private:
565570
#ifdef LLAMA_USE_CURL
566-
int download(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
567-
const bool progress, std::string * response_str = nullptr) {
571+
int download(const std::string & url, const std::string & output_file, const bool progress,
572+
const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
568573
HttpClient http;
569574
if (http.init(url, headers, output_file, progress, response_str)) {
570575
return 1;
@@ -573,57 +578,95 @@ class LlamaData {
573578
return 0;
574579
}
575580
#else
576-
int download(const std::string &, const std::vector<std::string> &, const std::string &, const bool,
581+
int download(const std::string &, const std::string &, const bool, const std::vector<std::string> & = {},
577582
std::string * = nullptr) {
578583
printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
584+
579585
return 1;
580586
}
581587
#endif
582588

583-
int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
589+
// Helper function to handle model tag extraction and URL construction
590+
std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
591+
std::string model_tag = "latest";
592+
const size_t colon_pos = model.find(':');
593+
if (colon_pos != std::string::npos) {
594+
model_tag = model.substr(colon_pos + 1);
595+
model = model.substr(0, colon_pos);
596+
}
597+
598+
std::string url = base_url + model + "/manifests/" + model_tag;
599+
600+
return { model, url };
601+
}
602+
603+
// Helper function to download and parse the manifest
604+
int download_and_parse_manifest(const std::string & url, const std::vector<std::string> & headers,
605+
nlohmann::json & manifest) {
606+
std::string manifest_str;
607+
int ret = download(url, "", false, headers, &manifest_str);
608+
if (ret) {
609+
return ret;
610+
}
611+
612+
manifest = nlohmann::json::parse(manifest_str);
613+
614+
return 0;
615+
}
616+
617+
int huggingface_dl(std::string & model, const std::string & bn) {
584618
// Find the second occurrence of '/' after protocol string
585619
size_t pos = model.find('/');
586620
pos = model.find('/', pos + 1);
621+
std::string hfr, hff;
622+
std::vector<std::string> headers = { "User-Agent: llama-cpp", "Accept: application/json" };
623+
std::string url;
624+
587625
if (pos == std::string::npos) {
588-
return 1;
626+
auto [model_name, manifest_url] = extract_model_and_tag(model, "https://huggingface.co/v2/");
627+
hfr = model_name;
628+
629+
nlohmann::json manifest;
630+
int ret = download_and_parse_manifest(manifest_url, headers, manifest);
631+
if (ret) {
632+
return ret;
633+
}
634+
635+
hff = manifest["ggufFile"]["rfilename"];
636+
} else {
637+
hfr = model.substr(0, pos);
638+
hff = model.substr(pos + 1);
589639
}
590640

591-
const std::string hfr = model.substr(0, pos);
592-
const std::string hff = model.substr(pos + 1);
593-
const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
594-
return download(url, headers, bn, true);
641+
url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
642+
643+
return download(url, bn, true, headers);
595644
}
596645

597-
int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
646+
int ollama_dl(std::string & model, const std::string & bn) {
647+
const std::vector<std::string> headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" };
598648
if (model.find('/') == std::string::npos) {
599649
model = "library/" + model;
600650
}
601651

602-
std::string model_tag = "latest";
603-
size_t colon_pos = model.find(':');
604-
if (colon_pos != std::string::npos) {
605-
model_tag = model.substr(colon_pos + 1);
606-
model = model.substr(0, colon_pos);
607-
}
608-
609-
std::string manifest_url = "https://registry.ollama.ai/v2/" + model + "/manifests/" + model_tag;
610-
std::string manifest_str;
611-
const int ret = download(manifest_url, headers, "", false, &manifest_str);
652+
auto [model_name, manifest_url] = extract_model_and_tag(model, "https://registry.ollama.ai/v2/");
653+
nlohmann::json manifest;
654+
int ret = download_and_parse_manifest(manifest_url, {}, manifest);
612655
if (ret) {
613656
return ret;
614657
}
615658

616-
nlohmann::json manifest = nlohmann::json::parse(manifest_str);
617-
std::string layer;
659+
std::string layer;
618660
for (const auto & l : manifest["layers"]) {
619661
if (l["mediaType"] == "application/vnd.ollama.image.model") {
620662
layer = l["digest"];
621663
break;
622664
}
623665
}
624666

625-
std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
626-
return download(blob_url, headers, bn, true);
667+
std::string blob_url = "https://registry.ollama.ai/v2/" + model_name + "/blobs/" + layer;
668+
669+
return download(blob_url, bn, true, headers);
627670
}
628671

629672
std::string basename(const std::string & path) {
@@ -653,22 +696,18 @@ class LlamaData {
653696
return ret;
654697
}
655698

656-
const std::string bn = basename(model_);
657-
const std::vector<std::string> headers = { "--header",
658-
"Accept: application/vnd.docker.distribution.manifest.v2+json" };
699+
const std::string bn = basename(model_);
659700
if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
660701
rm_until_substring(model_, "://");
661-
ret = huggingface_dl(model_, headers, bn);
702+
ret = huggingface_dl(model_, bn);
662703
} else if (string_starts_with(model_, "hf.co/")) {
663704
rm_until_substring(model_, "hf.co/");
664-
ret = huggingface_dl(model_, headers, bn);
665-
} else if (string_starts_with(model_, "ollama://")) {
666-
rm_until_substring(model_, "://");
667-
ret = ollama_dl(model_, headers, bn);
705+
ret = huggingface_dl(model_, bn);
668706
} else if (string_starts_with(model_, "https://")) {
669-
ret = download(model_, headers, bn, true);
670-
} else {
671-
ret = ollama_dl(model_, headers, bn);
707+
ret = download(model_, bn, true);
708+
} else { // ollama:// or nothing
709+
rm_until_substring(model_, "://");
710+
ret = ollama_dl(model_, bn);
672711
}
673712

674713
model_ = bn;

0 commit comments

Comments (0)