diff --git a/axonn/communication.py b/axonn/communication.py index e37902b..cf0ecc4 100644 --- a/axonn/communication.py +++ b/axonn/communication.py @@ -184,7 +184,7 @@ def __init__( ranks_in_ith_jth_intra_layer_group[i, j, :] ) group = torch.distributed.new_group( - ranks=group_members, backend="gloo" + ranks=group_members, backend=self.backend ) if self.world_rank in group_members: self.inner_intra_layer_parallel_group = group @@ -196,7 +196,7 @@ def __init__( ranks_in_ith_jth_intra_layer_group[i, :, j] ) group = torch.distributed.new_group( - ranks=group_members, backend="gloo" + ranks=group_members, backend=self.backend ) if self.world_rank in group_members: self.outer_intra_layer_parallel_group = group @@ -208,7 +208,7 @@ def __init__( ranks_in_ith_jth_intra_layer_group[:, i, j] ) group = torch.distributed.new_group( - ranks=group_members, backend="gloo" + ranks=group_members, backend=self.backend ) if self.world_rank in group_members: self.depth_intra_layer_parallel_group = group