Commit 01a931d

update ddp

1 parent d352b4c commit 01a931d

File tree: src/lightning (2 files changed, +27 -14 lines)
Diff for: src/lightning/fabric/strategies/ddp.py (+13 -7)

@@ -124,13 +124,7 @@ def setup_module(self, module: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self._determine_ddp_device_ids()
         # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = (
-            getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(
-                getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()
-            )
-            if device_ids is not None
-            else nullcontext()
-        )
+        ctx = self._create_stream_context(device_ids=device_ids)
         with ctx:
             return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)

@@ -234,6 +228,18 @@ def _set_world_ranks(self) -> None:
     def _determine_ddp_device_ids(self) -> Optional[list[int]]:
         return None if self.root_device.type == "cpu" else [self.root_device.index]

+    def _create_stream_context(self, device_ids=None):
+        """Create a stream context for the current device, if supported."""
+
+        torch_lib = getattr(torch, self.root_device.type)
+        # Check if the device type supports streams and has the necessary attributes.
+        if hasattr(torch_lib, "Stream") and hasattr(torch_lib, "stream") and device_ids is not None:
+            stream = torch_lib.Stream()
+            ctx = torch_lib.stream(stream)
+        else:
+            ctx = nullcontext()
+        return ctx
+

 class _DDPBackwardSyncControl(_BackwardSyncControl):
     @override

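The refactor replaces the inline getattr/ternary with a helper that only enters a device stream when the backend module actually exposes a stream API. A minimal standalone sketch of that decision logic, outside the strategy class (the free function and its root_device parameter are illustrative stand-ins for the strategy method and its attribute, not part of this commit):

from contextlib import nullcontext
from typing import Optional

import torch


def create_stream_context(root_device: torch.device, device_ids: Optional[list] = None):
    """Return a side-stream context if the device backend supports streams, else a no-op context."""
    torch_lib = getattr(torch, root_device.type, None)
    supports_streams = (
        torch_lib is not None
        and hasattr(torch_lib, "Stream")
        and hasattr(torch_lib, "stream")
    )
    # Only enter a side stream when the backend supports it and device ids were resolved.
    if supports_streams and device_ids is not None:
        return torch_lib.stream(torch_lib.Stream())
    return nullcontext()


# CPU runs resolve device_ids to None, so the helper degrades to a no-op context.
assert isinstance(create_stream_context(torch.device("cpu")), nullcontext)

The hasattr checks are what let one code path serve CUDA and any other backend that exposes torch.<type>.Stream, while everything else falls back to nullcontext.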
Diff for: src/lightning/pytorch/strategies/ddp.py (+14 -7)

@@ -190,13 +190,7 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         device_ids = self.determine_ddp_device_ids()
         log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
         # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = (
-            getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(
-                getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()
-            )
-            if device_ids is not None
-            else nullcontext()
-        )
+        ctx = self._create_stream_context(device_ids=device_ids)
         with ctx:
             return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

@@ -424,6 +418,19 @@ def teardown(self) -> None:

         super().teardown()

+    def _create_stream_context(self, device_ids=None):
+        """Create a stream context for the current device, if supported."""
+
+        torch_lib = getattr(torch, self.root_device.type)
+        # Check if the device type supports streams and has the necessary attributes.
+        if hasattr(torch_lib, "Stream") and hasattr(torch_lib, "stream") and device_ids is not None:
+            # ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+            stream = torch_lib.Stream()
+            ctx = torch_lib.stream(stream)
+        else:
+            ctx = nullcontext()
+        return ctx
+

 class _DDPForwardRedirection(_ForwardRedirection):
     @override

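On the consumer side nothing changes: DDP construction, which broadcasts the initial parameters and buffers from rank 0, still runs inside whatever context the helper returns, following the linked PyTorch CUDA notes. A rough sketch of that pattern as a free function (wrap_model and its signature are illustrative, not the strategy's actual API, and a process group must already be initialized before calling it):

from contextlib import nullcontext
from typing import Optional

import torch
from torch.nn.parallel import DistributedDataParallel


def wrap_model(model: torch.nn.Module, root_device: torch.device, **ddp_kwargs) -> DistributedDataParallel:
    # Mirrors _determine_ddp_device_ids(): no device ids on CPU, the device index otherwise.
    device_ids: Optional[list] = None if root_device.type == "cpu" else [root_device.index]

    torch_lib = getattr(torch, root_device.type, None)
    if device_ids is not None and hasattr(torch_lib, "Stream") and hasattr(torch_lib, "stream"):
        ctx = torch_lib.stream(torch_lib.Stream())  # side stream, per the linked CUDA notes
    else:
        ctx = nullcontext()

    # DDP construction synchronizes the initial model state across ranks; run it under the chosen context.
    with ctx:
        return DistributedDataParallel(module=model, device_ids=device_ids, **ddp_kwargs)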