
Commit 83f8cc7

Implement feature (#2641)
1 parent d34f1c2 commit 83f8cc7

2 files changed: +83 additions, −54 deletions


ignite/handlers/checkpoint.py

Lines changed: 54 additions & 42 deletions
@@ -7,7 +7,7 @@
 from abc import ABCMeta, abstractmethod
 from collections import OrderedDict
 from pathlib import Path
-from typing import Any, Callable, Dict, IO, List, Mapping, NamedTuple, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -101,6 +101,8 @@ class Checkpoint(Serializable):
             there must not be another object in ``to_save`` with key ``checkpointer``.
         greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise, the first model.
             Default, `False`.
+        save_on_rank: Which rank to save the objects on, in the distributed configuration. If ``save_handler`` is
+            string or :class:`~pathlib.Path`, this is also used to instantiate a :class:`~ignite.handlers.DiskSaver`.

     .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
         torch.nn.parallel.DistributedDataParallel.html
@@ -164,9 +166,8 @@ class Checkpoint(Serializable):
         > checkpoint_12345.pt

     Note:
-        This class is distributed configuration-friendly: it is not required to instantiate the class in rank 0 only
-        process. This class supports automatically distributed configuration and if used with
-        :class:`~ignite.handlers.DiskSaver`, checkpoint is stored by rank 0 process.
+        This class is distributed configuration-friendly: it is not required to instantiate the class in rank 0
+        process only.

     .. warning::
@@ -263,6 +264,7 @@ class Checkpoint(Serializable):

         - `score_name` can be used to define `score_function` automatically without providing `score_function`.
         - `save_handler` automatically saves to disk if path to directory is provided.
+        - `save_on_rank` saves objects on this rank in a distributed configuration.
     """

     Item = NamedTuple("Item", [("priority", int), ("filename", str)])
@@ -280,6 +282,7 @@ def __init__(
         filename_pattern: Optional[str] = None,
         include_self: bool = False,
         greater_or_equal: bool = False,
+        save_on_rank: Optional[int] = 0,
     ):

         if not isinstance(to_save, collections.Mapping):
@@ -312,7 +315,7 @@ def __init__(
         self.to_save = to_save
         self.filename_prefix = filename_prefix
         if isinstance(save_handler, str) or isinstance(save_handler, Path):
-            self.save_handler = DiskSaver(save_handler, create_dir=True)
+            self.save_handler = DiskSaver(save_handler, create_dir=True, save_on_rank=save_on_rank)
         else:
             self.save_handler = save_handler  # type: ignore
         self.score_function = score_function
@@ -326,6 +329,7 @@ def __init__(
         self._saved = []  # type: List["Checkpoint.Item"]
         self.include_self = include_self
         self.greater_or_equal = greater_or_equal
+        self.save_on_rank = save_on_rank

     def _get_filename_pattern(self, global_step: Optional[int]) -> str:
         if self.filename_pattern is None:
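
With the changes above, every rank still constructs the Checkpoint, while only the rank given by ``save_on_rank`` performs disk I/O. A minimal usage sketch (the directory path and the two-process setup are illustrative, not part of this commit):

import torch.nn as nn

from ignite.handlers import Checkpoint

# Assumes a distributed group is already initialized (e.g. two gloo processes).
# A string/Path save_handler makes Checkpoint build a DiskSaver internally and
# forward save_on_rank to it, so only rank 1 writes checkpoint files.
model = nn.Linear(2, 2)
handler = Checkpoint(
    to_save={"model": model},
    save_handler="/tmp/checkpoints",  # illustrative directory
    n_saved=2,
    save_on_rank=1,
)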
@@ -761,10 +765,15 @@ class DiskSaver(BaseSaveHandler):
         create_dir: if True, will create directory ``dirname`` if it doesn't exist.
         require_empty: If True, will raise exception if there are any files in the
             directory ``dirname``.
+        save_on_rank: The rank on which the checkpoint will be saved. Used in distributed
+            configuration.
         kwargs: Accepted keyword arguments for `torch.save` or `xm.save`.

     .. versionchanged:: 0.4.2
         Accept ``kwargs`` for `torch.save` or `xm.save`.
+
+    .. versionchanged:: 0.5.0
+        Argument ``save_on_rank`` was added to specify the rank on which checkpoint should be saved.
     """

     def __init__(
@@ -773,15 +782,18 @@ def __init__(
         atomic: bool = True,
         create_dir: bool = True,
         require_empty: bool = True,
+        save_on_rank: Optional[int] = 0,
         **kwargs: Any,
     ):
         self.dirname = Path(dirname).expanduser()
         self._atomic = atomic
-        self._check_and_setup(self.dirname, create_dir, require_empty)
+        self.save_on_rank = save_on_rank
+
+        if idist.get_rank() == save_on_rank:
+            self._check_and_setup(self.dirname, create_dir, require_empty)
         self.kwargs = kwargs

     @staticmethod
-    @idist.one_rank_only()
     def _check_and_setup(dirname: Path, create_dir: bool, require_empty: bool) -> None:
         if create_dir:
             if not dirname.exists():
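
Here the decorator guard ``@idist.one_rank_only()``, which was hard-wired to rank 0, is replaced by an explicit comparison against the configurable rank. The same pattern in isolation, as a sketch (the helper name ``setup_dir_on`` is hypothetical):

from pathlib import Path

import ignite.distributed as idist

def setup_dir_on(save_on_rank: int, dirname: str) -> None:
    # Only the designated saving rank touches the filesystem;
    # every other rank returns immediately.
    if idist.get_rank() == save_on_rank:
        Path(dirname).mkdir(parents=True, exist_ok=True)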
@@ -804,49 +816,36 @@ def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None:
         path = self.dirname / filename

         if idist.has_xla_support:
-            self._save_xla(checkpoint, path)
-        else:
-            self._save_native(checkpoint, path)
-
-    @idist.one_rank_only()
-    def _save_native(self, checkpoint: Mapping, path: Path) -> None:
-        self._save_func(checkpoint, path, torch.save)
+            import torch_xla.core.xla_model as xm

-    def _save_xla(self, checkpoint: Mapping, path: Path) -> None:
-        import torch_xla.core.xla_model as xm
+            # all tpu procs should enter here as internally performs sync across device
+            self._save_func(checkpoint, path, xm.save)
+        elif self.save_on_rank == idist.get_rank():
+            self._save_func(checkpoint, path, torch.save)

-        # all tpu procs should enter here as internally performs sync across device
-        self._save_func(checkpoint, path, xm.save, rank=idist.get_rank())
-
-    def _save_func(self, checkpoint: Mapping, path: Path, func: Callable, rank: int = 0) -> None:
+    def _save_func(self, checkpoint: Mapping, path: Path, func: Callable) -> None:
         if not self._atomic:
             func(checkpoint, path, **self.kwargs)
         else:
-            tmp_file = None
-            tmp_name = ""
-            tmp: Optional[IO[bytes]] = None
-            if rank == 0:
-                tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname)
-                tmp_file = tmp.file
-                tmp_name = tmp.name
+            tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname)
+            tmp_file = tmp.file
+            tmp_name = tmp.name
             try:
                 func(checkpoint, tmp_file, **self.kwargs)
             except BaseException:
-                if tmp is not None:
-                    tmp.close()
-                    os.remove(tmp_name)
-                raise
+                tmp.close()
+                os.remove(tmp_name)
+                raise
             else:
-                if tmp is not None:
-                    tmp.close()
-                    os.replace(tmp.name, path)
-                    # append group/others read mode
-                    os.chmod(path, os.stat(path).st_mode | stat.S_IRGRP | stat.S_IROTH)
+                tmp.close()
+                os.replace(tmp.name, path)
+                # append group/others read mode
+                os.chmod(path, os.stat(path).st_mode | stat.S_IRGRP | stat.S_IROTH)

-    @idist.one_rank_only()
     def remove(self, filename: str) -> None:
-        path = self.dirname / filename
-        path.unlink()
+        if idist.get_rank() == self.save_on_rank:
+            path = self.dirname / filename
+            path.unlink()


 class ModelCheckpoint(Checkpoint):
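
With the rank filtering moved up into ``__call__``, ``_save_func`` no longer needs its ``rank`` parameter, and the atomic branch reduces to the standard write-to-temp-then-rename idiom. That idiom in isolation (a sketch mirroring the library code, with the hypothetical helper name ``atomic_save``):

import os
import stat
import tempfile

import torch

def atomic_save(obj, path: str) -> None:
    # Write into a temporary file in the destination directory, then atomically
    # replace the target, so readers never observe a half-written checkpoint.
    tmp = tempfile.NamedTemporaryFile(delete=False, dir=os.path.dirname(path))
    try:
        torch.save(obj, tmp.file)
    except BaseException:
        tmp.close()
        os.remove(tmp.name)
        raise
    else:
        tmp.close()
        os.replace(tmp.name, path)
        # append group/others read mode, as DiskSaver does
        os.chmod(path, os.stat(path).st_mode | stat.S_IRGRP | stat.S_IROTH)

Note that os.replace is atomic only when source and destination are on the same filesystem, which is why the temporary file is created in the destination directory rather than in the system temp location.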
@@ -901,14 +900,18 @@ class ModelCheckpoint(Checkpoint):
             there must not be another object in ``to_save`` with key ``checkpointer``.
         greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise, the first model.
             Default, `False`.
+        save_on_rank: Which rank to save the objects on, in the distributed configuration. Used to
+            instantiate a :class:`~ignite.handlers.DiskSaver` and is also passed to the parent class.
         kwargs: Accepted keyword arguments for `torch.save` or `xm.save` in `DiskSaver`.

     .. versionchanged:: 0.4.2
         Accept ``kwargs`` for `torch.save` or `xm.save`

     .. versionchanged:: 0.5.0
-        Accept ``filename_pattern`` and ``greater_or_equal`` for parity
-        with :class:`~ignite.handlers.checkpoint.Checkpoint`
+
+        - ``filename_pattern`` and ``greater_or_equal`` for parity
+          with :class:`~ignite.handlers.checkpoint.Checkpoint`
+        - `save_on_rank` saves objects on this rank in a distributed configuration.

     Examples:
         .. testcode:: python
@@ -945,10 +948,18 @@ def __init__(
         filename_pattern: Optional[str] = None,
         include_self: bool = False,
         greater_or_equal: bool = False,
+        save_on_rank: Optional[int] = 0,
         **kwargs: Any,
     ):

-        disk_saver = DiskSaver(dirname, atomic=atomic, create_dir=create_dir, require_empty=require_empty, **kwargs)
+        disk_saver = DiskSaver(
+            dirname,
+            atomic=atomic,
+            create_dir=create_dir,
+            require_empty=require_empty,
+            save_on_rank=save_on_rank,
+            **kwargs,
+        )

         super(ModelCheckpoint, self).__init__(
             to_save={},
@@ -961,6 +972,7 @@ def __init__(
             filename_pattern=filename_pattern,
             include_self=include_self,
             greater_or_equal=greater_or_equal,
+            save_on_rank=save_on_rank,
         )

     @property
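
``ModelCheckpoint`` simply forwards the new argument both to its internal ``DiskSaver`` and to the parent ``Checkpoint``. A short usage sketch (directory name and process setup are illustrative):

from ignite.handlers import ModelCheckpoint

# Assumes a running distributed group with at least two processes.
checkpointer = ModelCheckpoint(
    "/tmp/models",  # illustrative dirname
    filename_prefix="run",
    n_saved=2,
    create_dir=True,
    require_empty=False,
    save_on_rank=1,
)
# Attach on every rank as usual; only rank 1 writes, e.g.:
# trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {"model": model})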

tests/ignite/handlers/test_checkpoint.py

Lines changed: 29 additions & 12 deletions
@@ -629,7 +629,7 @@ def test_disk_saver_atomic(dirname):
     to_save_serializable = {"model": model}
     to_save_non_serializable = {"model": lambda x: x}

-    def _test_existance(atomic, _to_save, expected):
+    def _test_existence(atomic, _to_save, expected):

         saver = DiskSaver(dirname, atomic=atomic, create_dir=False, require_empty=False)
         fname = "test.pt"
@@ -652,11 +652,11 @@ def _test_existance(atomic, _to_save, expected):
         if expected:
             saver.remove(fname)

-    _test_existance(atomic=False, _to_save=to_save_serializable, expected=True)
-    _test_existance(atomic=False, _to_save=to_save_non_serializable, expected=True)
+    _test_existence(atomic=False, _to_save=to_save_serializable, expected=True)
+    _test_existence(atomic=False, _to_save=to_save_non_serializable, expected=True)

-    _test_existance(atomic=True, _to_save=to_save_serializable, expected=True)
-    _test_existance(atomic=True, _to_save=to_save_non_serializable, expected=False)
+    _test_existence(atomic=True, _to_save=to_save_serializable, expected=True)
+    _test_existence(atomic=True, _to_save=to_save_non_serializable, expected=False)


 @pytest.mark.skipif(
@@ -856,7 +856,7 @@ def test_valid_state_dict_save(dirname):
         pytest.fail("Unexpected ValueError")


-def _test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname, on_zero_rank=False):
+def _test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname, just_on_zero_rank=False):

     torch.manual_seed(23)

@@ -885,7 +885,7 @@ def update_fn(engine, batch):

     engine = Engine(update_fn)

-    if (not on_zero_rank) or (on_zero_rank and idist.get_rank() == 0):
+    if (not just_on_zero_rank) or (just_on_zero_rank and idist.get_rank() == 0):
         handler = ModelCheckpoint(dirname, _PREFIX, create_dir=True, n_saved=1)

         engine.add_event_handler(
@@ -942,7 +942,7 @@ def test_save_model_optimizer_lr_scheduler_with_state_dict(dirname):
     _test_save_model_optimizer_lr_scheduler_with_state_dict("cpu", dirname)


-def _test_save_model_optimizer_lr_scheduler_with_validation(device, dirname, on_zero_rank=False):
+def _test_save_model_optimizer_lr_scheduler_with_validation(device, dirname, just_on_zero_rank=False):
     torch.manual_seed(23)

     def _build_objects(acc_list):
@@ -1248,9 +1248,9 @@ def _test_checkpoint_load_objects_ddp(device):
 def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo, get_rank_zero_dirname):

     device = idist.device()
-    dirname = get_rank_zero_dirname()
-    _test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname / "1")
-    _test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname / "2", on_zero_rank=True)
+    rank_zero_dirname = get_rank_zero_dirname()
+    _test_save_model_optimizer_lr_scheduler_with_state_dict(device, rank_zero_dirname / "1")
+    _test_save_model_optimizer_lr_scheduler_with_state_dict(device, rank_zero_dirname / "2", just_on_zero_rank=True)
     _test_checkpoint_with_ddp(device)
     _test_checkpoint_load_objects_ddp(device)

@@ -1263,7 +1263,7 @@ def test_distrib_nccl_gpu(distributed_context_single_node_nccl, get_rank_zero_dirname):
     device = idist.device()
     dirname = get_rank_zero_dirname()
     _test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname / "1")
-    _test_save_model_optimizer_lr_scheduler_with_state_dict("cpu", dirname / "2", on_zero_rank=True)
+    _test_save_model_optimizer_lr_scheduler_with_state_dict("cpu", dirname / "2", just_on_zero_rank=True)
     _test_checkpoint_with_ddp(device=device)
     _test_checkpoint_load_objects_ddp(device=device)

@@ -1784,3 +1784,20 @@ def test_load_single_object(obj_to_save, dirname):

     checkpoint_fp = dirname / c.last_checkpoint
     Checkpoint.load_objects(to_load=to_save, checkpoint=str(checkpoint_fp))
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.parametrize("atomic", [False, True])
+def test_disksaver_distrib(distributed_context_single_node_gloo, dirname, local_rank, atomic):
+
+    saver = DiskSaver(dirname, atomic, save_on_rank=1)
+    mocked_saver = MagicMock(wraps=saver)
+
+    mocked_saver(checkpoint={}, filename="test_disksaver_distrib.pt")
+
+    if local_rank == 1:
+        assert (dirname / "test_disksaver_distrib.pt").exists()
+
+    else:
+        mocked_saver._save_func.assert_not_called()
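
The test depends on the two-process gloo fixture so that rank 1 exists. A rough standalone equivalent for running outside the test suite (a sketch; ``idist.Parallel`` spawns the two processes, and the path is illustrative):

import ignite.distributed as idist
from ignite.handlers import DiskSaver

def run(local_rank: int) -> None:
    saver = DiskSaver("/tmp/ckpt_test", create_dir=True, require_empty=False, save_on_rank=1)
    saver(checkpoint={}, filename="test.pt")
    # Only rank 1 should have produced /tmp/ckpt_test/test.pt.

if __name__ == "__main__":
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(run)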
