Skip to content

Commit 8cde331

Browse files
committed
bugfix
1 parent 933036d commit 8cde331

File tree

4 files changed

+16
-5
lines changed

4 files changed

+16
-5
lines changed

Diff for: src/lightning/fabric/strategies/ddp.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,10 @@ def __init__(
7878
self._backward_sync_control = _DDPBackwardSyncControl()
7979
self._ddp_kwargs = kwargs
8080

81-
self.device_type = self.root_device.type
81+
if isinstance(self.accelerator, Accelerator):
82+
self.device_type = self.accelerator.get_device_type()
83+
else:
84+
self.device_type = "cuda"
8285
self.torch_lib = getattr(torch, self.device_type)
8386

8487
@property

Diff for: src/lightning/fabric/strategies/deepspeed.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,10 @@ def __init__(
299299

300300
self._deepspeed_engine: Optional[DeepSpeedEngine] = None
301301

302-
self.device_type = self.root_device.type
302+
if isinstance(self.accelerator, Accelerator):
303+
self.device_type = self.accelerator.get_device_type()
304+
else:
305+
self.device_type = "cuda"
303306
self.torch_lib = getattr(torch, self.device_type)
304307

305308
@property

Diff for: src/lightning/pytorch/strategies/ddp.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,10 @@ def __init__(
103103
self._timeout: Optional[timedelta] = timeout
104104
self._start_method = start_method
105105

106-
self.device_type = self.root_device.type
107-
self.torch_lib = getattr(torch, self.device_type)
106+
try:
107+
self.device_type = self.accelerator.get_device_type()
108+
except Exception:
109+
self.device_type = "cuda"
108110

109111
@property
110112
def is_distributed(self) -> bool: # pragma: no-cover

Diff for: src/lightning/pytorch/strategies/deepspeed.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,10 @@ def __init__(
319319
self.hysteresis = hysteresis
320320
self.min_loss_scale = min_loss_scale
321321

322-
self.device_type = self.root_device.type
322+
try:
323+
self.device_type = self.accelerator.get_device_type()
324+
except Exception:
325+
self.device_type = "cuda"
323326
self.torch_lib = getattr(torch, self.device_type)
324327

325328
@override

0 commit comments

Comments
 (0)