Fixed communication handle flags
Fixed formatting
Sathwik Yanamaddi committed Mar 3, 2024
1 parent 51bff43 commit 72df221
Showing 3 changed files with 12 additions and 10 deletions.
axonn/axonn.py: 15 changes (10 additions & 5 deletions)
@@ -111,7 +111,7 @@ def init(
     mixed_precision=False,
     fp16_allreduce=True,
     cpu_offload=False,
-    device='cuda',
+    device="cuda",
 ) -> None:
     """
     Initialize AxoNN's 2D parallelism with G_inter-way inter-layer
@@ -136,10 +136,15 @@ def init(
     global comm_handle, is_initialized, computation_dtype, _fp16_all_reduce
     global _cpu_offload
     comm_handle = communication_handle(
-        G_inter, G_data, G_intra_r, G_intra_c, G_intra_d, gpus_per_node=gpus_per_node
+        G_inter,
+        G_data,
+        G_intra_r,
+        G_intra_c,
+        G_intra_d,
+        gpus_per_node=gpus_per_node,
+        device=device,
     )
-    print("maybe its this")
-    config.device=device
+    config.device = device
     config.G_inter = G_inter
     config.G_data = G_data
     config.G_intra = G_intra_r * G_intra_c * G_intra_d
@@ -155,7 +160,7 @@ def init(
             comm_handle.intra_layer_column_parallel_rank
         )
     is_initialized = True
-    if device == 'cuda' and not torch.cuda.is_available():
+    if device == "cuda" and not torch.cuda.is_available():
         raise ValueError("CUDA is not available. Please choose a different device.")
 
     if mixed_precision:
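For context, a minimal sketch of how a caller might pass the new device flag through init() after this commit. The keyword names come from the signature shown in the diff above; the import alias and the parallelism degrees are illustrative assumptions, not part of this change:

    # Illustrative usage sketch (assumed import path; degrees are examples).
    from axonn import axonn as ax

    ax.init(
        G_inter=1,       # inter-layer (pipeline) parallel degree
        G_data=2,        # data-parallel degree
        G_intra_r=2,     # row intra-layer parallel degree
        G_intra_c=1,     # column intra-layer parallel degree
        G_intra_d=1,     # depth intra-layer parallel degree
        device="cuda",   # now forwarded to communication_handle
    )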
axonn/communication.py: 5 changes (1 addition & 4 deletions)
@@ -45,7 +45,7 @@ def __init__(
         G_intra_c=1,
         G_intra_d=1,
         gpus_per_node=None,
-        device="cpu",
+        device="cuda",
     ):
         """Constructor for the communication handle
 
@@ -58,10 +58,7 @@ def __init__(
         G_intra_c (int): number of GPUs in the column intra-layer parallel dimension
         G_intra_d (int): number of GPUs in the depth intra-layer parallel dimension
         """
-        print("is this the first thing that runs")
         config.device = device
-        print(config.device)
-        print(device)
         if config.device == "cpu":
             self.backend = "gloo"
             env = DistributedEnvironment()
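The second hunk shows why the default matters: the constructor copies its device argument into config.device and selects the "gloo" backend when it is "cpu". A minimal sketch of that mapping follows; only the "cpu" branch appears in the diff, so the "nccl" branch is an assumption based on standard torch.distributed practice:

    # Sketch of the device-to-backend choice implied by the hunk above.
    def pick_backend(device: str) -> str:
        if device == "cpu":
            return "gloo"  # CPU collectives, as in the hunk above
        return "nccl"      # typical backend when device="cuda" (assumed)

    assert pick_backend("cpu") == "gloo"
    assert pick_backend("cuda") == "nccl"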
axonn/config.py: 2 changes (1 addition & 1 deletion)
@@ -8,4 +8,4 @@
 G_data = 0
 micro_batch_size = 0
 batch_size = 0
-device = 'cuda'
+device = "cuda"
