Fix wrong behavior of DDPStrategy option with simple GAN training using DDP #20936

Open · wants to merge 19 commits into master
41 changes: 26 additions & 15 deletions examples/pytorch/domain_templates/generative_adversarial_net.py
@@ -29,13 +29,22 @@
from lightning.pytorch import cli_lightning_logo
from lightning.pytorch.core import LightningModule
from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
import torchvision


def _block(in_feat: int, out_feat: int, normalize: bool = True) -> list:
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers


class Generator(nn.Module):
"""
>>> Generator(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
@@ -47,19 +56,11 @@ class Generator(nn.Module):
def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)):
super().__init__()
self.img_shape = img_shape

def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers

self.model = nn.Sequential(
*block(latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
*_block(latent_dim, 128, normalize=False),
*_block(128, 256),
*_block(256, 512),
*_block(512, 1024),
nn.Linear(1024, int(math.prod(img_shape))),
nn.Tanh(),
)
@@ -209,10 +210,18 @@ def main(args: Namespace) -> None:
# ------------------------
# 2 INIT TRAINER
# ------------------------
# If using distributed training, PyTorch recommends DistributedDataParallel.
# See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
dm = MNISTDataModule()
trainer = Trainer(accelerator="gpu", devices=1)

if args.use_ddp:
# `MultiModelDDPStrategy` is critical for multi-GPU GAN training.
# With the existing `DDPStrategy`, there are only two ways to make this training run:
# 1) enable the `find_unused_parameters` option, or
# 2) change `self.manual_backward(loss)` to `loss.backward()`.
# Neither of them is desirable.
trainer = Trainer(accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy())
else:
# If you want to run on a single GPU, you can use the default strategy.
trainer = Trainer(accelerator="gpu", devices=1)

# ------------------------
# 3 START TRAINING
@@ -229,6 +238,8 @@ def main(args: Namespace) -> None:
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--use_ddp", action="store_true", help="distributed strategy to use")

args = parser.parse_args()

main(args)
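
For context, here is a minimal, self-contained sketch (not part of this diff; the `TinyGAN` module and its losses are illustrative) of the kind of manually optimized GAN this example implements, with the generator and discriminator kept as direct child attributes so that `MultiModelDDPStrategy` can wrap each of them in its own DDP instance:

# Hypothetical minimal sketch, not part of this PR: a manually optimized GAN whose
# generator and discriminator are direct child attributes of the LightningModule,
# which is what `MultiModelDDPStrategy` relies on to wrap each sub-model separately.
import torch
from torch import nn
from torch.nn import functional as F

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy


class TinyGAN(LightningModule):
    def __init__(self, latent_dim: int = 100):
        super().__init__()
        self.automatic_optimization = False  # GAN training steps its optimizers manually
        self.latent_dim = latent_dim
        self.generator = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 784), nn.Tanh())
        self.discriminator = nn.Sequential(nn.Linear(784, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1))

    def training_step(self, batch, batch_idx):
        imgs, _ = batch
        imgs = imgs.flatten(1)
        opt_g, opt_d = self.optimizers()
        z = torch.randn(imgs.size(0), self.latent_dim, device=self.device)

        # generator update: push generated samples towards the "real" label
        g_loss = F.binary_cross_entropy_with_logits(
            self.discriminator(self.generator(z)), torch.ones(imgs.size(0), 1, device=self.device)
        )
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()

        # discriminator update: separate real samples from detached fakes
        real_logits = self.discriminator(imgs)
        fake_logits = self.discriminator(self.generator(z).detach())
        d_loss = F.binary_cross_entropy_with_logits(
            real_logits, torch.ones_like(real_logits)
        ) + F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits))
        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()

    def configure_optimizers(self):
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        return opt_g, opt_d


if __name__ == "__main__":
    # multi-GPU run with the strategy introduced by this PR; supply your own dataloader
    trainer = Trainer(accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy(), max_epochs=1)
    # trainer.fit(TinyGAN(), train_dataloaders=...)

The essential assumption is that the sub-models stay as direct attributes, since the strategy discovers them via `named_children()` (see the `_setup_model` override in `ddp.py` below).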
3 changes: 2 additions & 1 deletion src/lightning/pytorch/strategies/__init__.py
@@ -15,7 +15,7 @@

from lightning.fabric.strategies.registry import _StrategyRegistry
from lightning.fabric.utilities.registry import _register_classes
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch.strategies.ddp import DDPStrategy, MultiModelDDPStrategy
from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy
from lightning.pytorch.strategies.fsdp import FSDPStrategy
from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy
@@ -30,6 +30,7 @@

__all__ = [
"DDPStrategy",
"MultiModelDDPStrategy",
"DeepSpeedStrategy",
"FSDPStrategy",
"ModelParallelStrategy",
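With the new export, the strategy is importable from the public strategies namespace. A small illustrative check (not taken from the diff):

from lightning.pytorch.strategies import DDPStrategy, MultiModelDDPStrategy

# MultiModelDDPStrategy subclasses DDPStrategy, so it accepts the same constructor
# arguments and is passed to Trainer(strategy=...) in the same way.
assert issubclass(MultiModelDDPStrategy, DDPStrategy)
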
43 changes: 42 additions & 1 deletion src/lightning/pytorch/strategies/ddp.py
@@ -214,7 +214,7 @@ def set_world_ranks(self) -> None:
rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank

def _register_ddp_hooks(self) -> None:
log.debug(f"{self.__class__.__name__}: registering ddp hooks")
log.debug(f"{self.__class__.__name__}: registering DDP hooks")
# currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
# https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
if self.root_device.type == "cuda":
@@ -419,6 +419,47 @@ def teardown(self) -> None:
super().teardown()


class MultiModelDDPStrategy(DDPStrategy):
"""Specific strategy for training on multiple models with multiple optimizers (e.g. GAN training).

This strategy wraps each individual child module in :class:`~torch.nn.parallel.distributed.DistributedDataParallel`
module. Ensures manual backward only updates parameters of the targeted child module, preventing cross-references
between modules' parameters.

"""

@override
def _setup_model(self, model: Module) -> Module:
device_ids = self.determine_ddp_device_ids()
log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
# https://pytorch.org/docs/stable/notes/cuda.html#id5
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
with ctx:
for name, module in model.named_children():
if isinstance(module, Module):
ddp_module = DistributedDataParallel(module, device_ids=device_ids, **self._ddp_kwargs)
setattr(model, name, ddp_module)
return model

@override
def _register_ddp_hooks(self) -> None:
log.debug(f"{self.__class__.__name__}: registering DDP hooks")
# currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
# https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
if self.root_device.type != "cuda":
return
assert isinstance(self.model, Module)

for name, module in self.model.named_children():
assert isinstance(module, DistributedDataParallel)
_register_ddp_comm_hook(
model=module,
ddp_comm_state=self._ddp_comm_state,
ddp_comm_hook=self._ddp_comm_hook,
ddp_comm_wrapper=self._ddp_comm_wrapper,
)


class _DDPForwardRedirection(_ForwardRedirection):
@override
def on_after_inner_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None:
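To make the difference in wrapping concrete, here is a rough, runnable sketch (the helper names and the single-process "gloo" group are assumptions, not from the diff; the real `_setup_model` also passes `device_ids` and the DDP kwargs and runs under a CUDA stream context) contrasting the stock whole-module wrapping of `DDPStrategy` with the per-child wrapping that `MultiModelDDPStrategy._setup_model` performs:

# Illustrative sketch only: contrast whole-module DDP wrapping, as stock DDPStrategy
# does, with the per-child wrapping that MultiModelDDPStrategy._setup_model performs.
import os

import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel


def wrap_whole_module(model: nn.Module) -> DistributedDataParallel:
    # Stock behavior: one DDP instance (and one gradient reducer) owns every parameter.
    return DistributedDataParallel(model)


def wrap_each_child(model: nn.Module) -> nn.Module:
    # Simplified MultiModelDDPStrategy behavior: each direct child, e.g. `model.generator`
    # and `model.discriminator`, gets its own DDP instance and therefore its own reducer,
    # so a backward pass for one sub-model does not involve the other's parameters.
    for name, child in model.named_children():
        setattr(model, name, DistributedDataParallel(child))
    return model


if __name__ == "__main__":
    # single-process "gloo" group so the sketch runs without a distributed launcher
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    gan = nn.Module()
    gan.generator = nn.Linear(10, 10)
    gan.discriminator = nn.Linear(10, 1)

    wrapped = wrap_each_child(gan)
    assert isinstance(wrapped.generator, DistributedDataParallel)
    assert isinstance(wrapped.discriminator, DistributedDataParallel)

    whole = wrap_whole_module(nn.Sequential(nn.Linear(10, 1)))
    assert isinstance(whole, DistributedDataParallel)

    dist.destroy_process_group()
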
174 changes: 174 additions & 0 deletions tests/tests_pytorch/strategies/test_multi_model_ddp.py
@@ -0,0 +1,174 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import torch
from torch.multiprocessing import ProcessRaisedException

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import MultiModelDDPStrategy
from lightning.pytorch.trainer import seed_everything
from tests_pytorch.helpers.advanced_models import BasicGAN
from tests_pytorch.helpers.runif import RunIf


@RunIf(min_cuda_gpus=2, standalone=True, sklearn=True)
def test_multi_gpu_with_multi_model_ddp_fit_only(tmp_path):
model = BasicGAN()
dm = model.train_dataloader()
trainer = Trainer(
default_root_dir=tmp_path, max_epochs=1, accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy()
)
trainer.fit(model, train_dataloaders=dm)


@RunIf(min_cuda_gpus=2, standalone=True, sklearn=True)
def test_multi_gpu_with_multi_model_ddp_predict_only(tmp_path):
model = BasicGAN()
dm = model.train_dataloader()
trainer = Trainer(
default_root_dir=tmp_path, max_epochs=1, accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy()
)
trainer.predict(model, dataloaders=dm)


@RunIf(min_cuda_gpus=2, standalone=True, sklearn=True)
def test_multi_gpu_multi_model_ddp_fit_predict(tmp_path):
seed_everything(4321)
model = BasicGAN()
dm = model.train_dataloader()
trainer = Trainer(
default_root_dir=tmp_path, max_epochs=1, accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy()
)
trainer.fit(model, train_dataloaders=dm)
trainer.predict(model, dataloaders=dm)


class UnusedParametersBasicGAN(BasicGAN):
def __init__(self):
super().__init__()
mnist_shape = (1, 28, 28)
self.intermediate_layer = torch.nn.Linear(mnist_shape[-1], mnist_shape[-1])

def training_step(self, batch, batch_idx):
with torch.no_grad():
img = self.intermediate_layer(batch[0])
batch[0] = img # modify the batch to use the intermediate layer result
return super().training_step(batch, batch_idx)


@RunIf(standalone=True)
def test_find_unused_parameters_ddp_spawn_raises():
"""Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning users."""
trainer = Trainer(
accelerator="cpu",
devices=1,
strategy=MultiModelDDPStrategy(start_method="spawn"),
max_steps=2,
logger=False,
)
with pytest.raises(
ProcessRaisedException, match="It looks like your LightningModule has parameters that were not used in"
):
trainer.fit(UnusedParametersBasicGAN())


@RunIf(standalone=True)
def test_find_unused_parameters_ddp_exception():
"""Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning users."""
trainer = Trainer(
accelerator="cpu",
devices=1,
strategy=MultiModelDDPStrategy(),
max_steps=2,
logger=False,
)
with pytest.raises(RuntimeError, match="It looks like your LightningModule has parameters that were not used in"):
trainer.fit(UnusedParametersBasicGAN())


class CheckOptimizerDeviceModel(BasicGAN):
def configure_optimizers(self):
assert all(param.device.type == "cuda" for param in self.parameters())
return super().configure_optimizers()


@RunIf(min_cuda_gpus=1)
def test_model_parameters_on_device_for_optimizer():
"""Test that the strategy has moved the parameters to the device by the time the optimizer gets created."""
model = CheckOptimizerDeviceModel()
trainer = Trainer(
default_root_dir=os.getcwd(),
fast_dev_run=1,
accelerator="gpu",
devices=1,
strategy=MultiModelDDPStrategy(),
)
trainer.fit(model)


class BasicGANCPU(BasicGAN):
def on_train_start(self) -> None:
# make sure that the model is on CPU when training
assert self.device == torch.device("cpu")


@RunIf(skip_windows=True)
def test_multi_model_ddp_with_cpu():
"""Tests if device is set correctly when training for MultiModelDDPStrategy."""
trainer = Trainer(
accelerator="cpu",
devices=-1,
strategy=MultiModelDDPStrategy(),
fast_dev_run=True,
)
# assert strategy attributes for device setting
assert isinstance(trainer.strategy, MultiModelDDPStrategy)
assert trainer.strategy.root_device == torch.device("cpu")
model = BasicGANCPU()
trainer.fit(model)


class BasicGANGPU(BasicGAN):
def on_train_start(self) -> None:
# make sure that the model is on GPU when training
assert self.device == torch.device(f"cuda:{self.trainer.strategy.local_rank}")
self.start_cuda_memory = torch.cuda.memory_allocated()


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
def test_multi_model_ddp_with_gpus():
"""Tests if device is set correctly when training and after teardown for MultiModelDDPStrategy."""
trainer = Trainer(
accelerator="gpu",
devices=-1,
strategy=MultiModelDDPStrategy(),
fast_dev_run=True,
enable_progress_bar=False,
enable_model_summary=False,
)
# assert strategy attributes for device setting
assert isinstance(trainer.strategy, MultiModelDDPStrategy)
local_rank = trainer.strategy.local_rank
assert trainer.strategy.root_device == torch.device(f"cuda:{local_rank}")

model = BasicGANGPU()

trainer.fit(model)

# assert after training, model is moved to CPU and memory is deallocated
assert model.device == torch.device("cpu")
cuda_memory = torch.cuda.memory_allocated()
assert cuda_memory < model.start_cuda_memory