diff --git a/axonn/intra_layer/fully_connected.py b/axonn/intra_layer/fully_connected.py index f4561e4..6d2cfae 100644 --- a/axonn/intra_layer/fully_connected.py +++ b/axonn/intra_layer/fully_connected.py @@ -264,7 +264,6 @@ def forward( if not self.transpose: if scatter_input: x = Drop.apply(x, self.inner_group) - x = Drop.apply(x, self.depth_group, 0) x = AsyncLinear.apply( x, weight, @@ -278,11 +277,9 @@ def forward( ) if gather_output: x = Gather.apply(x, self.outer_group) - x = Gather.apply(x, self.depth_group, 0) else: if scatter_input: x = Drop.apply(x, self.outer_group) - x = Drop.apply(x, self.depth_group, 0) x = AsyncLinear.apply( x, @@ -297,7 +294,6 @@ def forward( ) if gather_output: x = Gather.apply(x, self.inner_group) - x = Gather.apply(x, self.depth_group, 0) if self.bias is None: return x diff --git a/axonn/tests/test_intra_layer_fc.py b/axonn/tests/test_intra_layer_fc.py index d8d5821..af09ab7 100644 --- a/axonn/tests/test_intra_layer_fc.py +++ b/axonn/tests/test_intra_layer_fc.py @@ -35,17 +35,14 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, easy_tp, bias): outer_group = ax.comm_handle.outer_intra_layer_parallel_group depth_group = ax.comm_handle.depth_intra_layer_parallel_group + X_local = _drop(X, 0, depth_group) # divide rows of X along the depth tensor group + if not easy_tp: # manually divide input X_local = _drop( X, 1, inner_group ) # divide colunns of X along the inner tensor group # manually divide input - X_local = _drop( - X_local, 0, depth_group - ) # divide colunns of X along the inner tensor group - else: - X_local = X layer = Linear(in_features=H, out_features=H, bias=bias).cuda() layer_sequential = torch.nn.Linear(in_features=H, out_features=H, bias=bias).cuda() @@ -58,11 +55,9 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, easy_tp, bias): with torch.no_grad(): # parallel FW pass Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp) + Y_parallel = _gather(Y_local.clone(), 0, depth_group) 
if not easy_tp: # gather output manually Y_parallel = _gather(Y_local.clone(), 1, outer_group) - Y_parallel = _gather(Y_parallel.clone(), 0, depth_group) - else: - Y_parallel = Y_local Y_sequential = layer_sequential(X) assert torch.allclose(Y_sequential, Y_parallel), "FW Pass - output does not match" @@ -117,22 +112,19 @@ def test_bw_pass( # test if load state dict works with a sharded checkpoint layer.load_state_dict(layer.state_dict()) + X_local = ( + _drop(X, 0, depth_group).detach().clone() + ) # divide rows of X along the depth tensor group if not easy_tp: X_local = ( - _drop(X, 1, inner_group).detach().clone() + _drop(X_local, 1, inner_group).detach().clone() ) # divide colunns of X along the inner tensor group - X_local = ( - _drop(X_local, 0, depth_group).detach().clone() - ) # divide colunns of X along the inner tensor group - else: - X_local = X X_local.requires_grad = True + + Y_local_grad = _drop(Y_grad, 0, depth_group).detach().clone() if not easy_tp: - Y_local_grad = _drop(Y_grad, 1, outer_group).detach().clone() - Y_local_grad = _drop(Y_local_grad, 0, depth_group).detach().clone() - else: - Y_local_grad = Y_grad + Y_local_grad = _drop(Y_local_grad, 1, outer_group).detach().clone() with optimize_communication( overlap_all_reduce=comm_opt_level >= 1, @@ -144,8 +136,7 @@ def test_bw_pass( Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp) Y_local.backward(Y_local_grad) - if not easy_tp: - sync_gradients(layer) + sync_gradients(layer) if comm_opt_level >= 3: clear_weights_cache() # sequential backward pass @@ -159,11 +150,9 @@ def test_bw_pass( layer_sequential.parameters(), max_norm=clip_grad_norm ) + X_grad_parallel = _gather(X_local.grad, 0, depth_group) if not easy_tp: - X_grad_parallel = _gather(X_local.grad, 0, depth_group) X_grad_parallel = _gather(X_grad_parallel, 1, inner_group) - else: - X_grad_parallel = X_local.grad assert torch.allclose( X_grad_parallel, X.grad