[PyTorch] Implement fp32 accumulation for attention with context parallelism in both the forward and backward passes. #821

Open. Wants to merge 2 commits into base: main.
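
As a minimal, self-contained sketch of the idea behind this change (illustrative only; merge_partial_attn_outputs, the shapes, and the toy inputs are hypothetical, not code from this PR): each context-parallel step produces a partial attention output and a per-step softmax log-sum-exp; the partials are rescaled by exp(lse_step - lse_total) and summed in fp32, with the cast back to the activation dtype happening only once at the end.

import torch

def merge_partial_attn_outputs(outs, lses, accumulate_in_fp32=True):
    """Hypothetical helper: merge per-step attention outputs under context parallelism.

    outs: list of [b, h, s, d] partial outputs (e.g. bf16), one per CP step.
    lses: list of [b, h, s] per-step softmax log-sum-exp values.
    """
    # Combined log-sum-exp across all steps.
    lse = lses[0].float()
    for lse_i in lses[1:]:
        lse = torch.logaddexp(lse, lse_i.float())

    acc_dtype = torch.float32 if accumulate_in_fp32 else outs[0].dtype
    out = torch.zeros_like(outs[0], dtype=acc_dtype)
    for out_i, lse_i in zip(outs, lses):
        # exp(lse_i - lse) rescales each partial result before summation.
        scale = torch.exp(lse_i.float() - lse).unsqueeze(-1)
        out = out + out_i.to(acc_dtype) * scale.to(acc_dtype)
    # Cast back to the activation dtype only after accumulation is finished.
    return out.to(outs[0].dtype)

# Toy usage with made-up shapes.
outs = [torch.randn(2, 8, 128, 64, dtype=torch.bfloat16) for _ in range(2)]
lses = [torch.randn(2, 8, 128) for _ in range(2)]
merged = merge_partial_attn_outputs(outs, lses, accumulate_in_fp32=True)
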
52 changes: 35 additions & 17 deletions transformer_engine/pytorch/attention.py
@@ -491,8 +491,10 @@ def flash_attn_p2p_communicate(rank, send_tensor, send_dst,

@jit_fuser
def flash_attn_fwd_out_correction(out, out_per_step, seq_dim,
-softmax_lse, softmax_lse_per_step):
+softmax_lse, softmax_lse_per_step,
+accumulate_in_fp32):
"""Merge partial outputs of each step in Attention with context parallelism"""
+if accumulate_in_fp32: out_per_step = out_per_step.to(softmax_lse.dtype)
softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim)
softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
out_corrected = out_per_step*softmax_lse_corrected_exp
@@ -518,7 +520,8 @@ class AttnFuncWithCP(torch.autograd.Function):
@staticmethod
def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, qkv_format,
-attn_mask_type, attn_bias_type, attn_bias, deterministic, use_fused_attention):
+attn_mask_type, attn_bias_type, attn_bias, deterministic, use_fused_attention,
+accumulate_in_fp32):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)

@@ -779,8 +782,8 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,

with torch.cuda.stream(flash_attn_streams[(i-1)%2]):
if i == 1:
-out = torch.empty_like(q).zero_()
-softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
+out = torch.empty_like(q, dtype=torch.float32).zero_()
+softmax_lse = torch.clone(softmax_lse_per_step[0])
if causal:
# [b, np, sq] -> [b, np, 2, sq//2]
softmax_lse_ = softmax_lse.view(
@@ -812,13 +815,15 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
out_per_step[i],
seq_dim,
softmax_lse,
-softmax_lse_per_step[i])
+softmax_lse_per_step[i],
+accumulate_in_fp32)
else:
flash_attn_fwd_out_correction(out_,
out_per_step[i],
seq_dim,
softmax_lse_[..., 1, :],
-softmax_lse_per_step[i])
+softmax_lse_per_step[i],
+accumulate_in_fp32)

kv = p2p_comm_buffers[-1]
if use_fused_attention:
@@ -828,7 +833,7 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
out = out.view(-1, *out.shape[-3:])
else:
out = out.view(-1, *out.shape[-2:])
-ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k)
+ctx.save_for_backward(q, kv, out.to(q.dtype), softmax_lse, cu_seqlens_q, cu_seqlens_k)
ctx.rng_states = rng_states
ctx.cp_group = cp_group
ctx.cp_global_ranks = cp_global_ranks
@@ -843,7 +848,9 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
ctx.attn_biases = attn_biases
ctx.deterministic = deterministic
ctx.use_fused_attention = use_fused_attention
-return out
+ctx.accumulate_in_fp32 = accumulate_in_fp32
+
+return out.to(q.dtype)

@staticmethod
def backward(ctx, dout):
@@ -884,9 +891,13 @@ def backward(ctx, dout):
out = out.view(*q.shape)
dout = dout.view(*q.shape)
# Flash Attn outputs
-dq = torch.empty_like(q)
-
-p2p_comm_buffers = [torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), \
+if ctx.accumulate_in_fp32:
+dq = torch.empty_like(q, dtype=torch.float32)
+p2p_comm_buffers = [torch.empty((2, *kv.shape), dtype=torch.float32, device=kv.device), \
+torch.empty((2, *kv.shape), dtype=torch.float32, device=kv.device)]
+else:
+dq = torch.empty_like(q)
+p2p_comm_buffers = [torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), \
torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device)]
p2p_comm_buffers[0][0].copy_(kv)
send_recv_reqs = []
@@ -919,7 +930,7 @@ def backward(ctx, dout):
ctx.cp_group,
batch_p2p_comm)

-kv = p2p_comm_buffers[i%2][0]
+kv = p2p_comm_buffers[i%2][0].to(q.dtype)
# In reversed order of fwd
if ctx.causal:
if i == (cp_size-1):
@@ -1242,15 +1253,15 @@ def backward(ctx, dout):
# [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)

-return None, dq, dkv[0], dkv[1], None, None, None, None, None, None, \
-None, None, None, None, None, None, attn_dbias, None, None
+return None, dq.to(q.dtype), dkv[0].to(kv.dtype), dkv[1].to(kv.dtype), None, None, None, \
+None, None, None, None, None, None, None, None, None, attn_dbias, None, None, None


def attn_forward_func_with_cp(
is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale=None, qkv_format="bshd",
attn_mask_type="causal", attn_bias_type="no_bias", attn_bias=None, deterministic=False,
-use_fused_attention=False
+use_fused_attention=False, accumulate_in_fp32=True
) -> torch.Tensor:
"""Attention implementation with context parallelism"""
assert(qkv_format in ["bshd", "sbhd"]
@@ -1264,7 +1275,8 @@ def attn_forward_func_with_cp(
out = AttnFuncWithCP.apply(
is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, qkv_format,
-attn_mask_type, attn_bias_type, attn_bias, deterministic, use_fused_attention
+attn_mask_type, attn_bias_type, attn_bias, deterministic, use_fused_attention,
+accumulate_in_fp32
)
return out

@@ -1923,6 +1935,7 @@ def __init__(
attention_type: str = "self",
layer_number: Optional[int] = None,
deterministic: bool = False,
+accumulate_in_fp32: bool = True,
) -> None:
super().__init__()

@@ -1936,6 +1949,7 @@ def __init__(
self.attention_type = attention_type
self.layer_number = 1 if layer_number is None else layer_number
self.deterministic = deterministic
+self.accumulate_in_fp32 = accumulate_in_fp32

def forward(
self,
@@ -2074,7 +2088,8 @@ def forward(
softmax_scale=1.0/self.norm_factor,
qkv_format="bshd" if qkv_format=="sbhd" else qkv_format,
attn_mask_type=attn_mask_type,
-deterministic=self.deterministic
+deterministic=self.deterministic,
+accumulate_in_fp32=self.accumulate_in_fp32,
)
else:

@@ -2909,6 +2924,7 @@ def __init__(
deterministic: bool = False,
tp_size: int = 1,
tp_group: Optional[dist_group_type] = None,
+accumulate_in_fp32: bool = True,
) -> None:
super().__init__()

@@ -2937,6 +2953,7 @@

self.tp_size = tp_size
self.tp_group = tp_group
+self.accumulate_in_fp32 = accumulate_in_fp32

def get_fp8_weights_scratchpad(
self,
@@ -3056,6 +3073,7 @@ def forward(
attn_bias_type=core_attention_bias_type,
attn_bias=core_attention_bias,
use_fused_attention=True,
+accumulate_in_fp32=self.accumulate_in_fp32
)
else:
with self.prepare_forward(query_layer,
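
A toy illustration of the backward-side convention the diff follows (made-up shapes and step count, not code from this PR): gradient partials produced across the context-parallel correction steps are upcast and accumulated into fp32 buffers, and the result is cast back to the activation dtype only once at the end.

import torch

# fp32 accumulator for dq; low-precision partials are upcast before adding.
dq_acc = torch.zeros(2, 8, 128, 64, dtype=torch.float32)
for _ in range(4):  # stand-in for the per-step corrections under context parallelism
    dq_partial = torch.randn(2, 8, 128, 64, dtype=torch.bfloat16)
    dq_acc += dq_partial.float()
# Single downcast after all steps have been accumulated.
dq = dq_acc.to(torch.bfloat16)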