From 83f8cc7c8889e39006b9d95499cb4b2e7976f22d Mon Sep 17 00:00:00 2001
From: Sadra Barikbin
Date: Mon, 15 Aug 2022 01:40:27 +0430
Subject: [PATCH] Add save_on_rank argument to Checkpoint and DiskSaver handlers (#2641)

---
 ignite/handlers/checkpoint.py            | 96 +++++++++++++-----------
 tests/ignite/handlers/test_checkpoint.py | 41 +++++++---
 2 files changed, 83 insertions(+), 54 deletions(-)

diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py
index 124c9744687..17edf058b1a 100644
--- a/ignite/handlers/checkpoint.py
+++ b/ignite/handlers/checkpoint.py
@@ -7,7 +7,7 @@
 from abc import ABCMeta, abstractmethod
 from collections import OrderedDict
 from pathlib import Path
-from typing import Any, Callable, Dict, IO, List, Mapping, NamedTuple, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -101,6 +101,8 @@ class Checkpoint(Serializable):
             there must not be another object in ``to_save`` with key ``checkpointer``.
         greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise, the first model.
            Default, `False`.
+        save_on_rank: Which rank to save the objects on, in the distributed configuration. If ``save_handler`` is
+            a string or :class:`~pathlib.Path`, this is also used to instantiate a :class:`~ignite.handlers.DiskSaver`.
 
     .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
         torch.nn.parallel.DistributedDataParallel.html
@@ -164,9 +166,8 @@ class Checkpoint(Serializable):
             > checkpoint_12345.pt
 
     Note:
-        This class is distributed configuration-friendly: it is not required to instantiate the class in rank 0 only
-        process. This class supports automatically distributed configuration and if used with
-        :class:`~ignite.handlers.DiskSaver`, checkpoint is stored by rank 0 process.
+        This class is distributed configuration-friendly: it is not required to instantiate the class in the rank 0
+        process only.
 
     .. warning::
 
@@ -263,6 +264,7 @@ class Checkpoint(Serializable):
 
         - `score_name` can be used to define `score_function` automatically without providing `score_function`.
         - `save_handler` automatically saves to disk if path to directory is provided.
+        - `save_on_rank` defines the rank on which the objects are saved in a distributed configuration.
     """
 
     Item = NamedTuple("Item", [("priority", int), ("filename", str)])
@@ -280,6 +282,7 @@ def __init__(
         filename_pattern: Optional[str] = None,
         include_self: bool = False,
         greater_or_equal: bool = False,
+        save_on_rank: int = 0,
     ):
 
         if not isinstance(to_save, collections.Mapping):
@@ -312,7 +315,7 @@ def __init__(
         self.to_save = to_save
         self.filename_prefix = filename_prefix
         if isinstance(save_handler, str) or isinstance(save_handler, Path):
-            self.save_handler = DiskSaver(save_handler, create_dir=True)
+            self.save_handler = DiskSaver(save_handler, create_dir=True, save_on_rank=save_on_rank)
         else:
             self.save_handler = save_handler  # type: ignore
         self.score_function = score_function
@@ -326,6 +329,7 @@ def __init__(
         self._saved = []  # type: List["Checkpoint.Item"]
         self.include_self = include_self
         self.greater_or_equal = greater_or_equal
+        self.save_on_rank = save_on_rank
 
     def _get_filename_pattern(self, global_step: Optional[int]) -> str:
         if self.filename_pattern is None:
@@ -761,10 +765,15 @@ class DiskSaver(BaseSaveHandler):
         create_dir: if True, will create directory ``dirname`` if it doesnt exist.
         require_empty: If True, will raise exception if there are any files in the
             directory ``dirname``.
+        save_on_rank: The rank on which the checkpoint will be saved. Used in distributed
+            configuration.
         kwargs: Accepted keyword arguments for `torch.save` or `xm.save`.
 
     .. versionchanged:: 0.4.2
         Accept ``kwargs`` for `torch.save` or `xm.save`.
+
+    .. versionchanged:: 0.5.0
+        Argument ``save_on_rank`` was added to specify the rank on which the checkpoint should be saved.
     """
 
     def __init__(
@@ -773,15 +782,18 @@ def __init__(
         atomic: bool = True,
         create_dir: bool = True,
         require_empty: bool = True,
+        save_on_rank: int = 0,
         **kwargs: Any,
     ):
         self.dirname = Path(dirname).expanduser()
         self._atomic = atomic
-        self._check_and_setup(self.dirname, create_dir, require_empty)
+        self.save_on_rank = save_on_rank
+
+        if idist.get_rank() == save_on_rank:
+            self._check_and_setup(self.dirname, create_dir, require_empty)
         self.kwargs = kwargs
 
     @staticmethod
-    @idist.one_rank_only()
     def _check_and_setup(dirname: Path, create_dir: bool, require_empty: bool) -> None:
         if create_dir:
             if not dirname.exists():
@@ -804,49 +816,36 @@ def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mappin
         path = self.dirname / filename
 
         if idist.has_xla_support:
-            self._save_xla(checkpoint, path)
-        else:
-            self._save_native(checkpoint, path)
-
-    @idist.one_rank_only()
-    def _save_native(self, checkpoint: Mapping, path: Path) -> None:
-        self._save_func(checkpoint, path, torch.save)
+            import torch_xla.core.xla_model as xm
 
-    def _save_xla(self, checkpoint: Mapping, path: Path) -> None:
-        import torch_xla.core.xla_model as xm
+            # all tpu procs should enter here as xm.save internally performs sync across devices
+            self._save_func(checkpoint, path, xm.save)
+        elif self.save_on_rank == idist.get_rank():
+            self._save_func(checkpoint, path, torch.save)
 
-        # all tpu procs should enter here as internally performs sync across device
-        self._save_func(checkpoint, path, xm.save, rank=idist.get_rank())
-
-    def _save_func(self, checkpoint: Mapping, path: Path, func: Callable, rank: int = 0) -> None:
+    def _save_func(self, checkpoint: Mapping, path: Path, func: Callable) -> None:
         if not self._atomic:
             func(checkpoint, path, **self.kwargs)
         else:
-            tmp_file = None
-            tmp_name = ""
-            tmp: Optional[IO[bytes]] = None
-            if rank == 0:
-                tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname)
-                tmp_file = tmp.file
-                tmp_name = tmp.name
+            tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname)
+            tmp_file = tmp.file
+            tmp_name = tmp.name
             try:
                 func(checkpoint, tmp_file, **self.kwargs)
             except BaseException:
-                if tmp is not None:
-                    tmp.close()
-                    os.remove(tmp_name)
-                raise
+                tmp.close()
+                os.remove(tmp_name)
+                raise
             else:
-                if tmp is not None:
-                    tmp.close()
-                    os.replace(tmp.name, path)
-                    # append group/others read mode
-                    os.chmod(path, os.stat(path).st_mode | stat.S_IRGRP | stat.S_IROTH)
+                tmp.close()
+                os.replace(tmp.name, path)
+                # append group/others read mode
+                os.chmod(path, os.stat(path).st_mode | stat.S_IRGRP | stat.S_IROTH)
 
-    @idist.one_rank_only()
     def remove(self, filename: str) -> None:
-        path = self.dirname / filename
-        path.unlink()
+        if idist.get_rank() == self.save_on_rank:
+            path = self.dirname / filename
+            path.unlink()
 
 
 class ModelCheckpoint(Checkpoint):
@@ -901,14 +900,18 @@ class ModelCheckpoint(Checkpoint):
             there must not be another object in ``to_save`` with key ``checkpointer``.
         greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise, the first model.
             Default, `False`.
+        save_on_rank: Which rank to save the objects on, in the distributed configuration. Used to
+            instantiate a :class:`~ignite.handlers.DiskSaver` and is also passed to the parent class.
         kwargs: Accepted keyword arguments for `torch.save` or `xm.save` in `DiskSaver`.
 
     .. versionchanged:: 0.4.2
         Accept ``kwargs`` for `torch.save` or `xm.save`
 
     .. versionchanged:: 0.5.0
-        Accept ``filename_pattern`` and ``greater_or_equal`` for parity
-        with :class:`~ignite.handlers.checkpoint.Checkpoint`
+
+        - Accept ``filename_pattern`` and ``greater_or_equal`` for parity
+          with :class:`~ignite.handlers.checkpoint.Checkpoint`
+        - `save_on_rank` defines the rank on which the objects are saved in a distributed configuration.
 
     Examples:
         .. testcode:: python
@@ -945,10 +948,18 @@ def __init__(
         filename_pattern: Optional[str] = None,
         include_self: bool = False,
         greater_or_equal: bool = False,
+        save_on_rank: int = 0,
         **kwargs: Any,
     ):
 
-        disk_saver = DiskSaver(dirname, atomic=atomic, create_dir=create_dir, require_empty=require_empty, **kwargs)
+        disk_saver = DiskSaver(
+            dirname,
+            atomic=atomic,
+            create_dir=create_dir,
+            require_empty=require_empty,
+            save_on_rank=save_on_rank,
+            **kwargs,
+        )
 
         super(ModelCheckpoint, self).__init__(
             to_save={},
@@ -961,6 +972,7 @@ def __init__(
             filename_pattern=filename_pattern,
             include_self=include_self,
             greater_or_equal=greater_or_equal,
+            save_on_rank=save_on_rank,
         )
 
     @property
diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py
index cbd34fce17d..a6a2d7e78df 100644
--- a/tests/ignite/handlers/test_checkpoint.py
+++ b/tests/ignite/handlers/test_checkpoint.py
@@ -629,7 +629,7 @@ def test_disk_saver_atomic(dirname):
     to_save_serializable = {"model": model}
     to_save_non_serializable = {"model": lambda x: x}
 
-    def _test_existance(atomic, _to_save, expected):
+    def _test_existence(atomic, _to_save, expected):
         saver = DiskSaver(dirname, atomic=atomic, create_dir=False, require_empty=False)
         fname = "test.pt"
@@ -652,11 +652,11 @@ def _test_existance(atomic, _to_save, expected):
         if expected:
             saver.remove(fname)
 
-    _test_existance(atomic=False, _to_save=to_save_serializable, expected=True)
-    _test_existance(atomic=False, _to_save=to_save_non_serializable, expected=True)
+    _test_existence(atomic=False, _to_save=to_save_serializable, expected=True)
+    _test_existence(atomic=False, _to_save=to_save_non_serializable, expected=True)
 
-    _test_existance(atomic=True, _to_save=to_save_serializable, expected=True)
-    _test_existance(atomic=True, _to_save=to_save_non_serializable, expected=False)
+    _test_existence(atomic=True, _to_save=to_save_serializable, expected=True)
+    _test_existence(atomic=True, _to_save=to_save_non_serializable, expected=False)
 
 
 @pytest.mark.skipif(
@@ -856,7 +856,7 @@ def test_valid_state_dict_save(dirname):
         pytest.fail("Unexpected ValueError")
 
 
-def _test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname, on_zero_rank=False):
+def _test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname, just_on_zero_rank=False):
 
     torch.manual_seed(23)
 
@@ -885,7 +885,7 @@ def update_fn(engine, batch):
 
     engine = Engine(update_fn)
 
-    if (not on_zero_rank) or (on_zero_rank and idist.get_rank() == 0):
+    if (not just_on_zero_rank) or (just_on_zero_rank and idist.get_rank() == 0):
         handler = ModelCheckpoint(dirname, _PREFIX, create_dir=True, n_saved=1)
 
         engine.add_event_handler(
@@ -942,7 +942,7 @@ def test_save_model_optimizer_lr_scheduler_with_state_dict(dirname):
     _test_save_model_optimizer_lr_scheduler_with_state_dict("cpu", dirname)
 
 
-def _test_save_model_optimizer_lr_scheduler_with_validation(device, dirname, on_zero_rank=False):
+def _test_save_model_optimizer_lr_scheduler_with_validation(device, dirname, just_on_zero_rank=False): torch.manual_seed(23) def _build_objects(acc_list): @@ -1248,9 +1248,9 @@ def _test_checkpoint_load_objects_ddp(device): def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo, get_rank_zero_dirname): device = idist.device() - dirname = get_rank_zero_dirname() - _test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname / "1") - _test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname / "2", on_zero_rank=True) + rank_zero_dirname = get_rank_zero_dirname() + _test_save_model_optimizer_lr_scheduler_with_state_dict(device, rank_zero_dirname / "1") + _test_save_model_optimizer_lr_scheduler_with_state_dict(device, rank_zero_dirname / "2", just_on_zero_rank=True) _test_checkpoint_with_ddp(device) _test_checkpoint_load_objects_ddp(device) @@ -1263,7 +1263,7 @@ def test_distrib_nccl_gpu(distributed_context_single_node_nccl, get_rank_zero_di device = idist.device() dirname = get_rank_zero_dirname() _test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname / "1") - _test_save_model_optimizer_lr_scheduler_with_state_dict("cpu", dirname / "2", on_zero_rank=True) + _test_save_model_optimizer_lr_scheduler_with_state_dict("cpu", dirname / "2", just_on_zero_rank=True) _test_checkpoint_with_ddp(device=device) _test_checkpoint_load_objects_ddp(device=device) @@ -1784,3 +1784,20 @@ def test_load_single_object(obj_to_save, dirname): checkpoint_fp = dirname / c.last_checkpoint Checkpoint.load_objects(to_load=to_save, checkpoint=str(checkpoint_fp)) + + +@pytest.mark.distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +@pytest.mark.parametrize("atomic", [False, True]) +def test_disksaver_distrib(distributed_context_single_node_gloo, dirname, local_rank, atomic): + + saver = DiskSaver(dirname, atomic, save_on_rank=1) + mocked_saver = MagicMock(wraps=saver) + + mocked_saver(checkpoint={}, filename="test_disksaver_distrib.pt") + + if local_rank == 1: + assert (dirname / "test_disksaver_distrib.pt").exists() + + else: + mocked_saver._save_func.assert_not_called()
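
For reference, a minimal usage sketch of the new ``save_on_rank`` argument (not part of the patch). It assumes two CPU processes launched with the gloo backend; the checkpoint directory, toy model and no-op update function are illustrative only:

    import torch.nn as nn
    import ignite.distributed as idist
    from ignite.engine import Engine, Events
    from ignite.handlers import Checkpoint, DiskSaver

    def training(local_rank):
        model = nn.Linear(3, 1)  # toy model, illustrative only
        trainer = Engine(lambda engine, batch: None)  # no-op update function

        # Every process creates the handler; only the process whose rank equals
        # save_on_rank (rank 1 here) creates the directory and writes to disk.
        handler = Checkpoint(
            {"model": model},
            DiskSaver("/tmp/ckpts", create_dir=True, require_empty=False, save_on_rank=1),
            n_saved=2,
        )
        trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
        trainer.run([0, 1, 2], max_epochs=2)

    if __name__ == "__main__":
        with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
            parallel.run(training)

The same effect can be obtained by passing a directory path together with ``save_on_rank`` directly to ``Checkpoint``, which then builds the ``DiskSaver`` internally; in either case rank 0 only keeps the handler's bookkeeping state while rank 1 touches the filesystem.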