diff --git a/axonn/communication.py b/axonn/communication.py
index 6a11ebd..8036af7 100644
--- a/axonn/communication.py
+++ b/axonn/communication.py
@@ -33,11 +33,10 @@ def __init__(
         self.world_size = MPI.COMM_WORLD.Get_size()

         if gpus_per_node is None:
-            self.backend = "gloo"
+            self.backend = 'gloo'
             self.is_gpu_available = False
-            self.local_rank = os.environ.get("SLURM_LOCALID", 0)
         else:
-            self.backend = "nccl"
+            self.backend = 'nccl'
             self.is_gpu_available = True
             self.local_rank = self.world_rank % gpus_per_node
             torch.cuda.set_device(self.local_rank)
@@ -66,7 +65,7 @@ def __init__(
         self.p2p_mpi_comm = MPI.COMM_WORLD.Split(colour)
         assert self.p2p_mpi_comm.Get_size() == G_inter

-        # create communicator for collective (NCCL) communication
+        # create communicator for collective communication
         if not torch.distributed.is_initialized():
             init_method = "tcp://"
             master_ip = os.getenv("MASTER_ADDR", "localhost")
@@ -121,7 +120,7 @@ def __init__(
         for i in range(G_intra_c):
             group_members = intra_layer_ranks[i::G_intra_c]
             group = torch.distributed.new_group(
-                ranks=group_members, backend="gloo"
+                ranks=group_members, backend=self.backend
             )
             if self.world_rank in group_members:
                 self.outer_intra_layer_parallel_group = group