Skip to content

Commit 9c74b75

Browse files
authored
kernel: optimize merge_attn_states CUDA kernel dispatch (#278)
* misc: add cuda merge_attn_states kernel index
* misc: add cuda merge_attn_states kernel index
* misc: add cuda merge_attn_states kernel index
* [Kernel] optimize merge_attn_states CUDA kernel dispatch
1 parent 482b3fd commit 9c74b75

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

kernels/openai-triton/merge-attn-states/cuda_merge_attn_states.cu

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -217,9 +217,12 @@ void merge_attn_states_launcher(
217217
if (output_lse.has_value()) {
218218
output_lse_ptr = output_lse.value().data_ptr<float>();
219219
}
220+
const bool skip_loop_over_head = (
221+
num_tokens <= 1024 || num_heads >= 64
222+
|| disable_loop_over_head
223+
);
220224

221-
if (num_tokens <= 1024 || num_heads >= 64
222-
|| disable_loop_over_head) {
225+
if (skip_loop_over_head) {
223226
dim3 grid(num_tokens, num_heads);
224227
dim3 block(head_size / pack_size);
225228
LAUNCH_MERGE_ATTN_STATES(SCALAR_T, false);

0 commit comments

Comments
 (0)