diff --git a/src/accelerate/state.py b/src/accelerate/state.py
index 81f1ce6c3c4..8c61d726a76 100644
--- a/src/accelerate/state.py
+++ b/src/accelerate/state.py
@@ -729,12 +729,6 @@ def _prepare_backend(
         elif is_torch_xla_available():
             backend = "xla"
             distributed_type = DistributedType.XLA
-        elif is_hpu_available():
-            import habana_frameworks.torch.distributed.hccl
-
-            if int(os.environ.get("LOCAL_RANK", -1)) != -1:
-                backend = "hccl"
-                distributed_type = DistributedType.MULTI_HPU
         elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
             if is_mlu_available():
                 backend = "cncl"
@@ -748,6 +742,9 @@ def _prepare_backend(
             elif is_npu_available():
                 backend = "hccl"
                 distributed_type = DistributedType.MULTI_NPU
+            elif is_hpu_available():
+                backend = "hccl"
+                distributed_type = DistributedType.MULTI_HPU
             elif torch.cuda.is_available():
                 if backend is None:
                     backend = "nccl"
diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py
index 25e71cfb16c..3d4436db71a 100644
--- a/src/accelerate/utils/imports.py
+++ b/src/accelerate/utils/imports.py
@@ -385,6 +385,7 @@ def is_hpu_available(check_device=False):
         return False
 
     import habana_frameworks.torch  # noqa: F401
+    import habana_frameworks.torch.distributed.hccl as hccl  # noqa: F401
 
     if check_device:
         try:
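
For context, a minimal sketch of the backend-selection order that results from this patch. select_backend_sketch is a hypothetical standalone helper, not the library's _prepare_backend method: it returns plain strings where the real code assigns DistributedType members, and it only covers the branches visible in the diff (XLA, NPU, HPU, CUDA), under the assumption that is_hpu_available, is_npu_available, and is_torch_xla_available are importable from accelerate.utils as they are today.

# Hypothetical sketch of the post-patch decision order; not the library's
# actual `_prepare_backend` method.
import os

import torch

from accelerate.utils import is_hpu_available, is_npu_available, is_torch_xla_available


def select_backend_sketch(cpu=False, backend=None):
    if is_torch_xla_available():
        return "xla", "XLA"
    if int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
        if is_npu_available():
            return "hccl", "MULTI_NPU"
        # HPU is now probed inside the generic distributed branch, after NPU and
        # before CUDA; the habana_frameworks.torch.distributed.hccl import moved
        # into is_hpu_available(), so it happens as part of this check.
        if is_hpu_available():
            return "hccl", "MULTI_HPU"
        if torch.cuda.is_available():
            return backend or "nccl", "MULTI_GPU"
    # The real method continues with further branches (e.g. CPU handling).
    return backend, "NO"

The other visible effect of the imports.py change is that any caller of is_hpu_available() now pulls in the Habana distributed (hccl) bindings as a side effect of the check, rather than _prepare_backend importing them in its own HPU branch.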