Skip to content

Commit 4534123

Browse files
committed
cont : remove redundant ifs
ggml-ci
1 parent 6f33a9d commit 4534123

File tree

1 file changed

+12
-40
lines changed

1 file changed

+12
-40
lines changed

src/llama-graph.cpp

Lines changed: 12 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -281,43 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
281281
}
282282

283283
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
284-
if (self_k_idxs) {
285-
mctx->set_input_k_idxs(self_k_idxs, ubatch);
286-
}
287-
288-
if (self_v_idxs) {
289-
mctx->set_input_v_idxs(self_v_idxs, ubatch);
290-
}
284+
mctx->set_input_k_idxs(self_k_idxs, ubatch);
285+
mctx->set_input_v_idxs(self_v_idxs, ubatch);
291286

292-
if (self_kq_mask) {
293-
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
294-
}
287+
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
295288
}
296289

297290
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
298-
if (self_k_idxs) {
299-
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
300-
}
301-
302-
if (self_v_idxs) {
303-
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
304-
}
305-
306-
if (self_k_idxs_swa) {
307-
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
308-
}
291+
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
292+
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
309293

310-
if (self_v_idxs_swa) {
311-
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
312-
}
294+
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
313295

314-
if (self_kq_mask) {
315-
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
316-
}
296+
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
297+
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
317298

318-
if (self_kq_mask_swa) {
319-
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
320-
}
299+
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
321300
}
322301

323302
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
@@ -357,17 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
357336
}
358337

359338
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
360-
if (self_k_idxs) {
361-
mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
362-
}
363-
364-
if (self_v_idxs) {
365-
mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
366-
}
339+
mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
340+
mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
367341

368-
if (self_kq_mask) {
369-
mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
370-
}
342+
mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
371343

372344
const int64_t n_rs = mctx->get_recr()->get_n_rs();
373345

0 commit comments

Comments
 (0)