Skip to content

Commit 436dfc3

Browse files
[Native WebGPU] Fix the error when past and present key/value share buffer (#23315)
### Description Fix error causing incorrect output when past key/value share buffer with present key/value ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent e7d8596 commit 436dfc3

File tree

2 files changed

+24
-19
lines changed

2 files changed

+24
-19
lines changed

onnxruntime/contrib_ops/webgpu/bert/attention.cc

+16-13
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,10 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
6969
return context.RunProgram(program);
7070
};
7171

72-
void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k, bool is_first_prompt) {
72+
void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k) {
7373
if (seqlen_k != nullptr) {
7474
ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n";
75-
ss << "var past_sequence_length: u32 = " << (is_first_prompt ? "0" : "total_sequence_length - sequence_length") << ";\n";
75+
ss << "var past_sequence_length: u32 = select(total_sequence_length - sequence_length, 0u, uniforms.is_first_prompt > 0);\n";
7676
} else {
7777
ss << "let past_sequence_length = uniforms.past_sequence_length;\n";
7878
}
@@ -106,7 +106,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
106106
<< "let sequence_length = uniforms.M;\n"
107107
<< "var total_sequence_length = uniforms.N;\n";
108108
std::ostringstream oss;
109-
InitVarStub(oss, seqlen_k_, is_first_prompt_);
109+
InitVarStub(oss, seqlen_k_);
110110
shader.MainFunctionBody() << oss.str();
111111
shader.MainFunctionBody() << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n";
112112
if (has_present_key_) {
@@ -121,7 +121,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
121121
" if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n"
122122
" var idx = TILE_SIZE * local_id.y + local_id.x;\n";
123123

124-
if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) {
124+
if ((feed_past_key_ && has_present_key_) || (past_present_share_buffer_ && !is_first_prompt_)) {
125125
shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n"
126126
<< " let pastKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.past_sequence_length * uniforms.K;\n"
127127
<< " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
@@ -213,7 +213,8 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
213213
{static_cast<uint32_t>(past_sequence_length)},
214214
{static_cast<uint32_t>(parameters.kv_sequence_length_)},
215215
{static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
216-
{static_cast<uint32_t>(parameters.n_reps)}})
216+
{static_cast<uint32_t>(parameters.n_reps)},
217+
{static_cast<uint32_t>(parameters.is_first_prompt_ ? 1 : 0)}})
217218
.SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});
218219

219220
return context.RunProgram(program);
@@ -231,7 +232,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
231232
<< "let sequence_length = uniforms.sequence_length;\n"
232233
<< "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n";
233234
std::ostringstream oss;
234-
InitVarStub(oss, seqlen_k_, is_first_prompt_);
235+
InitVarStub(oss, seqlen_k_);
235236
shader.MainFunctionBody() << oss.str()
236237
<< "let local_offset = local_idx * uniforms.elements_per_thread;\n"
237238
<< "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n"
@@ -285,20 +286,21 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
285286
}
286287
const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size;
287288

288-
InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, is_first_prompt, seqlen_k};
289+
InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, seqlen_k};
289290
if (seqlen_k != nullptr) {
290291
program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
291292
}
292293
program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
293-
.CacheHint(work_group_size, is_first_prompt)
294+
.CacheHint(work_group_size)
294295
.SetDispatchGroupSize(1, sequence_length, batch_size * num_heads)
295296
.SetWorkgroupSize(work_group_size)
296297
.AddUniformVariables({{static_cast<uint32_t>(batch_size)},
297298
{static_cast<uint32_t>(num_heads)},
298299
{static_cast<uint32_t>(past_sequence_length)},
299300
{static_cast<uint32_t>(sequence_length)},
300301
{static_cast<uint32_t>(total_sequence_length_comp)},
301-
{static_cast<uint32_t>(elementsPerThread)}});
302+
{static_cast<uint32_t>(elementsPerThread)},
303+
{static_cast<uint32_t>(is_first_prompt ? 1 : 0)}});
302304

303305
return context.RunProgram(program);
304306
}
@@ -327,7 +329,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
327329
<< "let sequence_length = uniforms.M;\n"
328330
<< "var total_sequence_length = uniforms.K;\n";
329331
std::ostringstream oss;
330-
InitVarStub(oss, seqlen_k_, is_first_prompt_);
332+
InitVarStub(oss, seqlen_k_);
331333
shader.MainFunctionBody() << oss.str();
332334
shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n";
333335
if (has_present_value_) {
@@ -342,12 +344,12 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
342344
<< " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n"
343345
<< " var idx = TILE_SIZE * local_id.y + local_id.x;\n";
344346

345-
if ((feed_past_value_ && has_present_value_) || past_present_share_buffer_) {
347+
if ((feed_past_value_ && has_present_value_) || (past_present_share_buffer_ && !is_first_prompt_)) {
346348
shader.MainFunctionBody() << " if (w + local_id.y < past_sequence_length) {\n"
347349
<< " let pastValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.past_sequence_length + n;\n"
348350
<< " tileK[idx] = " << (past_present_share_buffer_ ? "present_value" : "past_value") << "[pastValueOffset + (w + local_id.y) * uniforms.N];\n"
349351
<< " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
350-
<< " tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];\n"
352+
<< " tileK[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];\n"
351353
<< " }\n";
352354
} else {
353355
shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length) {\n"
@@ -425,7 +427,8 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
425427
{static_cast<uint32_t>(past_sequence_length)},
426428
{static_cast<uint32_t>(parameters.kv_sequence_length_)},
427429
{static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
428-
{static_cast<uint32_t>(parameters.n_reps)}})
430+
{static_cast<uint32_t>(parameters.n_reps)},
431+
{static_cast<uint32_t>(parameters.is_first_prompt_)}})
429432
.SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});
430433

431434
return context.RunProgram(program);

onnxruntime/contrib_ops/webgpu/bert/attention.h

+8-6
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
4949
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
5050
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
5151
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
52-
{"n_reps", ProgramUniformVariableDataType::Uint32});
52+
{"n_reps", ProgramUniformVariableDataType::Uint32},
53+
{"is_first_prompt", ProgramUniformVariableDataType::Uint32});
5354

5455
WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
5556

@@ -67,8 +68,8 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
6768

6869
class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
6970
public:
70-
InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, bool is_first_prompt, const Tensor* seqlen_k = nullptr)
71-
: Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k), is_first_prompt_(is_first_prompt) {
71+
InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, const Tensor* seqlen_k = nullptr)
72+
: Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k) {
7273
}
7374

7475
Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -78,13 +79,13 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
7879
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
7980
{"sequence_length", ProgramUniformVariableDataType::Uint32},
8081
{"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32},
81-
{"elements_per_thread", ProgramUniformVariableDataType::Uint32});
82+
{"elements_per_thread", ProgramUniformVariableDataType::Uint32},
83+
{"is_first_prompt", ProgramUniformVariableDataType::Uint32});
8284

8385
private:
8486
int work_group_size_;
8587
int components_;
8688
const Tensor* seqlen_k_;
87-
bool is_first_prompt_;
8889
};
8990

9091
class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
@@ -104,7 +105,8 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
104105
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
105106
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
106107
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
107-
{"n_reps", ProgramUniformVariableDataType::Uint32});
108+
{"n_reps", ProgramUniformVariableDataType::Uint32},
109+
{"is_first_prompt", ProgramUniformVariableDataType::Uint32});
108110

109111
WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
110112

0 commit comments

Comments
 (0)