
Commit

inter-layer parallelism for cpu
Sathwik Yanamaddi committed Nov 2, 2023
1 parent 584ba0a commit 0830f8c
Showing 1 changed file with 4 additions and 5 deletions.
axonn/communication.py: 9 changes (4 additions & 5 deletions)
@@ -33,11 +33,10 @@ def __init__(
self.world_size = MPI.COMM_WORLD.Get_size()

if gpus_per_node is None:
self.backend = "gloo"
self.backend = 'gloo'
self.is_gpu_available = False
self.local_rank = os.environ.get("SLURM_LOCALID", 0)
else:
self.backend = "nccl"
self.backend = 'nccl'
self.is_gpu_available = True
self.local_rank = self.world_rank % gpus_per_node
torch.cuda.set_device(self.local_rank)
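
For context, a minimal, self-contained sketch of the backend-selection logic shown in the hunk above, using a hypothetical helper name (select_backend) that is not part of this commit: gpus_per_node=None signals a CPU-only job, so the handle falls back to the gloo backend and reads the node-local rank from SLURM_LOCALID, while GPU jobs use NCCL and bind each rank to a device.

import os
import torch

def select_backend(world_rank, gpus_per_node=None):
    if gpus_per_node is None:
        # CPU-only job: use gloo and take the local rank from SLURM (0 if unset).
        backend = "gloo"
        local_rank = int(os.environ.get("SLURM_LOCALID", 0))
    else:
        # GPU job: use NCCL and pin this rank to one of the node's GPUs.
        backend = "nccl"
        local_rank = world_rank % gpus_per_node
        torch.cuda.set_device(local_rank)
    return backend, local_rank
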
@@ -66,7 +65,7 @@ def __init__(
self.p2p_mpi_comm = MPI.COMM_WORLD.Split(colour)
assert self.p2p_mpi_comm.Get_size() == G_inter

# create communicator for collective (NCCL) communication
# create communicator for collective communication
if not torch.distributed.is_initialized():
init_method = "tcp://"
master_ip = os.getenv("MASTER_ADDR", "localhost")
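
The rest of this hunk is collapsed in the view above; what follows is a minimal sketch of the usual TCP-rendezvous initialization that the visible lines begin. The MASTER_PORT default and the exact init_process_group arguments are assumptions, not taken from the commit.

import os
import torch

def init_collective(backend, world_rank, world_size):
    # Skip if another component has already initialized torch.distributed.
    if torch.distributed.is_initialized():
        return
    master_ip = os.getenv("MASTER_ADDR", "localhost")
    master_port = os.getenv("MASTER_PORT", "29500")  # assumed default port
    torch.distributed.init_process_group(
        backend=backend,  # "gloo" for CPU-only runs, "nccl" when GPUs are available
        init_method=f"tcp://{master_ip}:{master_port}",
        rank=world_rank,
        world_size=world_size,
    )
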
@@ -121,7 +120,7 @@ def __init__(
for i in range(G_intra_c):
group_members = intra_layer_ranks[i::G_intra_c]
group = torch.distributed.new_group(
ranks=group_members, backend="gloo"
ranks=group_members, backend=self.backend
)
if self.world_rank in group_members:
self.outer_intra_layer_parallel_group = group
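
A minimal, self-contained sketch of the grouping pattern in this hunk, with illustrative names (build_column_groups, my_group) that are not part of the commit: torch.distributed.new_group is a collective call, so every rank creates every subgroup and keeps only the one that contains it; taking the backend from the communication handle keeps the grouping code backend-agnostic (gloo on CPU, NCCL on GPU).

import torch.distributed as dist

def build_column_groups(intra_layer_ranks, G_intra_c, world_rank, backend):
    my_group = None
    for i in range(G_intra_c):
        # Every G_intra_c-th rank belongs to the same column group.
        group_members = intra_layer_ranks[i::G_intra_c]
        # new_group must be called by all ranks, even non-members.
        group = dist.new_group(ranks=group_members, backend=backend)
        if world_rank in group_members:
            my_group = group
    return my_group
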
