Skip to content

Commit 91f2960

Browse files
committed
feat: Support hybrid recurrent in llama-graph
NOTE: I intentionally did not add support for s_mask since it will be going away soon Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 5276203 commit 91f2960

File tree

2 files changed

+78
-4
lines changed

2 files changed

+78
-4
lines changed

src/llama-graph.cpp

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "llama-kv-cache-unified.h"
88
#include "llama-kv-cache-unified-iswa.h"
99
#include "llama-kv-cache-recurrent.h"
10+
#include "llama-kv-cache-hybrid-recurrent.h"
1011

1112
#include <cassert>
1213
#include <cmath>
@@ -396,6 +397,13 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
396397
}
397398
}
398399

400+
llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
401+
const llama_hparams & hparams,
402+
const llama_cparams & cparams,
403+
const llama_kv_cache_hybrid_recurrent_state * kv_state) :
404+
llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
405+
}
406+
399407
//
400408
// llm_graph_context
401409
//
@@ -953,8 +961,10 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
953961
return cur;
954962
}
955963

956-
ggml_tensor * llm_graph_context::build_inp_s_copy() const {
957-
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
964+
ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const {
965+
if (kv_state == nullptr) {
966+
kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
967+
}
958968

959969
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
960970

@@ -1283,6 +1293,44 @@ ggml_tensor * llm_graph_context::build_attn(
12831293
return cur;
12841294
}
12851295

1296+
llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
1297+
const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
1298+
1299+
auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(hparams, cparams, kv_state);
1300+
1301+
{
1302+
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
1303+
1304+
const auto n_kv = kv_state->get_state_attn()->get_n_kv();
1305+
1306+
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1307+
//cb(inp->self_kq_mask, "KQ_mask", -1);
1308+
ggml_set_input(inp->self_kq_mask);
1309+
1310+
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1311+
}
1312+
1313+
return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
1314+
}
1315+
1316+
ggml_tensor * llm_graph_context::build_attn(
1317+
llm_graph_input_attn_kv_hybrid_recurrent * inp,
1318+
ggml_cgraph * gf,
1319+
ggml_tensor * wo,
1320+
ggml_tensor * wo_b,
1321+
ggml_tensor * q_cur,
1322+
ggml_tensor * k_cur,
1323+
ggml_tensor * v_cur,
1324+
ggml_tensor * kq_b,
1325+
ggml_tensor * v_mla,
1326+
float kq_scale,
1327+
int il) const {
1328+
return build_attn(
1329+
static_cast<llm_graph_input_attn_kv_unified *>(inp),
1330+
gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il
1331+
);
1332+
}
1333+
12861334
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
12871335
const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
12881336

src/llama-graph.h

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ struct llama_memory_state_i;
2222
class llama_kv_cache_unified_state;
2323
class llama_kv_cache_unified_iswa_state;
2424
class llama_kv_cache_recurrent_state;
25+
class llama_kv_cache_hybrid_recurrent_state;
2526

2627
// certain models (typically multi-modal) can produce different types of graphs
2728
enum llm_graph_type {
@@ -242,7 +243,7 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
242243
cparams(cparams),
243244
kv_state(kv_state) {
244245
}
245-
~llm_graph_input_attn_kv_unified() = default;
246+
virtual ~llm_graph_input_attn_kv_unified() = default;
246247

247248
void set_input(const llama_ubatch * ubatch) override;
248249

@@ -285,6 +286,16 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
285286
const llama_kv_cache_unified_iswa_state * kv_state;
286287
};
287288

289+
class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified {
290+
public:
291+
llm_graph_input_attn_kv_hybrid_recurrent(
292+
const llama_hparams & hparams,
293+
const llama_cparams & cparams,
294+
const llama_kv_cache_hybrid_recurrent_state * kv_state);
295+
296+
virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default;
297+
};
298+
288299
class llm_graph_input_attn_cross : public llm_graph_input_i {
289300
public:
290301
llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
@@ -508,7 +519,7 @@ struct llm_graph_context {
508519
ggml_tensor * build_inp_out_ids() const;
509520
ggml_tensor * build_inp_mean() const;
510521
ggml_tensor * build_inp_cls() const;
511-
ggml_tensor * build_inp_s_copy() const;
522+
ggml_tensor * build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state = nullptr) const;
512523

513524
ggml_tensor * build_inp_cross_embd() const;
514525
ggml_tensor * build_inp_pos_bucket_enc() const;
@@ -574,6 +585,21 @@ struct llm_graph_context {
574585
float kq_scale,
575586
int il) const;
576587

588+
llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const;
589+
590+
ggml_tensor * build_attn(
591+
llm_graph_input_attn_kv_hybrid_recurrent * inp,
592+
ggml_cgraph * gf,
593+
ggml_tensor * wo,
594+
ggml_tensor * wo_b,
595+
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
596+
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
597+
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
598+
ggml_tensor * kq_b,
599+
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
600+
float kq_scale,
601+
int il) const;
602+
577603
llm_graph_input_attn_cross * build_attn_inp_cross() const;
578604

579605
ggml_tensor * build_attn(

0 commit comments

Comments
 (0)