From 246f8367e0a7faf7f4c384584c5657c35837b854 Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Wed, 5 Feb 2025 09:13:16 +0000
Subject: [PATCH] fix simple trainer tests

---
 optimum/habana/transformers/trainer.py       | 17 +++++++++++------
 optimum/habana/transformers/training_args.py | 20 ++++++++++----------
 2 files changed, 21 insertions(+), 16 deletions(-)

diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py
index 943ceab098..eecb6750b2 100644
--- a/optimum/habana/transformers/trainer.py
+++ b/optimum/habana/transformers/trainer.py
@@ -747,7 +747,7 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio
             self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)
 
             # Wrap `_gradient_checkpointing_func` in the model with `transformer_engine` `activation_checkpointing` context.
-            if self.accelerator.state.is_fp8_enabled:
+            if self.accelerator.state.mixed_precision == "fp8":
                 FP8ContextWrapper.gradient_checkpointing_wrap(self.model)
         else:
             # Hack because `RegressionModel` in test_trainer.py doesn't have `gradient_checkpointing_disable`
@@ -1540,7 +1540,7 @@ def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
 
         # Merge autocast context and `fp8_autocast` context if FP8 is enabled.
         # Currently FP8 is enabled only for training.
-        if self.accelerator.state.is_fp8_enabled and self.model.training:
+        if self.accelerator.state.mixed_precision == "fp8" and self.model.training:
             ctx_manager = FP8ContextWrapper(ctx_manager, self.accelerator.fp8_recipe_handler)
 
         return ctx_manager
@@ -1597,7 +1597,7 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te
                 self.accelerator.backward(loss, **kwargs)
                 self.model.base_model.update_and_allocate(self.state.global_step)
             else:
-                if self.accelerator.state.is_fp8_enabled and self.args.gradient_checkpointing:
+                if self.accelerator.state.mixed_precision == "fp8" and self.args.gradient_checkpointing:
                     # The precision used in backward pass should be same as the one used in forward pass.
                     # However when training with gradient_checkpointing and FP8 precision, recompute forward
                     # in backward does not automatically run with FP8 precision. In order to handle this,
@@ -2460,14 +2460,19 @@ def create_accelerator_and_postprocess(self):
         args = {
             "deepspeed_plugin": self.args.deepspeed_plugin,
             "gradient_accumulation_plugin": gradient_accumulation_plugin,
-            "distribution_strategy": self.args.distribution_strategy,
-            "dynamic": self.args.compile_dynamic,
+            # "distribution_strategy": self.args.distribution_strategy,
+            # "dynamic": self.args.compile_dynamic,
             "dataloader_config": dataloader_config,
-            "use_regional_compilation": self.args.use_regional_compilation,
+            # "use_regional_compilation": self.args.use_regional_compilation,
         }
 
         # create accelerator object
         self.accelerator = Accelerator(**args)
+
+        # we patch accelerator with the mpu for now
+        from ..distributed import parallel_state
+
+        self.accelerator.mpu = parallel_state
 
         # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
         self.gather_function = self.accelerator.gather_for_metrics
diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py
index 04984e5763..720038786d 100644
--- a/optimum/habana/transformers/training_args.py
+++ b/optimum/habana/transformers/training_args.py
@@ -312,14 +312,14 @@ class GaudiTrainingArguments(TrainingArguments):
         },
     )
 
-    # Overriding ddp_backend to replace all possible backends by hccl
-    ddp_backend: Optional[str] = field(
-        default="hccl",
-        metadata={
-            "help": "The backend to be used for distributed training.",
-            "choices": ["hccl"],
-        },
-    )
+    # # Overriding ddp_backend to replace all possible backends by hccl
+    # ddp_backend: Optional[str] = field(
+    #     default="hccl",
+    #     metadata={
+    #         "help": "The backend to be used for distributed training.",
+    #         "choices": ["hccl"],
+    #     },
+    # )
 
     sdp_on_bf16: bool = field(
         default=False,
@@ -913,7 +913,7 @@ def _setup_devices(self) -> "torch.device":
             self._n_gpu = 1
         if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")):
             accelerator_state_kwargs["cpu"] = True
-            accelerator_state_kwargs["backend"] = self.ddp_backend
+            # accelerator_state_kwargs["backend"] = None
             self._n_gpu = 0
         elif self.use_habana:
             # Some methods needs to be tweaked to optimally run on Gaudi
@@ -935,7 +935,7 @@ def _setup_devices(self) -> "torch.device":
                 accelerator_state_kwargs["use_deepspeed"] = True
                 accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout)
             else:
-                accelerator_state_kwargs["backend"] = self.ddp_backend
+                # accelerator_state_kwargs["backend"] = self.ddp_backend
                 accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout)
             accelerator_state_kwargs["context_parallel_size"] = self.context_parallel_size
             accelerator_state_kwargs["minimize_memory"] = self.minimize_memory
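
Note on the recurring change above: the patch replaces checks on an `is_fp8_enabled` attribute with a comparison against `AcceleratorState.mixed_precision`, the field upstream Accelerate uses to report the active precision ("no", "fp16", "bf16" or "fp8"). A minimal standalone sketch of the same check, assuming a plain `Accelerator` instance outside the Trainer (the helper name `fp8_enabled` is illustrative, not part of the patch):

from accelerate import Accelerator

def fp8_enabled(accelerator: Accelerator) -> bool:
    # Same check as in the patch: compare the reported mixed-precision mode to "fp8".
    return accelerator.state.mixed_precision == "fp8"

accelerator = Accelerator()
print(fp8_enabled(accelerator))  # False unless FP8 mixed precision was configured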