Skip to content

Commit

Permalink
fix simple trainer tests
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Feb 5, 2025
1 parent 37ff6b0 commit 246f836
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 16 deletions.
17 changes: 11 additions & 6 deletions optimum/habana/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)

# Wrap `_gradient_checkpointing_func` in the model with `transformer_engine` `activation_checkpointing` context.
if self.accelerator.state.is_fp8_enabled:
if self.accelerator.state.mixed_precision == "fp8":
FP8ContextWrapper.gradient_checkpointing_wrap(self.model)
else:
# Hack because `RegressionModel` in test_trainer.py doesn't have `gradient_checkpointing_disable`
Expand Down Expand Up @@ -1540,7 +1540,7 @@ def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):

# Merge autocast context and `fp8_autocast` context if FP8 is enabled.
# Currently FP8 is enabled only for training.
if self.accelerator.state.is_fp8_enabled and self.model.training:
if self.accelerator.state.mixed_precision == "fp8" and self.model.training:
ctx_manager = FP8ContextWrapper(ctx_manager, self.accelerator.fp8_recipe_handler)

return ctx_manager
Expand Down Expand Up @@ -1597,7 +1597,7 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te
self.accelerator.backward(loss, **kwargs)
self.model.base_model.update_and_allocate(self.state.global_step)
else:
if self.accelerator.state.is_fp8_enabled and self.args.gradient_checkpointing:
if self.accelerator.state.mixed_precision == "fp8" and self.args.gradient_checkpointing:
# The precision used in backward pass should be same as the one used in forward pass.
# However when training with gradient_checkpointing and FP8 precision, recompute forward
# in backward does not automatically run with FP8 precision. In order to handle this,
Expand Down Expand Up @@ -2460,14 +2460,19 @@ def create_accelerator_and_postprocess(self):
args = {
"deepspeed_plugin": self.args.deepspeed_plugin,
"gradient_accumulation_plugin": gradient_accumulation_plugin,
"distribution_strategy": self.args.distribution_strategy,
"dynamic": self.args.compile_dynamic,
# "distribution_strategy": self.args.distribution_strategy,
# "dynamic": self.args.compile_dynamic,
"dataloader_config": dataloader_config,
"use_regional_compilation": self.args.use_regional_compilation,
# "use_regional_compilation": self.args.use_regional_compilation,
}

# create accelerator object
self.accelerator = Accelerator(**args)

# we patch accelerator with the mpu for now
from ..distributed import parallel_state

self.accelerator.mpu = parallel_state
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
self.gather_function = self.accelerator.gather_for_metrics

Expand Down
20 changes: 10 additions & 10 deletions optimum/habana/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,14 @@ class GaudiTrainingArguments(TrainingArguments):
},
)

# Overriding ddp_backend to replace all possible backends by hccl
ddp_backend: Optional[str] = field(
default="hccl",
metadata={
"help": "The backend to be used for distributed training.",
"choices": ["hccl"],
},
)
# # Overriding ddp_backend to replace all possible backends by hccl
# ddp_backend: Optional[str] = field(
# default="hccl",
# metadata={
# "help": "The backend to be used for distributed training.",
# "choices": ["hccl"],
# },
# )

sdp_on_bf16: bool = field(
default=False,
Expand Down Expand Up @@ -913,7 +913,7 @@ def _setup_devices(self) -> "torch.device":
self._n_gpu = 1
if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")):
accelerator_state_kwargs["cpu"] = True
accelerator_state_kwargs["backend"] = self.ddp_backend
# accelerator_state_kwargs["backend"] = None
self._n_gpu = 0
elif self.use_habana:
# Some methods need to be tweaked to optimally run on Gaudi
Expand All @@ -935,7 +935,7 @@ def _setup_devices(self) -> "torch.device":
accelerator_state_kwargs["use_deepspeed"] = True
accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout)
else:
accelerator_state_kwargs["backend"] = self.ddp_backend
# accelerator_state_kwargs["backend"] = self.ddp_backend
accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout)
accelerator_state_kwargs["context_parallel_size"] = self.context_parallel_size
accelerator_state_kwargs["minimize_memory"] = self.minimize_memory
Expand Down

0 comments on commit 246f836

Please sign in to comment.