Skip to content

Commit ce4d348

Browse files
Wauplinrwightman
authored andcommitted
refactor push_to_hub helper
1 parent c037db0 commit ce4d348

File tree

1 file changed

+43
-49
lines changed

1 file changed

+43
-49
lines changed

timm/models/hub.py

Lines changed: 43 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
import os
44
from functools import partial
55
from pathlib import Path
6-
from typing import Union
6+
from tempfile import TemporaryDirectory
7+
from typing import Optional, Union
78

89
import torch
910
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
11+
1012
try:
1113
from torch.hub import get_dir
1214
except ImportError:
@@ -15,7 +17,10 @@
1517
from timm import __version__
1618

1719
try:
18-
from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, hf_hub_url
20+
from huggingface_hub import (create_repo, get_hf_file_metadata,
21+
hf_hub_download, hf_hub_url,
22+
repo_type_and_id_from_hf_id, upload_folder)
23+
from huggingface_hub.utils import EntryNotFoundError
1924
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
2025
_has_hf_hub = True
2126
except ImportError:
@@ -121,56 +126,45 @@ def save_for_hf(model, save_directory, model_config=None):
121126

122127
def push_to_hf_hub(
123128
model,
124-
local_dir,
125-
repo_namespace_or_url=None,
126-
commit_message='Add model',
127-
use_auth_token=True,
128-
git_email=None,
129-
git_user=None,
130-
revision=None,
131-
model_config=None,
129+
repo_id: str,
130+
commit_message: str ='Add model',
131+
token: Optional[str] = None,
132+
revision: Optional[str] = None,
133+
private: bool = False,
134+
create_pr: bool = False,
135+
model_config: Optional[dict] = None,
132136
):
133-
if isinstance(use_auth_token, str):
134-
token = use_auth_token
135-
else:
136-
token = HfFolder.get_token()
137-
if token is None:
138-
raise ValueError(
139-
"You must login to the Hugging Face hub on this computer by typing `huggingface-cli login` and "
140-
"entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
141-
"token as the `use_auth_token` argument."
142-
)
143-
144-
if repo_namespace_or_url:
145-
repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:]
146-
else:
147-
repo_owner = HfApi().whoami(token)['name']
148-
repo_name = Path(local_dir).name
149-
150-
repo_id = f'{repo_owner}/{repo_name}'
151-
repo_url = f'https://huggingface.co/{repo_id}'
152-
153137
# Create repo if doesn't exist yet
154-
HfApi().create_repo(repo_id, token=use_auth_token, exist_ok=True)
155-
156-
repo = Repository(
157-
local_dir,
158-
clone_from=repo_url,
159-
use_auth_token=use_auth_token,
160-
git_user=git_user,
161-
git_email=git_email,
162-
revision=revision,
163-
)
164-
165-
# Prepare a default model card that includes the necessary tags to enable inference.
166-
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}'
167-
with repo.commit(commit_message):
138+
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
139+
140+
# Infer complete repo_id from repo_url
141+
# Can be different from the input `repo_id` if repo_owner was implicit
142+
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
143+
repo_id = f"{repo_owner}/{repo_name}"
144+
145+
# Check if README file already exist in repo
146+
try:
147+
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
148+
has_readme = True
149+
except EntryNotFoundError:
150+
has_readme = False
151+
152+
# Dump model and push to Hub
153+
with TemporaryDirectory() as tmpdir:
168154
# Save model weights and config.
169-
save_for_hf(model, repo.local_dir, model_config=model_config)
155+
save_for_hf(model, tmpdir, model_config=model_config)
170156

171-
# Save a model card if it doesn't exist.
172-
readme_path = Path(repo.local_dir) / 'README.md'
173-
if not readme_path.exists():
157+
# Add readme if does not exist
158+
if not has_readme:
159+
readme_path = Path(tmpdir) / "README.md"
160+
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_id}'
174161
readme_path.write_text(readme_text)
175162

176-
return repo.git_remote_url()
163+
# Upload model and return
164+
return upload_folder(
165+
repo_id=repo_id,
166+
folder_path=tmpdir,
167+
revision=revision,
168+
create_pr=create_pr,
169+
commit_message=commit_message,
170+
)

0 commit comments

Comments
 (0)