@@ -463,11 +463,16 @@ def _check_strategy_and_fallback(self) -> None:
         if (
             strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy
-        ) and self._accelerator_flag not in ("cuda", "gpu"):
+        ) and self._accelerator_flag not in ("cuda", "gpu") and isinstance(self._accelerator_flag, str):
             raise ValueError(
                 f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but got:"
                 f" {self._accelerator_flag}"
             )
+        elif isinstance(self._accelerator_flag, Accelerator):
+            Warning(
+                f"Using a custom accelerator `{self._accelerator_flag.__class__.__name__}`."
+                f" Please ensure it is compatible with the selected strategy `{strategy_flag}`."
+            )
 
         if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods():
             raise ValueError(
                 f"You selected `Trainer(strategy='{strategy_flag}')` but process forking is not supported on this"
@@ -501,7 +506,7 @@ def _check_and_init_precision(self) -> Precision:
         if isinstance(self.strategy, DeepSpeedStrategy):
             return DeepSpeedPrecision(self._precision_flag)  # type: ignore[arg-type]
         if isinstance(self.strategy, FSDPStrategy):
-            return FSDPPrecision(precision=self._precision_input, device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None)
+            return FSDPPrecision(precision=self._precision_flag, device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None)
         if self._precision_flag in ("16-true", "bf16-true"):
             return HalfPrecision(self._precision_flag)  # type: ignore
         if self._precision_flag == "32-true":