From 2a4130d4b2c77b81843315d38a2c119968e1a42f Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Wed, 5 Feb 2025 10:27:24 +0000
Subject: [PATCH] force backend hccl and multi_hpu type when sure of
 distributed launch

---
 src/accelerate/state.py         | 9 +++------
 src/accelerate/utils/imports.py | 1 +
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/src/accelerate/state.py b/src/accelerate/state.py
index 81f1ce6c3c4..8c61d726a76 100644
--- a/src/accelerate/state.py
+++ b/src/accelerate/state.py
@@ -729,12 +729,6 @@ def _prepare_backend(
         elif is_torch_xla_available():
             backend = "xla"
             distributed_type = DistributedType.XLA
-        elif is_hpu_available():
-            import habana_frameworks.torch.distributed.hccl
-
-            if int(os.environ.get("LOCAL_RANK", -1)) != -1:
-                backend = "hccl"
-                distributed_type = DistributedType.MULTI_HPU
 
         elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
             if is_mlu_available():
@@ -748,6 +742,9 @@ def _prepare_backend(
         elif is_npu_available():
             backend = "hccl"
             distributed_type = DistributedType.MULTI_NPU
+        elif is_hpu_available():
+            backend = "hccl"
+            distributed_type = DistributedType.MULTI_HPU
         elif torch.cuda.is_available():
             if backend is None:
                 backend = "nccl"
diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py
index 25e71cfb16c..3d4436db71a 100644
--- a/src/accelerate/utils/imports.py
+++ b/src/accelerate/utils/imports.py
@@ -385,6 +385,7 @@ def is_hpu_available(check_device=False):
         return False
 
     import habana_frameworks.torch  # noqa: F401
+    import habana_frameworks.torch.distributed.hccl as hccl  # noqa: F401
 
     if check_device:
         try:
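
Note: the sketch below condenses the backend-selection order as it stands after this
patch. `pick_backend`, its boolean flags, and the string return values are hypothetical
stand-ins for `_prepare_backend`, the real `is_*_available()` helpers, and the
`DistributedType` members; accelerator branches that fall between the two hunks
(e.g. MUSA) are omitted.

import os

def pick_backend(cpu=False, backend=None, xla=False, mlu=False,
                 npu=False, hpu=False, cuda=False):
    # Condensed decision order after this patch; flags stand in for
    # the real is_*_available() checks in accelerate.
    if xla:
        return "xla", "XLA"
    if int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
        if mlu:
            return "cncl", "MULTI_MLU"
        if npu:
            return "hccl", "MULTI_NPU"
        if hpu:
            # New placement: HPU is only considered once LOCAL_RANK proves a
            # distributed launch, so hccl / MULTI_HPU are forced outright
            # instead of being guarded by a second LOCAL_RANK check.
            return "hccl", "MULTI_HPU"
        if cuda:
            return backend or "nccl", "MULTI_GPU"
    return backend, None

# A torchrun-style launch on Gaudi now hits the MULTI_HPU branch:
os.environ["LOCAL_RANK"] = "0"
assert pick_backend(hpu=True) == ("hccl", "MULTI_HPU")

With the HPU branch moved under the LOCAL_RANK guard, the backend-selection path no
longer needs to import the hccl module itself, which is why the patch's second hunk
relocates that import into `is_hpu_available()` in imports.py.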