diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index e74e4b6f50..0d2e9d4ed7 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -2431,19 +2431,19 @@ def forward(
         elif qkv_format == "sbhd":
             out = out.view(-1, *out.shape[-3:])
             ctx.batch_size = out.shape[1]
-
+
         tex.fused_out_correction(
-                out,
-                out_per_step,
-                softmax_lse,
-                softmax_lse_per_step,
-                cu_seqlens_q_padded,
-                qkv_format,
-                cp_size,
-                rank,
-                causal,
-                softmax_lse_in_packed_format,
-            )
+            out,
+            out_per_step,
+            softmax_lse,
+            softmax_lse_per_step,
+            cu_seqlens_q_padded,
+            qkv_format,
+            cp_size,
+            rank,
+            causal,
+            softmax_lse_in_packed_format,
+        )
 
         if cp_size_a2a > 1:
             chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, False)
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 2eaacac742..0d3cfc148b 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -452,7 +452,7 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
 
 at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
                                        int world_size, int rank);
-
+
 void fused_out_correction(at::Tensor out, const std::vector<at::Tensor> &out_per_step,
                           const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step,
                           const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu
index cb005bcde1..cbdc8387cc 100644
--- a/transformer_engine/pytorch/csrc/extensions/attention.cu
+++ b/transformer_engine/pytorch/csrc/extensions/attention.cu
@@ -1256,7 +1256,7 @@ void fused_out_correction_helper(at::Tensor out, const std::vector<at::Tensor> &
     if (softmax_lse_in_packed_format) {
       lse_seqlen = total_tokens;
     } else {
-        lse_seqlen = lse.size(2);
+      lse_seqlen = lse.size(2);
     }
   }
   constexpr int tile = 16;
@@ -1277,15 +1277,12 @@
       tensors.addresses_lse[j] = lse_per_step[i + j].data_ptr();
     }
     if (qkv_format == "sbhd") {
-      NVTE_CHECK(softmax_lse_in_packed_format == false, "Packed lse doesn't support SBHD format.");
-      fused_out_correction_kernel
-          <<<...>>>(
-              out.data_ptr(), tensors, lse.data_ptr(), cu_seqlens.data_ptr(),
-              batch, num_heads, dim_per_head, lse_seqlen, cp_size, rank, i);
-
+      fused_out_correction_kernel
+          <<<...>>>(
+              out.data_ptr(), tensors, lse.data_ptr(), cu_seqlens.data_ptr(),
+              batch, num_heads, dim_per_head, lse_seqlen, cp_size, rank, i);
     } else if (qkv_format == "bshd") {
       if (softmax_lse_in_packed_format) {
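
For context, here is a minimal pure-PyTorch sketch of the correction that `tex.fused_out_correction` performs in a single fused kernel: each context-parallelism step produces a partial attention output plus its softmax log-sum-exp (LSE), and the partials are merged by weighting each step with its share of the combined softmax. The function name, the `bshd` shape assumptions (`out_per_step` entries of shape `[b, s, h, d]`, `lse_per_step` entries of shape `[b, h, s]`), and the omission of the causal half-sequence and `cu_seqlens` handling are simplifications for illustration, not the extension's actual implementation.

```python
# Illustrative reference only -- not TE's implementation. Assumes "bshd"
# layout: out_per_step entries are [b, s, h, d], lse_per_step entries are
# [b, h, s]. Ignores causal half-sequence handling and cu_seqlens.
import torch


def out_correction_reference(out_per_step, lse_per_step):
    """Merge per-CP-step partial attention outputs via their softmax LSE."""
    # Combined log-sum-exp across all context-parallel steps: [b, h, s].
    lse = torch.logsumexp(torch.stack(lse_per_step, dim=0), dim=0)
    out = torch.zeros_like(out_per_step[0])
    for out_i, lse_i in zip(out_per_step, lse_per_step):
        # exp(lse_i - lse) is this step's share of the full softmax mass.
        w = torch.exp(lse_i - lse)                      # [b, h, s]
        out += out_i * w.transpose(1, 2).unsqueeze(-1)  # broadcast over d
    return out, lse
```

Fusing this loop into one launch is the point of `fused_out_correction_kernel` in the diff: rather than running a separate correction pass per CP step, the helper packs pointers for several steps' `out`/`lse` tensors into a `TensorList` (visible in the `tensors.addresses_*` assignments) and hands the kernel a `start` index `i`, so the per-step rescale-and-accumulate happens in-register inside a single kernel.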