 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
 #include "llama-kv-cache-recurrent.h"
+#include "llama-kv-cache-hybrid-recurrent.h"
 
 #include <cassert>
 #include <cmath>
@@ -396,6 +397,13 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
+        const llama_hparams & hparams,
+        const llama_cparams & cparams,
+        const llama_kv_cache_hybrid_recurrent_state * kv_state) :
+    llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
+}
+
 //
 // llm_graph_context
 //
@@ -953,8 +961,10 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const {
+    if (kv_state == nullptr) {
+        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    }
 
     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
 
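With the parameter defaulted, existing recurrent-only callers keep working unchanged, while a hybrid graph can hand in the recurrent sub-state of the hybrid cache explicitly instead of having mstate (which points at the hybrid state there) cast to the wrong type. A minimal caller sketch, assuming the hybrid state exposes a get_state_recurrent() accessor analogous to the get_state_attn() used above:

    // sketch only: get_state_recurrent() is an assumed accessor mirroring get_state_attn()
    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);

    // recurrent layers read their state-copy indices from the recurrent half of the hybrid cache
    ggml_tensor * s_copy = build_inp_s_copy(kv_state->get_state_recurrent());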
@@ -1283,6 +1293,44 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
+llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(hparams, cparams, kv_state);
+
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
+
+        const auto n_kv = kv_state->get_state_attn()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_hybrid_recurrent * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    return build_attn(
+        static_cast<llm_graph_input_attn_kv_unified *>(inp),
+        gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il
+    );
+}
+
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
     const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
 
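For reference, this is roughly how a hybrid model's graph builder would consume the two additions above: one hybrid attention input is built per graph and shared by every attention layer, and the new build_attn() overload simply casts and forwards to the unified-cache path. A sketch under those assumptions (model.layers[il].wo / bo, Qcur / Kcur / Vcur and kq_scale stand in for the usual per-layer tensors and are not part of this diff):

    // built once per graph; the kq mask comes from the attention sub-state of the hybrid cache
    auto * inp_attn = build_attn_inp_kv_hybrid_recurrent();

    // inside an attention layer: the call shape is identical to the unified-cache case
    cur = build_attn(inp_attn, gf,
            model.layers[il].wo, model.layers[il].bo,
            Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);

Keeping the overload as a thin cast-and-forward means the hybrid attention path inherits the existing unified-cache implementation (and any future fixes to it) rather than duplicating it.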