from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from pathlib import Path
-from typing import Any, Callable, Dict, IO, List, Mapping, NamedTuple, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -101,6 +101,8 @@ class Checkpoint(Serializable):
            there must not be another object in ``to_save`` with key ``checkpointer``.
        greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise, the first model.
            Default, `False`.
+        save_on_rank: The rank on which to save the objects in a distributed configuration. If ``save_handler``
+            is a string or :class:`~pathlib.Path`, it is also used to instantiate a :class:`~ignite.handlers.DiskSaver`.

    .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
        torch.nn.parallel.DistributedDataParallel.html
@@ -164,9 +166,8 @@ class Checkpoint(Serializable):
        > checkpoint_12345.pt

    Note:
-        This class is distributed configuration-friendly: it is not required to instantiate the class in rank 0 only
-        process. This class supports automatically distributed configuration and if used with
-        :class:`~ignite.handlers.DiskSaver`, checkpoint is stored by rank 0 process.
+        This class is distributed configuration-friendly: it is not required to instantiate the class in the
+        rank 0 process only.

    .. warning::

@@ -263,6 +264,7 @@ class Checkpoint(Serializable):

        - `score_name` can be used to define `score_function` automatically without providing `score_function`.
        - `save_handler` automatically saves to disk if a path to a directory is provided.
+        - `save_on_rank` saves the objects on this rank in a distributed configuration.
    """

    Item = NamedTuple("Item", [("priority", int), ("filename", str)])
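For review context, a minimal usage sketch of the new argument on `Checkpoint` (the model, directory, and rank value are illustrative assumptions, not part of this diff):

```python
import torch.nn as nn

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint

model = nn.Linear(10, 2)

# When save_handler is a str/Path, this diff forwards `save_on_rank`
# to the DiskSaver that Checkpoint instantiates internally.
handler = Checkpoint({"model": model}, "/tmp/checkpoints", n_saved=2, save_on_rank=1)

trainer = Engine(lambda engine, batch: None)  # dummy engine for illustration
trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
```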
@@ -280,6 +282,7 @@ def __init__(
        filename_pattern: Optional[str] = None,
        include_self: bool = False,
        greater_or_equal: bool = False,
+        save_on_rank: Optional[int] = 0,
    ):

        if not isinstance(to_save, collections.Mapping):
@@ -312,7 +315,7 @@ def __init__(
        self.to_save = to_save
        self.filename_prefix = filename_prefix
        if isinstance(save_handler, str) or isinstance(save_handler, Path):
-            self.save_handler = DiskSaver(save_handler, create_dir=True)
+            self.save_handler = DiskSaver(save_handler, create_dir=True, save_on_rank=save_on_rank)
        else:
            self.save_handler = save_handler  # type: ignore
        self.score_function = score_function
@@ -326,6 +329,7 @@ def __init__(
        self._saved = []  # type: List["Checkpoint.Item"]
        self.include_self = include_self
        self.greater_or_equal = greater_or_equal
+        self.save_on_rank = save_on_rank

    def _get_filename_pattern(self, global_step: Optional[int]) -> str:
        if self.filename_pattern is None:
@@ -761,10 +765,15 @@ class DiskSaver(BaseSaveHandler):
        create_dir: if True, will create directory ``dirname`` if it doesn't exist.
        require_empty: If True, will raise exception if there are any files in the
            directory ``dirname``.
+        save_on_rank: The rank on which the checkpoint will be saved. Used in a distributed
+            configuration.
        kwargs: Accepted keyword arguments for `torch.save` or `xm.save`.

    .. versionchanged:: 0.4.2
        Accept ``kwargs`` for `torch.save` or `xm.save`.
+
+    .. versionchanged:: 0.5.0
+        Argument ``save_on_rank`` was added to specify the rank on which the checkpoint should be saved.
    """

    def __init__(
@@ -773,15 +782,18 @@ def __init__(
        atomic: bool = True,
        create_dir: bool = True,
        require_empty: bool = True,
+        save_on_rank: Optional[int] = 0,
        **kwargs: Any,
    ):
        self.dirname = Path(dirname).expanduser()
        self._atomic = atomic
-        self._check_and_setup(self.dirname, create_dir, require_empty)
+        self.save_on_rank = save_on_rank
+
+        if idist.get_rank() == save_on_rank:
+            self._check_and_setup(self.dirname, create_dir, require_empty)
        self.kwargs = kwargs

    @staticmethod
-    @idist.one_rank_only()
    def _check_and_setup(dirname: Path, create_dir: bool, require_empty: bool) -> None:
        if create_dir:
            if not dirname.exists():
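Dropping `@idist.one_rank_only()` in favour of an explicit `idist.get_rank() == save_on_rank` check is what makes the rank configurable: the decorator fixes its rank when the class is defined, while `save_on_rank` is only known per instance. A construction sketch (path and rank are illustrative):

```python
from ignite.handlers import DiskSaver

# Only the process whose rank equals `save_on_rank` runs _check_and_setup
# (directory creation/validation); every other rank skips the filesystem
# work at construction time.
saver = DiskSaver("/tmp/checkpoints", require_empty=False, save_on_rank=2)
```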
@@ -804,49 +816,36 @@ def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mappin
        path = self.dirname / filename

        if idist.has_xla_support:
-            self._save_xla(checkpoint, path)
-        else:
-            self._save_native(checkpoint, path)
-
-    @idist.one_rank_only()
-    def _save_native(self, checkpoint: Mapping, path: Path) -> None:
-        self._save_func(checkpoint, path, torch.save)
+            import torch_xla.core.xla_model as xm

-    def _save_xla(self, checkpoint: Mapping, path: Path) -> None:
-        import torch_xla.core.xla_model as xm
+            # all tpu procs should enter here as it internally performs a sync across devices
+            self._save_func(checkpoint, path, xm.save)
+        elif self.save_on_rank == idist.get_rank():
+            self._save_func(checkpoint, path, torch.save)

-        # all tpu procs should enter here as internally performs sync across device
-        self._save_func(checkpoint, path, xm.save, rank=idist.get_rank())
-
-    def _save_func(self, checkpoint: Mapping, path: Path, func: Callable, rank: int = 0) -> None:
+    def _save_func(self, checkpoint: Mapping, path: Path, func: Callable) -> None:
        if not self._atomic:
            func(checkpoint, path, **self.kwargs)
        else:
-            tmp_file = None
-            tmp_name = ""
-            tmp: Optional[IO[bytes]] = None
-            if rank == 0:
-                tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname)
-                tmp_file = tmp.file
-                tmp_name = tmp.name
+            tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname)
+            tmp_file = tmp.file
+            tmp_name = tmp.name
            try:
                func(checkpoint, tmp_file, **self.kwargs)
            except BaseException:
-                if tmp is not None:
-                    tmp.close()
-                    os.remove(tmp_name)
-                raise
+                tmp.close()
+                os.remove(tmp_name)
+                raise
            else:
-                if tmp is not None:
-                    tmp.close()
-                    os.replace(tmp.name, path)
-                    # append group/others read mode
-                    os.chmod(path, os.stat(path).st_mode | stat.S_IRGRP | stat.S_IROTH)
+                tmp.close()
+                os.replace(tmp.name, path)
+                # append group/others read mode
+                os.chmod(path, os.stat(path).st_mode | stat.S_IRGRP | stat.S_IROTH)

-    @idist.one_rank_only()
    def remove(self, filename: str) -> None:
-        path = self.dirname / filename
-        path.unlink()
+        if idist.get_rank() == self.save_on_rank:
+            path = self.dirname / filename
+            path.unlink()


class ModelCheckpoint(Checkpoint):
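The simplified `_save_func` above follows the standard atomic-write pattern: write to a temporary file in the destination directory, then `os.replace` it into place so readers never observe a half-written checkpoint. A standalone sketch of the same pattern (the helper name `atomic_save` is illustrative, not from this diff):

```python
import os
import stat
import tempfile
from pathlib import Path

import torch


def atomic_save(obj, path: Path) -> None:
    # The temp file lives in the destination directory so os.replace stays
    # on one filesystem and is atomic on POSIX.
    tmp = tempfile.NamedTemporaryFile(delete=False, dir=path.parent)
    try:
        torch.save(obj, tmp.file)
    except BaseException:
        # On any failure, close and discard the partial file, then re-raise.
        tmp.close()
        os.remove(tmp.name)
        raise
    else:
        tmp.close()
        os.replace(tmp.name, path)
        # append group/others read mode, as the diff does
        os.chmod(path, os.stat(path).st_mode | stat.S_IRGRP | stat.S_IROTH)
```

Now that only the saving rank reaches `_save_func` on the native path, the old `rank` plumbing and the `Optional[IO[bytes]]` bookkeeping (and with it the `IO` import removed at the top of the file) are no longer needed.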
@@ -901,14 +900,18 @@ class ModelCheckpoint(Checkpoint):
            there must not be another object in ``to_save`` with key ``checkpointer``.
        greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise, the first model.
            Default, `False`.
+        save_on_rank: The rank on which to save the objects in a distributed configuration. It is used to
+            instantiate the :class:`~ignite.handlers.DiskSaver` and is also passed to the parent class.
        kwargs: Accepted keyword arguments for `torch.save` or `xm.save` in `DiskSaver`.

    .. versionchanged:: 0.4.2
        Accept ``kwargs`` for `torch.save` or `xm.save`

    .. versionchanged:: 0.5.0
-        Accept ``filename_pattern`` and ``greater_or_equal`` for parity
-        with :class:`~ignite.handlers.checkpoint.Checkpoint`
+
+        - ``filename_pattern`` and ``greater_or_equal`` for parity
+          with :class:`~ignite.handlers.checkpoint.Checkpoint`
+        - `save_on_rank` saves the objects on this rank in a distributed configuration.

    Examples:
        .. testcode:: python
@@ -945,10 +948,18 @@ def __init__(
        filename_pattern: Optional[str] = None,
        include_self: bool = False,
        greater_or_equal: bool = False,
+        save_on_rank: Optional[int] = 0,
        **kwargs: Any,
    ):

-        disk_saver = DiskSaver(dirname, atomic=atomic, create_dir=create_dir, require_empty=require_empty, **kwargs)
+        disk_saver = DiskSaver(
+            dirname,
+            atomic=atomic,
+            create_dir=create_dir,
+            require_empty=require_empty,
+            save_on_rank=save_on_rank,
+            **kwargs,
+        )

        super(ModelCheckpoint, self).__init__(
            to_save={},
@@ -961,6 +972,7 @@ def __init__(
            filename_pattern=filename_pattern,
            include_self=include_self,
            greater_or_equal=greater_or_equal,
+            save_on_rank=save_on_rank,
        )

    @property
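Finally, an end-user sketch for `ModelCheckpoint` with the new argument (model, directory, and rank are illustrative):

```python
import torch.nn as nn

from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint

model = nn.Linear(10, 2)
trainer = Engine(lambda engine, batch: None)  # dummy engine for illustration

# `save_on_rank` is forwarded both to the internal DiskSaver and to the
# parent Checkpoint, so disk I/O happens on exactly one chosen rank.
checkpointer = ModelCheckpoint(
    "/tmp/models", filename_prefix="best", n_saved=2, require_empty=False, save_on_rank=1
)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {"model": model})
```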