@@ -281,43 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
281
281
}
282
282
283
283
void llm_graph_input_attn_kv_unified::set_input (const llama_ubatch * ubatch) {
284
- if (self_k_idxs) {
285
- mctx->set_input_k_idxs (self_k_idxs, ubatch);
286
- }
287
-
288
- if (self_v_idxs) {
289
- mctx->set_input_v_idxs (self_v_idxs, ubatch);
290
- }
284
+ mctx->set_input_k_idxs (self_k_idxs, ubatch);
285
+ mctx->set_input_v_idxs (self_v_idxs, ubatch);
291
286
292
- if (self_kq_mask) {
293
- mctx->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
294
- }
287
+ mctx->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
295
288
}
296
289
297
290
void llm_graph_input_attn_kv_unified_iswa::set_input (const llama_ubatch * ubatch) {
298
- if (self_k_idxs) {
299
- mctx->get_base ()->set_input_k_idxs (self_k_idxs, ubatch);
300
- }
301
-
302
- if (self_v_idxs) {
303
- mctx->get_base ()->set_input_v_idxs (self_v_idxs, ubatch);
304
- }
305
-
306
- if (self_k_idxs_swa) {
307
- mctx->get_swa ()->set_input_k_idxs (self_k_idxs_swa, ubatch);
308
- }
291
+ mctx->get_base ()->set_input_k_idxs (self_k_idxs, ubatch);
292
+ mctx->get_base ()->set_input_v_idxs (self_v_idxs, ubatch);
309
293
310
- if (self_v_idxs_swa) {
311
- mctx->get_swa ()->set_input_v_idxs (self_v_idxs_swa, ubatch);
312
- }
294
+ mctx->get_base ()->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
313
295
314
- if (self_kq_mask) {
315
- mctx->get_base ()->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
316
- }
296
+ mctx->get_swa ()->set_input_k_idxs (self_k_idxs_swa, ubatch);
297
+ mctx->get_swa ()->set_input_v_idxs (self_v_idxs_swa, ubatch);
317
298
318
- if (self_kq_mask_swa) {
319
- mctx->get_swa ()->set_input_kq_mask (self_kq_mask_swa, ubatch, cparams.causal_attn );
320
- }
299
+ mctx->get_swa ()->set_input_kq_mask (self_kq_mask_swa, ubatch, cparams.causal_attn );
321
300
}
322
301
323
302
void llm_graph_input_attn_cross::set_input (const llama_ubatch * ubatch) {
@@ -357,17 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
357
336
}
358
337
359
338
void llm_graph_input_mem_hybrid::set_input (const llama_ubatch * ubatch) {
360
- if (self_k_idxs) {
361
- mctx->get_attn ()->set_input_k_idxs (self_k_idxs, ubatch);
362
- }
363
-
364
- if (self_v_idxs) {
365
- mctx->get_attn ()->set_input_v_idxs (self_v_idxs, ubatch);
366
- }
339
+ mctx->get_attn ()->set_input_k_idxs (self_k_idxs, ubatch);
340
+ mctx->get_attn ()->set_input_v_idxs (self_v_idxs, ubatch);
367
341
368
- if (self_kq_mask) {
369
- mctx->get_attn ()->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
370
- }
342
+ mctx->get_attn ()->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
371
343
372
344
const int64_t n_rs = mctx->get_recr ()->get_n_rs ();
373
345
0 commit comments