
Commit 2cf31a2

Cuda: Decoder Masked Multihead Attention Q values get corrupted when using cross attention (microsoft#16721)
### Description
Some code was accidentally moved into the `if (!params.is_cross_attention)` block; it must stay outside the block so that it runs for both self attention and cross attention.

### Motivation and Context
The bug causes invalid results. We detected it as a performance bug: the corrupted Q values caused the EOS early exit to never happen, so runs would always take max_length to complete, which was slow.
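For context, the sketch below is a hypothetical, self-contained CUDA kernel, not the actual ONNX Runtime code; names such as `attention_pattern` and the float-based `q_smem` are illustrative only. It shows the control-flow shape of the fix: the shared-memory store of the Q values happens before, and independently of, the `is_cross_attention` branch, because later stages of the kernel read `q_smem` in both the self-attention and cross-attention paths.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical kernel illustrating the pattern fixed by this commit:
// the Q store into shared memory must not depend on is_cross_attention.
__global__ void attention_pattern(const float* q_in, float* out, bool is_cross_attention, int n) {
  extern __shared__ float q_smem[];
  const int tidx = threadIdx.x;
  const bool is_masked = tidx >= n;

  // Store the Q values to shared memory for BOTH attention modes (mirrors the fix).
  if (!is_masked) {
    q_smem[tidx] = q_in[tidx];
  }

  if (!is_cross_attention) {
    // Self-attention-only work (e.g. writing K to the cache) would go here.
  }

  __syncthreads();

  // Later stages read q_smem regardless of the attention mode. If the store above
  // sat inside the !is_cross_attention branch, this read would see uninitialized
  // shared memory for cross attention, which is the corruption described above.
  if (!is_masked) {
    out[tidx] = q_smem[tidx];
  }
}

int main() {
  const int n = 32;
  float h_q[n], h_out[n];
  for (int i = 0; i < n; ++i) h_q[i] = static_cast<float>(i);

  float *d_q = nullptr, *d_out = nullptr;
  cudaMalloc(&d_q, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_q, h_q, n * sizeof(float), cudaMemcpyHostToDevice);

  // Cross-attention run: q_smem is still populated because the store does not
  // depend on is_cross_attention.
  attention_pattern<<<1, n, n * sizeof(float)>>>(d_q, d_out, /*is_cross_attention=*/true, n);
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("out[5] = %f (expected 5.0)\n", h_out[5]);

  cudaFree(d_q);
  cudaFree(d_out);
  return 0;
}
```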
1 parent 2b7a94e commit 2cf31a2

File tree

1 file changed (+5 −3 lines)


onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu

@@ -179,6 +179,11 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
 
   const float inv_sqrt_dh = params.scale;
 
+  if (!is_masked) {
+    // Store the Q values to shared memory.
+    *reinterpret_cast<Qk_vec_k*>(&q_smem[tidx * QK_VEC_SIZE]) = q;
+  }
+
   if (!params.is_cross_attention) {
     Qk_vec_k k;
 
@@ -241,9 +246,6 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
   }
 
   if (!is_masked) {
-    // Store the Q values to shared memory.
-    *reinterpret_cast<Qk_vec_k*>(&q_smem[tidx * QK_VEC_SIZE]) = q;
-
     // Write the K values to the global memory cache.
     // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
     // system. We designed it this way as it allows much better memory loads (and there are many
