Modify the attention backward pass to accumulate dq & dkv using an fp32 precision buffer when employing context parallelism
Yuxin-CV committed Apr 28, 2024
1 parent 409c45d commit 9f675f0
16 changes: 10 additions & 6 deletions transformer_engine/pytorch/attention.py
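
Before reading the diff, it helps to see why fp32 accumulation matters here: under context parallelism the backward pass adds up one partial dq/dkv contribution per ring step, and repeated low-precision additions compound rounding error. A minimal standalone sketch of the effect (not from this commit; toy shapes):

import torch

# Summing many partial gradients directly in bf16 compounds rounding error;
# an fp32 accumulation buffer keeps the sum close to the exact result.
torch.manual_seed(0)
cp_size, n = 8, 4096
partials = [torch.randn(n) for _ in range(cp_size)]   # fp32 partial grads

acc_bf16 = torch.zeros(n, dtype=torch.bfloat16)       # low-precision buffer
acc_fp32 = torch.zeros(n)                             # fp32 buffer
for p in partials:
    acc_bf16 += p.to(torch.bfloat16)
    acc_fp32 += p

ref = torch.stack(partials).sum(dim=0)
print((acc_bf16.float() - ref).abs().max())           # visibly larger error
print((acc_fp32 - ref).abs().max())                   # essentially zero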
@@ -891,9 +891,13 @@ def backward(ctx, dout):
         out = out.view(*q.shape)
         dout = dout.view(*q.shape)
         # Flash Attn outputs
-        dq = torch.empty_like(q)
-
-        p2p_comm_buffers = [torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), \
+        if ctx.accumulate_in_fp32:
+            dq = torch.empty_like(q, dtype=torch.float32)
+            p2p_comm_buffers = [torch.empty((2, *kv.shape), dtype=torch.float32, device=kv.device), \
+                                torch.empty((2, *kv.shape), dtype=torch.float32, device=kv.device)]
+        else:
+            dq = torch.empty_like(q)
+            p2p_comm_buffers = [torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), \
                             torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device)]
         p2p_comm_buffers[0][0].copy_(kv)
         send_recv_reqs = []
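
This hunk makes the dtype of dq and the two ping-pong p2p buffers conditional on a new flag. A hedged sketch of the allocation pattern, assuming ctx.accumulate_in_fp32 is the switch this commit adds and that each buffer stacks a kv half and a dkv half (toy shapes):

import torch

# When accumulate_in_fp32 is set, dq and both ping-pong buffers are fp32,
# so every in-place add into dkv during the ring exchange is an fp32 add.
accumulate_in_fp32 = True
q = torch.randn(4, 8, dtype=torch.bfloat16)
kv = torch.randn(2, 4, 8, dtype=torch.bfloat16)

buf_dtype = torch.float32 if accumulate_in_fp32 else kv.dtype
dq = torch.empty_like(q, dtype=buf_dtype)
p2p_comm_buffers = [torch.empty((2, *kv.shape), dtype=buf_dtype, device=kv.device)
                    for _ in range(2)]
p2p_comm_buffers[0][0].copy_(kv)   # copy_ upcasts bf16 kv into the fp32 slot

Note that copy_ performs the bf16-to-fp32 conversion implicitly, so no separate upcast pass is needed before the first ring step.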
@@ -926,7 +930,7 @@ def backward(ctx, dout):
                                                     ctx.cp_group,
                                                     batch_p2p_comm)
 
-            kv = p2p_comm_buffers[i%2][0]
+            kv = p2p_comm_buffers[i%2][0].to(q.dtype)
             # In reversed order of fwd
             if ctx.causal:
                 if i == (cp_size-1):
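
The one-line change above is the flip side of the fp32 buffers: the attention kernels still expect activation-dtype inputs, so the kv half fetched from the buffer is downcast before compute while the buffer itself stays fp32. A sketch with toy shapes (the kv-at-index-0 layout is taken from the surrounding code):

import torch

# The p2p buffer may be fp32; cast its kv half back to the activation dtype
# for the kernel, leaving the buffer itself in fp32 for accumulation.
q = torch.randn(4, 8, dtype=torch.bfloat16)
kv_shape = (2, 4, 8)
p2p_comm_buffers = [torch.empty((2, *kv_shape), dtype=torch.float32)
                    for _ in range(2)]

i = 0
kv = p2p_comm_buffers[i % 2][0].to(q.dtype)             # mirrors the changed line
assert kv.dtype == q.dtype                              # compute runs in bf16
assert p2p_comm_buffers[i % 2].dtype == torch.float32   # buffer stays fp32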
@@ -1249,8 +1253,8 @@ def backward(ctx, dout):
             # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
             attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)
 
-        return None, dq, dkv[0], dkv[1], None, None, None, None, None, None, \
-            None, None, None, None, None, None, attn_dbias, None, None, None
+        return None, dq.to(q.dtype), dkv[0].to(kv.dtype), dkv[1].to(kv.dtype), None, None, None, \
+            None, None, None, None, None, None, None, None, None, attn_dbias, None, None, None


 def attn_forward_func_with_cp(
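
The final hunk performs the matching downcast exactly once, on return, so the gradients handed back to autograd have the same dtypes as the forward inputs. A minimal sketch of that last step (toy shapes; names mirror the diff):

import torch

# Gradients accumulated in fp32 are converted back to the input dtypes once,
# at the end of backward(), rather than at every ring step.
q = torch.randn(4, 8, dtype=torch.bfloat16)
kv = torch.randn(2, 4, 8, dtype=torch.bfloat16)
dq = torch.randn(4, 8, dtype=torch.float32)       # fp32 accumulator
dkv = torch.randn(2, 4, 8, dtype=torch.float32)   # fp32 accumulator

grads = (dq.to(q.dtype), dkv[0].to(kv.dtype), dkv[1].to(kv.dtype))
assert all(g.dtype == torch.bfloat16 for g in grads)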
