
Commit

revert
IlyasMoutawwakil committed Feb 5, 2025
1 parent 5fd4de2 commit 7f72745
Showing 2 changed files with 17 additions and 42 deletions.
8 changes: 3 additions & 5 deletions src/accelerate/accelerator.py
@@ -416,9 +416,9 @@ def __init__(

        if kwargs_handlers is not None:
            for handler in kwargs_handlers:
-               assert isinstance(
-                   handler, KwargsHandler
-               ), f"Unsupported kwargs handler passed: {handler}, must be one that inherits `accelerate.utils.KwargsHandler`."
+               assert isinstance(handler, KwargsHandler), (
+                   f"Unsupported kwargs handler passed: {handler}, must be one that inherits `accelerate.utils.KwargsHandler`."
+               )
                if isinstance(handler, DistributedDataParallelKwargs):
                    if self.ddp_handler is not None:
                        raise ValueError("You can only pass one `DistributedDataParallelKwargs` in `kwargs_handler`.")
@@ -1416,8 +1416,6 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
"""
if device_placement is None:
device_placement = self.device_placement and self.distributed_type != DistributedType.FSDP
if not evaluation_mode and self.distributed_type == DistributedType.MULTI_HPU:
device_placement = None
self._models.append(model)

        # TODO: Look at enabling native TP training directly with a proper config
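
For context, a minimal usage sketch of the code paths these accelerator.py hunks touch (illustrative only, not part of the commit; the toy model and handler arguments are assumptions):

import torch
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# Every kwargs handler must inherit from accelerate.utils.KwargsHandler,
# which is what the assert in the first hunk enforces.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

# prepare() dispatches to prepare_model(), the method whose HPU-specific
# device_placement override is removed in the second hunk.
model = accelerator.prepare(torch.nn.Linear(8, 8))
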
51 changes: 14 additions & 37 deletions src/accelerate/state.py
@@ -210,8 +210,7 @@ def __init__(self, cpu: bool = False, **kwargs):
            use_deepspeed = True
        # Deal with all other backends but XPU and CPU, that gets handled special later
        elif (
-           self.distributed_type
-           not in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU, DistributedType.MULTI_HPU)
+           self.distributed_type not in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU)
            and not torch.distributed.is_initialized()
        ):
            torch.distributed.init_process_group(backend=self.backend, **kwargs)
@@ -257,16 +256,6 @@ def __init__(self, cpu: bool = False, **kwargs):
            if not torch.distributed.is_initialized():
                torch.distributed.init_process_group(backend=self.backend, **kwargs)

-       elif self.distributed_type == DistributedType.MULTI_HPU:
-           from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
-
-           hpu_world_size, hpu_rank, hpu_local_rank = initialize_distributed_hpu()
-
-           if not torch.distributed.is_initialized():
-               torch.distributed.init_process_group(
-                   backend=self.backend, rank=hpu_rank, world_size=hpu_world_size
-               )
-
        # No backend == no distributed training
        if self.backend is None:
            self.distributed_type = DistributedType.NO
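
As an aside (a hedged sketch, not from the diff): the removed MULTI_HPU branch passed rank and world_size to init_process_group explicitly, whereas the surviving branches rely on env:// rendezvous, where torch.distributed reads the launcher-provided variables; roughly:

import os
import torch.distributed as dist

if not dist.is_initialized():
    # env:// initialization reads RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT
    # from the environment set by torchrun / accelerate launch.
    dist.init_process_group(backend="nccl", init_method="env://")

# Explicit form, analogous to the removed rank/world_size arguments
# (the backend name here is illustrative):
# dist.init_process_group(backend="nccl",
#                         rank=int(os.environ["RANK"]),
#                         world_size=int(os.environ["WORLD_SIZE"]))
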
@@ -283,10 +272,6 @@ def __init__(self, cpu: bool = False, **kwargs):
                self.local_process_index = xm.get_local_ordinal()
            else:
                self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
-       elif self.distributed_type == DistributedType.MULTI_HPU:
-           self.process_index = hpu_rank
-           self.num_processes = hpu_world_size
-           self.local_process_index = hpu_local_rank
        else:
            self.num_processes = torch.distributed.get_world_size()
            self.process_index = torch.distributed.get_rank()
@@ -311,24 +296,6 @@ def __init__(self, cpu: bool = False, **kwargs):
"will do this automatically."
)

if self.device.type == "hpu":
# we should do this in optimum-habana somehow and not here
from optimum.habana.distributed import parallel_state # noqa: F401

if self.distributed_type != DistributedType.DEEPSPEED:
context_parallel_size = 1
if parallel_state.is_unitialized():
parallel_state.initialize_model_parallel(
sequence_parallel_size=context_parallel_size, use_fp8=False
)
else:
if parallel_state.get_sequence_parallel_world_size() != context_parallel_size:
raise ValueError(
"The initialized sequence parallel world size does not match the context parallel size."
)
if parallel_state.amax_reduction_is_initialized():
logger.info("FP8 amax reduction group is already initialized.")

# Important: This should be the *only* code outside of `self.initialized!`
self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)

@@ -762,6 +729,19 @@ def _prepare_backend(
        elif is_torch_xla_available():
            backend = "xla"
            distributed_type = DistributedType.XLA
+       elif is_hpu_available():
+           from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu  # noqa: F401
+
+           print("after importing habana_frameworks in prepare_backend")
+
+           print(os.environ.get("LOCAL_RANK", -1))
+           print(os.environ.get("WORLD_SIZE", -1))
+           print(os.environ.get("RANK", -1))
+           print(os.environ.get("MASTER_ADDR", -1))
+           print(os.environ.get("MASTER_PORT", -1))
+
+           backend = "hccl"
+           distributed_type = DistributedType.MULTI_HPU
        elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
            if is_mlu_available():
                backend = "cncl"
@@ -778,9 +758,6 @@ def _prepare_backend(
            if backend is None:
                backend = "nccl"
                distributed_type = DistributedType.MULTI_GPU
-       elif is_hpu_available():
-           backend = "hccl"
-           distributed_type = DistributedType.MULTI_HPU

        if distributed_type is None and (
            int(os.environ.get("LOCAL_RANK", -1)) != -1
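
A quick way to see which backend and distributed type _prepare_backend ended up choosing (a sketch assuming a standard accelerate install; run under `accelerate launch` to get non-default values for the same LOCAL_RANK / WORLD_SIZE / RANK variables the added debug prints inspect):

import os
from accelerate import PartialState

state = PartialState()
print("LOCAL_RANK:", os.environ.get("LOCAL_RANK", -1))
print("backend:", state.backend)
print("distributed_type:", state.distributed_type)
print("process index:", state.process_index, "of", state.num_processes)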