Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add save_on_rank argument to DiskSaver and Checkpoint #2641

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 54 additions & 42 deletions ignite/handlers/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from pathlib import Path
from typing import Any, Callable, Dict, IO, List, Mapping, NamedTuple, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -101,6 +101,8 @@ class Checkpoint(Serializable):
there must not be another object in ``to_save`` with key ``checkpointer``.
greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise, the first model.
Default, `False`.
save_on_rank: Which rank to save the objects on, in the distributed configuration. If ``save_handler`` is
a string or :class:`~pathlib.Path`, this is also used to instantiate a :class:`~ignite.handlers.DiskSaver`.

.. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
torch.nn.parallel.DistributedDataParallel.html
Expand Down Expand Up @@ -164,9 +166,8 @@ class Checkpoint(Serializable):
> checkpoint_12345.pt

Note:
This class is distributed configuration-friendly: it is not required to instantiate the class in rank 0 only
process. This class supports automatically distributed configuration and if used with
:class:`~ignite.handlers.DiskSaver`, checkpoint is stored by rank 0 process.
This class is distributed configuration-friendly: it is not required to instantiate the class in rank 0
process only.

.. warning::

Expand Down Expand Up @@ -263,6 +264,7 @@ class Checkpoint(Serializable):

- `score_name` can be used to define `score_function` automatically without providing `score_function`.
- `save_handler` automatically saves to disk if path to directory is provided.
- `save_on_rank` saves objects on this rank in a distributed configuration.
"""

Item = NamedTuple("Item", [("priority", int), ("filename", str)])
Expand All @@ -280,6 +282,7 @@ def __init__(
filename_pattern: Optional[str] = None,
include_self: bool = False,
greater_or_equal: bool = False,
save_on_rank: Optional[int] = 0,
):

if not isinstance(to_save, collections.Mapping):
Expand Down Expand Up @@ -312,7 +315,7 @@ def __init__(
self.to_save = to_save
self.filename_prefix = filename_prefix
if isinstance(save_handler, str) or isinstance(save_handler, Path):
self.save_handler = DiskSaver(save_handler, create_dir=True)
self.save_handler = DiskSaver(save_handler, create_dir=True, save_on_rank=save_on_rank)
else:
self.save_handler = save_handler # type: ignore
self.score_function = score_function
Expand All @@ -326,6 +329,7 @@ def __init__(
self._saved = [] # type: List["Checkpoint.Item"]
self.include_self = include_self
self.greater_or_equal = greater_or_equal
self.save_on_rank = save_on_rank

def _get_filename_pattern(self, global_step: Optional[int]) -> str:
if self.filename_pattern is None:
Expand Down Expand Up @@ -761,10 +765,15 @@ class DiskSaver(BaseSaveHandler):
create_dir: if True, will create directory ``dirname`` if it doesn't exist.
require_empty: If True, will raise exception if there are any files in the
directory ``dirname``.
save_on_rank: The rank on which the checkpoint will be saved. Used in distributed
configuration.
kwargs: Accepted keyword arguments for `torch.save` or `xm.save`.

.. versionchanged:: 0.4.2
Accept ``kwargs`` for `torch.save` or `xm.save`.

.. versionchanged:: 0.5.0
Argument ``save_on_rank`` was added to specify the rank on which checkpoint should be saved.
"""

