@@ -69,10 +69,10 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
69
69
return context.RunProgram (program);
70
70
};
71
71
72
- void InitVarStub (std::ostringstream& ss, const Tensor* seqlen_k, bool is_first_prompt ) {
72
+ void InitVarStub (std::ostringstream& ss, const Tensor* seqlen_k) {
73
73
if (seqlen_k != nullptr ) {
74
74
ss << " total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n " ;
75
- ss << " var past_sequence_length: u32 = " << (is_first_prompt ? " 0 " : " total_sequence_length - sequence_length" ) << " ;\n " ;
75
+ ss << " var past_sequence_length: u32 = select( total_sequence_length - sequence_length, 0u, uniforms.is_first_prompt > 0) ;\n " ;
76
76
} else {
77
77
ss << " let past_sequence_length = uniforms.past_sequence_length;\n " ;
78
78
}
@@ -106,7 +106,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
106
106
<< " let sequence_length = uniforms.M;\n "
107
107
<< " var total_sequence_length = uniforms.N;\n " ;
108
108
std::ostringstream oss;
109
- InitVarStub (oss, seqlen_k_, is_first_prompt_ );
109
+ InitVarStub (oss, seqlen_k_);
110
110
shader.MainFunctionBody () << oss.str ();
111
111
shader.MainFunctionBody () << " let kOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.kv_sequence_length * uniforms.K;\n " ;
112
112
if (has_present_key_) {
@@ -121,7 +121,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
121
121
" if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n "
122
122
" var idx = TILE_SIZE * local_id.y + local_id.x;\n " ;
123
123
124
- if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) {
124
+ if ((feed_past_key_ && has_present_key_) || ( past_present_share_buffer_ && !is_first_prompt_) ) {
125
125
shader.MainFunctionBody () << " if (n + local_id.y < past_sequence_length) {\n "
126
126
<< " let pastKeyOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.past_sequence_length * uniforms.K;\n "
127
127
<< " tileK[idx] = " << (past_present_share_buffer_ ? " present_key" : " past_key" ) << " [pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n "
@@ -213,7 +213,8 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
213
213
{static_cast <uint32_t >(past_sequence_length)},
214
214
{static_cast <uint32_t >(parameters.kv_sequence_length_ )},
215
215
{static_cast <uint32_t >(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_ )},
216
- {static_cast <uint32_t >(parameters.n_reps )}})
216
+ {static_cast <uint32_t >(parameters.n_reps )},
217
+ {static_cast <uint32_t >(parameters.is_first_prompt_ ? 1 : 0 )}})
217
218
.SetOverridableConstants ({{static_cast <uint32_t >(tile_size)}});
218
219
219
220
return context.RunProgram (program);
@@ -231,7 +232,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
231
232
<< " let sequence_length = uniforms.sequence_length;\n "
232
233
<< " var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << " ;\n " ;
233
234
std::ostringstream oss;
234
- InitVarStub (oss, seqlen_k_, is_first_prompt_ );
235
+ InitVarStub (oss, seqlen_k_);
235
236
shader.MainFunctionBody () << oss.str ()
236
237
<< " let local_offset = local_idx * uniforms.elements_per_thread;\n "
237
238
<< " let offset = (global_idx / " << work_group_size_ << " ) * uniforms.total_sequence_length_comp + local_offset;\n "
@@ -285,20 +286,21 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
285
286
}
286
287
const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1 ) / work_group_size;
287
288
288
- InPlaceSoftmaxProgram program{" InPlaceSoftmax" , work_group_size, components, is_first_prompt, seqlen_k};
289
+ InPlaceSoftmaxProgram program{" InPlaceSoftmax" , work_group_size, components, seqlen_k};
289
290
if (seqlen_k != nullptr ) {
290
291
program.AddInput ({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
291
292
}
292
293
program.AddOutputs ({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
293
- .CacheHint (work_group_size, is_first_prompt )
294
+ .CacheHint (work_group_size)
294
295
.SetDispatchGroupSize (1 , sequence_length, batch_size * num_heads)
295
296
.SetWorkgroupSize (work_group_size)
296
297
.AddUniformVariables ({{static_cast <uint32_t >(batch_size)},
297
298
{static_cast <uint32_t >(num_heads)},
298
299
{static_cast <uint32_t >(past_sequence_length)},
299
300
{static_cast <uint32_t >(sequence_length)},
300
301
{static_cast <uint32_t >(total_sequence_length_comp)},
301
- {static_cast <uint32_t >(elementsPerThread)}});
302
+ {static_cast <uint32_t >(elementsPerThread)},
303
+ {static_cast <uint32_t >(is_first_prompt ? 1 : 0 )}});
302
304
303
305
return context.RunProgram (program);
304
306
}
@@ -327,7 +329,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
327
329
<< " let sequence_length = uniforms.M;\n "
328
330
<< " var total_sequence_length = uniforms.K;\n " ;
329
331
std::ostringstream oss;
330
- InitVarStub (oss, seqlen_k_, is_first_prompt_ );
332
+ InitVarStub (oss, seqlen_k_);
331
333
shader.MainFunctionBody () << oss.str ();
332
334
shader.MainFunctionBody () << " let vOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.N * uniforms.kv_sequence_length + n;\n " ;
333
335
if (has_present_value_) {
@@ -342,12 +344,12 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
342
344
<< " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n "
343
345
<< " var idx = TILE_SIZE * local_id.y + local_id.x;\n " ;
344
346
345
- if ((feed_past_value_ && has_present_value_) || past_present_share_buffer_) {
347
+ if ((feed_past_value_ && has_present_value_) || ( past_present_share_buffer_ && !is_first_prompt_) ) {
346
348
shader.MainFunctionBody () << " if (w + local_id.y < past_sequence_length) {\n "
347
349
<< " let pastValueOffset = (workgroup_id.z / " << n_reps_ << " ) * uniforms.N * uniforms.past_sequence_length + n;\n "
348
350
<< " tileK[idx] = " << (past_present_share_buffer_ ? " present_value" : " past_value" ) << " [pastValueOffset + (w + local_id.y) * uniforms.N];\n "
349
351
<< " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n "
350
- << " tileK[idx] = v[vOffset + (w + local_id.y - uniforms. past_sequence_length) * uniforms.N];\n "
352
+ << " tileK[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];\n "
351
353
<< " }\n " ;
352
354
} else {
353
355
shader.MainFunctionBody () << " if (w + local_id.y < uniforms.kv_sequence_length) {\n "
@@ -425,7 +427,8 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
425
427
{static_cast <uint32_t >(past_sequence_length)},
426
428
{static_cast <uint32_t >(parameters.kv_sequence_length_ )},
427
429
{static_cast <uint32_t >(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_ )},
428
- {static_cast <uint32_t >(parameters.n_reps )}})
430
+ {static_cast <uint32_t >(parameters.n_reps )},
431
+ {static_cast <uint32_t >(parameters.is_first_prompt_ )}})
429
432
.SetOverridableConstants ({{static_cast <uint32_t >(tile_size)}});
430
433
431
434
return context.RunProgram (program);
0 commit comments