Commit 2ae041f

atomicAdd returns previous value, not current value. (microsoft#16690)
### Description
Fixes a mistake in beam scorer processing: `atomicAdd` returns the original value, not the updated one, so its result should be compared with `1` rather than `0` to detect the counter reaching zero. The error only causes slow performance; nothing fails.

### Motivation and Context
Fixes microsoft#16642
1 parent 44fd98e commit 2ae041f

1 file changed: +1 -1 lines changed


onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu

+1 -1
@@ -392,7 +392,7 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
       if (beam_hyp.beams_used_ == state.num_beams_) {
         if (state.early_stopping_ || !beam_hyp.CanImprove(*std::max_element(next_scores + batch_start, next_scores + batch_start + top_k), sequence_length)) {
           beam_hyp.done_ = true;
-          if (atomicAdd(&state.not_done_count_, -1) == 0)
+          if (atomicAdd(&state.not_done_count_, -1) == 1)
             state_cpu.not_done_count_ = 0;  // Update the CPU side
         }
       }
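
The return-value behavior behind this one-character change can be reproduced on the host without a GPU. The sketch below is an analogue, not the onnxruntime code: it swaps in `std::atomic<int>::fetch_add`, which like CUDA's `atomicAdd` returns the value held before the addition. The helper name `count_last_finishers` and its `sentinel` parameter are hypothetical, and each loop iteration stands in for one hypothesis finishing.

```cpp
#include <atomic>

// Hypothetical helper (not part of onnxruntime): decrements a shared counter
// `workers` times and counts how often the "I was the last one" check fires
// when the value returned by fetch_add is compared against `sentinel`.
int count_last_finishers(int workers, int sentinel) {
    std::atomic<int> not_done_count{workers};
    int times_detected = 0;

    for (int i = 0; i < workers; ++i) {
        // fetch_add returns the value *before* the addition, so the
        // decrement that takes the counter from 1 to 0 returns 1.
        int previous = not_done_count.fetch_add(-1);
        if (previous == sentinel) {
            ++times_detected;
        }
    }
    return times_detected;
}
```

With `sentinel` set to `1` the check fires exactly once, on the decrement that brings the counter to zero. With `sentinel` set to `0` it never fires for these decrements, which matches the behavior the commit corrects.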
