Commit 09e5724

[CUDA] Fix beam search of num_beams > 32 (microsoft#23599)
### Description

* Pass topk_scores to the beam scorer in the slow topk path.
* Add an environment variable `ORT_BEAM_SEARCH_USE_FAST_TOPK` to enable/disable the fast topk kernel.
* Add a test case for the slow topk path.

### Motivation and Context

The bug was introduced in microsoft#16272. Beam search uses a fast CUDA kernel when the number of beams is at most 32. When the beam size exceeds that threshold, a slower CUDA code path computes the topk instead. In this slow topk path, `topk_scores` must be passed to the beam scorer, but it was not, so results were incorrect whenever num_beams > 32. The bug went unnoticed because such large beam sizes are rarely used.
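
As a minimal illustration of the pattern the fix restores (a standalone sketch, not the actual ONNX Runtime source; the type and function names are placeholders), the scores handed to the beam scorer must come from whichever top-k path actually ran:

```cpp
// Standalone sketch of the fixed control flow (placeholder names, not ONNX
// Runtime APIs): pick the score buffer produced by the path that actually
// ran, then hand exactly that buffer to the beam scorer.
#include <cstddef>
#include <vector>

struct BeamScorerSketch {
  void Process(const float* scores, size_t count) {
    // Consume [batch_size, 2 * num_beams] candidate scores.
    (void)scores;
    (void)count;
  }
};

void ProcessLogitsSketch(bool use_fast_topk, int num_beams,
                         std::vector<float>& next_scores,  // written by the fast top-k kernel
                         std::vector<float>& topk_scores,  // written by the slow (batched) top-k path
                         BeamScorerSketch& scorer) {
  const float* scores_to_process = next_scores.data();
  size_t count = next_scores.size();

  if (use_fast_topk && num_beams <= 32) {
    // Fast CUDA kernel path: candidates already live in next_scores.
  } else {
    // Slow path: candidates live in a separate topk_scores buffer, so that
    // buffer (not next_scores) must reach the scorer.
    scores_to_process = topk_scores.data();
    count = topk_scores.size();
  }

  // Before the fix, next_scores was passed unconditionally, which produced
  // incorrect results whenever the slow path ran (num_beams > 32).
  scorer.Process(scores_to_process, count);
}
```

This mirrors the `scores_to_process` span introduced in generation_device_helper.cc below.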
Parent: 82840f6 · Commit: 09e5724

File tree

4 files changed (+35 −15 lines):

* onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc (+5)
* onnxruntime/contrib_ops/cpu/transformers/generation_shared.h (+6)
* onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc (+11 −14)
* onnxruntime/test/contrib_ops/beam_search_test.cc (+13 −1)

onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc (+5)

@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 #include "contrib_ops/cpu/transformers/beam_search_parameters.h"
+#include "core/platform/env_var_utils.h"
 
 namespace onnxruntime {
 namespace contrib {
@@ -136,7 +137,11 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
       temperature = 1.0f;
     }
   }
+
+  // The following parameter is read from environment variable for testing purpose.
+  use_fast_topk = ParseEnvironmentVariableWithDefault<bool>(kBeamSearchUseFastTopK, true);
 }
+
 void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
   // Override vocab_size using the inferred shape from the decoder subgraph ONLY IF
   // the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch)

onnxruntime/contrib_ops/cpu/transformers/generation_shared.h (+6)

@@ -199,8 +199,14 @@ struct IGenerationParameters {
   int extra_decoding_ids_input_id = -1;
   int cross_qk_output_id = -1;
   int no_speech_probs_output_id = -1;
+
+  // Parameter for testing slow topk path. It can be updated by the below environment variable.
+  bool use_fast_topk = true;
 };
 
+// Environment variable to enable/disable fast topk kernel on GPU. Default is 1 (enabled).
+constexpr const char* kBeamSearchUseFastTopK = "ORT_BEAM_SEARCH_USE_FAST_TOPK";
+
 }  // namespace transformers
 }  // namespace contrib
 }  // namespace onnxruntime
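
For testing, the slow path can be forced by setting this variable before the BeamSearch op reads its parameters. Below is a rough usage sketch against the public ONNX Runtime C++ API; the model path is a hypothetical placeholder, and `setenv` would be replaced by `_putenv_s` on Windows. This mirrors the intent of the `ScopedEnvironmentVariables` helper used in the unit test further down.

```cpp
// Sketch: force the slow top-k path by disabling the fast kernel via the
// ORT_BEAM_SEARCH_USE_FAST_TOPK environment variable (default is 1/enabled).
#include <cstdlib>
#include <onnxruntime_cxx_api.h>

int main() {
  // Must be set before the BeamSearch op parses its parameters, i.e. before Run().
  setenv("ORT_BEAM_SEARCH_USE_FAST_TOPK", "0", /*overwrite*/ 1);

  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "beam_search_slow_topk");
  Ort::SessionOptions so;
  OrtCUDAProviderOptions cuda_options{};  // default device 0
  so.AppendExecutionProvider_CUDA(cuda_options);

  // Hypothetical model containing a BeamSearch contrib op with num_beams > 32.
  Ort::Session session(env, "beam_search_model.onnx", so);
  // ... build inputs and call session.Run(...) as usual ...
  return 0;
}
```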

onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc (+11 −14)

@@ -524,7 +524,8 @@ Status ProcessLogits(const OrtValue& logits, //
     beam_state->remaining_scores = beam_state->remaining_scores.subspan(next_token_scores.size());
   }
 
-  if (num_beams <= 32) {
+  gsl::span<float> scores_to_process = beam_state->next_scores;
+  if (parameters->use_fast_topk && num_beams <= 32) {
     constexpr size_t max_parts_of_vocab = 128;
     size_t candidate_count = SafeInt<size_t>(batch_beam_size) * 2 * num_beams;
     float* topk_tmp_buffer = beam_state->topk_buffer.data();
@@ -546,13 +547,6 @@ Status ProcessLogits(const OrtValue& logits, //
                               beam_state->next_tokens.data(),
                               beam_state->next_indices.data(),
                               cuda_stream);
-
-    // Select [batch_size, 2 * num_beams] from [batch_size * num_beams, 2 * num_beams]
-#ifdef DEBUG_GENERATION
-    dumper->Print("next_tokens before scorer", beam_state->next_tokens.data(), batch_size, 2 * num_beams);
-    dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, 2 * num_beams);
-    dumper->Print("next_scores before scorer", beam_state->next_scores.data(), batch_size, 2 * num_beams);
-#endif
   } else {
     // Apply top-k selection like the following:
     //   next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
@@ -588,18 +582,20 @@ Status ProcessLogits(const OrtValue& logits, //
     cuda::LaunchNextTokenKernel(next_token_indices, beam_state->next_indices.data(), beam_state->next_tokens.data(),
                                 batch_size, top_k, vocab_size, cuda_stream);
 
-#ifdef DEBUG_GENERATION
-    dumper->Print("next_scores before scorer", topk_scores->Data<float>(), batch_size, top_k);
-    dumper->Print("next_tokens before scorer", beam_state->next_tokens.data(), batch_size, top_k);
-    dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, top_k);
-#endif
+    scores_to_process = gsl::span<float>(topk_scores->MutableData<float>(), batch_size * top_k);
   }
 
   // gsl::span doesn't convert from non const to const, so all we're doing here is making each const.
-  gsl::span<const float> next_scores(beam_state->next_scores.data(), beam_state->next_scores.size());
+  gsl::span<const float> next_scores(scores_to_process.data(), scores_to_process.size());
   gsl::span<const int32_t> next_tokens(beam_state->next_tokens.data(), beam_state->next_tokens.size());
  gsl::span<const int32_t> next_indices(beam_state->next_indices.data(), beam_state->next_indices.size());
 
+#ifdef DEBUG_GENERATION
+  dumper->Print("next_scores before scorer", next_scores.data(), batch_size, 2 * num_beams);
+  dumper->Print("next_tokens before scorer", next_tokens.data(), batch_size, 2 * num_beams);
+  dumper->Print("next_indices before scorer", next_indices.data(), batch_size, 2 * num_beams);
+#endif
+
   beam_scorer->Process(
       *sequences,
       next_scores,
@@ -735,6 +731,7 @@ void CudaBeamSearchScorer::Process(transformers::ISequences& sequences,
                                      next_tokens,
                                      next_indices,
                                      stream_);
+
  CUDA_CALL_THROW(cudaEventRecord(event_process_complete_.Get(), stream_));
 
  cuda::LaunchBeamSearchScorer_AppendNextTokenToSequences(*state_cpu_,
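
For reference, here is a CPU-side sketch of what the slow path computes per batch, following the view/top-k comment in the diff above: take the top 2 * num_beams scores over the flattened [num_beams * vocab_size] row, then split each flat index into a beam index and a token id. The function below is illustrative only; the index split mirrors what the LaunchNextTokenKernel arguments suggest and is not code copied from the kernel.

```cpp
// CPU reference sketch of the slow top-k path for beam search (illustrative).
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

struct TopKResult {
  std::vector<float> scores;        // [batch_size * 2 * num_beams] candidate scores
  std::vector<int32_t> tokens;      // token id within the vocabulary (next_tokens)
  std::vector<int32_t> beam_index;  // which beam the candidate came from (next_indices)
};

TopKResult SlowTopKReference(const std::vector<float>& next_token_scores,
                             int batch_size, int num_beams, int vocab_size) {
  const int top_k = 2 * num_beams;
  const int row_size = num_beams * vocab_size;  // view(batch_size, num_beams * vocab_size)
  TopKResult out;
  for (int b = 0; b < batch_size; ++b) {
    const float* row = next_token_scores.data() + static_cast<size_t>(b) * row_size;
    std::vector<int> idx(row_size);
    std::iota(idx.begin(), idx.end(), 0);
    // Partially sort indices so the first top_k entries have the largest scores.
    std::partial_sort(idx.begin(), idx.begin() + top_k, idx.end(),
                      [row](int i, int j) { return row[i] > row[j]; });
    for (int k = 0; k < top_k; ++k) {
      out.scores.push_back(row[idx[k]]);
      out.beam_index.push_back(idx[k] / vocab_size);  // beam the flat index falls into
      out.tokens.push_back(idx[k] % vocab_size);      // token id within that beam's row
    }
  }
  return out;
}
```

The commit's fix is simply to route `out.scores` (the `topk_scores` buffer on the GPU) into the beam scorer instead of the untouched `next_scores` buffer.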

onnxruntime/test/contrib_ops/beam_search_test.cc (+13 −1)

@@ -9,6 +9,8 @@
 #include "test/common/cuda_op_test_utils.h"
 #include "test/providers/model_tester.h"
 #include "test/util/include/current_test_name.h"
+#include "test/util/include/scoped_env_vars.h"
+#include "contrib_ops/cpu/transformers/generation_shared.h"
 
 #ifdef USE_CUDA
 #include "core/providers/cuda/cuda_provider_options.h"
@@ -19,7 +21,7 @@ extern std::unique_ptr<Ort::Env> ort_env;
 namespace onnxruntime {
 namespace test {
 
-TEST(BeamSearchTest, GptBeamSearchFp32) {
+void RunGptBeamSearchFp32() {
   std::vector<int64_t> input_ids_shape{3, 12};
   std::vector<int32_t> input_ids{
       0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620,
@@ -107,6 +109,16 @@ TEST(BeamSearchTest, GptBeamSearchFp32) {
   ASSERT_TRUE(std::equal(expected_output.cbegin(), expected_output.cend(), result_span.begin(), result_span.end()));
 }
 
+TEST(BeamSearchTest, GptBeamSearchFp32) {
+  RunGptBeamSearchFp32();
+}
+
+TEST(BeamSearchTest, GptBeamSearchFp32_DisableFastTopK) {
+  ScopedEnvironmentVariables scoped_env_vars{
+      EnvVarMap{{onnxruntime::contrib::transformers::kBeamSearchUseFastTopK, "0"}}};
+  RunGptBeamSearchFp32();
+}
+
 TEST(BeamSearchTest, GptBeamSearchFp16) {
   std::vector<int64_t> input_ids_shape{3, 12};
   std::vector<int32_t> input_ids{