def __init__(
Expand All @@ -773,15 +782,18 @@ def __init__(
atomic: bool = True,
create_dir: bool = True,
require_empty: bool = True,
save_on_rank: Optional[int] = 0,
**kwargs: Any,
):
self.dirname = Path(dirname).expanduser()
self._atomic = atomic
self._check_and_setup(self.dirname, create_dir, require_empty)
self.save_on_rank = save_on_rank

if idist.get_rank() == save_on_rank:
self._check_and_setup(self.dirname, create_dir, require_empty)
self.kwargs = kwargs

@staticmethod
@idist.one_rank_only()
def _check_and_setup(dirname: Path, create_dir: bool, require_empty: bool) -> None:
if create_dir:
if not dirname.exists():
Expand All @@ -804,49 +816,36 @@ def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mappin
path = self.dirname / filename

if idist.has_xla_support:
self._save_xla(checkpoint, path)
else:
self._save_native(checkpoint, path)

@idist.one_rank_only()
def _save_native(self, checkpoint: Mapping, path: Path) -> None:
self._save_func(checkpoint, path, torch.save)
import torch_xla.core.xla_model as xm

def _save_xla(self, checkpoint: Mapping, path: Path) -> None:
import torch_xla.core.xla_model as xm
# all tpu procs should enter here as it internally performs a sync across devices
self._save_func(checkpoint, path, xm.save)
elif self.save_on_rank == idist.get_rank():
self._save_func(checkpoint, path, torch.save)

# all tpu procs should enter here as it internally performs a sync across devices
self._save_func(checkpoint, path, xm.save, rank=idist.get_rank())

def _save_func(self, checkpoint: Mapping, path: Path, func: Callable, rank: int = 0) -> None:
def _save_func(self, checkpoint: Mapping, path: Path, func: Callable) -> None:
if not self._atomic:
func(checkpoint, path, **self.kwargs)
else:
tmp_file = None
tmp_name = ""
tmp: Optional[IO[bytes]] = None
if rank == 0:
tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname)
tmp_file = tmp.file
tmp_name = tmp.name
tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname)
tmp_file = tmp.file
tmp_name = tmp.name
try:
func(checkpoint, tmp_file, **self.kwargs)
except BaseException:
if tmp is not None:
tmp.close()
os.remove(tmp_name)
raise
tmp.close()
os.remove(tmp_name)
raise
else:
if tmp is not None:
tmp.close()
os.replace(tmp.name, path)
# append group/others read mode
os.chmod(path, os.stat(path).st_mode | stat.S_IRGRP | stat.S_IROTH)
tmp.close()
os.replace(tmp.name, path)
# append group/others read mode
os.chmod(path, os.stat(path).st_mode | stat.S_IRGRP | stat.S_IROTH)

@idist.one_rank_only()
def remove(self, filename: str) -> None:
path = self.dirname / filename
path.unlink()
if idist.get_rank() == self.save_on_rank:
path = self.dirname / filename
path.unlink()


class ModelCheckpoint(Checkpoint):
Expand Down Expand Up @@ -901,14 +900,18 @@ class ModelCheckpoint(Checkpoint):
there must not be another object in ``to_save`` with key ``checkpointer``.
greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise, the first model.
Default, `False`.
save_on_rank: Which rank to save the objects on, in the distributed configuration. Used to
instantiate a :class:`~ignite.handlers.DiskSaver` and is also passed to the parent class.
kwargs: Accepted keyword arguments for `torch.save` or `xm.save` in `DiskSaver`.

.. versionchanged:: 0.4.2
Accept ``kwargs`` for `torch.save` or `xm.save`

.. versionchanged:: 0.5.0
Accept ``filename_pattern`` and ``greater_or_equal`` for parity
with :class:`~ignite.handlers.checkpoint.Checkpoint`

- ``filename_pattern`` and ``greater_or_equal`` for parity
with :class:`~ignite.handlers.checkpoint.Checkpoint`
- `save_on_rank` saves objects on this rank in a distributed configuration.

Examples:
.. testcode:: python
Expand Down Expand Up @@ -945,10 +948,18 @@ def __init__(
filename_pattern: Optional[str] = None,
include_self: bool = False,
greater_or_equal: bool = False,
save_on_rank: Optional[int] = 0,
**kwargs: Any,
):

disk_saver = DiskSaver(dirname, atomic=atomic, create_dir=create_dir, require_empty=require_empty, **kwargs)
disk_saver = DiskSaver(
dirname,
atomic=atomic,
create_dir=create_dir,
require_empty=require_empty,
save_on_rank=save_on_rank,
**kwargs,
)

