@@ -16,6 +16,7 @@
 
 import inspect
 import itertools
+import json
 import os
 import re
 from collections import OrderedDict
@@ -25,17 +26,20 @@
 
 import safetensors
 import torch
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, split_torch_state_dict_into_shards
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import Tensor, nn
 
 from .. import __version__
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     _add_variant,
+    _get_checkpoint_shard_files,
     _get_model_file,
     deprecate,
     is_accelerate_available,
@@ -49,6 +53,7 @@
 )
 from .model_loading_utils import (
     _determine_device_map,
+    _fetch_index_file,
     _load_state_dict_into_model,
     load_model_dict_into_meta,
     load_state_dict,
@@ -57,6 +62,8 @@
 
 logger = logging.get_logger(__name__)
 
+_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")
+
 
 if is_torch_version(">=", "1.9.0"):
     _LOW_CPU_MEM_USAGE_DEFAULT = True
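Note for context: `_REGEX_SHARD` matches the stem of a sharded weight file so that stale shards from a previous save can be cleaned up. A minimal sketch of the intended behavior (the filenames are illustrative, using the default `diffusion_pytorch_model` stem):

```python
import re

_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")

# A shard stem (extension already stripped) matches the pattern...
assert _REGEX_SHARD.fullmatch("diffusion_pytorch_model-00001-of-00002")
# ...while a single-file checkpoint stem does not.
assert _REGEX_SHARD.fullmatch("diffusion_pytorch_model") is None
```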
@@ -263,6 +270,7 @@ def save_pretrained(
         save_function: Optional[Callable] = None,
         safe_serialization: bool = True,
         variant: Optional[str] = None,
+        max_shard_size: Union[int, str] = "5GB",
         push_to_hub: bool = False,
         **kwargs,
     ):
@@ -285,6 +293,10 @@ def save_pretrained(
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
                 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
+            max_shard_size (`int` or `str`, defaults to `"5GB"`):
+                The maximum size for a checkpoint before it is sharded. Each checkpoint shard will then be
+                smaller than this size. If expressed as a string, it needs to be digits followed by a unit
+                (like `"5GB"`). If expressed as an integer, the unit is bytes.
             push_to_hub (`bool`, *optional*, defaults to `False`):
                 Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
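To make the accepted `max_shard_size` formats concrete, both calls below request roughly 5 GB shards (a sketch; the model class is illustrative, and the string form assumes `huggingface_hub`'s size parsing, where `"5GB"` means 5 * 10**9 bytes):

```python
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel()  # any ModelMixin subclass works

# String form: digits followed by a unit.
model.save_pretrained("checkpoint-str", max_shard_size="5GB")

# Integer form: a raw byte count.
model.save_pretrained("checkpoint-int", max_shard_size=5 * 10**9)
```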
@@ -296,6 +308,14 @@ def save_pretrained(
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+        weights_name = _add_variant(weights_name, variant)
+        weight_name_split = weights_name.split(".")
+        if len(weight_name_split) in [2, 3]:
+            weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
+        else:
+            raise ValueError(f"Invalid {weights_name} provided.")
+
         os.makedirs(save_directory, exist_ok=True)
 
         if push_to_hub:
@@ -317,18 +337,58 @@ def save_pretrained(
         # Save the model
         state_dict = model_to_save.state_dict()
 
-        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
-        weights_name = _add_variant(weights_name, variant)
-
         # Save the model
-        if safe_serialization:
-            safetensors.torch.save_file(
-                state_dict, Path(save_directory, weights_name).as_posix(), metadata={"format": "pt"}
+        state_dict_split = split_torch_state_dict_into_shards(
+            state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
+        )
+
+        # Clean the folder from a previous save
+        if is_main_process:
+            for filename in os.listdir(save_directory):
+                if filename in state_dict_split.filename_to_tensors.keys():
+                    continue
+                full_filename = os.path.join(save_directory, filename)
+                if not os.path.isfile(full_filename):
+                    continue
+                weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
+                weights_without_ext = weights_without_ext.replace("{suffix}", "")
+                filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
+                # make sure the file to be deleted matches the format of a sharded file, e.g. pytorch_model-00001-of-00005
+                if (
+                    filename.startswith(weights_without_ext)
+                    and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
+                ):
+                    os.remove(full_filename)
+
+        for filename, tensors in state_dict_split.filename_to_tensors.items():
+            shard = {tensor: state_dict[tensor] for tensor in tensors}
+            filepath = os.path.join(save_directory, filename)
+            if safe_serialization:
+                # At some point we will need to deal better with save_function (used for TPU and other distributed
+                # joyfulness), but for now this is enough.
+                safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
+            else:
+                torch.save(shard, filepath)
+
+        if state_dict_split.is_sharded:
+            index = {
+                "metadata": state_dict_split.metadata,
+                "weight_map": state_dict_split.tensor_to_filename,
+            }
+            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+            save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
+            # Save the index as well
+            with open(save_index_file, "w", encoding="utf-8") as f:
+                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                f.write(content)
+            logger.info(
+                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                f"split into {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each "
+                f"parameter has been saved in the index located at {save_index_file}."
             )
         else:
-            torch.save(state_dict, Path(save_directory, weights_name).as_posix())
-
-        logger.info(f"Model weights saved in {Path(save_directory, weights_name).as_posix()}")
+            path_to_weights = os.path.join(save_directory, weights_name)
+            logger.info(f"Model weights saved in {path_to_weights}")
 
         if push_to_hub:
             # Create a new empty model card and eventually tag it
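For reviewers, this is the on-disk layout the new save path should produce once a model exceeds `max_shard_size` (a sketch: the repo, model, and shard count are illustrative, and the filenames assume the default safetensors weight names):

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.save_pretrained("sharded-unet", max_shard_size="2GB")

# Expected contents of sharded-unet/:
#   config.json
#   diffusion_pytorch_model-00001-of-00002.safetensors
#   diffusion_pytorch_model-00002-of-00002.safetensors
#   diffusion_pytorch_model.safetensors.index.json  <- metadata + weight_map
```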
@@ -566,6 +626,32 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             **kwargs,
         )
 
+        # Determine if we're loading from a directory of sharded checkpoints.
+        is_sharded = False
+        index_file = None
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        index_file = _fetch_index_file(
+            is_local=is_local,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            subfolder=subfolder or "",
+            use_safetensors=use_safetensors,
+            cache_dir=cache_dir,
+            variant=variant,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            user_agent=user_agent,
+            commit_hash=commit_hash,
+        )
+        if index_file is not None and index_file.is_file():
+            is_sharded = True
+
+        if is_sharded and from_flax:
+            raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
+
         # load model
         model_file = None
         if from_flax:
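The index file `_fetch_index_file` resolves here is the standard Hugging Face shard index written by the save path above; its shape is roughly the following (tensor names and sizes are illustrative):

```python
# diffusion_pytorch_model.safetensors.index.json, shown as a Python dict:
index = {
    "metadata": {"total_size": 3438167552},
    "weight_map": {
        "conv_in.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
        "conv_in.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
        # ... one entry per tensor, mapping it to the shard that stores it ...
    },
}
```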
@@ -590,7 +676,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
             model = load_flax_checkpoint_in_pytorch_model(model, model_file)
         else:
-            if use_safetensors:
+            if is_sharded:
+                sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
+                    pretrained_model_name_or_path,
+                    index_file,
+                    cache_dir=cache_dir,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder or "",
+                )
+
+            elif use_safetensors and not is_sharded:
                 try:
                     model_file = _get_model_file(
                         pretrained_model_name_or_path,
@@ -606,11 +706,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                         user_agent=user_agent,
                         commit_hash=commit_hash,
                     )
+
                 except IOError as e:
+                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
                     if not allow_pickle:
-                        raise e
-                    pass
-            if model_file is None:
+                        raise
+                    logger.warning(
+                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+                    )
+
+            if model_file is None and not is_sharded:
                 model_file = _get_model_file(
                     pretrained_model_name_or_path,
                     weights_name=_add_variant(WEIGHTS_NAME, variant),
@@ -632,7 +737,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 model = cls.from_config(config, **unused_kwargs)
 
             # if device_map is None, load the state dict and move the params from meta device to the cpu
-            if device_map is None:
+            if device_map is None and not is_sharded:
                 param_device = "cpu"
                 state_dict = load_state_dict(model_file, variant=variant)
                 model._convert_deprecated_attention_blocks(state_dict)
@@ -670,7 +775,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             try:
                 accelerate.load_checkpoint_and_dispatch(
                     model,
-                    model_file,
+                    model_file if not is_sharded else sharded_ckpt_cached_folder,
                     device_map,
                     max_memory=max_memory,
                     offload_folder=offload_folder,