Skip to content

Commit f863645

Browse files
committed
enhance FSDP support for third-party (non-CUDA) accelerator devices
1 parent 31e3812 commit f863645

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -463,11 +463,16 @@ def _check_strategy_and_fallback(self) -> None:
463463

464464
if (
465465
strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy
466-
) and self._accelerator_flag not in ("cuda", "gpu"):
466+
) and self._accelerator_flag not in ("cuda", "gpu") and isinstance(self._accelerator_flag, str):
467467
raise ValueError(
468468
f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but got:"
469469
f" {self._accelerator_flag}"
470470
)
471+
elif isinstance(self._accelerator_flag, Accelerator):
472+
Warning(
473+
f"Using a custom accelerator `{self._accelerator_flag.__class__.__name__}`."
474+
f" Please ensure it is compatible with the selected strategy `{strategy_flag}`."
475+
)
471476
if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods():
472477
raise ValueError(
473478
f"You selected `Trainer(strategy='{strategy_flag}')` but process forking is not supported on this"
@@ -501,7 +506,7 @@ def _check_and_init_precision(self) -> Precision:
501506
if isinstance(self.strategy, DeepSpeedStrategy):
502507
return DeepSpeedPrecision(self._precision_flag) # type: ignore[arg-type]
503508
if isinstance(self.strategy, FSDPStrategy):
504-
return FSDPPrecision(precision=self._precision_input, device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None)
509+
return FSDPPrecision(precision=self._precision_flag, device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None)
505510
if self._precision_flag in ("16-true", "bf16-true"):
506511
return HalfPrecision(self._precision_flag) # type: ignore
507512
if self._precision_flag == "32-true":

0 commit comments

Comments
 (0)