From 0341de79ddbf8f21fe061e34e107e7a3fdfd717f Mon Sep 17 00:00:00 2001 From: Charlene Yang Date: Fri, 21 Feb 2025 17:06:52 -0800 Subject: [PATCH] WIP: fix 1c31b68d Signed-off-by: Charlene Yang --- transformer_engine/common/fused_attn/utils.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index daf4ce71c1..1c6072da07 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -434,6 +434,7 @@ __device__ void cu_seqlens_padded_to_offsets_impl( offsets_o[tid] = h * d_v * cu_seqlens_q_padded[cu_seqlens_id]; switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: + case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD: offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; break; case NVTE_QKV_Layout_Group::NVTE_3HD: @@ -449,6 +450,7 @@ __device__ void cu_seqlens_padded_to_offsets_impl( if (offsets_k != nullptr && offsets_v != nullptr) { switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: + case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD: offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id]; offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[cu_seqlens_id]; break; @@ -495,6 +497,7 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at std::array offsets_qkvo{}; switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: + case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD: offsets_qkvo[0] = num_attn_heads * head_dim_qk * max_seqlen_q; offsets_qkvo[1] = num_gqa_groups * head_dim_qk * max_seqlen_kv; offsets_qkvo[2] = num_gqa_groups * head_dim_v * max_seqlen_kv;