Rebased cpu changes onto latest axonn
Sathwik Yanamaddi committed Feb 2, 2024
1 parent a9d38c2 commit e3bfa64
Showing 3 changed files with 35 additions and 14 deletions.
4 changes: 2 additions & 2 deletions axonn/axonn.py
@@ -537,7 +537,7 @@ def _post_fw_recv_requests():
     if (requests["fw"] is None) and config.inter_layer_parallel_rank > 0:
         tensor = torch.empty(
             size=_fill_shape(model.get_input_shape()),
-            device="cuda",
+            device="cpu",
             dtype=computation_dtype,
         )
         tensor.requires_grad = True
@@ -556,7 +556,7 @@ def _post_bw_recv_requests():
     ):
         tensor = torch.empty(
             size=_fill_shape(model.get_output_shape()),
-            device="cuda",
+            device="cpu",
             dtype=computation_dtype,
         )
         requests["bw"] = [
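
The two hunks above move the pre-posted receive buffers in _post_fw_recv_requests and _post_bw_recv_requests from GPU to host memory, matching the CPU/gloo path added in communication.py below: gloo performs point-to-point communication on CPU tensors, so buffers handed to irecv must be allocated with device="cpu". A minimal sketch of that pattern (illustrative only, not AxoNN code; the shape, source rank, and helper name are assumptions):

# Minimal sketch (not AxoNN code): pre-posting a non-blocking recv into a
# CPU buffer, which is what the gloo backend expects for point-to-point ops.
import torch
import torch.distributed as dist

def post_cpu_recv(shape, src_rank, dtype=torch.float32):
    # With backend="gloo" the receive buffer lives on the CPU; with
    # backend="nccl" it would be a CUDA tensor instead.
    buf = torch.empty(size=shape, device="cpu", dtype=dtype)
    buf.requires_grad = True          # mirrors the AxoNN code above
    work = dist.irecv(buf, src=src_rank)
    return buf, work

# Usage (assumes torch.distributed is already initialized with backend="gloo"
# and that src_rank is the previous pipeline stage):
# buf, work = post_cpu_recv((16, 1024), src_rank=dist.get_rank() - 1)
# work.wait()
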
39 changes: 30 additions & 9 deletions axonn/communication.py
@@ -17,6 +17,16 @@
 import torch
 import numpy as np

+class DistributedEnvironment:
+    def __init__(self):
+        self.world_size = int(os.environ["SLURM_NTASKS"])
+        self.local_rank = int(os.environ["SLURM_PROCID"])
+
+    def get_world_size(self):
+        return self.world_size
+
+    def get_rank(self):
+        return self.local_rank

 class communication_handle:
     """
@@ -44,6 +54,19 @@ def __init__(
         G_intra_c (int): number of GPUs in the column intra-layer parallel dimension
         G_intra_d (int): number of GPUs in the depth intra-layer parallel dimension
         """
+        env = DistributedEnvironment()
+        self.world_rank = env.get_rank()
+        self.world_size = env.get_world_size()
+
+        if gpus_per_node is None:
+            self.backend = "gloo"
+            self.is_gpu_available = False
+        else:
+            self.backend = "nccl"
+            self.is_gpu_available = True
+            self.local_rank = self.world_rank % gpus_per_node
+            torch.cuda.set_device(self.local_rank)
+
         if not torch.distributed.is_initialized():
             assert MPI4PY, "either install mpi4py and launch via mpirun/srun"
             "or initialize torch.distributed outside axonn"
@@ -71,8 +94,6 @@
         self.gpus_per_node = (
             gpus_per_node if gpus_per_node is not None else torch.cuda.device_count()
         )
-        self.local_rank = self.world_rank % self.gpus_per_node
-        torch.cuda.set_device(self.local_rank)
         self.intra_layer_parallel_rank = self.world_rank % G_intra
         self.intra_layer_column_parallel_rank = (
             self.intra_layer_parallel_rank % G_intra_c
@@ -115,7 +136,7 @@ def __init__(
            master_port = os.getenv("MASTER_PORT", "6000")
            init_method += master_ip + ":" + master_port
            torch.distributed.init_process_group(
-                backend="nccl",
+                backend=self.backend,
                world_size=self.world_size,
                rank=self.world_rank,
                init_method=init_method,
@@ -130,7 +151,7 @@ def __init__(
                    for k in range(self.G_data)
                ]
                ith_jth_data_parallel_group = torch.distributed.new_group(
-                    ranks=ranks_in_ith_jth_data_parallel_group, backend="nccl"
+                    ranks=ranks_in_ith_jth_data_parallel_group, backend=self.backend
                )
                if self.world_rank in ranks_in_ith_jth_data_parallel_group:
                    self.coll_nccl_comm = ith_jth_data_parallel_group
@@ -142,7 +163,7 @@ def __init__(
                    i_ * G_inter * G_intra + j_ * G_intra + k for k in range(G_intra)
                ]
                ith_jth_intra_layer_group = torch.distributed.new_group(
-                    ranks=ranks_in_ith_jth_intra_layer_group, backend="nccl"
+                    ranks=ranks_in_ith_jth_intra_layer_group, backend=self.backend
                )
                if self.world_rank in ranks_in_ith_jth_intra_layer_group:
                    self.intra_layer_group = ith_jth_intra_layer_group
@@ -165,7 +186,7 @@ def __init__(
                        ranks_in_ith_jth_intra_layer_group[i, j, :]
                    )
                    group = torch.distributed.new_group(
-                        ranks=group_members, backend="nccl"
+                        ranks=group_members, backend=self.backend
                    )
                    if self.world_rank in group_members:
                        self.inner_intra_layer_parallel_group = group
@@ -177,7 +198,7 @@ def __init__(
                        ranks_in_ith_jth_intra_layer_group[i, :, j]
                    )
                    group = torch.distributed.new_group(
-                        ranks=group_members, backend="nccl"
+                        ranks=group_members, backend=self.backend
                    )
                    if self.world_rank in group_members:
                        self.outer_intra_layer_parallel_group = group
@@ -189,7 +210,7 @@ def __init__(
                        ranks_in_ith_jth_intra_layer_group[:, i, j]
                    )
                    group = torch.distributed.new_group(
-                        ranks=group_members, backend="nccl"
+                        ranks=group_members, backend=self.backend
                    )
                    if self.world_rank in group_members:
                        self.depth_intra_layer_parallel_group = group
@@ -200,7 +221,7 @@ def __init__(
                        ranks_in_ith_jth_intra_layer_group[i, :, :].flatten()
                    )
                    group = torch.distributed.new_group(
-                        ranks=group_members, backend="nccl"
+                        ranks=group_members, backend=self.backend
                    )
                    if self.world_rank in group_members:
                        self.outer_inner_intra_layer_parallel_group = group
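
Taken together, these hunks let communication_handle come up without GPUs: rank and world size are now read from SLURM through the new DistributedEnvironment helper, and passing gpus_per_node=None selects the gloo backend instead of nccl, both for the global init_process_group call and for every torch.distributed.new_group created afterwards. A rough, self-contained sketch of the same selection logic (the SLURM variable names and the MASTER_PORT default follow the diff; everything else, including the function name, is an assumption):

# Illustrative sketch of the SLURM-driven, backend-aware setup in this commit.
import os
import torch
import torch.distributed as dist

def init_distributed(gpus_per_node=None):
    # Rank and world size are taken from SLURM, as DistributedEnvironment does.
    world_size = int(os.environ["SLURM_NTASKS"])
    world_rank = int(os.environ["SLURM_PROCID"])

    if gpus_per_node is None:
        backend = "gloo"                                   # CPU-only run
    else:
        backend = "nccl"                                   # GPU run
        torch.cuda.set_device(world_rank % gpus_per_node)  # pin the local GPU

    master_ip = os.getenv("MASTER_ADDR", "localhost")
    master_port = os.getenv("MASTER_PORT", "6000")
    dist.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=world_rank,
        init_method="tcp://" + master_ip + ":" + master_port,
    )
    return backend

# Subgroups should reuse the same backend so collectives also work on CPU:
# backend = init_distributed()
# data_parallel_group = dist.new_group(ranks=[0, 1], backend=backend)
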
6 changes: 3 additions & 3 deletions axonn/intra_layer/fully_connected.py
@@ -38,7 +38,7 @@ def initialize_params(
     in_features_group,
     depth_group,
     init_method,
-    init_device="cuda",
+    init_device="cpu",
 ):
     params = torch.empty((out_features, in_features), device=init_device)
     init_method(params)
@@ -253,8 +253,8 @@ def get_output_feature_size(self):
     def forward(
         self,
         x,
-        scatter_input=True,
-        gather_output=True,
+        scatter_input=False,
+        gather_output=False,
         cache_weights_in_all_gather=False,
     ):
         # gather weights from depth parallel group

Inline review comment on the changed defaults from siddharth9820 (Collaborator), Feb 2, 2024:
    These need to be true by default. Could you make this change?
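
The reviewer's point matters for callers: with the new defaults, forward() no longer scatters its input across the tensor-parallel group or gathers the full-size output, so code that relied on scatter_input=True, gather_output=True must now pass them explicitly (or the defaults need to be reverted as requested above). A hedged usage sketch, assuming the layer in fully_connected.py is exported as axonn.intra_layer.Linear with nn.Linear-style constructor arguments (both assumptions), and that torch.distributed and AxoNN are already initialized:

# Hypothetical usage; the class name and constructor arguments are assumptions,
# only the forward() keyword arguments come from the diff above.
import torch
from axonn.intra_layer import Linear  # assumed export of fully_connected.py

# Assumes AxoNN's intra-layer parallel groups have already been set up.
layer = Linear(in_features=1024, out_features=4096)
x = torch.randn(8, 1024, device="cpu")

# Pre-commit behavior: scatter the input over the tensor-parallel group and
# gather the full-size output on every rank.
y_full = layer(x, scatter_input=True, gather_output=True)

# New defaults keep activations partitioned; callers expecting full-size
# tensors must now opt in explicitly.
y_partial = layer(x)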

