Commit 7331a4c

WIP: fix last commit
Signed-off-by: Charlene Yang <[email protected]>
cyanguwa committed Feb 22, 2025
1 parent 1c31b68 commit 7331a4c
Showing 3 changed files with 21 additions and 19 deletions.
6 changes: 3 additions & 3 deletions tests/pytorch/fused_attn/test_paged_attn.py
@@ -525,9 +525,9 @@ def gen_data():
                 rtol=tols[dtype],
             )
         if qkv_format == "thd":
-            print('i ', i, seq, cu_seqlens_q)
-            print(full_output[seq, sim.t_total_lens[i] - 1, :4])
-            print(line_output[cu_seqlens_q[i + 1] - 1, :4])
+            #print('i ', i, seq, cu_seqlens_q)
+            #print(full_output[seq, sim.t_total_lens[i] - 1, :4])
+            #print(line_output[cu_seqlens_q[i + 1] - 1, :4])
         torch.testing.assert_close(
             #full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :],
             #line_output[cu_seqlens_q[i]:cu_seqlens_q[i + 1] - 1, :],
32 changes: 17 additions & 15 deletions
@@ -428,22 +428,23 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
   if (is_ragged_q || is_ragged_kv) {
     constexpr size_t nthreads_per_block = 128;
     const size_t grid = (b + nthreads_per_block) / nthreads_per_block;
+    void *devOffsets =
+        static_cast<int8_t *>(workspace) + plan_workspace_size + actual_seqlen_workspace_size;
     void *devOffsetsQ = nullptr;
+    void *devOffsetsK = nullptr;
+    void *devOffsetsV = nullptr;
     void *devOffsetsO = nullptr;
     if (is_ragged_q) {
-      devOffsetsQ =
-          static_cast<int8_t *>(workspace) + plan_workspace_size + actual_seqlen_workspace_size;
-      devOffsetsO = static_cast<int8_t *>(devOffsetsV) + num_bytes_per_ragged_offset;
+      devOffsetsQ = devOffsets;
+      devOffsetsO = static_cast<int8_t *>(devOffsetsQ) + num_bytes_per_ragged_offset;
     }
-    void *devOffsetsK = nullptr;
-    void *devOffsetsV = nullptr;
     if (is_ragged_kv) {
-      devOffsetsK = static_cast<int8_t *>(devOffsetsQ) + num_bytes_per_ragged_offset;
+      devOffsetsK = static_cast<int8_t *>(devOffsets) + (int)is_ragged_q * 2 * num_bytes_per_ragged_offset;
       devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
     }
     void *devOffsetsS = nullptr;
     if (is_ragged_q && cudnn_runtime_version >= 90600) {
-      devOffsetsS = static_cast<int8_t *>(devOffsetsO) + num_bytes_per_ragged_offset;
+      devOffsetsS = static_cast<int8_t *>(devOffsets) + ((int)is_ragged_q + (int)is_ragged_kv) * 2 * num_bytes_per_ragged_offset;
     }
     const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
     cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
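
A note on the rewrite above: the replaced lines computed devOffsetsO from devOffsetsV before devOffsetsV was declared, and based devOffsetsK on devOffsetsQ even when devOffsetsQ was still nullptr (Q contiguous, KV ragged). The new code derives every pointer from one devOffsets base in a fixed order: Q and O if is_ragged_q, then K and V if is_ragged_kv, then S. The following is a minimal standalone sketch of that layout, not part of the commit; block and the 256-byte stride stand in for the workspace slice and num_bytes_per_ragged_offset, and the cudnn_runtime_version >= 90600 guard on S is reduced to is_ragged_q.

// Standalone sketch of the ragged-offset layout established by the hunk above.
#include <cstdint>
#include <cstdio>

int main() {
  const long stride = 256;         // stands in for num_bytes_per_ragged_offset
  int8_t block[5 * 256] = {};      // stands in for the workspace slice
  void *devOffsets = block;
  for (int is_ragged_q = 0; is_ragged_q <= 1; ++is_ragged_q) {
    for (int is_ragged_kv = 0; is_ragged_kv <= 1; ++is_ragged_kv) {
      void *q = nullptr, *k = nullptr, *v = nullptr, *o = nullptr, *s = nullptr;
      if (is_ragged_q) {
        q = devOffsets;
        o = static_cast<int8_t *>(q) + stride;
      }
      if (is_ragged_kv) {
        k = static_cast<int8_t *>(devOffsets) + is_ragged_q * 2 * stride;
        v = static_cast<int8_t *>(k) + stride;
      }
      if (is_ragged_q) {  // S lands after every Q/O/K/V region that exists
        s = static_cast<int8_t *>(devOffsets) + (is_ragged_q + is_ragged_kv) * 2 * stride;
      }
      // Print each pointer as a byte offset from the base (-1 if unused).
      auto off = [&](void *p) { return p ? static_cast<int8_t *>(p) - block : -1; };
      std::printf("ragged_q=%d ragged_kv=%d  Q=%td O=%td K=%td V=%td S=%td\n",
                  is_ragged_q, is_ragged_kv, off(q), off(o), off(k), off(v), off(s));
    }
  }
  return 0;
}

With both layouts ragged it prints Q=0 O=256 K=512 V=768 S=1024; no live region overlaps another, which is the invariant the fix restores.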
@@ -655,7 +656,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
       o->set_ragged_offset(offset_o);
       dO->set_ragged_offset(offset_o);
     }
-    if (is_ragged_q) {
+    if (is_ragged_kv) {
       offset_k = mha_graph->tensor(fe::graph::Tensor_attributes()
                                        .set_name("offset_k")
                                        .set_dim({b + 1, 1, 1, 1})
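
The one-line change above fixes the guard for the backward graph's K/V ragged offsets: offset_k and offset_v describe the K and V tensors, so they must be created when the KV layout is ragged, not when Q is. A trivially small standalone sketch, not part of the commit, enumerating the four layout combinations:

// Which combinations the pre-fix guard (is_ragged_q) got wrong.
#include <cstdio>

int main() {
  for (int rq = 0; rq <= 1; ++rq) {
    for (int rkv = 0; rkv <= 1; ++rkv) {
      bool old_creates_kv = rq;   // pre-fix:  if (is_ragged_q)
      bool new_creates_kv = rkv;  // post-fix: if (is_ragged_kv)
      std::printf("ragged_q=%d ragged_kv=%d need_kv_offsets=%d old=%d new=%d%s\n",
                  rq, rkv, rkv, (int)old_creates_kv, (int)new_creates_kv,
                  old_creates_kv != (bool)rkv ? "  <- old guard wrong" : "");
    }
  }
  return 0;
}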
@@ -886,22 +887,23 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
   if (is_ragged_q || is_ragged_kv) {
     constexpr size_t nthreads_per_block = 128;
     const size_t grid = (b + nthreads_per_block) / nthreads_per_block;
+    void *devOffsets =
+        static_cast<int8_t *>(workspace) + plan_workspace_size + actual_seqlen_workspace_size;
     void *devOffsetsQ = nullptr;
+    void *devOffsetsK = nullptr;
+    void *devOffsetsV = nullptr;
     void *devOffsetsO = nullptr;
     if (is_ragged_q) {
-      devOffsetsQ =
-          static_cast<int8_t *>(workspace) + plan_workspace_size + actual_seqlen_workspace_size;
-      devOffsetsO = static_cast<int8_t *>(devOffsetsV) + num_bytes_per_ragged_offset;
+      devOffsetsQ = devOffsets;
+      devOffsetsO = static_cast<int8_t *>(devOffsetsQ) + num_bytes_per_ragged_offset;
     }
-    void *devOffsetsK = nullptr;
-    void *devOffsetsV = nullptr;
     if (is_ragged_kv) {
-      devOffsetsK = static_cast<int8_t *>(devOffsetsQ) + num_bytes_per_ragged_offset;
+      devOffsetsK = static_cast<int8_t *>(devOffsets) + (int)is_ragged_q * 2 * num_bytes_per_ragged_offset;
       devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
     }
     void *devOffsetsS = nullptr;
     if (is_ragged_q && cudnn_runtime_version >= 90600) {
-      devOffsetsS = static_cast<int8_t *>(devOffsetsO) + num_bytes_per_ragged_offset;
+      devOffsetsS = static_cast<int8_t *>(devOffsets) + ((int)is_ragged_q + (int)is_ragged_kv) * 2 * num_bytes_per_ragged_offset;
     }
     const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
     cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
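
One detail shared by the forward and backward launches above, assuming cu_seqlens_padded_to_offsets covers the b + 1 cumulative entries implied by the {b + 1, 1, 1, 1} offset-tensor dims: grid = (b + nthreads_per_block) / nthreads_per_block is exactly ceil((b + 1) / nthreads_per_block), one thread per entry. A quick standalone check of that identity, not part of the commit:

// (b + n) / n == ceil((b + 1) / n) for all b, so the grid covers b + 1 entries.
#include <cassert>
#include <cstddef>

int main() {
  constexpr std::size_t n = 128;  // nthreads_per_block in the diff
  for (std::size_t b = 0; b <= 100000; ++b) {
    std::size_t grid = (b + n) / n;              // formula used in the diff
    std::size_t ceil_div = (b + 1 + n - 1) / n;  // ceil((b + 1) / n)
    assert(grid == ceil_div);
  }
  return 0;
}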
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/inference.py
@@ -386,7 +386,7 @@ def post_step(
             output = output[:self.batch_size, :self.max_seqlen_q].transpose(0, 1).contiguous()
         if self.input_qkv_format == "thd":
             print('oooo ', output.shape)
-            print(output[:2, :4])
+            #print(output[:2, :4])
             #output_buffer = self.q_orig[layer_number]
             #step_lens = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1]
             #tex.reshape_o(output, output_buffer, step_lens,
