
Commit f975e58

Removing the drop and gathers in depth tensor parallelism for the easy API (#66)
siddharth9820 authored Feb 28, 2024
1 parent 508798b commit f975e58
Showing 2 changed files with 12 additions and 27 deletions.
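The practical effect of this change is that, under the easy API (scatter_input=True / gather_output=True), the layer now only drops and gathers activations across the inner and outer tensor-parallel groups; splitting and re-assembling along the depth group becomes the caller's job, which is exactly what the updated test does with _drop(..., 0, depth_group) and _gather(..., 0, depth_group). Below is a minimal sketch of the new forward-pass calling pattern. It is not a verbatim example from the repository: it assumes AxoNN has already been initialized (e.g. via ax.init under torchrun), that the import paths match those used by the test, and that _drop/_gather are shard/all-gather helpers like the ones in axonn/tests/test_intra_layer_fc.py.

    # Sketch only: mirrors the updated test_fw_pass logic for the easy API.
    import torch
    from axonn import axonn as ax          # assumed import, as in AxoNN's examples
    from axonn.intra_layer import Linear   # the tensor-parallel layer changed below

    B, H = 16, 1024                        # illustrative sizes
    depth_group = ax.comm_handle.depth_intra_layer_parallel_group
    X = torch.randn(B, H).cuda()           # full input, identical on every rank
    layer = Linear(in_features=H, out_features=H).cuda()

    # Before this commit, layer(X, scatter_input=True, gather_output=True) also
    # dropped and gathered X along the depth group internally. Now the caller
    # shards the batch across the depth ranks first and re-gathers at the end.
    X_local = _drop(X, 0, depth_group)     # each depth rank keeps a slice of the rows of X
    Y_local = layer(X_local, scatter_input=True, gather_output=True)
    Y = _gather(Y_local, 0, depth_group)   # reassemble the full output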
axonn/intra_layer/fully_connected.py (4 changes: 0 additions & 4 deletions)
@@ -264,7 +264,6 @@ def forward(
        if not self.transpose:
            if scatter_input:
                x = Drop.apply(x, self.inner_group)
                x = Drop.apply(x, self.depth_group, 0)
            x = AsyncLinear.apply(
                x,
                weight,
@@ -278,11 +277,9 @@ def forward(
            )
            if gather_output:
                x = Gather.apply(x, self.outer_group)
                x = Gather.apply(x, self.depth_group, 0)
        else:
            if scatter_input:
                x = Drop.apply(x, self.outer_group)
                x = Drop.apply(x, self.depth_group, 0)

            x = AsyncLinear.apply(
                x,
@@ -297,7 +294,6 @@ def forward(
            )
            if gather_output:
                x = Gather.apply(x, self.inner_group)
                x = Gather.apply(x, self.depth_group, 0)

        if self.bias is None:
            return x
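For readers unfamiliar with the Drop/Gather primitives removed above: AxoNN ships its own Drop and Gather implementations, and the pattern they follow is a pair of autograd functions whose forward and backward mirror each other (drop = keep the local shard, gather = all-gather the shards). The sketch below is an illustrative re-implementation of that pattern, not the library's code; it assumes the sharded dimension divides evenly across the group and that torch.distributed is already initialized. The calls above also suggest that in the library the dimension argument defaults to the feature (last) dimension, which would be why only the depth-group calls pass an explicit 0; the sketch makes dim explicit for simplicity.

    import torch
    import torch.distributed as dist

    def _split(x, dim, group):
        """Keep only this rank's chunk of x along dim."""
        rank = dist.get_rank(group)
        world = dist.get_world_size(group)
        return torch.chunk(x, world, dim=dim)[rank].contiguous()

    def _all_gather(x, dim, group):
        """Concatenate every rank's chunk of x along dim."""
        world = dist.get_world_size(group)
        shards = [torch.empty_like(x) for _ in range(world)]
        dist.all_gather(shards, x.contiguous(), group=group)
        return torch.cat(shards, dim=dim)

    class Drop(torch.autograd.Function):
        # forward: keep the local slice; backward: all-gather the gradient
        @staticmethod
        def forward(ctx, x, group, dim):
            ctx.group, ctx.dim = group, dim
            return _split(x, dim, group)

        @staticmethod
        def backward(ctx, grad_output):
            return _all_gather(grad_output, ctx.dim, ctx.group), None, None

    class Gather(torch.autograd.Function):
        # forward: all-gather the shards; backward: keep the local slice
        @staticmethod
        def forward(ctx, x, group, dim):
            ctx.group, ctx.dim = group, dim
            return _all_gather(x, dim, group)

        @staticmethod
        def backward(ctx, grad_output):
            return _split(grad_output, ctx.dim, ctx.group), None, None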
axonn/tests/test_intra_layer_fc.py (35 changes: 12 additions & 23 deletions)
@@ -35,17 +35,14 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, easy_tp, bias):
    outer_group = ax.comm_handle.outer_intra_layer_parallel_group
    depth_group = ax.comm_handle.depth_intra_layer_parallel_group

    X_local = _drop(X, 0, depth_group)  # divide rows of X along the depth tensor group

    if not easy_tp:
        # manually divide input
        X_local = _drop(
            X, 1, inner_group
        )  # divide columns of X along the inner tensor group
        # manually divide input
        X_local = _drop(
            X_local, 0, depth_group
        )  # divide rows of X along the depth tensor group
    else:
        X_local = X

    layer = Linear(in_features=H, out_features=H, bias=bias).cuda()
    layer_sequential = torch.nn.Linear(in_features=H, out_features=H, bias=bias).cuda()
@@ -58,11 +55,9 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, easy_tp, bias):
    with torch.no_grad():
        # parallel FW pass
        Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
        Y_parallel = _gather(Y_local.clone(), 0, depth_group)
        if not easy_tp:  # gather output manually
            Y_parallel = _gather(Y_local.clone(), 1, outer_group)
            Y_parallel = _gather(Y_parallel.clone(), 0, depth_group)
        else:
            Y_parallel = Y_local
        Y_sequential = layer_sequential(X)

    assert torch.allclose(Y_sequential, Y_parallel), "FW Pass - output does not match"
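To make the gather bookkeeping in this test concrete, the shapes implied by the _drop/_gather calls for the manual (easy_tp=False) path are listed below, writing |g| for a group's size; this annotation is added for this write-up and is not part of the test (the order of the two output gathers does not affect the final shape):

    # X                                : (B, H)                  full input, identical on every rank
    # _drop(X, 0, depth_group)         : (B/|depth|, H)          rows sharded across the depth group
    # _drop(..., 1, inner_group)       : (B/|depth|, H/|inner|)  columns sharded across the inner group
    # Y_local = layer(X_local, ...)    : (B/|depth|, H/|outer|)  output sharded along depth and outer
    # _gather(Y_local, 1, outer_group) : (B/|depth|, H)
    # _gather(..., 0, depth_group)     : (B, H)                  same shape as Y_sequential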
@@ -117,22 +112,19 @@ def test_bw_pass(
    # test if load state dict works with a sharded checkpoint
    layer.load_state_dict(layer.state_dict())

    X_local = (
        _drop(X, 0, depth_group).detach().clone()
    )  # divide rows of X along the depth tensor group
    if not easy_tp:
        X_local = (
            _drop(X, 1, inner_group).detach().clone()
            _drop(X_local, 1, inner_group).detach().clone()
        )  # divide columns of X along the inner tensor group
        X_local = (
            _drop(X_local, 0, depth_group).detach().clone()
        )  # divide rows of X along the depth tensor group
    else:
        X_local = X

    X_local.requires_grad = True

    Y_local_grad = _drop(Y_grad, 0, depth_group).detach().clone()
    if not easy_tp:
        Y_local_grad = _drop(Y_grad, 1, outer_group).detach().clone()
        Y_local_grad = _drop(Y_local_grad, 0, depth_group).detach().clone()
    else:
        Y_local_grad = Y_grad
        Y_local_grad = _drop(Y_local_grad, 1, outer_group).detach().clone()

    with optimize_communication(
        overlap_all_reduce=comm_opt_level >= 1,
@@ -144,8 +136,7 @@ def test_bw_pass(
        Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
        Y_local.backward(Y_local_grad)

    if not easy_tp:
        sync_gradients(layer)
    sync_gradients(layer)
    if comm_opt_level >= 3:
        clear_weights_cache()
    # sequential backward pass
@@ -159,11 +150,9 @@ def test_bw_pass(
            layer_sequential.parameters(), max_norm=clip_grad_norm
        )

    X_grad_parallel = _gather(X_local.grad, 0, depth_group)
    if not easy_tp:
        X_grad_parallel = _gather(X_local.grad, 0, depth_group)
        X_grad_parallel = _gather(X_grad_parallel, 1, inner_group)
    else:
        X_grad_parallel = X_local.grad

    assert torch.allclose(
        X_grad_parallel, X.grad
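Condensing the backward-pass test above for the easy API: the caller now shards the incoming output-gradient along the depth group, calls backward, synchronizes parameter gradients, and gathers the input-gradient back. The sketch below restates what test_bw_pass does for the easy_tp=True path using the same names; it is not a verbatim snippet from the repository, and sync_gradients is the helper the test imports.

    # Sketch of the easy-API backward bookkeeping after this commit.
    Y_local_grad = _drop(Y_grad, 0, depth_group)     # shard dL/dY the same way the output is sharded
    Y_local = layer(X_local, scatter_input=True, gather_output=True)
    Y_local.backward(Y_local_grad)
    sync_gradients(layer)                            # synchronize parameter gradients; now needed for the easy path too
    X_grad = _gather(X_local.grad, 0, depth_group)   # reassemble dL/dX across the depth ranks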

