Skip to content

Commit bf2e0cf

Browse files
sguggerjulien-cpatrickvonplaten
authored
Trainer push to hub (huggingface#11328)
* Initial support for upload to hub * push -> upload * Fixes + examples * Fix torchhub test * Torchhub test I hate you * push_model_to_hub -> push_to_hub * Apply mixin to other pretrained models * Remove ABC inheritance * Add tests * Typo * Run tests * Install git-lfs * Change approach * Add push_to_hub to all * Staging test suite * Typo * Maybe like this? * More deps * Cache * Adapt name * Quality * MOAR tests * Put it in testing_utils * Docs + torchhub last hope * Styling * Wrong method * Typos * Update src/transformers/file_utils.py Co-authored-by: Julien Chaumond <[email protected]> * Address review comments * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Julien Chaumond <[email protected]> Co-authored-by: Patrick von Platen <[email protected]>
1 parent 7bc86be commit bf2e0cf

31 files changed

+766
-31
lines changed

.circleci/config.yml

+13-4
Original file line numberDiff line numberDiff line change
@@ -317,24 +317,33 @@ jobs:
317317
- store_artifacts:
318318
path: ~/transformers/reports
319319

320-
run_tests_git_lfs:
320+
run_tests_hub:
321321
working_directory: ~/transformers
322322
docker:
323323
- image: circleci/python:3.7
324324
environment:
325+
HUGGINGFACE_CO_STAGING: yes
325326
RUN_GIT_LFS_TESTS: yes
326327
TRANSFORMERS_IS_CI: yes
327328
resource_class: xlarge
328329
parallelism: 1
329330
steps:
330331
- checkout
332+
- restore_cache:
333+
keys:
334+
- v0.4-hub-{{ checksum "setup.py" }}
335+
- v0.4-{{ checksum "setup.py" }}
331336
- run: sudo apt-get install git-lfs
332337
- run: |
333338
git config --global user.email "[email protected]"
334339
git config --global user.name "ci"
335340
- run: pip install --upgrade pip
336-
- run: pip install .[testing]
337-
- run: python -m pytest -sv ./tests/test_hf_api.py -k "HfLargefilesTest"
341+
- run: pip install .[torch,sentencepiece,testing]
342+
- save_cache:
343+
key: v0.4-hub-{{ checksum "setup.py" }}
344+
paths:
345+
- '~/.cache/pip'
346+
- run: python -m pytest -sv ./tests/ -m is_staging_test
338347

339348
build_doc:
340349
working_directory: ~/transformers
@@ -469,7 +478,7 @@ workflows:
469478
- run_tests_flax
470479
- run_tests_pipelines_torch
471480
- run_tests_pipelines_tf
472-
- run_tests_git_lfs
481+
- run_tests_hub
473482
- build_doc
474483
- deploy_doc: *workflow_filters
475484
# tpu_testing_jobs:

docs/source/main_classes/model.rst

+7
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,10 @@ Generation
7373

7474
.. autoclass:: transformers.generation_tf_utils.TFGenerationMixin
7575
:members:
76+
77+
78+
Pushing to the Hub
79+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
80+
81+
.. autoclass:: transformers.file_utils.PushToHubMixin
82+
:members:

docs/source/model_sharing.rst

+101-3
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,14 @@ the `model hub <https://huggingface.co/models>`__.
2222

2323
Optionally, you can join an existing organization or create a new one.
2424

25-
Prepare your model for uploading
26-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2725

2826
We have seen in the :doc:`training tutorial <training>`: how to fine-tune a model on a given task. You have probably
2927
done something similar on your task, either using the model directly in your own training loop or using the
3028
:class:`~.transformers.Trainer`/:class:`~.transformers.TFTrainer` class. Let's see how you can share the result on the
3129
`model hub <https://huggingface.co/models>`__.
3230

3331
Model versioning
34-
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
32+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
3533

3634
Since version v3.5.0, the model hub has built-in model versioning based on git and git-lfs. It is based on the paradigm
3735
that one model *is* one repo.
@@ -54,6 +52,106 @@ For instance:
5452
>>> revision="v2.0.1" # tag name, or branch name, or commit hash
5553
>>> )
5654
55+
56+
Push your model from Python
57+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
58+
59+
Preparation
60+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
61+
62+
The first step is to make sure your credentials to the hub are stored somewhere. This can be done in two ways. If you
63+
have access to a terminal, you cam just run the following command in the virtual environment where you installed 🤗
64+
Transformers:
65+
66+
.. code-block:: bash
67+
68+
transformers-cli login
69+
70+
It will store your access token in the Hugging Face cache folder (by default :obj:`~/.cache/`).
71+
72+
If you don't have an easy access to a terminal (for instance in a Colab session), you can find a token linked to your
73+
acount by going on `huggingface.co <https://huggingface.co/>`, click on your avatar on the top left corner, then on
74+
`Edit profile` on the left, just beneath your profile picture. In the submenu `API Tokens`, you will find your API
75+
token that you can just copy.
76+
77+
Directly push your model to the hub
78+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
79+
80+
Once you have an API token (either stored in the cache or copied and pasted in your notebook), you can directly push a
81+
finetuned model you saved in :obj:`save_drectory` by calling:
82+
83+
.. code-block:: python
84+
85+
finetuned_model.push_to_hub("my-awesome-model")
86+
87+
If you have your API token not stored in the cache, you will need to pass it with :obj:`use_auth_token=your_token`.
88+
This is also be the case for all the examples below, so we won't mention it again.
89+
90+
This will create a repository in your namespace name :obj:`my-awesome-model`, so anyone can now run:
91+
92+
.. code-block:: python
93+
94+
from transformers import AutoModel
95+
96+
model = AutoModel.from_pretrained("your_username/my-awesome-model")
97+
98+
Even better, you can combine this push to the hub with the call to :obj:`save_pretrained`:
99+
100+
.. code-block:: python
101+
102+
finetuned_model.save_pretrained(save_directory, push_to_hub=True, repo_name="my-awesome-model")
103+
104+
If you are a premium user and want your model to be private, just add :obj:`private=True` to this call.
105+
106+
If you are a member of an organization and want to push it inside the namespace of the organization instead of yours,
107+
just add :obj:`organization=my_amazing_org`.
108+
109+
Add new files to your model repo
110+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
111+
112+
Once you have pushed your model to the hub, you might want to add the tokenizer, or a version of your model for another
113+
framework (TensorFlow, PyTorch, Flax). This is super easy to do! Let's begin with the tokenizer. You can add it to the
114+
repo you created before like this
115+
116+
.. code-block:: python
117+
118+
tokenizer.push_to_hub("my-awesome-model")
119+
120+
If you know its URL (it should be :obj:`https://huggingface.co/username/repo_name`), you can also do:
121+
122+
.. code-block:: python
123+
124+
tokenizer.push_to_hub(repo_url=my_repo_url)
125+
126+
And that's all there is to it! It's also a very easy way to fix a mistake if one of the files online had a bug.
127+
128+
To add a model for another backend, it's also super easy. Let's say you have fine-tuned a TensorFlow model and want to
129+
add the pytorch model files to your model repo, so that anyone in the community can use it. The following allows you to
130+
directly create a PyTorch version of your TensorFlow model:
131+
132+
.. code-block:: python
133+
134+
from transfomers import AutoModel
135+
136+
model = AutoModel.from_pretrained(save_directory, from_tf=True)
137+
138+
You can also replace :obj:`save_directory` by the identifier of your model (:obj:`username/repo_name`) if you don't
139+
have a local save of it anymore. Then, just do the same as before:
140+
141+
.. code-block:: python
142+
143+
model.push_to_hub("my-awesome-model")
144+
145+
or
146+
147+
.. code-block:: python
148+
149+
model.push_to_hub(repo_url=my_repo_url)
150+
151+
152+
Use your terminal and git
153+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
154+
57155
Basic steps
58156
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
59157

