force backend hccl and multi_hpu type when sure of distributed launch
IlyasMoutawwakil committed Feb 5, 2025
1 parent f66c5df · commit 2a4130d
Showing 2 changed files with 4 additions and 6 deletions.
src/accelerate/state.py (9 changes: 3 additions & 6 deletions)

@@ -729,12 +729,6 @@ def _prepare_backend(
         elif is_torch_xla_available():
             backend = "xla"
             distributed_type = DistributedType.XLA
-        elif is_hpu_available():
-            import habana_frameworks.torch.distributed.hccl
-
-            if int(os.environ.get("LOCAL_RANK", -1)) != -1:
-                backend = "hccl"
-                distributed_type = DistributedType.MULTI_HPU
 
         elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
             if is_mlu_available():
@@ -748,6 +742,9 @@ def _prepare_backend(
             elif is_npu_available():
                 backend = "hccl"
                 distributed_type = DistributedType.MULTI_NPU
+            elif is_hpu_available():
+                backend = "hccl"
+                distributed_type = DistributedType.MULTI_HPU
             elif torch.cuda.is_available():
                 if backend is None:
                     backend = "nccl"
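
To make the behavioral change concrete, here is a condensed, hypothetical sketch of the selection order after this commit. pick_backend and its boolean arguments are stand-ins for accelerate's is_npu_available / is_hpu_available / torch.cuda checks, not the library's API. The HPU branch now sits inside the LOCAL_RANK guard, so hccl / MULTI_HPU are only forced once a distributed launcher such as torchrun or accelerate launch has set LOCAL_RANK.

    import os

    # Hedged sketch of the post-commit selection order in _prepare_backend.
    # pick_backend and its arguments are hypothetical stand-ins.
    def pick_backend(npu: bool, hpu: bool, cuda: bool, cpu: bool = False):
        if int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
            # Only reached when a launcher has set LOCAL_RANK,
            # i.e. when we are sure of a distributed launch.
            if npu:
                return "hccl", "MULTI_NPU"
            elif hpu:
                return "hccl", "MULTI_HPU"  # branch moved here by this commit
            elif cuda:
                return "nccl", "MULTI_GPU"
        return None, "NO"

    # Before the commit, an HPU machine could be pushed onto hccl even for a
    # plain single-process run; now LOCAL_RANK must be set first.
    assert pick_backend(npu=False, hpu=True, cuda=False) == (None, "NO")
    os.environ["LOCAL_RANK"] = "0"
    assert pick_backend(npu=False, hpu=True, cuda=False) == ("hccl", "MULTI_HPU")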
src/accelerate/utils/imports.py (1 change: 1 addition & 0 deletions)

@@ -385,6 +385,7 @@ def is_hpu_available(check_device=False):
         return False
 
     import habana_frameworks.torch  # noqa: F401
+    import habana_frameworks.torch.distributed.hccl as hccl  # noqa: F401
 
     if check_device:
         try:
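
One plausible reading of this second change (an inference from the diff, not stated in the commit message): importing habana_frameworks.torch.distributed.hccl is what registers the "hccl" backend with torch.distributed on Gaudi machines, so folding it into is_hpu_available() means any caller that has confirmed HPU support can initialize the process group without the local import that _prepare_backend previously carried. A hedged usage sketch, runnable only where the Habana software stack is installed:

    import os

    import torch.distributed as dist

    # This import registers the "hccl" backend with torch.distributed;
    # without it, init_process_group(backend="hccl") would fail.
    import habana_frameworks.torch.distributed.hccl  # noqa: F401

    # Single-process group purely for illustration; a real launch would get
    # these values from torchrun / accelerate launch.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="hccl", rank=0, world_size=1)
    print(dist.get_backend())  # hccl
    dist.destroy_process_group()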