super(ModelCheckpoint, self).__init__(
to_save={},
Expand All @@ -961,6 +972,7 @@ def __init__(
filename_pattern=filename_pattern,
include_self=include_self,
greater_or_equal=greater_or_equal,
save_on_rank=save_on_rank,
)

@property
Expand Down
41 changes: 29 additions & 12 deletions tests/ignite/handlers/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ def test_disk_saver_atomic(dirname):
to_save_serializable = {"model": model}
to_save_non_serializable = {"model": lambda x: x}

def _test_existance(atomic, _to_save, expected):
def _test_existence(atomic, _to_save, expected):

saver = DiskSaver(dirname, atomic=atomic, create_dir=False, require_empty=False)
fname = "test.pt"
Expand All @@ -652,11 +652,11 @@ def _test_existance(atomic, _to_save, expected):
if expected:
saver.remove(fname)

_test_existance(atomic=False, _to_save=to_save_serializable, expected=True)
_test_existance(atomic=False, _to_save=to_save_non_serializable, expected=True)
_test_existence(atomic=False, _to_save=to_save_serializable, expected=True)
_test_existence(atomic=False, _to_save=to_save_non_serializable, expected=True)

_test_existance(atomic=True, _to_save=to_save_serializable, expected=True)
_test_existance(atomic=True, _to_save=to_save_non_serializable, expected=False)
_test_existence(atomic=True, _to_save=to_save_serializable, expected=True)
_test_existence(atomic=True, _to_save=to_save_non_serializable, expected=False)


@pytest.mark.skipif(
Expand Down Expand Up @@ -856,7 +856,7 @@ def test_valid_state_dict_save(dirname):
pytest.fail("Unexpected ValueError")


def _test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname, on_zero_rank=False):
def _test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname, just_on_zero_rank=False):

torch.manual_seed(23)

Expand Down Expand Up @@ -885,7 +885,7 @@ def update_fn(engine, batch):

engine = Engine(update_fn)

if (not on_zero_rank) or (on_zero_rank and idist.get_rank() == 0):
if (not just_on_zero_rank) or (just_on_zero_rank and idist.get_rank() == 0):
handler = ModelCheckpoint(dirname, _PREFIX, create_dir=True, n_saved=1)

engine.add_event_handler(
Expand Down Expand Up @@ -942,7 +942,7 @@ def test_save_model_optimizer_lr_scheduler_with_state_dict(dirname):
_test_save_model_optimizer_lr_scheduler_with_state_dict("cpu", dirname)


def _test_save_model_optimizer_lr_scheduler_with_validation(device, dirname, on_zero_rank=False):
def _test_save_model_optimizer_lr_scheduler_with_validation(device, dirname, just_on_zero_rank=False):
torch.manual_seed(23)

def _build_objects(acc_list):
Expand Down Expand Up @@ -1248,9 +1248,9 @@ def _test_checkpoint_load_objects_ddp(device):
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo, get_rank_zero_dirname):

device = idist.device()
dirname = get_rank_zero_dirname()
_test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname / "1")
_test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname / "2", on_zero_rank=True)
rank_zero_dirname = get_rank_zero_dirname()
_test_save_model_optimizer_lr_scheduler_with_state_dict(device, rank_zero_dirname / "1")
_test_save_model_optimizer_lr_scheduler_with_state_dict(device, rank_zero_dirname / "2", just_on_zero_rank=True)
_test_checkpoint_with_ddp(device)
_test_checkpoint_load_objects_ddp(device)

Expand All @@ -1263,7 +1263,7 @@ def test_distrib_nccl_gpu(distributed_context_single_node_nccl, get_rank_zero_di
device = idist.device()
dirname = get_rank_zero_dirname()
_test_save_model_optimizer_lr_scheduler_with_state_dict(device, dirname / "1")
_test_save_model_optimizer_lr_scheduler_with_state_dict("cpu", dirname / "2", on_zero_rank=True)
_test_save_model_optimizer_lr_scheduler_with_state_dict("cpu", dirname / "2", just_on_zero_rank=True)
_test_checkpoint_with_ddp(device=device)
_test_checkpoint_load_objects_ddp(device=device)

Expand Down Expand Up @@ -1784,3 +1784,20 @@ def test_load_single_object(obj_to_save, dirname):

checkpoint_fp = dirname / c.last_checkpoint
Checkpoint.load_objects(to_load=to_save, checkpoint=str(checkpoint_fp))


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.parametrize("atomic", [False, True])
def test_disksaver_distrib(distributed_context_single_node_gloo, dirname, local_rank, atomic):

saver = DiskSaver(dirname, atomic, save_on_rank=1)
mocked_saver = MagicMock(wraps=saver)

mocked_saver(checkpoint={}, filename="test_disksaver_distrib.pt")

if local_rank == 1:
assert (dirname / "test_disksaver_distrib.pt").exists()

else:
mocked_saver._save_func.assert_not_called()