diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py
index 60ee5e87ca..752d3f70a9 100644
--- a/tests/pytorch/fused_attn/test_paged_attn.py
+++ b/tests/pytorch/fused_attn/test_paged_attn.py
@@ -525,9 +525,9 @@ def gen_data():
                 rtol=tols[dtype],
             )
             if qkv_format == "thd":
-                print('i ', i, seq, cu_seqlens_q)
-                print(full_output[seq, sim.t_total_lens[i] - 1, :4])
-                print(line_output[cu_seqlens_q[i + 1] - 1, :4])
+                #print('i ', i, seq, cu_seqlens_q)
+                #print(full_output[seq, sim.t_total_lens[i] - 1, :4])
+                #print(line_output[cu_seqlens_q[i + 1] - 1, :4])
                 torch.testing.assert_close(
                     #full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :],
                     #line_output[cu_seqlens_q[i]:cu_seqlens_q[i + 1] - 1, :],
diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
index 313ca416da..b5710eab86 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -428,22 +428,23 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     if (is_ragged_q || is_ragged_kv) {
       constexpr size_t nthreads_per_block = 128;
       const size_t grid = (b + nthreads_per_block) / nthreads_per_block;
+      void *devOffsets =
+          static_cast<int8_t *>(workspace) + plan_workspace_size + actual_seqlen_workspace_size;
       void *devOffsetsQ = nullptr;
-      void *devOffsetsK = nullptr;
-      void *devOffsetsV = nullptr;
       void *devOffsetsO = nullptr;
       if (is_ragged_q) {
-        devOffsetsQ =
-            static_cast<int8_t *>(workspace) + plan_workspace_size + actual_seqlen_workspace_size;
-        devOffsetsO = static_cast<int8_t *>(devOffsetsV) + num_bytes_per_ragged_offset;
+        devOffsetsQ = devOffsets;
+        devOffsetsO = static_cast<int8_t *>(devOffsetsQ) + num_bytes_per_ragged_offset;
       }
+      void *devOffsetsK = nullptr;
+      void *devOffsetsV = nullptr;
       if (is_ragged_kv) {
-        devOffsetsK = static_cast<int8_t *>(devOffsetsQ) + num_bytes_per_ragged_offset;
+        devOffsetsK = static_cast<int8_t *>(devOffsets) + (int)is_ragged_q * 2 * num_bytes_per_ragged_offset;
         devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
       }
       void *devOffsetsS = nullptr;
       if (is_ragged_q && cudnn_runtime_version >= 90600) {
-        devOffsetsS = static_cast<int8_t *>(devOffsetsO) + num_bytes_per_ragged_offset;
+        devOffsetsS = static_cast<int8_t *>(devOffsets) + ((int)is_ragged_q + (int)is_ragged_kv) * 2 * num_bytes_per_ragged_offset;
       }
       const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
       cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
@@ -655,7 +656,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
       o->set_ragged_offset(offset_o);
       dO->set_ragged_offset(offset_o);
     }
-    if (is_ragged_q) {
+    if (is_ragged_kv) {
      offset_k = mha_graph->tensor(fe::graph::Tensor_attributes()
                                       .set_name("offset_k")
                                       .set_dim({b + 1, 1, 1, 1})
@@ -886,22 +887,23 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     if (is_ragged_q || is_ragged_kv) {
       constexpr size_t nthreads_per_block = 128;
       const size_t grid = (b + nthreads_per_block) / nthreads_per_block;
+      void *devOffsets =
+          static_cast<int8_t *>(workspace) + plan_workspace_size + actual_seqlen_workspace_size;
       void *devOffsetsQ = nullptr;
-      void *devOffsetsK = nullptr;
-      void *devOffsetsV = nullptr;
       void *devOffsetsO = nullptr;
       if (is_ragged_q) {
-        devOffsetsQ =
-            static_cast<int8_t *>(workspace) + plan_workspace_size + actual_seqlen_workspace_size;
-        devOffsetsO = static_cast<int8_t *>(devOffsetsV) + num_bytes_per_ragged_offset;
+        devOffsetsQ = devOffsets;
+        devOffsetsO = static_cast<int8_t *>(devOffsetsQ) + num_bytes_per_ragged_offset;
       }
+      void *devOffsetsK = nullptr;
+      void *devOffsetsV = nullptr;
       if (is_ragged_kv) {
-        devOffsetsK = static_cast<int8_t *>(devOffsetsQ) + num_bytes_per_ragged_offset;
+        devOffsetsK = static_cast<int8_t *>(devOffsets) + (int)is_ragged_q * 2 * num_bytes_per_ragged_offset;
         devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
       }
       void *devOffsetsS = nullptr;
       if (is_ragged_q && cudnn_runtime_version >= 90600) {
-        devOffsetsS = static_cast<int8_t *>(devOffsetsO) + num_bytes_per_ragged_offset;
+        devOffsetsS = static_cast<int8_t *>(devOffsets) + ((int)is_ragged_q + (int)is_ragged_kv) * 2 * num_bytes_per_ragged_offset;
       }
       const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
       cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
diff --git a/transformer_engine/pytorch/inference.py b/transformer_engine/pytorch/inference.py
index 25b13a9e81..0d8b660369 100644
--- a/transformer_engine/pytorch/inference.py
+++ b/transformer_engine/pytorch/inference.py
@@ -386,7 +386,7 @@ def post_step(
         output = output[:self.batch_size, :self.max_seqlen_q].transpose(0, 1).contiguous()
         if self.input_qkv_format == "thd":
             print('oooo ', output.shape)
-            print(output[:2, :4])
+            #print(output[:2, :4])
             #output_buffer = self.q_orig[layer_number]
             #step_lens = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1]
             #tex.reshape_o(output, output_buffer, step_lens,
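Not part of the patch: a minimal, self-contained C++ sketch of the workspace layout the devOffsets arithmetic above assumes, for reference while reviewing. When is_ragged_q is set, the Q and O ragged-offset buffers occupy the first two num_bytes_per_ragged_offset slots, the K/V pair follows when is_ragged_kv is set, and the softmax-stats offsets S come last. All sizes and the host-side buffer below are placeholders, not the real cuDNN workspace sizes.

// Sketch only: mirrors the pointer arithmetic in fused_attn_f16_arbitrary_seqlen.cu,
// but runs on the host with placeholder sizes instead of the cuDNN workspace.
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  const bool is_ragged_q = false;   // paged-attention case: only the KV cache is ragged
  const bool is_ragged_kv = true;
  const size_t num_bytes_per_ragged_offset = 256;  // placeholder, not the real size
  int8_t workspace[6 * 256] = {};   // stands in for the tail of the attention workspace

  // Common base; in the real code this is workspace + plan_workspace_size +
  // actual_seqlen_workspace_size.
  int8_t *devOffsets = workspace;

  int8_t *devOffsetsQ = nullptr, *devOffsetsO = nullptr;
  if (is_ragged_q) {
    devOffsetsQ = devOffsets;
    devOffsetsO = devOffsetsQ + num_bytes_per_ragged_offset;
  }

  int8_t *devOffsetsK = nullptr, *devOffsetsV = nullptr;
  if (is_ragged_kv) {
    // K is placed after the optional Q/O pair relative to the common base,
    // so it no longer depends on devOffsetsQ (null when only KV is ragged).
    devOffsetsK = devOffsets + (int)is_ragged_q * 2 * num_bytes_per_ragged_offset;
    devOffsetsV = devOffsetsK + num_bytes_per_ragged_offset;
  }

  int8_t *devOffsetsS = nullptr;
  if (is_ragged_q) {  // the real code also requires cudnn_runtime_version >= 90600
    devOffsetsS = devOffsets + ((int)is_ragged_q + (int)is_ragged_kv) * 2 * num_bytes_per_ragged_offset;
  }

  std::printf("Q=%td O=%td K=%td V=%td S=%td (byte offsets; -1 = unused)\n",
              devOffsetsQ ? devOffsetsQ - workspace : -1,
              devOffsetsO ? devOffsetsO - workspace : -1,
              devOffsetsK ? devOffsetsK - workspace : -1,
              devOffsetsV ? devOffsetsV - workspace : -1,
              devOffsetsS ? devOffsetsS - workspace : -1);
  return 0;
}

With only is_ragged_kv set (the paged KV-cache case), the old code chained devOffsetsK off a null devOffsetsQ; relative to the shared devOffsets base, K now lands at the start of the offsets region.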