Skip to content

Commit

Permalink
Fixed communication handle flags
Browse files Browse the repository at this point in the history
  • Loading branch information
Sathwik Yanamaddi committed Mar 3, 2024
1 parent 51bff43 commit 7732efc
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 6 deletions.
3 changes: 1 addition & 2 deletions axonn/axonn.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,8 @@ def init(
global comm_handle, is_initialized, computation_dtype, _fp16_all_reduce
global _cpu_offload
comm_handle = communication_handle(
G_inter, G_data, G_intra_r, G_intra_c, G_intra_d, gpus_per_node=gpus_per_node
G_inter, G_data, G_intra_r, G_intra_c, G_intra_d, gpus_per_node=gpus_per_node, device=device
)
print("maybe its this")
config.device=device
config.G_inter = G_inter
config.G_data = G_data
Expand Down
5 changes: 1 addition & 4 deletions axonn/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(
G_intra_c=1,
G_intra_d=1,
gpus_per_node=None,
device="cpu",
device="cuda",
):
"""Constructor for the communication handle
Expand All @@ -58,10 +58,7 @@ def __init__(
G_intra_c (int): number of GPUs in the column intra-layer parallel dimension
G_intra_d (int): number of GPUs in the depth intra-layer parallel dimension
"""
print("is this the first thing that runs")
config.device = device
print(config.device)
print(device)
if config.device == "cpu":
self.backend = "gloo"
env = DistributedEnvironment()
Expand Down

0 comments on commit 7732efc

Please sign in to comment.