|
3 | 3 | import os
|
4 | 4 | from functools import partial
|
5 | 5 | from pathlib import Path
|
6 |
| -from typing import Union |
| 6 | +from tempfile import TemporaryDirectory |
| 7 | +from typing import Optional, Union |
7 | 8 |
|
8 | 9 | import torch
|
9 | 10 | from torch.hub import HASH_REGEX, download_url_to_file, urlparse
|
| 11 | + |
10 | 12 | try:
|
11 | 13 | from torch.hub import get_dir
|
12 | 14 | except ImportError:
|
|
15 | 17 | from timm import __version__
|
16 | 18 |
|
17 | 19 | try:
|
18 |
| - from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, hf_hub_url |
| 20 | + from huggingface_hub import (create_repo, get_hf_file_metadata, |
| 21 | + hf_hub_download, hf_hub_url, |
| 22 | + repo_type_and_id_from_hf_id, upload_folder) |
| 23 | + from huggingface_hub.utils import EntryNotFoundError |
19 | 24 | hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
|
20 | 25 | _has_hf_hub = True
|
21 | 26 | except ImportError:
|
@@ -121,56 +126,45 @@ def save_for_hf(model, save_directory, model_config=None):
|
121 | 126 |
|
122 | 127 | def push_to_hf_hub(
|
123 | 128 | model,
|
124 |
| - local_dir, |
125 |
| - repo_namespace_or_url=None, |
126 |
| - commit_message='Add model', |
127 |
| - use_auth_token=True, |
128 |
| - git_email=None, |
129 |
| - git_user=None, |
130 |
| - revision=None, |
131 |
| - model_config=None, |
| 129 | + repo_id: str, |
| 130 | + commit_message: str ='Add model', |
| 131 | + token: Optional[str] = None, |
| 132 | + revision: Optional[str] = None, |
| 133 | + private: bool = False, |
| 134 | + create_pr: bool = False, |
| 135 | + model_config: Optional[dict] = None, |
132 | 136 | ):
|
133 |
| - if isinstance(use_auth_token, str): |
134 |
| - token = use_auth_token |
135 |
| - else: |
136 |
| - token = HfFolder.get_token() |
137 |
| - if token is None: |
138 |
| - raise ValueError( |
139 |
| - "You must login to the Hugging Face hub on this computer by typing `huggingface-cli login` and " |
140 |
| - "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own " |
141 |
| - "token as the `use_auth_token` argument." |
142 |
| - ) |
143 |
| - |
144 |
| - if repo_namespace_or_url: |
145 |
| - repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:] |
146 |
| - else: |
147 |
| - repo_owner = HfApi().whoami(token)['name'] |
148 |
| - repo_name = Path(local_dir).name |
149 |
| - |
150 |
| - repo_id = f'{repo_owner}/{repo_name}' |
151 |
| - repo_url = f'https://huggingface.co/{repo_id}' |
152 |
| - |
153 | 137 | # Create repo if doesn't exist yet
|
154 |
| - HfApi().create_repo(repo_id, token=use_auth_token, exist_ok=True) |
155 |
| - |
156 |
| - repo = Repository( |
157 |
| - local_dir, |
158 |
| - clone_from=repo_url, |
159 |
| - use_auth_token=use_auth_token, |
160 |
| - git_user=git_user, |
161 |
| - git_email=git_email, |
162 |
| - revision=revision, |
163 |
| - ) |
164 |
| - |
165 |
| - # Prepare a default model card that includes the necessary tags to enable inference. |
166 |
| - readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}' |
167 |
| - with repo.commit(commit_message): |
| 138 | + repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) |
| 139 | + |
| 140 | + # Infer complete repo_id from repo_url |
| 141 | + # Can be different from the input `repo_id` if repo_owner was implicit |
| 142 | + _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) |
| 143 | + repo_id = f"{repo_owner}/{repo_name}" |
| 144 | + |
| 145 | + # Check if README file already exist in repo |
| 146 | + try: |
| 147 | + get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) |
| 148 | + has_readme = True |
| 149 | + except EntryNotFoundError: |
| 150 | + has_readme = False |
| 151 | + |
| 152 | + # Dump model and push to Hub |
| 153 | + with TemporaryDirectory() as tmpdir: |
168 | 154 | # Save model weights and config.
|
169 |
| - save_for_hf(model, repo.local_dir, model_config=model_config) |
| 155 | + save_for_hf(model, tmpdir, model_config=model_config) |
170 | 156 |
|
171 |
| - # Save a model card if it doesn't exist. |
172 |
| - readme_path = Path(repo.local_dir) / 'README.md' |
173 |
| - if not readme_path.exists(): |
| 157 | + # Add readme if does not exist |
| 158 | + if not has_readme: |
| 159 | + readme_path = Path(tmpdir) / "README.md" |
| 160 | + readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_id}' |
174 | 161 | readme_path.write_text(readme_text)
|
175 | 162 |
|
176 |
| - return repo.git_remote_url() |
| 163 | + # Upload model and return |
| 164 | + return upload_folder( |
| 165 | + repo_id=repo_id, |
| 166 | + folder_path=tmpdir, |
| 167 | + revision=revision, |
| 168 | + create_pr=create_pr, |
| 169 | + commit_message=commit_message, |
| 170 | + ) |
0 commit comments