
Commit 31e3812

enhance 3rd-party devices in mixed precision
1 parent 8ad3e29 commit 31e3812

File tree

15 files changed (+47, -14 lines)

src/lightning/fabric/connector.py

Lines changed: 5 additions & 1 deletion

@@ -141,6 +141,8 @@ def __init__(
             self._accelerator_flag = self._choose_auto_accelerator()
         elif self._accelerator_flag == "gpu":
             self._accelerator_flag = self._choose_gpu_accelerator_backend()
+        elif isinstance(self._accelerator_flag, Accelerator):
+            pass  # do nothing

         self._set_parallel_devices_and_init_accelerator()

@@ -461,7 +463,7 @@ def _check_and_init_precision(self) -> Precision:
         if isinstance(self.strategy, DeepSpeedStrategy):
             return DeepSpeedPrecision(self._precision_input)  # type: ignore
         if isinstance(self.strategy, FSDPStrategy):
-            return FSDPPrecision(precision=self._precision_input)  # type: ignore[arg-type]
+            return FSDPPrecision(precision=self._precision_input, device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None)  # type: ignore[arg-type]
         mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true")
         if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported:
             raise ValueError(
@@ -493,6 +495,8 @@ def _check_and_init_precision(self) -> Precision:
                 else "Using bfloat16 Automatic Mixed Precision (AMP)"
             )
             device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
+            if isinstance(self._accelerator_flag, Accelerator):
+                device = self._accelerator_flag.get_device()
             return MixedPrecision(precision=self._precision_input, device=device)  # type: ignore[arg-type]

         raise RuntimeError("No precision set")
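
With this connector change, an accelerator passed as an object now reports its own device type; previously anything other than "cpu" was assumed to be "cuda". A rough sketch of the resolution logic for the AMP branch, assuming this commit's `get_device()` API; `_resolve_amp_device` is an illustrative helper name, not part of the codebase:

from lightning.pytorch.accelerators import Accelerator, CUDAAccelerator

def _resolve_amp_device(accelerator_flag) -> str:
    # Mirrors the updated _check_and_init_precision: ask Accelerator objects for
    # their device type, otherwise fall back to the old "cpu"/"cuda" guess.
    if isinstance(accelerator_flag, Accelerator):
        return accelerator_flag.get_device()
    return "cpu" if accelerator_flag == "cpu" else "cuda"

_resolve_amp_device("cpu")              # "cpu"
_resolve_amp_device(CUDAAccelerator())  # "cuda", via the new static method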

src/lightning/fabric/plugins/precision/amp.py

Lines changed: 1 addition & 1 deletion

@@ -50,7 +50,7 @@ def __init__(

         self.precision = precision
         if scaler is None and self.precision == "16-mixed":
-            scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
+            scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else getattr(torch, f"{device.split(':')[0]}").amp.GradScaler()
         if scaler is not None and self.precision == "bf16-mixed":
             raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
         self.device = device
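
PyTorch versions before 2.4 have no device-aware `torch.amp.GradScaler`, so the fallback now looks the scaler up in the namespace named by the device string instead of hard-coding `torch.cuda`. Roughly, and assuming the backend module actually exposes an `amp.GradScaler` (true for `torch.cuda`, and typically provided by out-of-tree backends such as torch_npu):

import torch

def _legacy_scaler(device: str):
    # Illustrative helper: "cuda:0" -> torch.cuda.amp.GradScaler();
    # "npu:1" -> torch.npu.amp.GradScaler() if such a namespace is registered.
    backend = device.split(":")[0]
    return getattr(torch, backend).amp.GradScaler()

_legacy_scaler("cuda:0")  # equivalent to the old torch.cuda.amp.GradScaler()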

src/lightning/fabric/plugins/precision/fsdp.py

Lines changed: 3 additions & 2 deletions

@@ -48,13 +48,14 @@ class FSDPPrecision(Precision):

     """

-    def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
+    def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: str = None) -> None:
         supported_precision = get_args(_PRECISION_INPUT)
         if precision not in supported_precision:
             raise ValueError(
                 f"`precision={precision!r})` is not supported in FSDP."
                 f" `precision` must be one of: {supported_precision}."
             )
+        self.device = device if device is not None else "cuda"

         from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

@@ -110,7 +111,7 @@ def module_init_context(self) -> ContextManager:
     @override
     def forward_context(self) -> ContextManager:
         if "mixed" in self.precision:
-            return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
+            return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
         return self.tensor_init_context()

     @override
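
`FSDPPrecision` now autocasts on whichever device type it was constructed with, defaulting to "cuda" so existing callers behave as before. A short sketch of the new behaviour (a non-CUDA device string assumes that backend is actually available):

import torch
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision

precision = FSDPPrecision(precision="bf16-mixed", device="cuda")
with precision.forward_context():  # torch.autocast("cuda", dtype=torch.bfloat16)
    pass                           # forward pass of the sharded model goes here

FSDPPrecision(precision="16-mixed").device  # "cuda" -- the unchanged default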

src/lightning/fabric/strategies/ddp.py

Lines changed: 1 addition & 1 deletion

@@ -124,7 +124,7 @@ def setup_module(self, module: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self._determine_ddp_device_ids()
         # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+        ctx = getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()) if device_ids is not None else nullcontext()
         with ctx:
             return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)
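
The DDP wrapping still happens inside a side stream (per the linked CUDA note), but the stream is now created from the backend that owns `root_device` rather than from `torch.cuda` unconditionally. The one-liner is dense; an equivalent helper, assuming the backend module exposes `Stream` and `stream` the way `torch.cuda` does, might read:

import torch
from contextlib import nullcontext

def _setup_stream_ctx(root_device: torch.device, device_ids):
    # Illustrative helper, not part of the commit. Note that device.type already
    # excludes the index ("cuda", not "cuda:0"), so the split(":") in the diff is
    # a defensive no-op.
    if device_ids is None:
        return nullcontext()
    backend = getattr(torch, root_device.type)  # torch.cuda, torch.xpu, ...
    return backend.stream(backend.Stream())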

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 1 addition & 1 deletion

@@ -506,7 +506,7 @@ def load_checkpoint(

         optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())

-        torch.cuda.empty_cache()
+        getattr(torch, f"{self.root_device.type.split(':')[0]}").empty_cache()
         _, client_state = engine.load_checkpoint(
             path,
             tag="checkpoint",

src/lightning/fabric/strategies/strategy.py

Lines changed: 1 addition & 1 deletion

@@ -325,7 +325,7 @@ def load_checkpoint(
             given, the full checkpoint will be returned.

         """
-        torch.cuda.empty_cache()
+        getattr(torch, f"{self.root_device.type.split(':')[0]}").empty_cache()
         checkpoint = self.checkpoint_io.load_checkpoint(path)
         if not state:
             return checkpoint
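
Both checkpoint-loading paths (`DeepSpeedStrategy.load_checkpoint` and the base `Strategy.load_checkpoint`) release cached allocator blocks before materialising the checkpoint; that call is now routed to the torch namespace matching the strategy's root device instead of always `torch.cuda`. The same idea as a standalone sketch, with an explicit guard since not every backend namespace exposes `empty_cache`:

import torch

def _empty_device_cache(root_device: torch.device) -> None:
    # Illustrative helper: release cached blocks on the backend owning root_device.
    backend = getattr(torch, root_device.type, None)
    if backend is not None and hasattr(backend, "empty_cache"):
        backend.empty_cache()

_empty_device_cache(torch.device("cuda", 0))  # -> torch.cuda.empty_cache()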

src/lightning/pytorch/accelerators/accelerator.py

Lines changed: 5 additions & 0 deletions

@@ -45,3 +45,8 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:

         """
         raise NotImplementedError
+
+    @staticmethod
+    def get_device() -> str:
+        """Get the device for the current process."""
+        raise NotImplementedError
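
With `get_device()` on the base `Accelerator` interface, a third-party accelerator can advertise the device-type string that the precision plugins and strategies above should use. A sketch of what an out-of-tree implementation might look like; the class name and the "npu" device type are illustrative, not part of this commit:

from typing_extensions import override

from lightning.pytorch.accelerators.accelerator import Accelerator

class MyNPUAccelerator(Accelerator):
    """Hypothetical accelerator for an out-of-tree 'npu' backend."""

    @staticmethod
    @override
    def get_device() -> str:
        return "npu"

    # A real accelerator must also implement setup_device, teardown,
    # get_parallel_devices, auto_device_count, is_available, get_device_stats, ...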

src/lightning/pytorch/accelerators/cpu.py

Lines changed: 5 additions & 0 deletions

@@ -80,6 +80,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
             description=cls.__name__,
         )

+    @staticmethod
+    @override
+    def get_device() -> str:
+        return "cpu"
+

 # CPU device metrics
 _CPU_VM_PERCENT = "cpu_vm_percent"

src/lightning/pytorch/accelerators/cuda.py

Lines changed: 5 additions & 0 deletions

@@ -113,6 +113,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
             description=cls.__name__,
         )

+    @staticmethod
+    @override
+    def get_device() -> str:
+        return "cuda"
+

 def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]:  # pragma: no-cover
     """Get GPU stats including memory, fan speed, and temperature from nvidia-smi.

src/lightning/pytorch/accelerators/mps.py

Lines changed: 5 additions & 0 deletions

@@ -87,6 +87,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
             description=cls.__name__,
         )

+    @staticmethod
+    @override
+    def get_device() -> str:
+        return "mps"
+

 # device metrics
 _VM_PERCENT = "M1_vm_percent"
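
The built-in accelerators return plain device-type strings, which is exactly the form `torch.device` and `torch.autocast` accept, so the precision plugins can consume them directly. A quick check, assuming the overrides added in this commit:

import torch
from lightning.pytorch.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator

CPUAccelerator.get_device(), CUDAAccelerator.get_device(), MPSAccelerator.get_device()
# ('cpu', 'cuda', 'mps')

torch.device(MPSAccelerator.get_device())  # device(type='mps')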
