Commit 7d88711

sayakpaul and Wauplin authored
[Core] support saving and loading of sharded checkpoints (huggingface#7830)
* feat: support saving a model in sharded checkpoints.
* feat: make loading of sharded checkpoints work.
* add tests
* cleanse the loading logic a bit more.
* more resilience while loading from the Hub.
* parallelize shard downloads by using snapshot_download()
* default to a shard size.
* more fix
* Empty-Commit
* debug
* fix
* quality
* more debugging
* fix more
* initial comments from Benjamin
* move certain methods to loading_utils
* add test to check if the correct number of shards are present.
* add a test to check if loading of sharded checkpoints from the Hub is okay
* clarify the unit when passed as an int.
* use hf_hub for sharding.
* remove unnecessary code
* remove unnecessary function
* lucain's comments.
* fixes
* address high-level comments.
* fix test
* subfolder shenanigans.
* Update src/diffusers/utils/hub_utils.py (Co-authored-by: Lucain <[email protected]>)
* Apply suggestions from code review (Co-authored-by: Lucain <[email protected]>)
* remove _huggingface_hub_version as not needed.
* address more feedback.
* add a test for local_files_only=True
* need hf hub to be at least 0.23.2
* style
* final comment.
* clean up subfolder.
* deal with suffixes in code.
* _add_variant default.
* use weights_name_pattern
* remove add_suffix_keyword
* clean up downloading of sharded ckpts.
* don't return something special when using index.json
* fix more
* don't use bare except
* remove comments and catch the errors better
* fix a couple of things when using is_file()
* empty

Co-authored-by: Lucain <[email protected]>

1 parent b63c956 commit 7d88711

11 files changed: +354 -23
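
In user-facing terms, the change works end to end roughly like this (a minimal sketch; the checkpoint id and shard size are illustrative):

```python
from diffusers import UNet2DConditionModel

# Saving: state dicts larger than `max_shard_size` are split into numbered
# shard files plus a `*.index.json` file mapping each tensor to its shard.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
unet.save_pretrained("./unet-sharded", max_shard_size="2GB")

# Loading: from_pretrained() detects the index file, fetches every shard
# (in parallel via snapshot_download() when pulling from the Hub), and
# reassembles the weights transparently.
unet = UNet2DConditionModel.from_pretrained("./unet-sharded")
```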

setup.py (+1 -1)

```diff
@@ -101,7 +101,7 @@
     "filelock",
     "flax>=0.4.1",
     "hf-doc-builder>=0.3.0",
-    "huggingface-hub>=0.20.2",
+    "huggingface-hub>=0.23.2",
     "requests-mock==1.10.0",
     "importlib_metadata",
     "invisible-watermark>=0.2.0",
```

src/diffusers/dependency_versions_table.py (+1 -1)

```diff
@@ -9,7 +9,7 @@
     "filelock": "filelock",
     "flax": "flax>=0.4.1",
     "hf-doc-builder": "hf-doc-builder>=0.3.0",
-    "huggingface-hub": "huggingface-hub>=0.20.2",
+    "huggingface-hub": "huggingface-hub>=0.23.2",
     "requests-mock": "requests-mock==1.10.0",
    "importlib_metadata": "importlib_metadata",
    "invisible-watermark": "invisible-watermark>=0.2.0",
```

src/diffusers/models/model_loading_utils.py (+55)

```diff
@@ -18,13 +18,19 @@
 import inspect
 import os
 from collections import OrderedDict
+from pathlib import Path
 from typing import List, Optional, Union
 
 import safetensors
 import torch
+from huggingface_hub.utils import EntryNotFoundError
 
 from ..utils import (
+    SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
+    WEIGHTS_INDEX_NAME,
+    _add_variant,
+    _get_model_file,
     is_accelerate_available,
     is_torch_version,
     logging,
@@ -175,3 +181,52 @@ def load(module: torch.nn.Module, prefix: str = ""):
     load(model_to_load)
 
     return error_msgs
+
+
+def _fetch_index_file(
+    is_local,
+    pretrained_model_name_or_path,
+    subfolder,
+    use_safetensors,
+    cache_dir,
+    variant,
+    force_download,
+    resume_download,
+    proxies,
+    local_files_only,
+    token,
+    revision,
+    user_agent,
+    commit_hash,
+):
+    if is_local:
+        index_file = Path(
+            pretrained_model_name_or_path,
+            subfolder or "",
+            _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
+        )
+    else:
+        index_file_in_repo = Path(
+            subfolder or "",
+            _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
+        ).as_posix()
+        try:
+            index_file = _get_model_file(
+                pretrained_model_name_or_path,
+                weights_name=index_file_in_repo,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+                commit_hash=commit_hash,
+            )
+            index_file = Path(index_file)
+        except (EntryNotFoundError, EnvironmentError):
+            index_file = None
+
+    return index_file
```
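
The index file this helper resolves is the JSON that `save_pretrained` (see `modeling_utils.py` below) writes next to the shards. Schematically, it looks like this (tensor names and sizes are illustrative):

```python
import json

# Illustrative shape of diffusion_pytorch_model.safetensors.index.json:
# "metadata" comes from split_torch_state_dict_into_shards(), and
# "weight_map" maps every tensor name to the shard file containing it.
index = {
    "metadata": {"total_size": 10_270_268_416},
    "weight_map": {
        "conv_in.bias": "diffusion_pytorch_model-00001-of-00003.safetensors",
        "conv_in.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
    },
}
print(json.dumps(index, indent=2, sort_keys=True))
```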

src/diffusers/models/modeling_utils.py (+121 -16)

```diff
@@ -16,6 +16,7 @@
 
 import inspect
 import itertools
+import json
 import os
 import re
 from collections import OrderedDict
@@ -25,17 +26,20 @@
 
 import safetensors
 import torch
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, split_torch_state_dict_into_shards
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import Tensor, nn
 
 from .. import __version__
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     _add_variant,
+    _get_checkpoint_shard_files,
     _get_model_file,
     deprecate,
     is_accelerate_available,
@@ -49,6 +53,7 @@
 )
 from .model_loading_utils import (
     _determine_device_map,
+    _fetch_index_file,
     _load_state_dict_into_model,
     load_model_dict_into_meta,
     load_state_dict,
@@ -57,6 +62,8 @@
 
 logger = logging.get_logger(__name__)
 
+_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")
+
 
 if is_torch_version(">=", "1.9.0"):
     _LOW_CPU_MEM_USAGE_DEFAULT = True
@@ -263,6 +270,7 @@ def save_pretrained(
         save_function: Optional[Callable] = None,
         safe_serialization: bool = True,
         variant: Optional[str] = None,
+        max_shard_size: Union[int, str] = "5GB",
         push_to_hub: bool = False,
         **kwargs,
     ):
@@ -285,6 +293,10 @@ def save_pretrained(
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
                 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
+            max_shard_size (`int` or `str`, defaults to `"5GB"`):
+                The maximum size for a checkpoint before being sharded. Each shard will then be no larger than
+                this size. If expressed as a string, it needs to be digits followed by a unit (like `"5GB"`).
+                If expressed as an integer, the unit is bytes.
             push_to_hub (`bool`, *optional*, defaults to `False`):
                 Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
@@ -296,6 +308,14 @@ def save_pretrained(
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+        weights_name = _add_variant(weights_name, variant)
+        weight_name_split = weights_name.split(".")
+        if len(weight_name_split) in [2, 3]:
+            weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
+        else:
+            raise ValueError(f"Invalid {weights_name} provided.")
+
         os.makedirs(save_directory, exist_ok=True)
 
         if push_to_hub:
@@ -317,18 +337,58 @@ def save_pretrained(
         # Save the model
         state_dict = model_to_save.state_dict()
 
-        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
-        weights_name = _add_variant(weights_name, variant)
-
         # Save the model
-        if safe_serialization:
-            safetensors.torch.save_file(
-                state_dict, Path(save_directory, weights_name).as_posix(), metadata={"format": "pt"}
+        state_dict_split = split_torch_state_dict_into_shards(
+            state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
+        )
+
+        # Clean the folder from a previous save
+        if is_main_process:
+            for filename in os.listdir(save_directory):
+                if filename in state_dict_split.filename_to_tensors.keys():
+                    continue
+                full_filename = os.path.join(save_directory, filename)
+                if not os.path.isfile(full_filename):
+                    continue
+                weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
+                weights_without_ext = weights_without_ext.replace("{suffix}", "")
+                filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
+                # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
+                if (
+                    filename.startswith(weights_without_ext)
+                    and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
+                ):
+                    os.remove(full_filename)
+
+        for filename, tensors in state_dict_split.filename_to_tensors.items():
+            shard = {tensor: state_dict[tensor] for tensor in tensors}
+            filepath = os.path.join(save_directory, filename)
+            if safe_serialization:
+                # At some point we will need to deal better with save_function (used for TPU and other distributed
+                # joyfulness), but for now this is enough.
+                safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
+            else:
+                torch.save(shard, filepath)
+
+        if state_dict_split.is_sharded:
+            index = {
+                "metadata": state_dict_split.metadata,
+                "weight_map": state_dict_split.tensor_to_filename,
+            }
+            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+            save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
+            # Save the index as well
+            with open(save_index_file, "w", encoding="utf-8") as f:
+                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                f.write(content)
+            logger.info(
+                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                f"split into {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameter has been saved in the "
+                f"index located at {save_index_file}."
             )
         else:
-            torch.save(state_dict, Path(save_directory, weights_name).as_posix())
-
-        logger.info(f"Model weights saved in {Path(save_directory, weights_name).as_posix()}")
+            path_to_weights = os.path.join(save_directory, weights_name)
+            logger.info(f"Model weights saved in {path_to_weights}")
 
         if push_to_hub:
             # Create a new empty model card and eventually tag it
@@ -566,6 +626,32 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             **kwargs,
         )
 
+        # Determine if we're loading from a directory of sharded checkpoints.
+        is_sharded = False
+        index_file = None
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        index_file = _fetch_index_file(
+            is_local=is_local,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            subfolder=subfolder or "",
+            use_safetensors=use_safetensors,
+            cache_dir=cache_dir,
+            variant=variant,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            user_agent=user_agent,
+            commit_hash=commit_hash,
+        )
+        if index_file is not None and index_file.is_file():
+            is_sharded = True
+
+        if is_sharded and from_flax:
+            raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
+
         # load model
         model_file = None
         if from_flax:
@@ -590,7 +676,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
             model = load_flax_checkpoint_in_pytorch_model(model, model_file)
         else:
-            if use_safetensors:
+            if is_sharded:
+                sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
+                    pretrained_model_name_or_path,
+                    index_file,
+                    cache_dir=cache_dir,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder or "",
+                )
+
+            elif use_safetensors and not is_sharded:
                 try:
                     model_file = _get_model_file(
                         pretrained_model_name_or_path,
@@ -606,11 +706,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                         user_agent=user_agent,
                         commit_hash=commit_hash,
                     )
+
                 except IOError as e:
+                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
                     if not allow_pickle:
-                        raise e
-                    pass
-            if model_file is None:
+                        raise
+                    logger.warning(
+                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+                    )
+
+            if model_file is None and not is_sharded:
                 model_file = _get_model_file(
                     pretrained_model_name_or_path,
                     weights_name=_add_variant(WEIGHTS_NAME, variant),
@@ -632,7 +737,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             model = cls.from_config(config, **unused_kwargs)
 
             # if device_map is None, load the state dict and move the params from meta device to the cpu
-            if device_map is None:
+            if device_map is None and not is_sharded:
                 param_device = "cpu"
                 state_dict = load_state_dict(model_file, variant=variant)
                 model._convert_deprecated_attention_blocks(state_dict)
@@ -670,7 +775,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             try:
                 accelerate.load_checkpoint_and_dispatch(
                     model,
-                    model_file,
+                    model_file if not is_sharded else sharded_ckpt_cached_folder,
                     device_map,
                     max_memory=max_memory,
                     offload_folder=offload_folder,
```
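
To make the shard naming concrete, here is a small sketch of how `weights_name_pattern` and `_REGEX_SHARD` fit together (filenames illustrative, mirroring the logic above):

```python
import re

_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")

# save_pretrained() turns "diffusion_pytorch_model.safetensors" into a pattern
# with a "{suffix}" placeholder that huggingface_hub fills in per shard.
weights_name = "diffusion_pytorch_model.safetensors"
name, ext = weights_name.split(".", 1)
weights_name_pattern = name + "{suffix}." + ext
shard_name = weights_name_pattern.format(suffix="-00001-of-00003")
print(shard_name)  # diffusion_pytorch_model-00001-of-00003.safetensors

# The cleanup pass only deletes stale files whose stem matches the shard
# pattern, so unrelated files in the save directory are left alone.
assert _REGEX_SHARD.fullmatch(shard_name.replace(".safetensors", "")) is not None
assert _REGEX_SHARD.fullmatch("some_other_file") is None
```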

src/diffusers/utils/__init__.py (+3)

```diff
@@ -28,9 +28,11 @@
     MIN_PEFT_VERSION,
     ONNX_EXTERNAL_WEIGHTS_NAME,
     ONNX_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
     SAFETENSORS_WEIGHTS_NAME,
     USE_PEFT_BACKEND,
+    WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
 )
 from .deprecation_utils import deprecate
@@ -40,6 +42,7 @@
 from .hub_utils import (
     PushToHubMixin,
     _add_variant,
+    _get_checkpoint_shard_files,
     _get_model_file,
     extract_commit_hash,
     http_user_agent,
```

src/diffusers/utils/constants.py (+2)

```diff
@@ -28,9 +28,11 @@
 
 CONFIG_NAME = "config.json"
 WEIGHTS_NAME = "diffusion_pytorch_model.bin"
+WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.bin.index.json"
 FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
 ONNX_WEIGHTS_NAME = "model.onnx"
 SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
+SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"
 SAFETENSORS_FILE_EXTENSION = "safetensors"
 ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
 HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
```
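
With these constants in place, a sharded safetensors save directory contains the numbered `diffusion_pytorch_model-XXXXX-of-XXXXX.safetensors` shard files plus `diffusion_pytorch_model.safetensors.index.json` (`SAFE_WEIGHTS_INDEX_NAME`); the `.bin` serialization path uses `WEIGHTS_INDEX_NAME` analogously.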