examples/legacy/text-classification/run_tf_text_classification.py

100644100755
File mode changed.

examples/pytorch/language-modeling/run_clm.py

+3
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,9 @@ def group_texts(examples):
447447
trainer.log_metrics("eval", metrics)
448448
trainer.save_metrics("eval", metrics)
449449

450+
if training_args.push_to_hub:
451+
trainer.push_to_hub()
452+
450453

451454
def _mp_fn(index):
452455
# For xla_spawn (TPUs)

examples/pytorch/language-modeling/run_mlm.py

+3
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,9 @@ def group_texts(examples):
476476
trainer.log_metrics("eval", metrics)
477477
trainer.save_metrics("eval", metrics)
478478

479+
if training_args.push_to_hub:
480+
trainer.push_to_hub()
481+
479482

480483
def _mp_fn(index):
481484
# For xla_spawn (TPUs)

examples/pytorch/language-modeling/run_plm.py

+3
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,9 @@ def group_texts(examples):
452452
trainer.log_metrics("eval", metrics)
453453
trainer.save_metrics("eval", metrics)
454454

455+
if training_args.push_to_hub:
456+
trainer.push_to_hub()
457+
455458

456459
def _mp_fn(index):
457460
# For xla_spawn (TPUs)

examples/pytorch/multiple-choice/run_swag.py

