Implement feature (#2641)
sadra-barikbin authored Aug 14, 2022
1 parent d34f1c2 commit 83f8cc7
Showing 2 changed files with 83 additions and 54 deletions.
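In short: ``Checkpoint``, ``DiskSaver``, and ``ModelCheckpoint`` gain a ``save_on_rank`` argument so that, in a distributed configuration, checkpoint files can be written from an arbitrary rank instead of the previously hard-coded rank 0; a distributed test for ``DiskSaver`` is added alongside minor test renames.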
96 changes: 54 additions & 42 deletions ignite/handlers/checkpoint.py
@@ -7,7 +7,7 @@
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from pathlib import Path
-from typing import Any, Callable, Dict, IO, List, Mapping, NamedTuple, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -101,6 +101,8 @@ class Checkpoint(Serializable):
            there must not be another object in ``to_save`` with key ``checkpointer``.
        greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise, the first model.
            Default, `False`.
+        save_on_rank: Which rank to save the objects on, in the distributed configuration. If ``save_handler`` is
+            a string or :class:`~pathlib.Path`, this is also used to instantiate a :class:`~ignite.handlers.DiskSaver`.
.. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
torch.nn.parallel.DistributedDataParallel.html
@@ -164,9 +166,8 @@ class Checkpoint(Serializable):
> checkpoint_12345.pt
Note:
-        This class is distributed configuration-friendly: it is not required to instantiate the class in rank 0 only
-        process. This class supports automatically distributed configuration and if used with
-        :class:`~ignite.handlers.DiskSaver`, checkpoint is stored by rank 0 process.
+        This class is distributed configuration-friendly: it is not required to instantiate the class in rank 0
+        process only.
.. warning::
@@ -263,6 +264,7 @@ class Checkpoint(Serializable):
        - `score_name` can be used to define `score_function` automatically without providing `score_function`.
        - `save_handler` automatically saves to disk if path to directory is provided.
+        - `save_on_rank` saves objects on this rank in a distributed configuration.
"""

Item = NamedTuple("Item", [("priority", int), ("filename", str)])
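To make the new argument concrete, here is a minimal hedged sketch of saving from rank 1 with ``Checkpoint``; the directory path, dummy model, and the assumption of a multi-process gloo run are illustrative, not part of this diff:

import torch.nn as nn
import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint

# Every rank constructs the handler; with save_on_rank=1 only rank 1 writes
# to disk. Passing a path string as save_handler makes Checkpoint build the
# DiskSaver internally and forward save_on_rank to it.
model = idist.auto_model(nn.Linear(3, 3))
checkpoint = Checkpoint({"model": model}, save_handler="/tmp/ckpts", n_saved=2, save_on_rank=1)

trainer = Engine(lambda engine, batch: None)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint)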
@@ -280,6 +282,7 @@ def __init__(
        filename_pattern: Optional[str] = None,
        include_self: bool = False,
        greater_or_equal: bool = False,
+        save_on_rank: int = 0,
):

if not isinstance(to_save, collections.Mapping):
@@ -312,7 +315,7 @@ def __init__(
self.to_save = to_save
self.filename_prefix = filename_prefix
        if isinstance(save_handler, str) or isinstance(save_handler, Path):
-            self.save_handler = DiskSaver(save_handler, create_dir=True)
+            self.save_handler = DiskSaver(save_handler, create_dir=True, save_on_rank=save_on_rank)
else:
self.save_handler = save_handler # type: ignore
self.score_function = score_function
@@ -326,6 +329,7 @@ def __init__(
        self._saved = []  # type: List["Checkpoint.Item"]
        self.include_self = include_self
        self.greater_or_equal = greater_or_equal
+        self.save_on_rank = save_on_rank

def _get_filename_pattern(self, global_step: Optional[int]) -> str:
if self.filename_pattern is None:
@@ -761,10 +765,15 @@ class DiskSaver(BaseSaveHandler):
        create_dir: if True, will create directory ``dirname`` if it doesn't exist.
        require_empty: If True, will raise exception if there are any files in the
            directory ``dirname``.
+        save_on_rank: The rank on which the checkpoint will be saved. Used in distributed
+            configuration.
kwargs: Accepted keyword arguments for `torch.save` or `xm.save`.
    .. versionchanged:: 0.4.2
        Accept ``kwargs`` for `torch.save` or `xm.save`.

+    .. versionchanged:: 0.5.0
+        Argument ``save_on_rank`` was added to specify the rank on which checkpoint should be saved.
    """

def __init__(
@@ -773,15 +782,18 @@ def __init__(
        atomic: bool = True,
        create_dir: bool = True,
        require_empty: bool = True,
+        save_on_rank: int = 0,
        **kwargs: Any,
    ):
        self.dirname = Path(dirname).expanduser()
        self._atomic = atomic
-        self._check_and_setup(self.dirname, create_dir, require_empty)
+        self.save_on_rank = save_on_rank
+
+        if idist.get_rank() == save_on_rank:
+            self._check_and_setup(self.dirname, create_dir, require_empty)
        self.kwargs = kwargs

    @staticmethod
-    @idist.one_rank_only()
    def _check_and_setup(dirname: Path, create_dir: bool, require_empty: bool) -> None:
if create_dir:
if not dirname.exists():
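Note the change in ``__init__`` above: directory creation and the ``require_empty`` check now run only on the saving rank, replacing the old ``@idist.one_rank_only()`` decorator on ``_check_and_setup``. A hedged sketch of the resulting behaviour (the path and rank choice are placeholders):

import ignite.distributed as idist
from ignite.handlers import DiskSaver

# Constructed on every process of a distributed run; only the process whose
# rank equals save_on_rank validates/creates the directory, and later only
# that process writes and removes checkpoint files.
saver = DiskSaver("/tmp/ckpts", require_empty=False, save_on_rank=1)
if idist.get_rank() == saver.save_on_rank:
    print("this rank owns checkpoint I/O")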
@@ -804,49 +816,36 @@ def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None:
        path = self.dirname / filename

        if idist.has_xla_support:
-            self._save_xla(checkpoint, path)
-        else:
-            self._save_native(checkpoint, path)
-
-    @idist.one_rank_only()
-    def _save_native(self, checkpoint: Mapping, path: Path) -> None:
-        self._save_func(checkpoint, path, torch.save)
-
-    def _save_xla(self, checkpoint: Mapping, path: Path) -> None:
-        import torch_xla.core.xla_model as xm
-
-        # all tpu procs should enter here as internally performs sync across device
-        self._save_func(checkpoint, path, xm.save, rank=idist.get_rank())
+            import torch_xla.core.xla_model as xm
+
+            # all tpu procs should enter here as internally performs sync across device
+            self._save_func(checkpoint, path, xm.save)
+        elif self.save_on_rank == idist.get_rank():
+            self._save_func(checkpoint, path, torch.save)

-    def _save_func(self, checkpoint: Mapping, path: Path, func: Callable, rank: int = 0) -> None:
+    def _save_func(self, checkpoint: Mapping, path: Path, func: Callable) -> None:
        if not self._atomic:
            func(checkpoint, path, **self.kwargs)
        else:
-            tmp_file = None
-            tmp_name = ""
-            tmp: Optional[IO[bytes]] = None
-            if rank == 0:
-                tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname)
-                tmp_file = tmp.file
-                tmp_name = tmp.name
+            tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname)
+            tmp_file = tmp.file
+            tmp_name = tmp.name
            try:
                func(checkpoint, tmp_file, **self.kwargs)
            except BaseException:
-                if tmp is not None:
-                    tmp.close()
-                    os.remove(tmp_name)
-                raise
+                tmp.close()
+                os.remove(tmp_name)
+                raise
            else:
-                if tmp is not None:
-                    tmp.close()
-                    os.replace(tmp.name, path)
-                    # append group/others read mode
-                    os.chmod(path, os.stat(path).st_mode | stat.S_IRGRP | stat.S_IROTH)
+                tmp.close()
+                os.replace(tmp.name, path)
+                # append group/others read mode
+                os.chmod(path, os.stat(path).st_mode | stat.S_IRGRP | stat.S_IROTH)
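The atomic branch is simpler now that rank gating happens in ``__call__`` instead of inside ``_save_func``: write to a temporary file in the target directory, then ``os.replace`` it into place so a reader never observes a half-written checkpoint. A standalone sketch of the same pattern (the function name and arguments are illustrative, not ignite API):

import os
import stat
import tempfile

import torch

def atomic_save(obj, dirname: str, filename: str) -> None:
    # Write into a temp file in the same directory so that os.replace()
    # stays on one filesystem and is therefore atomic.
    tmp = tempfile.NamedTemporaryFile(delete=False, dir=dirname)
    try:
        torch.save(obj, tmp.file)
    except BaseException:
        tmp.close()
        os.remove(tmp.name)
        raise
    else:
        tmp.close()
        path = os.path.join(dirname, filename)
        os.replace(tmp.name, path)
        # append group/others read permission, as DiskSaver does
        os.chmod(path, os.stat(path).st_mode | stat.S_IRGRP | stat.S_IROTH)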

-    @idist.one_rank_only()
    def remove(self, filename: str) -> None:
-        path = self.dirname / filename
-        path.unlink()
+        if idist.get_rank() == self.save_on_rank:
+            path = self.dirname / filename
+            path.unlink()


class ModelCheckpoint(Checkpoint):
@@ -901,14 +900,18 @@ class ModelCheckpoint(Checkpoint):
            there must not be another object in ``to_save`` with key ``checkpointer``.
        greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise, the first model.
            Default, `False`.
+        save_on_rank: Which rank to save the objects on, in the distributed configuration. Used to
+            instantiate a :class:`~ignite.handlers.DiskSaver` and is also passed to the parent class.
        kwargs: Accepted keyword arguments for `torch.save` or `xm.save` in `DiskSaver`.
    .. versionchanged:: 0.4.2
        Accept ``kwargs`` for `torch.save` or `xm.save`

    .. versionchanged:: 0.5.0

-        Accept ``filename_pattern`` and ``greater_or_equal`` for parity
-        with :class:`~ignite.handlers.checkpoint.Checkpoint`
+        - ``filename_pattern`` and ``greater_or_equal`` for parity
+          with :class:`~ignite.handlers.checkpoint.Checkpoint`
+        - `save_on_rank` saves objects on this rank in a distributed configuration.
Examples:
.. testcode:: python
@@ -945,10 +948,18 @@ def __init__(
        filename_pattern: Optional[str] = None,
        include_self: bool = False,
        greater_or_equal: bool = False,
+        save_on_rank: int = 0,
        **kwargs: Any,
    ):

-        disk_saver = DiskSaver(dirname, atomic=atomic, create_dir=create_dir, require_empty=require_empty, **kwargs)
+        disk_saver = DiskSaver(
+            dirname,
+            atomic=atomic,
+            create_dir=create_dir,
+            require_empty=require_empty,
+            save_on_rank=save_on_rank,
+            **kwargs,
+        )

super(ModelCheckpoint, self).__init__(
to_save={},
Expand All @@ -961,6 +972,7 @@ def __init__(
            filename_pattern=filename_pattern,
            include_self=include_self,
            greater_or_equal=greater_or_equal,
+            save_on_rank=save_on_rank,
        )

@property
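An end-to-end hedged sketch of the ``ModelCheckpoint`` side, using ``ignite.distributed.Parallel`` to spawn two gloo processes; the directory, prefix, and dummy training loop are assumptions for illustration:

import torch.nn as nn
import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint

def training(local_rank):
    # Placeholders: a real script would build a model, data, and trainer.
    model = idist.auto_model(nn.Linear(3, 3))
    trainer = Engine(lambda engine, batch: None)

    # save_on_rank=1 is forwarded to the internal DiskSaver, so only
    # rank 1 touches /tmp/models (the path is illustrative).
    handler = ModelCheckpoint("/tmp/models", "sample", n_saved=2, require_empty=False, save_on_rank=1)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {"model": model})
    trainer.run([0, 1], max_epochs=2)

if __name__ == "__main__":
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(training)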
41 changes: 29 additions & 12 deletions tests/ignite/handlers/test_checkpoint.py
@@ -629,7 +629,7 @@ def test_disk_saver_atomic(dirname):
to_save_serializable = {"model": model}
to_save_non_serializable = {"model": lambda x: x}

-    def _test_existance(atomic, _to_save, expected):
+    def _test_existence(atomic, _to_save, expected):

saver = DiskSaver(dirname, atomic=atomic, create_dir=False, require_empty=False)
fname = "test.pt"
@@ -652,11 +652,11 @@ def _test_existance(atomic, _to_save, expected):
if expected:
saver.remove(fname)

-    _test_existance(atomic=False, _to_save=to_save_serializable, expected=True)
-    _test_existance(atomic=False, _to_save=to_save_non_serializable, expected=True)
+    _test_existence(atomic=False, _to_save=to_save_serializable, expected=True)
+    _test_existence(atomic=False, _to_save=to_save_non_serializable, expected=True)

-    _test_existance(atomic=True, _to_save=to_save_serializable, expected=True)
-    _test_existance(atomic=True, _to_save=to_save_non_serializable, expected=False)
+    _test_existence(atomic=True, _to_save=to_save_serializable, expected=True)
+    _test_existence(atomic=True, _to_save=to_save_non_serializable, expected=False)


@pytest.mark.skipif(
@@ -856,7 +856,7 @@ def test_valid_state_dict_save(dirname):
pytest.fail("Unexpected ValueError")


-def _test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname, on_zero_rank=False):
+def _test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname, just_on_zero_rank=False):

torch.manual_seed(23)

@@ -885,7 +885,7 @@ def update_fn(engine, batch):

engine = Engine(update_fn)

-    if (not on_zero_rank) or (on_zero_rank and idist.get_rank() == 0):
+    if (not just_on_zero_rank) or (just_on_zero_rank and idist.get_rank() == 0):
handler = ModelCheckpoint(dirname, _PREFIX, create_dir=True, n_saved=1)

engine.add_event_handler(
@@ -942,7 +942,7 @@ def test_save_model_optimizer_lr_scheduler_with_state_dict(dirname):
_test_save_model_optimizer_lr_scheduler_with_state_dict("cpu", dirname)


-def _test_save_model_optimizer_lr_scheduler_with_validation(device, dirname, on_zero_rank=False):
+def _test_save_model_optimizer_lr_scheduler_with_validation(device, dirname, just_on_zero_rank=False):
torch.manual_seed(23)

def _build_objects(acc_list):
@@ -1248,9 +1248,9 @@ def _test_checkpoint_load_objects_ddp(device):
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo, get_rank_zero_dirname):

device = idist.device()
-    dirname = get_rank_zero_dirname()
-    _test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname / "1")
-    _test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname / "2", on_zero_rank=True)
+    rank_zero_dirname = get_rank_zero_dirname()
+    _test_save_model_optimizer_lr_scheduler_with_state_dict(device, rank_zero_dirname / "1")
+    _test_save_model_optimizer_lr_scheduler_with_state_dict(device, rank_zero_dirname / "2", just_on_zero_rank=True)
_test_checkpoint_with_ddp(device)
_test_checkpoint_load_objects_ddp(device)

@@ -1263,7 +1263,7 @@ def test_distrib_nccl_gpu(distributed_context_single_node_nccl, get_rank_zero_dirname):
device = idist.device()
dirname = get_rank_zero_dirname()
_test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname / "1")
_test_save_model_optimizer_lr_scheduler_with_state_dict("cpu", dirname / "2", on_zero_rank=True)
_test_save_model_optimizer_lr_scheduler_with_state_dict("cpu", dirname / "2", just_on_zero_rank=True)
_test_checkpoint_with_ddp(device=device)
_test_checkpoint_load_objects_ddp(device=device)

@@ -1784,3 +1784,20 @@ def test_load_single_object(obj_to_save, dirname):

checkpoint_fp = dirname / c.last_checkpoint
Checkpoint.load_objects(to_load=to_save, checkpoint=str(checkpoint_fp))


+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.parametrize("atomic", [False, True])
+def test_disksaver_distrib(distributed_context_single_node_gloo, dirname, local_rank, atomic):
+
+    saver = DiskSaver(dirname, atomic, save_on_rank=1)
+    mocked_saver = MagicMock(wraps=saver)
+
+    mocked_saver(checkpoint={}, filename="test_disksaver_distrib.pt")
+
+    if local_rank == 1:
+        assert (dirname / "test_disksaver_distrib.pt").exists()
+    else:
+        mocked_saver._save_func.assert_not_called()
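The new test leans on ``MagicMock(wraps=...)``: calls pass through to the real ``DiskSaver`` while still being recorded, so the non-saving rank can assert that ``_save_func`` was never invoked. A small self-contained illustration of that pattern (the ``Greeter`` class is made up for the example):

from unittest.mock import MagicMock

class Greeter:
    def hello(self) -> str:
        return "hi"

g = MagicMock(wraps=Greeter())
assert g.hello() == "hi"       # the call is delegated to the real object
g.hello.assert_called_once()   # ...and recorded by the mock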
