
Commit

pass fast distributed tests
IlyasMoutawwakil committed Feb 10, 2025
1 parent b8fc49f · commit a0f11d0
Showing 1 changed file with 18 additions and 10 deletions.
optimum/habana/transformers/trainer.py — 28 changes: 18 additions & 10 deletions
@@ -2470,22 +2470,30 @@ def create_accelerator_and_postprocess(self):
         # create accelerator object
         self.accelerator = Accelerator(**args)
 
-        # we patch accelerator with the mpu here
-        # should this be in deepspeed plugin instead?
+        # we patch accelerator with the mpu here as it used to be interwined with the custom accelerator implementation
+        # should this be somewhere else ?
         self.accelerator.mpu = parallel_state
 
-        context_parallel_size = self.args.context_parallel_size
-        if not is_deepspeed_available():
-            context_parallel_size = 1
-        if self.accelerator.mpu.is_unitialized():
-            self.accelerator.mpu.initialize_model_parallel(sequence_parallel_size=context_parallel_size, use_fp8=False)
-        else:
-            if self.accelerator.mpu.get_sequence_parallel_world_size() != context_parallel_size:
-                raise ValueError(
-                    "The initialized sequence parallel world size does not match the context parallel size."
+        if int(os.environ.get("LOCAL_RANK", -1)) != -1 and not self.args.use_cpu:
+            context_parallel_size = self.args.context_parallel_size
+            if not is_deepspeed_available():
+                context_parallel_size = 1
+
+            if self.accelerator.mpu.is_unitialized():
+                self.accelerator.mpu.initialize_model_parallel(
+                    sequence_parallel_size=context_parallel_size, use_fp8=False
+                )
+                if self.accelerator.mpu.amax_reduction_is_initialized():
+                    logger.info("FP8 amax reduction group is already initialized.")
+            else:
+                if self.accelerator.mpu.get_sequence_parallel_world_size() != context_parallel_size:
+                    raise ValueError(
+                        "The initialized sequence parallel world size does not match the context parallel size."
+                    )
+                if self.accelerator.mpu.amax_reduction_is_initialized():
+                    logger.info("FP8 amax reduction group is already initialized.")
 
         # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
         self.gather_function = self.accelerator.gather_for_metrics
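For readers skimming the change, the sketch below isolates the gating pattern this commit applies: parallel-state initialization only runs when a distributed launcher (torchrun, mpirun) has exported LOCAL_RANK and the run is not on CPU, so a plain single-process invocation such as a fast unit test skips it. The helper names and module-level flags here are illustrative stand-ins, not optimum-habana's actual API.

import os

# Illustrative stand-ins for the parallel_state helpers shown in the diff; these
# names and globals are hypothetical and do not mirror optimum-habana's API.
_initialized = False
_sequence_parallel_world_size = 1


def initialize_model_parallel(sequence_parallel_size: int) -> None:
    # Stand-in that records the requested sequence/context parallel size.
    global _initialized, _sequence_parallel_world_size
    _initialized = True
    _sequence_parallel_world_size = sequence_parallel_size


def maybe_init_parallel_state(context_parallel_size: int, use_cpu: bool = False) -> None:
    # Same gate as the commit: torchrun/mpirun export LOCAL_RANK, so a plain
    # single-process run (e.g. a fast unit test) sees -1 and skips initialization.
    if int(os.environ.get("LOCAL_RANK", -1)) == -1 or use_cpu:
        return
    if not _initialized:
        initialize_model_parallel(sequence_parallel_size=context_parallel_size)
    elif _sequence_parallel_world_size != context_parallel_size:
        raise ValueError(
            "The initialized sequence parallel world size does not match the context parallel size."
        )


if __name__ == "__main__":
    # No-op unless the process was launched by a distributed launcher.
    maybe_init_parallel_state(context_parallel_size=1)

Launching the same script with, for example, `torchrun --nproc_per_node=2 script.py` sets LOCAL_RANK for every worker, so the initialization branch runs there while single-process test runs stay untouched.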
