Skip to content

Commit

Permalink
WIP: fix 1c31b68
Browse files (browse the repository at this point in the history)
Signed-off-by: Charlene Yang <[email protected]>
  • Loading branch information
cyanguwa committed Feb 22, 2025
1 parent 7331a4c commit 0341de7
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions transformer_engine/common/fused_attn/utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ __device__ void cu_seqlens_padded_to_offsets_impl(
offsets_o[tid] = h * d_v * cu_seqlens_q_padded[cu_seqlens_id];
switch (layout_group) {
case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD:
offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
break;
case NVTE_QKV_Layout_Group::NVTE_3HD:
Expand All @@ -449,6 +450,7 @@ __device__ void cu_seqlens_padded_to_offsets_impl(
if (offsets_k != nullptr && offsets_v != nullptr) {
switch (layout_group) {
case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD:
offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id];
offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[cu_seqlens_id];
break;
Expand Down Expand Up @@ -495,6 +497,7 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at
std::array<int64_t, 4> offsets_qkvo{};
switch (layout_group) {
case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD:
offsets_qkvo[0] = num_attn_heads * head_dim_qk * max_seqlen_q;
offsets_qkvo[1] = num_gqa_groups * head_dim_qk * max_seqlen_kv;
offsets_qkvo[2] = num_gqa_groups * head_dim_v * max_seqlen_kv;
Expand Down

0 comments on commit 0341de7

Please sign in to comment.