diff --git a/axonn/axonn.py b/axonn/axonn.py
index fc0f3e9..07a5980 100644
--- a/axonn/axonn.py
+++ b/axonn/axonn.py
@@ -111,7 +111,7 @@ def init(
     mixed_precision=False,
     fp16_allreduce=True,
     cpu_offload=False,
-    device='cuda',
+    device="cuda",
 ) -> None:
     """
     Initialize AxoNN's 2D parallelism with G_inter-way inter-layer
@@ -136,10 +136,15 @@ def init(
     global comm_handle, is_initialized, computation_dtype, _fp16_all_reduce
     global _cpu_offload
     comm_handle = communication_handle(
-        G_inter, G_data, G_intra_r, G_intra_c, G_intra_d, gpus_per_node=gpus_per_node
+        G_inter,
+        G_data,
+        G_intra_r,
+        G_intra_c,
+        G_intra_d,
+        gpus_per_node=gpus_per_node,
+        device=device,
     )
-    print("maybe its this")
-    config.device=device
+    config.device = device
     config.G_inter = G_inter
     config.G_data = G_data
     config.G_intra = G_intra_r * G_intra_c * G_intra_d
@@ -155,7 +160,7 @@ def init(
         comm_handle.intra_layer_column_parallel_rank
     )
     is_initialized = True
-    if device == 'cuda' and not torch.cuda.is_available():
+    if device == "cuda" and not torch.cuda.is_available():
         raise ValueError("CUDA is not available. Please choose a different device.")
     if mixed_precision:
diff --git a/axonn/communication.py b/axonn/communication.py
index 3fb8e2a..57475c6 100644
--- a/axonn/communication.py
+++ b/axonn/communication.py
@@ -45,7 +45,7 @@ def __init__(
         G_intra_c=1,
         G_intra_d=1,
         gpus_per_node=None,
-        device="cpu",
+        device="cuda",
     ):
         """Constructor for the communication handle
@@ -58,10 +58,7 @@ def __init__(
         G_intra_c (int): number of GPUs in the column intra-layer parallel dimension
         G_intra_d (int): number of GPUs in the depth intra-layer parallel dimension
         """
-        print("is this the first thing that runs")
         config.device = device
-        print(config.device)
-        print(device)
         if config.device == "cpu":
             self.backend = "gloo"
         env = DistributedEnvironment()
diff --git a/axonn/config.py b/axonn/config.py
index 65cd18b..b036649 100644
--- a/axonn/config.py
+++ b/axonn/config.py
@@ -8,4 +8,4 @@
 G_data = 0
 micro_batch_size = 0
 batch_size = 0
-device = 'cuda'
+device = "cuda"