7
7
from abc import ABCMeta , abstractmethod
8
8
from collections import OrderedDict
9
9
from pathlib import Path
10
- from typing import Any , Callable , Dict , IO , List , Mapping , NamedTuple , Optional , Tuple , Union
10
+ from typing import Any , Callable , Dict , List , Mapping , NamedTuple , Optional , Tuple , Union
11
11
12
12
import torch
13
13
import torch .nn as nn
@@ -101,6 +101,8 @@ class Checkpoint(Serializable):
101
101
there must not be another object in ``to_save`` with key ``checkpointer``.
102
102
greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise, the first model.
103
103
Default, `False`.
104
+ save_on_rank: The rank on which to save the objects in a distributed configuration. If ``save_handler`` is
105
+ a string or :class:`~pathlib.Path`, this is also used to instantiate a :class:`~ignite.handlers.DiskSaver`.
104
106
105
107
.. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
106
108
torch.nn.parallel.DistributedDataParallel.html
@@ -164,9 +166,8 @@ class Checkpoint(Serializable):
164
166
> checkpoint_12345.pt
165
167
166
168
Note:
167
- This class is distributed configuration-friendly: it is not required to instantiate the class in rank 0 only
168
- process. This class supports automatically distributed configuration and if used with
169
- :class:`~ignite.handlers.DiskSaver`, checkpoint is stored by rank 0 process.
169
+ This class is distributed configuration-friendly: it is not required to instantiate the class in rank 0
170
+ process only.
170
171
171
172
.. warning::
172
173
@@ -263,6 +264,7 @@ class Checkpoint(Serializable):
263
264
264
265
- `score_name` can be used to define `score_function` automatically without providing `score_function`.
265
266
- `save_handler` automatically saves to disk if path to directory is provided.
267
+ - `save_on_rank` specifies the rank on which objects are saved in a distributed configuration.
266
268
"""
267
269
268
270
Item = NamedTuple ("Item" , [("priority" , int ), ("filename" , str )])
@@ -280,6 +282,7 @@ def __init__(
280
282
filename_pattern : Optional [str ] = None ,
281
283
include_self : bool = False ,
282
284
greater_or_equal : bool = False ,
285
+ save_on_rank : Optional [int ] = 0 ,
283
286
):
284
287
285
288
if not isinstance (to_save , collections .Mapping ):
@@ -312,7 +315,7 @@ def __init__(
312
315
self .to_save = to_save
313
316
self .filename_prefix = filename_prefix
314
317
if isinstance (save_handler , str ) or isinstance (save_handler , Path ):
315
- self .save_handler = DiskSaver (save_handler , create_dir = True )
318
+ self .save_handler = DiskSaver (save_handler , create_dir = True , save_on_rank = save_on_rank )
316
319
else :
317
320
self .save_handler = save_handler # type: ignore
318
321
self .score_function = score_function
@@ -326,6 +329,7 @@ def __init__(
326
329
self ._saved = [] # type: List["Checkpoint.Item"]
327
330
self .include_self = include_self
328
331
self .greater_or_equal = greater_or_equal
332
+ self .save_on_rank = save_on_rank
329
333
330
334
def _get_filename_pattern (self , global_step : Optional [int ]) -> str :
331
335
if self .filename_pattern is None :
@@ -761,10 +765,15 @@ class DiskSaver(BaseSaveHandler):
761
765
create_dir: if True, will create directory ``dirname`` if it doesn't exist.
762
766
require_empty: If True, will raise exception if there are any files in the
763
767
directory ``dirname``.
768
+ save_on_rank: The rank on which the checkpoint will be saved. Used in distributed
769
+ configuration.
764
770
kwargs: Accepted keyword arguments for `torch.save` or `xm.save`.
765
771
766
772
.. versionchanged:: 0.4.2
767
773
Accept ``kwargs`` for `torch.save` or `xm.save`.
774
+
775
+ .. versionchanged:: 0.5.0
776
+ Argument ``save_on_rank`` was added to specify the rank on which the checkpoint should be saved.
768
777
"""
769
778
770
779
def __init__ (
@@ -773,15 +782,18 @@ def __init__(
773
782
atomic : bool = True ,
774
783
create_dir : bool = True ,
775
784
require_empty : bool = True ,
785
+ save_on_rank : Optional [int ] = 0 ,
776
786
** kwargs : Any ,
777
787
):
778
788
self .dirname = Path (dirname ).expanduser ()
779
789
self ._atomic = atomic
780
- self ._check_and_setup (self .dirname , create_dir , require_empty )
790
+ self .save_on_rank = save_on_rank
791
+
792
+ if idist .get_rank () == save_on_rank :
793
+ self ._check_and_setup (self .dirname , create_dir , require_empty )
781
794
self .kwargs = kwargs
782
795
783
796
@staticmethod
784
- @idist .one_rank_only ()
785
797
def _check_and_setup (dirname : Path , create_dir : bool , require_empty : bool ) -> None :
786
798
if create_dir :
787
799
if not dirname .exists ():
@@ -804,49 +816,36 @@ def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mappin
804
816
path = self .dirname / filename
805
817
806
818
if idist .has_xla_support :
807
- self ._save_xla (checkpoint , path )
808
- else :
809
- self ._save_native (checkpoint , path )
810
-
811
- @idist .one_rank_only ()
812
- def _save_native (self , checkpoint : Mapping , path : Path ) -> None :
813
- self ._save_func (checkpoint , path , torch .save )
819
+ import torch_xla .core .xla_model as xm
814
820
815
- def _save_xla (self , checkpoint : Mapping , path : Path ) -> None :
816
- import torch_xla .core .xla_model as xm
821
+ # all TPU processes must enter here, as xm.save internally performs a sync across devices
822
+ self ._save_func (checkpoint , path , xm .save )
823
+ elif self .save_on_rank == idist .get_rank ():
824
+ self ._save_func (checkpoint , path , torch .save )
817
825
818
- # all tpu procs should enter here as internally performs sync across device
819
- self ._save_func (checkpoint , path , xm .save , rank = idist .get_rank ())
820
-
821
- def _save_func (self , checkpoint : Mapping , path : Path , func : Callable , rank : int = 0 ) -> None :
826
+ def _save_func (self , checkpoint : Mapping , path : Path , func : Callable ) -> None :
822
827
if not self ._atomic :
823
828
func (checkpoint , path , ** self .kwargs )
824
829
else :
825
- tmp_file = None
826
- tmp_name = ""
827
- tmp : Optional [IO [bytes ]] = None
828
- if rank == 0 :
829
- tmp = tempfile .NamedTemporaryFile (delete = False , dir = self .dirname )
830
- tmp_file = tmp .file
831
- tmp_name = tmp .name
830
+ tmp = tempfile .NamedTemporaryFile (delete = False , dir = self .dirname )
831
+ tmp_file = tmp .file
832
+ tmp_name = tmp .name
832
833
try :
833
834
func (checkpoint , tmp_file , ** self .kwargs )
834
835
except BaseException :
835
- if tmp is not None :
836
- tmp .close ()
837
- os .remove (tmp_name )
838
- raise
836
+ tmp .close ()
837
+ os .remove (tmp_name )
838
+ raise
839
839
else :
840
- if tmp is not None :
841
- tmp .close ()
842
- os .replace (tmp .name , path )
843
- # append group/others read mode
844
- os .chmod (path , os .stat (path ).st_mode | stat .S_IRGRP | stat .S_IROTH )
840
+ tmp .close ()
841
+ os .replace (tmp .name , path )
842
+ # append group/others read mode
843
+ os .chmod (path , os .stat (path ).st_mode | stat .S_IRGRP | stat .S_IROTH )
845
844
846
- @idist .one_rank_only ()
847
845
def remove (self , filename : str ) -> None :
848
- path = self .dirname / filename
849
- path .unlink ()
846
+ if idist .get_rank () == self .save_on_rank :
847
+ path = self .dirname / filename
848
+ path .unlink ()
850
849
851
850
852
851
class ModelCheckpoint (Checkpoint ):
@@ -901,14 +900,18 @@ class ModelCheckpoint(Checkpoint):
901
900
there must not be another object in ``to_save`` with key ``checkpointer``.
902
901
greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise, the first model.
903
902
Default, `False`.
903
+ save_on_rank: The rank on which to save the objects in a distributed configuration. It is used to
904
+ instantiate a :class:`~ignite.handlers.DiskSaver` and is also passed to the parent class.
904
905
kwargs: Accepted keyword arguments for `torch.save` or `xm.save` in `DiskSaver`.
905
906
906
907
.. versionchanged:: 0.4.2
907
908
Accept ``kwargs`` for `torch.save` or `xm.save`
908
909
909
910
.. versionchanged:: 0.5.0
910
- Accept ``filename_pattern`` and ``greater_or_equal`` for parity
911
- with :class:`~ignite.handlers.checkpoint.Checkpoint`
911
+
912
+ - Accept ``filename_pattern`` and ``greater_or_equal`` for parity
913
+ with :class:`~ignite.handlers.checkpoint.Checkpoint`.
914
+ - Add ``save_on_rank`` to specify the rank on which objects are saved in a distributed configuration.
912
915
913
916
Examples:
914
917
.. testcode:: python
@@ -945,10 +948,18 @@ def __init__(
945
948
filename_pattern : Optional [str ] = None ,
946
949
include_self : bool = False ,
947
950
greater_or_equal : bool = False ,
951
+ save_on_rank : Optional [int ] = 0 ,
948
952
** kwargs : Any ,
949
953
):
950
954
951
- disk_saver = DiskSaver (dirname , atomic = atomic , create_dir = create_dir , require_empty = require_empty , ** kwargs )
955
+ disk_saver = DiskSaver (
956
+ dirname ,
957
+ atomic = atomic ,
958
+ create_dir = create_dir ,
959
+ require_empty = require_empty ,
960
+ save_on_rank = save_on_rank ,
961
+ ** kwargs ,
962
+ )
952
963
953
964
super (ModelCheckpoint , self ).__init__ (
954
965
to_save = {},
@@ -961,6 +972,7 @@ def __init__(
961
972
filename_pattern = filename_pattern ,
962
973
include_self = include_self ,
963
974
greater_or_equal = greater_or_equal ,
975
+ save_on_rank = save_on_rank ,
964
976
)
965
977
966
978
@property
0 commit comments