+3
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,9 @@ def compute_metrics(eval_predictions):
428428
trainer.log_metrics("eval", metrics)
429429
trainer.save_metrics("eval", metrics)
430430

431+
if training_args.push_to_hub:
432+
trainer.push_to_hub()
433+
431434

432435
def _mp_fn(index):
433436
# For xla_spawn (TPUs)

examples/pytorch/question-answering/run_qa.py

+3
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,9 @@ def compute_metrics(p: EvalPrediction):
599599
trainer.log_metrics("test", metrics)
600600
trainer.save_metrics("test", metrics)
601601

602+
if training_args.push_to_hub:
603+
trainer.push_to_hub()
604+
602605

603606
def _mp_fn(index):
604607
# For xla_spawn (TPUs)

examples/pytorch/question-answering/run_qa_beam_search.py

+3
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,9 @@ def compute_metrics(p: EvalPrediction):
638638
trainer.log_metrics("test", metrics)
639639
trainer.save_metrics("test", metrics)
640640

641+
if training_args.push_to_hub:
642+
trainer.push_to_hub()
643+
641644

642645
def _mp_fn(index):
643646
# For xla_spawn (TPUs)

examples/pytorch/summarization/run_summarization.py

+3
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,9 @@ def compute_metrics(eval_preds):
579579
with open(output_test_preds_file, "w") as writer:
580580
writer.write("\n".join(test_preds))
581581

582+
if training_args.push_to_hub:
583+
trainer.push_to_hub()
584+
582585
return results
583586

584587

examples/pytorch/text-classification/run_glue.py

+3
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,9 @@ def compute_metrics(p: EvalPrediction):
517517
item = label_list[item]
518518
writer.write(f"{index}\t{item}\n")
519519

520+
if training_args.push_to_hub:
521+
trainer.push_to_hub()
522+
520523

521524
def _mp_fn(index):
522525
# For xla_spawn (TPUs)

examples/pytorch/token-classification/run_ner.py

+3
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,9 @@ def compute_metrics(p):
491491
for prediction in true_predictions:
492492
writer.write(" ".join(prediction) + "\n")
493493

494+
if training_args.push_to_hub:
495+
trainer.push_to_hub()
496+
494497

495498
def _mp_fn(index):
496499
# For xla_spawn (TPUs)

examples/pytorch/translation/run_translation.py

+3
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,9 @@ def compute_metrics(eval_preds):
571571
with open(output_test_preds_file, "w") as writer:
572572
writer.write("\n".join(test_preds))
573573

574+
if training_args.push_to_hub:
575+
trainer.push_to_hub()
576+
574577
return results
575578

576579

hubconf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
)
3232

3333

34-
dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses", "importlib_metadata"]
34+
dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses", "importlib_metadata", "huggingface_hub"]
3535

3636

3737
@add_start_docstrings(AutoConfig.__doc__)

src/transformers/configuration_utils.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@
2222
from typing import Any, Dict, Tuple, Union
2323

2424
from . import __version__
25-
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_offline_mode, is_remote_url
25+
from .file_utils import CONFIG_NAME, PushToHubMixin, cached_path, hf_bucket_url, is_offline_mode, is_remote_url
2626
from .utils import logging
2727

2828

2929
logger = logging.get_logger(__name__)
3030

3131

32-
class PretrainedConfig(object):
32+
class PretrainedConfig(PushToHubMixin):
3333
r"""
3434
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
3535
methods for loading/downloading/saving configurations.
@@ -310,14 +310,19 @@ def num_labels(self, num_labels: int):
310310
self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
311311
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
312312

313-
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
313+
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
314314
"""
315315
Save a configuration object to the directory ``save_directory``, so that it can be re-loaded using the
316316
:func:`~transformers.PretrainedConfig.from_pretrained` class method.
317317
318318
Args:
319319
save_directory (:obj:`str` or :obj:`os.PathLike`):
320320
Directory where the configuration JSON file will be saved (will be created if it does not exist).
321+
push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
322+
Whether or not to push your model to the Hugging Face model hub after saving it.
323+
kwargs:
324+
Additional key word arguments passed along to the
325+
:meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
321326
"""
322327
if os.path.isfile(save_directory):
323328
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -328,6 +333,10 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
328333
self.to_json_file(output_config_file, use_diff=True)
329334
logger.info(f"Configuration saved in {output_config_file}")
330335

336+
if push_to_hub:
337+
url = self._push_to_hub(save_files=[output_config_file], **kwargs)
338+
logger.info(f"Configuration pushed to the hub in this commit: {url}")
339+
331340
@classmethod
332341
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
333342
r"""

0 commit comments

Comments
 (0)