diff --git a/optimum/habana/accelerate/accelerator.py b/optimum/habana/accelerate/accelerator.py
index e5fb539a5b..8926866cc3 100644
--- a/optimum/habana/accelerate/accelerator.py
+++ b/optimum/habana/accelerate/accelerator.py
@@ -59,7 +59,7 @@
 from accelerate.utils.other import is_compiled_module
 from torch.optim.lr_scheduler import LRScheduler
 
-from .. import parallel_state
+from ..distributed import parallel_state
 
 
 if is_deepspeed_available():
diff --git a/optimum/habana/accelerate/data_loader.py b/optimum/habana/accelerate/data_loader.py
index afe8fc1cd8..bab31edb0e 100644
--- a/optimum/habana/accelerate/data_loader.py
+++ b/optimum/habana/accelerate/data_loader.py
@@ -22,7 +22,7 @@
 )
 from torch.utils.data import BatchSampler, DataLoader, IterableDataset
 
-from .. import parallel_state
+from ..distributed import parallel_state
 from .state import GaudiAcceleratorState
 from .utils.operations import (
     broadcast,
diff --git a/optimum/habana/accelerate/state.py b/optimum/habana/accelerate/state.py
index b9c4e794f7..6e507acc2c 100644
--- a/optimum/habana/accelerate/state.py
+++ b/optimum/habana/accelerate/state.py
@@ -21,7 +21,7 @@
 
 from optimum.utils import logging
 
-from .. import parallel_state
+from ..distributed import parallel_state
 from .utils import GaudiDistributedType
 
 
diff --git a/optimum/habana/distributed/contextparallel.py b/optimum/habana/distributed/contextparallel.py
index 0b48465542..2020b6a84e 100644
--- a/optimum/habana/distributed/contextparallel.py
+++ b/optimum/habana/distributed/contextparallel.py
@@ -1,6 +1,6 @@
 import torch
 
-from ..parallel_state import (
+from .parallel_state import (
     get_sequence_parallel_group,
     get_sequence_parallel_rank,
     get_sequence_parallel_world_size,
diff --git a/optimum/habana/parallel_state.py b/optimum/habana/distributed/parallel_state.py
similarity index 100%
rename from optimum/habana/parallel_state.py
rename to optimum/habana/distributed/parallel_state.py
diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 55d4475a87..6cd3de0a72 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -21,7 +21,8 @@
 )
 from transformers.utils import is_torchdynamo_compiling
 
-from .... import distributed, parallel_state
+from .... import distributed
+from ....distributed import parallel_state
 from ....distributed.strategy import DistributedStrategy, NoOpStrategy
 from ....distributed.tensorparallel import (
     reduce_from_tensor_model_parallel_region,
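
The net effect of this patch is that parallel_state moves under optimum.habana.distributed, and all in-tree imports are updated to the new path. Below is a minimal sketch of how external code would pick up the module after the rename; the three accessor names are taken from the contextparallel.py hunk above, while the usage itself is illustrative and assumes the distributed/parallel state has already been initialized elsewhere.

    # Sketch only: import path reflects the rename in this patch
    # (optimum/habana/parallel_state.py -> optimum/habana/distributed/parallel_state.py).
    from optimum.habana.distributed import parallel_state

    # Query the sequence-parallel topology; these accessors are the same ones
    # re-imported in contextparallel.py above. Calling them is only meaningful
    # after parallel state initialization has been performed by the caller.
    sp_group = parallel_state.get_sequence_parallel_group()
    sp_rank = parallel_state.get_sequence_parallel_rank()
    sp_size = parallel_state.get_sequence_parallel_world_size()
    print(f"sequence-parallel rank {sp_rank}/{sp_size}, group: {sp_group}")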