
llama : support Jamba hybrid Transformer-Mamba models #7531


Open · wants to merge 52 commits into master (base) from compilade/refactor-kv-cache

Conversation

compilade
Collaborator

@compilade compilade commented May 25, 2024

This adds support for Jamba (fixes #6372). (https://arxiv.org/abs/2403.19887)

(This has been open for a while, and the original description had a much broader scope; feel free to look at the edit history.)

New features

  • Jamba support
    • The first hybrid Transformer+Mamba model in llama.cpp
  • State checkpoints for recurrent models
    • Works best when n_parallel is at least 3 or 4 times the number of actual users
    • Allows backtracking tokens from the end of the last generation without having to reprocess the whole context
      • Very useful with the server example when trimming the stop string
  • Variable GQA (see also OpenELM support #7359)
    • GGUF metadata {model}.attention.head_count_kv can now also be an array of int32_t, one value per layer
    • Layers with 0 kv heads are considered recurrent layers (Mamba, in the case of Jamba); see the sketch after this list.
    • This will make proper support of DeciLM possible
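
A rough sketch of how such a per-layer array can be interpreted (illustrative only; the names and types below are not the actual llama.cpp code):

#include <cstdint>
#include <vector>

// Hypothetical helper: given the per-layer {model}.attention.head_count_kv
// values from the GGUF metadata, classify each layer.
enum class layer_kind { attention, recurrent };

static std::vector<layer_kind> classify_layers(const std::vector<int32_t> & head_count_kv) {
    std::vector<layer_kind> kinds;
    kinds.reserve(head_count_kv.size());
    for (const int32_t n_head_kv : head_count_kv) {
        // 0 KV heads => no KV cache for this layer => recurrent (Mamba) layer
        kinds.push_back(n_head_kv == 0 ? layer_kind::recurrent : layer_kind::attention);
    }
    return kinds;
}

// For the jamba-900M example in the test log below,
// head_count_kv = [0, 0, 8, 0, 0, 0, 8, 0, 0, 0, 8, 0]
// => layers 2, 6 and 10 are attention layers, all others are Mamba layers.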

Internal changes

  • move the build_mamba_layer function to a shared parent class of both llm_build_mamba and llm_build_jamba.
  • remove llm_graph_context::build_inp_mem_hybrid
    • Redundant, see next point.
  • remove llm_graph_input_mem_hybrid
    • It's redundant with llm_graph_input_rs and llm_graph_input_attn_kv_unified, and causes unnecessary duplication and overloads of build_rs and build_attn.

Future ideas

  • Recurrent state checkpoints, to allow for backtracking recurrent states
  • Fairly split the available KV cells among active sequences, similarly to RS cells.
  • Handle token shift (and Self-Extend?) when finding a slot.
    • This could help with the fair split of KV cells by freeing cells of sequences which use more than their fair share of cells.

Testing

Example output of jamba-900M-v0.13-KIx2 (click to expand)
$  ./bin/main -m /srv/LLMstash/tmp/jamba-900M.bf16.gguf --temp 0 -e -p "I believe the meaning of life is" --repeat-penalty 1.2 --repeat-last-n 256 -c 16384 -n 256
Log start
main: build = 3003 (0fd13e94)
main: built with gcc (GCC) 13.2.0 for x86_64-unknown-linux-gnu
main: seed  = 1716594011
llama_model_loader: loaded meta data with 26 key-value pairs and 189 tensors from /srv/LLMstash/tmp/jamba-900M.bf16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = jamba
llama_model_loader: - kv   1:                               general.name str              = jamba-900M-v0.13-KIx2
llama_model_loader: - kv   2:                          jamba.block_count u32              = 12
llama_model_loader: - kv   3:                       jamba.context_length u32              = 16384
llama_model_loader: - kv   4:                     jamba.embedding_length u32              = 1024
llama_model_loader: - kv   5:                  jamba.feed_forward_length u32              = 4096
llama_model_loader: - kv   6:                 jamba.attention.head_count u32              = 32
llama_model_loader: - kv   7:              jamba.attention.head_count_kv arr[i32,12]      = [0, 0, 8, 0, 0, 0, 8, 0, 0, 0, 8, 0]
llama_model_loader: - kv   8:                      jamba.ssm.conv_kernel u32              = 4
llama_model_loader: - kv   9:                       jamba.ssm.inner_size u32              = 2048
llama_model_loader: - kv  10:                       jamba.ssm.state_size u32              = 16
llama_model_loader: - kv  11:                   jamba.ssm.time_step_rank u32              = 256
llama_model_loader: - kv  12:     jamba.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  13:                         jamba.expert_count u32              = 8
llama_model_loader: - kv  14:                    jamba.expert_used_count u32              = 2
llama_model_loader: - kv  15:                          general.file_type u32              = 32
llama_model_loader: - kv  16:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  17:                         tokenizer.ggml.pre str              = gpt-2
llama_model_loader: - kv  18:                      tokenizer.ggml.tokens arr[str,65024]   = ["<EOT>", "<META>", "<META_START>", "...
llama_model_loader: - kv  19:                  tokenizer.ggml.token_type arr[i32,65024]   = [3, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  20:                      tokenizer.ggml.merges arr[str,64739]   = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "ĠĠ �...
llama_model_loader: - kv  21:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  22:                tokenizer.ggml.eos_token_id u32              = 0
llama_model_loader: - kv  23:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  24:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  25:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:  121 tensors
llama_model_loader: - type bf16:   68 tensors
llm_load_vocab: special tokens definition check successful ( 29/65024 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = jamba
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 65024
llm_load_print_meta: n_merges         = 64739
llm_load_print_meta: n_ctx_train      = 16384
llm_load_print_meta: n_embd           = 1024
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 32
llm_load_print_meta: n_layer          = 12
llm_load_print_meta: n_rot            = 32
llm_load_print_meta: n_embd_head_k    = 32
llm_load_print_meta: n_embd_head_v    = 32
llm_load_print_meta: n_gqa            = 0
llm_load_print_meta: n_embd_k_gqa     = 0
llm_load_print_meta: n_embd_v_gqa     = 0
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-06
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 4096
llm_load_print_meta: n_expert         = 8
llm_load_print_meta: n_expert_used    = 2
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = -1
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 16384
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 4
llm_load_print_meta: ssm_d_inner      = 2048
llm_load_print_meta: ssm_d_state      = 16
llm_load_print_meta: ssm_dt_rank      = 256
llm_load_print_meta: model type       = ?B
llm_load_print_meta: model ftype      = BF16
llm_load_print_meta: model params     = 887.66 M
llm_load_print_meta: model size       = 1.67 GiB (16.19 BPW) 
llm_load_print_meta: general.name     = jamba-900M-v0.13-KIx2
llm_load_print_meta: BOS token        = 0 '<EOT>'
llm_load_print_meta: EOS token        = 0 '<EOT>'
llm_load_print_meta: UNK token        = 0 '<EOT>'
llm_load_print_meta: PAD token        = 0 '<EOT>'
llm_load_print_meta: LF token         = 133 'Ä'
llm_load_tensors: ggml ctx size =    0.09 MiB
llm_load_tensors:        CPU buffer size =  1713.16 MiB
......................................
llama_new_context_with_model: n_ctx      = 16384
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_cache_init:        CPU cache buf size =    49.34 MiB
llama_new_context_with_model: SSM state size =     1.34 MiB, R (f32):    0.21 MiB, S (f32):    1.12 MiB
llama_new_context_with_model: KV cache size  =    48.00 MiB, K (f16):   24.00 MiB, V (f16):   24.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.25 MiB
llama_new_context_with_model:        CPU compute buffer size =  1062.03 MiB
llama_new_context_with_model: graph nodes  = 621
llama_new_context_with_model: graph splits = 1

system_info: n_threads = 2 / 4 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
sampling: 
        repeat_last_n = 256, repeat_penalty = 1.200, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.000
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 16384, n_batch = 2048, n_predict = 256, n_keep = 0


<EOT>I believe the meaning of life is not to be found in a single word, but rather as an expression of one's own feelings and thoughts.

The idea that we are all born with our bodies, whether they are human or animal, has been around for centuries. It was believed by some that it was something like a body made up of bones, which were attached to each other at birth. The most common form of this type of bone is called a "bone." This is what makes it so hard to tell if you're alive or dead. In fact, there are many different types of bones, including those that have been used for various purposes such as healing wounds, wounding wounds, etc.

In ancient times, people had a lot of teeth, and these were often very small. They could also be placed on top of their heads, where they would sit down and look at them. These were usually large, round stones, which were sometimes covered with hair. When the skin was removed from the head, the bones became more prominent, and the muscles began to grow larger.

This kind of bone was known as a "bone" because it was made out of two parts: the outermost part (the innermost portion) and the innermost part (the outermost
llama_print_timings:        load time =     252.28 ms
llama_print_timings:      sample time =     303.07 ms /   256 runs   (    1.18 ms per token,   844.68 tokens per second)
llama_print_timings: prompt eval time =     200.72 ms /     8 tokens (   25.09 ms per token,    39.86 tokens per second)
llama_print_timings:        eval time =   12516.79 ms /   255 runs   (   49.09 ms per token,    20.37 tokens per second)
llama_print_timings:       total time =   13213.95 ms /   263 tokens
Log end

@compilade compilade added the enhancement, model, refactoring, need feedback, embeddings, python, and Review Complexity : High labels May 25, 2024
@compilade compilade marked this pull request as draft May 25, 2024 03:38
llama.cpp Outdated
Comment on lines 5244 to 5248
switch (hparams.n_layer) {
// TODO: Jamba layers are a bit heterogenous, so naming this is hard.
case 12: // 900M 8x???M
case 32: // 51B 16x?B
default: model.type = e_model::MODEL_UNKNOWN;
Collaborator Author

I'm not sure what model size type(s) I should give to Jamba models.

@github-actions github-actions bot added the ggml label May 25, 2024
Contributor

github-actions bot commented May 25, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 557 iterations 🚀

  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8384.34ms p(95)=20451.68ms fails=, finish reason: stop=510 truncated=47
  • Prompt processing (pp): avg=102.96tk/s p(95)=478.95tk/s
  • Token generation (tg): avg=36.48tk/s p(95)=48.13tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=compilade/refactor-kv-cache commit=fee3c1d740c0e027c81e2f2f3fb48d619857175f

[Benchmark charts: llamacpp:prompt_tokens_seconds, llamacpp:predicted_tokens_seconds, llamacpp:kv_cache_usage_ratio, llamacpp:requests_processing — llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 557 iterations]

@arch-btw
Contributor

Great job! Works for me too, it's very fast. There were some warnings during compilation, but nothing major.

<EOT>Hello!

I'll get a new one for you and I think this is going to be really cool, so good. And I'm sure there's lots of ways in which [...]

llama_print_timings:        load time =     286.42 ms
llama_print_timings:      sample time =     155.94 ms /   256 runs   (    0.61 ms per token,  1641.63 tokens per second)
llama_print_timings: prompt eval time =      70.77 ms /     3 tokens (   23.59 ms per token,    42.39 tokens per second)
llama_print_timings:        eval time =    9368.54 ms /   255 runs   (   36.74 ms per token,    27.22 tokens per second)
llama_print_timings:       total time =    9686.16 ms /   258 tokens

@TechxGenus

Amazing work!
I initially tested Jamba-v0.1 on a machine with 500G RAM and it worked great!

./main -m ./Jamba-v0.1-hf-00001-of-00024.gguf -n 120 --prompt "def max(arr):" --temp 0
Log start
main: build = 3006 (fc59407e)
main: built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
main: seed  = 1716710334
llama_model_loader: additional 23 GGUFs metadata loaded.
llama_model_loader: loaded meta data with 31 key-value pairs and 531 tensors from ./Jamba-v0.1-hf-00001-of-00024.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = jamba
llama_model_loader: - kv   1:                               general.name str              = Jamba-v0.1-hf
llama_model_loader: - kv   2:                          jamba.block_count u32              = 32
llama_model_loader: - kv   3:                       jamba.context_length u32              = 262144
llama_model_loader: - kv   4:                     jamba.embedding_length u32              = 4096
llama_model_loader: - kv   5:                  jamba.feed_forward_length u32              = 14336
llama_model_loader: - kv   6:                 jamba.attention.head_count u32              = 32
llama_model_loader: - kv   7:              jamba.attention.head_count_kv arr[i32,32]      = [0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, ...
llama_model_loader: - kv   8:                      jamba.ssm.conv_kernel u32              = 4
llama_model_loader: - kv   9:                       jamba.ssm.inner_size u32              = 8192
llama_model_loader: - kv  10:                       jamba.ssm.state_size u32              = 16
llama_model_loader: - kv  11:                   jamba.ssm.time_step_rank u32              = 256
llama_model_loader: - kv  12:     jamba.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  13:                         jamba.expert_count u32              = 16
llama_model_loader: - kv  14:                    jamba.expert_used_count u32              = 2
llama_model_loader: - kv  15:                          general.file_type u32              = 32
llama_model_loader: - kv  16:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  17:                         tokenizer.ggml.pre str              = default
llama_model_loader: - kv  18:                      tokenizer.ggml.tokens arr[str,65536]   = ["<|pad|>", "<|startoftext|>", "<|end...
llama_model_loader: - kv  19:                      tokenizer.ggml.scores arr[f32,65536]   = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  20:                  tokenizer.ggml.token_type arr[i32,65536]   = [3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, ...
llama_model_loader: - kv  21:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  22:                tokenizer.ggml.eos_token_id u32              = 2
llama_model_loader: - kv  23:            tokenizer.ggml.unknown_token_id u32              = 3
llama_model_loader: - kv  24:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  25:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  26:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  27:               general.quantization_version u32              = 2
llama_model_loader: - kv  28:                                   split.no u16              = 0
llama_model_loader: - kv  29:                                split.count u16              = 24
llama_model_loader: - kv  30:                        split.tensors.count i32              = 531
llama_model_loader: - type  f32:  361 tensors
llama_model_loader: - type bf16:  170 tensors
llm_load_vocab: special tokens definition check successful ( 1799/65536 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = jamba
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 65536
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 262144
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 32
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 0
llm_load_print_meta: n_embd_k_gqa     = 0
llm_load_print_meta: n_embd_v_gqa     = 0
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-06
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 14336
llm_load_print_meta: n_expert         = 16
llm_load_print_meta: n_expert_used    = 2
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = -1
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 262144
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 4
llm_load_print_meta: ssm_d_inner      = 8192
llm_load_print_meta: ssm_d_state      = 16
llm_load_print_meta: ssm_dt_rank      = 256
llm_load_print_meta: model type       = ?B
llm_load_print_meta: model ftype      = BF16
llm_load_print_meta: model params     = 51.57 B
llm_load_print_meta: model size       = 96.30 GiB (16.04 BPW) 
llm_load_print_meta: general.name     = Jamba-v0.1-hf
llm_load_print_meta: BOS token        = 1 '<|startoftext|>'
llm_load_print_meta: EOS token        = 2 '<|endoftext|>'
llm_load_print_meta: UNK token        = 3 '<|unk|>'
llm_load_print_meta: PAD token        = 0 '<|pad|>'
llm_load_print_meta: LF token         = 1554 '<0x0A>'
llm_load_print_meta: EOT token        = 2 '<|endoftext|>'
llm_load_tensors: ggml ctx size =    0.24 MiB
llm_load_tensors:        CPU buffer size =  4851.72 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  4210.03 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  5095.47 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  3584.03 MiB
llm_load_tensors:        CPU buffer size =  4210.03 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  4210.03 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  4210.03 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  4851.77 MiB
llm_load_tensors:        CPU buffer size =  3584.03 MiB
..............................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: n_batch    = 512
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_cache_init:        CPU cache buf size =    24.63 MiB
llama_new_context_with_model: SSM state size =    16.62 MiB, R (f32):    2.62 MiB, S (f32):   14.00 MiB
llama_new_context_with_model: KV cache size  =     8.00 MiB, K (f16):    4.00 MiB, V (f16):    4.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.25 MiB
llama_new_context_with_model:        CPU compute buffer size =   145.10 MiB
llama_new_context_with_model: graph nodes  = 1730
llama_new_context_with_model: graph splits = 1

system_info: n_threads = 32 / 64 | AVX = 1 | AVX_VNNI = 1 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | AVX512_BF16 = 1 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
sampling: 
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.000
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 512, n_batch = 2048, n_predict = 120, n_keep = 1


<|startoftext|> def max(arr):
    return max(arr)


def min(arr):
    return min(arr)


def mean(arr):
    return sum(arr) / len(arr)


def median(arr):
    arr.sort()
    if len(arr) % 2 == 0:
        return (arr[len(arr) // 2] + arr[len(arr) // 2 - 1]) / 2
    else:
        return arr[len(arr) // 2]


llama_print_timings:        load time =   82494.54 ms
llama_print_timings:      sample time =       9.61 ms /   120 runs   (    0.08 ms per token, 12490.89 tokens per second)
llama_print_timings: prompt eval time =     666.33 ms /     6 tokens (  111.06 ms per token,     9.00 tokens per second)
llama_print_timings:        eval time =   27656.31 ms /   119 runs   (  232.41 ms per token,     4.30 tokens per second)
llama_print_timings:       total time =   28862.18 ms /   125 tokens
Log end

ggml.c Outdated
Comment on lines 16264 to 16267
if (n_rs > 1) {
// multiple sequences means it's hard to know when it's the first time a state is read,
// so copy them all over to the destination, just to be sure.
for (int i3 = 0; i3 < n_kv; ++i3) {
for (int i3 = 0; i3 < n_rs; ++i3) {
Member

I'm looking at adding the missing Metal kernels for SSM_CONV and SSM_SCAN. I'm wondering if this part of the kernels where we copy src0 -> dst could be extracted outside of the operation via ggml_cpy + ggml_view or ggml_acc? Would simplify the implementation

Also, I still haven't understood the details of the computation, but if we find a way to express these ops via existing ops all together (e.g. using ggml_conv, ggml_mul_mat, ...), it would be preferred to do so, in order to reduce the amount of kernels that we have to write.

Collaborator Author

I'm wondering if this part of the kernels where we copy src0 -> dst could be extracted outside of the operation via ggml_cpy + ggml_view or ggml_acc? Would simplify the implementation

Yes, this is definitely possible. I'll find a way to extract the copies outside.

if we find a way to express these ops via existing ops all together (e.g. using ggml_conv, ggml_mul_mat, ...), it would be preferred to do so, in order to reduce the amount of kernels that we have to write.

For SSM_SCAN, I think there's a way to fully express it in terms of other ops, though it will use much more memory because of the big intermediate tensors, and new operators like SOFT_PLUS and EXP would be needed instead. But different lengths of simultaneous sequences might make a custom operator still necessary. I'll think about ways to make it simpler, especially since other recurrent architectures (like RWKV) will also need to work on multiple sequences per batch.
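
For reference, a rough sketch of the selective-scan recurrence being discussed (the standard Mamba formulation from the paper, not llama.cpp-specific notation), which is where the need for SOFT_PLUS and EXP comes from:

\Delta_t = \operatorname{softplus}(W_{\Delta} x_t + b_{\Delta})
h_t = \exp(\Delta_t A) \odot h_{t-1} + (\Delta_t B_t)\, x_t
y_t = C_t h_t + D \odot x_t

Expressed with existing ops, the per-token \exp(\Delta_t A) and \Delta_t B_t terms would become intermediate tensors of size on the order of d_state × d_inner × n_tokens, which is where the extra memory would come from.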

For simplifying SSM_CONV, I don't think ggml_conv supports working on independent 1D rolling windows with varying sequence lengths.

When working on a single sequence, though, it's quite simple to do the equivalent of ggml_ssm_conv with a self-overlapping view, as I did in my original implementation which I described in more detail in #5328 (comment):

https://github.com/ggerganov/llama.cpp/blob/64fbce052373faf07a36b599528f8fe1cb1d62fb/llama.cpp#L6973-L6982

Setting nb[2] to the element size makes the view self-overlapping.

But this would create too many nodes in the compute graph when done with multiple sequences (unless they're always all the same length in which case the 4th dimension could be used), so a custom operator is necessary.
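
As a concrete illustration of the self-overlapping view trick (the linked code uses a 3-D view with nb[2] set to the element size; the following is a minimal 2-D sketch with assumed shapes, not the actual llama.cpp code):

// signal: an F32 ggml tensor of shape {n} holding the (d_conv - 1) past
// elements followed by the new elements of one channel of one sequence.
// Returns a {k, n - k + 1} view whose rows overlap by (k - 1) elements,
// because the row stride nb1 is a single element.
static struct ggml_tensor * conv_windows(
        struct ggml_context * ctx,
        struct ggml_tensor  * signal,
        int64_t               k) {
    const int64_t n_windows = signal->ne[0] - k + 1;
    return ggml_view_2d(ctx, signal, k, n_windows,
            ggml_element_size(signal), // nb1 == one element => overlapping rows
            0);
}

Each row of the view is one rolling window, so multiplying the rows by the size-k kernel and summing over the first dimension gives the convolution output for that channel, without a dedicated SSM_CONV kernel, in the single-sequence case.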

Member

@ggerganov ggerganov May 26, 2024

One idea that we might consider is to unfuse the n_rs dimension from the SSM ops and make them work per 1 recurrent state. Then, during inference and right before the SSM operations, we split the batch into same-sequence chunks and SSM them individually. After that we concat back the results into the full hidden state for the batch

The main goal would be to simplify the SSM operators, and potentially express them as other existing ops if possible. But additionally, I'm considering a similar processing mode for the standard transformer KV cache in which we don't rely on a "unified" buffer for all the sequences, but instead each sequence has its own separate KV cache buffer. In that mode, we would do a similar same-sequence batch split before the attention. The main purpose of supporting this mode would be to achieve reproducible results during parallel decoding (currently, decoding the same sequence in parallel can yield slightly different results due to the unified KV cache).

Just throwing some thoughts that I have so far - will continue looking at the PR in the next days

Edit: I was writing this comment before I saw you posted - will take a look tomorrow

Collaborator Author

@compilade compilade May 26, 2024

One idea that we might consider is to unfuse the n_rs dimension from the SSM ops and make them work per 1 recurrent state. Then, during inference and right before the SSM operations, we split the batch into same-sequence chunks and SSM them individually. After that we concat back the results into the full hidden state for the batch

Yes, this would be doable, but would make the number of compute graph nodes scale with the number of sequences. (EDIT: if it's split when making ubatches, then the number of compute graph nodes can stay constant)

Another way would be to make all sequences have the same number of new tokens in a ubatch, to allow using another dimension instead of having to loop when building the compute graphs. This would still allow batching multiple sequences with recurrent models, but without the need for new custom operators for each architecture, and still with a constant number of compute graph nodes.

The recurrent steps are simpler for ubatches with sequence lengths of 1, but prompt processing performance would be much slower than with a per-recurrent-architecture operator for longer sequences. Still thinking about ways to generalize this while keeping good performance.

But additionally, I'm considering a similar processing mode for the standard transformer KV cache in which we don't rely on a "unified" buffer for all the sequences, but instead each sequence has it's own separate KV cache buffer. In that mode, we would do a similar same-sequence batch split before the attention.

For the transformer KV cache, if there's logic to make all sequences within a ubatch have the same number of new tokens, I think a mode to split batches sequence-wise will be simpler and could re-use much of the same code.

I also think there's a way to keep the unified KV cache (one buffer) and chunk it to make each sequence have their own independent contiguous reserved cells. Batching sequences together might still be possible though, if the KQ mask gets another dimension (the number of sequences in the ubatch, and the number of new tokens per sequence instead of the batch size) so that these equal-sized "chunks" get processed independently in parallel. But this might not work (because the newly-calculated KV cells have to be copied in a bunch of not-regularly-spaced places), unless... unless maybe with some kind of ggml_set_rows? Not sure about the transposed V cache, though.

A sequence-wise processing mode is likely simpler, although it's not really parallel processing then (the model weights are all read at each ubatch).

Member

if it's split when making ubatches, then the number of compute graph nodes can stay constant

No, it has to be split only for the attention so that the rest of the ops are still batched. Otherwise we will sacrifice a lot of performance

Another way would be to make all sequences have the same number of new tokens in a ubatch, to allow using another dimension instead of having to loop when building the compute graphs. This would still allow batching multiple sequences with recurrent models, but without the need for new custom operators for each architecture, and still with a constant number of compute graph nodes.

For the transformer KV cache, if there's logic to make all sequences within a ubatch have the same number of new tokens, I think a mode to split batches sequence-wise will be simpler and could re-use much of the same code.

Not sure how that would work. Adding dummy tokens sounds like too much overhead (at least in the case of the regular transformer). Any other ideas?

A sequence-wise processing mode is likely simpler, although it's not really parallel processing then (the model weights are all read at each ubatch).

From a broad PoV, if we have an implementation that works with a single sequence and any batch size, then to extend it to multi-sequence batches we can split the batch into same-sequence tokens right before the attention and merge it back after the attention. Each split will do what we already do for the single-sequence solution, using a separate cache for each sequence. I didn't consider the number of nodes until you noted it - so that might be a problem indeed.

I'm currently working on a big refactor of how Mamba (and Jamba) works to make all sequences of a sub-batch be of the same length (initially only for models with recurrent states), and to make recurrent state slots contiguous, with the goal of simplifying the SSM operations (and removing GGML_OP_SSM_CONV), so that GPU support will be much easier to implement after that.

Looking forward to this!

Collaborator Author

@compilade compilade May 30, 2024

No, it has to be split only for the attention so that the rest of the ops are still batched. Otherwise we will sacrifice a lot of performance

It will sacrifice some performance, but only in the cases where a batch contains an unequal number of tokens for each affected sequence. So this should not affect large prompt processing or parallel text generation, if both are not done in the same batch.

Not sure how that would work. Adding dummy tokens sounds like too much overhead (at least in the case of the regular transformer). Any other ideas?

This is not about adding dummy tokens, but about making the number of new tokens in each ubatch the same per sequence. I think the overhead will be minimal, though there is still some.

Let me illustrate.

Let's say there's a batch with new tokens for 4 sequences of length 16, 7, 1, 1, respectively.

0: ################
1: #######
2: #
3: #

Splitting that into equal-length sequences would make 3 ubatches, like so:

0: #
1: #
2: #
3: #
0: ######
1: ######
0: #########

Each of these shapes is nice and rectangular, which is good for recurrent architectures because their operations can be more easily batched across sequences this way.

But I'm not yet sure if it would also benefit Transformers, which is why I'm thinking of initially only enabling the equal-length splitting for recurrent (or hybrid) model architectures.

From a broad PoV, if we have an implementation that works with a single sequence and any batch size, then to extend it to multi-sequence batches we can split the batch into same-sequence tokens right before the attention and merge it back after the attention. Each split will do what we already do for the single-sequence solution, using a separate cache for each sequence. I didn't consider the number of nodes until you noted it - so that might be a problem indeed.

Doing this with a constant number of graph nodes is pretty much what using same-length sequences (as illustrated above) allows, because the split into same-sequence tokens can then simply become another tensor dimension.
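
A rough sketch of that equal-length splitting, assuming only the number of new tokens per sequence matters (illustrative only; the names below are not the llama.cpp API):

#include <algorithm>
#include <cstdint>
#include <vector>

struct ubatch_shape {
    int64_t n_seqs;   // sequences included in this ubatch
    int64_t n_tokens; // equal number of new tokens taken from each of them
};

// Split a batch described by its per-sequence token counts (assumed > 0)
// into "rectangular" ubatches where every included sequence contributes
// the same number of tokens.
static std::vector<ubatch_shape> split_equal(std::vector<int64_t> counts) {
    std::sort(counts.begin(), counts.end()); // shortest sequences first
    std::vector<ubatch_shape> shapes;
    size_t first = 0; // sequences before this index are exhausted
    while (first < counts.size()) {
        const int64_t take = counts[first]; // smallest remaining length
        shapes.push_back({(int64_t)(counts.size() - first), take});
        for (size_t i = first; i < counts.size(); ++i) {
            counts[i] -= take;
        }
        while (first < counts.size() && counts[first] == 0) {
            ++first;
        }
    }
    return shapes;
}

// For counts {16, 7, 1, 1} this yields {4 seqs x 1 token}, {2 seqs x 6 tokens},
// {1 seq x 9 tokens}, matching the illustration above.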

Member

Aha, got it. Good idea. I'm also not sure if this can help Transformers, but it's something to think about 👍

@theogbob

theogbob commented Jan 4, 2025

Progress?

@lexasub
Contributor

lexasub commented Jan 28, 2025

@theogbob, you may want to tag the author, @compilade, and ask about progress :)

@hg0428

hg0428 commented Mar 12, 2025

Progress @compilade ?

@gabe-l-hart
Contributor

@compilade Thanks for all the work here! I've also been working through a very similar architecture for bamba independently. Bamba is essentially the same as Jamba, but with mamba2 layers instead of mamba layers.

I suspect your implementation of llama_rs_cache is a much better approach than the one I took of simply creating a duplicate llama_kv_cache and conditionally making the two caches have zero-sized layers. I've also based my branch on your mamba2 work, so I'd be really interested in consolidating these threads and helping where possible with your work to support hybrid-recurrent models (we are really interested in these architectures at IBM).

It looks like this branch is pretty out of date with the latest refactors in the codebase. I have a version of my branch that I got working against the rebased tip of your mamba2 branch (BambaArchitectureRefactor), but it looks like it's out-of-date again based on further changes in the KV caching interface, and similarly it looks like the mamba2 branch is somewhat out of date at this point.

We just released an updated V2 of bamba, so I'd love to push forward with the architecture. If there's interest, I'd be happy to try to rebase this branch on the tip of master with all other refactors. I'm a lot less familiar with the kernel-level optimizations for mamba2, but could look at resolving conflicts there too.

@compilade
Collaborator Author

I suspect your implementation of llama_rs_cache is a much better approach than the one I took of simply creating a duplicate llama_kv_cache and conditionally making the two caches have zero-sized layers.

@gabe-l-hart
Interestingly, this sounds very similar to what I've done here. llama_rs_cache and llama_kv_cache have some mutually-exclusive zero-sized layers in Jamba.

Another approach like using per-layer cache types would need considerable additional refactoring which would conflict even more with #12799 (although it might simplify #13194).

(Now that I write this out, you're making me realize that all the KV cache needs for hybrid models is per-type top-level metadata (the cells) for self-attention and recurrent layers, plus some data buffers, of which there seem to always be at most 2 per layer (k and v, or r and s), since no layer ever has both attention and recurrent states (at least this seems true for the hybrid models I've seen so far). That is pretty much what is implemented here with the zero-sized layers, but it hints towards possible future simplifications (which will be doable after resolving the conflicts from #12799 here).)

I've also based my branch on your mamba2 work, so I'd be really interested in consolidating these threads and helping where possible with your work to support hybrid-recurrent models (we are really interested in these architectures at IBM).

I too would be interested in consolidating with your work, or at least making it easier for you to get Bamba supported. How would you prefer this to happen?

Note that I will update the mamba2 branch to keep up with the latest changes, and it may or may not result in lots of conflicts in your branch. I'm not sure if that's avoidable. Hopefully it's not too bad.

It looks like this branch is pretty out of date with the latest refactors in the codebase. I have a version of my branch that I got working against the rebased tip of your mamba2 branch (BambaArchitectureRefactor), but it looks like it's out-of-date again based on further changes in the KV caching interface, and similarly it looks like the mamba2 branch is somewhat out of date at this point.

Yes, this branch is not very up to date, but it's fixable. The main reason close to no progress was made here is that I don't find it particularly fun to resolve thousands of lines of conflicts. Or at least I need to dedicate a good chunk of time to it so that I don't get lost half-way (since conflict resolutions of this size mostly mean re-thinking the approach and porting it to the new structures).

So this PR might have been neglected for a while because the moments where I had enough time and the moments where I wanted to fix this and/or reply [1] to "progress?" comments did not align.

But I am in a period where I'm starting to have more spare time, and so I could dedicate a day (or more) to resolve the conflicts here and in #9126 (but I suspect it's going to take more than a day).

We just released an updated V2 of bamba, so I'd love to push forward with the architecture.

That's awesome!

If there's interest, I'd be happy to try to rebase this branch on the tip of master with all other refactors.

When branches drift that much, merging is usually simpler to handle than rebasing: it still leaves a trail of tested versions, and allows resolving conflicts once (per merge) instead of at every commit that changes conflicting parts.

But I see what you mean, and I'd love to get help with the conflict resolutions, but it's unfortunately something which almost has to be done in one go (because git doesn't have first-class conflicts), and so collaboration on that aspect isn't particularly straightforward.

I'm a lot less familiar with the kernel-level optimizations for mamba2, but could look at resolving conflicts there too.

Right, the Mamba2 branch (in #9126) modifies a bit how the SSM operator works (to minimize useless copies), and that will need to be adapted to the CUDA version of the operator which was added in #10558.

Footnotes

  1. This comment did take me more than 3 hours to write. I should probably write smaller comments.

@gabe-l-hart
Contributor

Thank you for the detailed response! It's really helpful. I 100% hear you on the giant merge conflicts, and I agree at this stage merging is better than rebasing.

I spent yesterday trying to resolve mamba2 with the latest master (https://github.com/gabe-l-hart/llama.cpp/tree/BambaAbstractMemory). It's not actually working yet, so I clearly missed something. I'll take another whack at it today and see how far I can get it. If I can get mamba2 working by itself, I may try to push on the hybrid architecture more.

It looks like the biggest change since I last synced is around moving to more abstract interfaces for things. In particular, it looks like all caching has moved behind the memory interface, though it gets liberally cast back to the unified cache type. This makes me think the intent is to move closer to how this is done in transformers where individual models can define their own cache semantics, but I'm not totally clear here yet. I'll post useful findings as I go unless you end up getting deep into it and making a lot of progress.

As always, thanks for the outstanding work here, 3-hour comments included!

@gabe-l-hart
Contributor

gabe-l-hart commented May 1, 2025

Ok, I found my merge bugs in https://github.com/gabe-l-hart/llama.cpp/tree/BambaAbstractMemory and I'm now able to run a lightweight mamba2 model (details below).

As a separate question, this probably isn't the right place to centralize this discussion. Would it be best to create a central issue to discuss the convergence of mamba2, jamba, and bamba?

Details

# Download lightweight mamba2 model
huggingface-cli download AntonV/mamba2-370m-hf --local-dir ~/models/mamba2-370m-hf

# Convert to GGUF
python convert_hf_to_gguf.py ~/models/mamba2-370m-hf/

# Run a sample query
./build/bin/llama-cli -m ~/models/mamba2-370m-hf/mamba2-370M-hf-F16.gguf -p "Hello world" -ngl 0 --temp 0 -n 20

@compilade
Collaborator Author

compilade commented May 1, 2025

I'm now able to run a lightweight mamba2 model (details below).

@gabe-l-hart Amazing!

I've also merged from latest master (into #9126), and some parts differ, but most is similar or the same.

It's very helpful to compare both merges and their approaches [1] (and sometimes notice when changes are missing). It does reduce the stress of a bad merge. Thank you!

(although it seems like git log --remerge-diff doesn't work on your merge; was it a squash merge perhaps?)

Multi-sequence inference is broken, though (that's also true on master with plain Mamba and RWKV). To test this, you can use:

$ ./build/bin/llama-parallel -m ~/models/mamba2-370m-hf/mamba2-370M-hf-F16.gguf -np 5 -ns 8 --temp 0 --repeat-penalty 1.1

Part of the problem is caused by an early return true in seq_rm, but there's another problem where it seems like the states are not properly isolated between sequences (which also seems to be a problem on master). I'll try to find a fix. I suspect it might be due to modifying const_cast-ed values, but it might be something else.

As a separate question, this probably isn't the right place to centralize this discussion. Would it be best to create a central issue to discuss the convergence of mamba2, jamba, and bamba?

Yes, I think that should be more appropriate. It's true that technically Mamba2 isn't directly related to Jamba. If I create the issue, I will tag you and refer to the relevant PRs and issues.

Footnotes

  1. with git diff 611a470fc1e25e7388c71734f09852a5d9c6ed06 6def5cd729fdde64b2addeaa5cce016c72485e06

@gabe-l-hart
Contributor

@compilade Great to hear that you got the merge working, and not at all surprised that I missed some nuance beyond basic single-sequence generation. I'll look to pick up your changes on my branch.

(although it seems like git log --remerge-diff doesn't work on your merge; was it a squash merge perhaps?)

I've never used --remerge-diff! I love learning new tricks. I did not do anything with squashing intentionally, but I did amend the merge commit a couple of times, so maybe that did it?

I did also start taking a whack at the hybrid cache based on the new layers of abstraction in llama-memory and llama-kv-cache. It's in a broken state, so nothing is pushed yet, but the approach I'm taking is to move everything in llama-context to use the llama_kv_cache abstract interface and then liberally hoisting methods from llama_kv_cache_unified up as part of the abstract method set in llama_kv_cache. This would then allow llama_kv_cache_hybrid to implement them by dispatching to the appropriate cache by layer.

The trickiest part seems to be the intermixing of kv_self_update in llama_context, which currently needs intimate details of the member data from llama_kv_cache_unified. I tried moving all of that over into the kv cache class hierarchy, but it also needs intimate knowledge of graph creation and execution, which seems to be correctly siloed in llama-context. I'll keep digging tomorrow!

@gabe-l-hart
Contributor

It looks like the work of hoisting the cache abstraction is almost all done in #12799! I'll move to build off of that branch.

@compilade compilade force-pushed the compilade/refactor-kv-cache branch from aa4039d to 2bcaf64 on July 3, 2025 03:42
@compilade compilade marked this pull request as ready for review July 3, 2025 05:49
Collaborator Author

@compilade compilade left a comment

I've updated this to use the hybrid cache implementation from #13979. As discussed a while ago, state checkpoints for recurrent state rollbacks will be implemented in a separate PR (which I did not begin).

This seems to still work; I've tested a previous conversion of https://huggingface.co/pszemraj/jamba-900M-v0.13-KIx2, and a new conversion, and both work. I will test the official Jamba-Mini-1.6 in the next days.

I've also shortened the main description of the PR, since the scope has reduced a lot over time (most of the changes were split into other pull requests which were merged in the past year).

Comment on lines +4964 to +4967
def get_vocab_base_pre(self, tokenizer) -> str:
del tokenizer # unused

return "gpt-2"
Collaborator Author

@compilade compilade Jul 3, 2025


This pre-tokenizer override is pretty much only used by https://huggingface.co/pszemraj/jamba-900M-v0.13-KIx2.
The official Jamba models and finetunes use a sentencepiece tokenizer.model.

Comment on lines 10223 to 10227
const auto * mctx_hyb = static_cast<const llama_memory_hybrid_context *>(mctx);

auto * inp_rs = build_rs_inp(mctx_hyb->get_recr());

auto * inp_attn = build_attn_inp_kv_unified(mctx_hyb->get_attn());
Collaborator Author


build_rs and build_attn now use inp->mctx instead of casting this->mctx again. This removes the need for build_inp_mem_hybrid and the hybrid-specific overloads of build_rs and build_attn: the bare build_rs_inp and build_attn_inp_kv_unified can be used directly by giving them the correct mctx (which defaults to a cast of this->mctx, so existing model graph builders don't need to change).

This will also potentially make it easier to support hybrid models with sliding-window attention.
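A simplified sketch of the shape of this change, with assumed names and signatures rather than the exact llama.cpp declarations: the input builders take an optional memory context that defaults to a cast of this->mctx, so hybrid graphs can pass the sub-contexts explicitly while existing model builders keep calling the no-argument form.

// Simplified sketch of the pattern (assumed names, not the exact llama.cpp API).
struct memory_context_i { virtual ~memory_context_i() = default; };
struct recurrent_context : memory_context_i { /* recurrent cells ... */ };
struct attn_context      : memory_context_i { /* K/V cells ... */ };

struct hybrid_context : memory_context_i {
    recurrent_context recr;
    attn_context      attn;
    const recurrent_context * get_recr() const { return &recr; }
    const attn_context      * get_attn() const { return &attn; }
};

struct graph_input_rs   { const recurrent_context * mctx; };
struct graph_input_attn { const attn_context      * mctx; };

struct graph_context {
    const memory_context_i * mctx;

    graph_input_rs build_rs_inp(const recurrent_context * cur = nullptr) const {
        // default path: cast this->mctx, as the existing model builders do
        if (!cur) cur = static_cast<const recurrent_context *>(mctx);
        return graph_input_rs{cur};
    }

    graph_input_attn build_attn_inp_kv_unified(const attn_context * cur = nullptr) const {
        if (!cur) cur = static_cast<const attn_context *>(mctx);
        return graph_input_attn{cur};
    }

    void build_hybrid_layer_inputs() {
        // hybrid model: fetch the sub-contexts once and pass them explicitly
        const auto * hyb = static_cast<const hybrid_context *>(mctx);
        graph_input_rs   inp_rs   = build_rs_inp(hyb->get_recr());
        graph_input_attn inp_attn = build_attn_inp_kv_unified(hyb->get_attn());
        (void) inp_rs; (void) inp_attn;
    }
};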

Member


This is a good change. Thanks.

Member

@ggerganov ggerganov Jul 3, 2025


Hm, however this will not be compatible with the llama_graph_result_i::update(mctx) mechanism from #14482 for reusing compute graphs. The idea is that when we can reuse the old graph, we call res->update(mctx) with the new memory context in order to make all inputs from the previous graph result point to that new memory context.

In this case, we would call update([hybrid_context]) and the rs input wouldn't know that it needs to get the recurrent sub-context.
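For readers who haven't followed #14482, a rough sketch of the reuse mechanism being referred to, with illustrative names only: when the old graph can be reused, the graph result re-points each of its inputs at the new memory context instead of rebuilding the graph, which is exactly where a hybrid context would be handed to an rs input that expects the recurrent sub-context.

#include <memory>
#include <vector>

// Illustrative sketch, not the actual #14482 code.
struct memory_context_i { virtual ~memory_context_i() = default; };

struct graph_input_i {
    virtual ~graph_input_i() = default;
    virtual void update(const memory_context_i * mctx) = 0;
};

struct graph_result {
    std::vector<std::unique_ptr<graph_input_i>> inputs;

    void update(const memory_context_i * mctx) {
        // every input of the reused graph gets the same new memory context;
        // an rs input would receive the hybrid context here, not the
        // recurrent sub-context it actually needs
        for (auto & inp : inputs) {
            inp->update(mctx);
        }
    }
};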

Member


@gabe-l-hart Heads up, I am still contemplating this change. If we can't figure out a reasonable alternative for the mechanism in #14482, we will have to go back to using a separate hybrid memory input, as on master.

Contributor


Makes sense. I think we could probably make the virtual inheritance thing work even with the separate hybrid memory input, so it probably still makes sense to put GR4 after Jamba.

Contributor


If I'm reading the changes in #14482 correctly, the issue you're describing would come up here when llm_graph_result::update gets called and in turn calls update(mctx) on each of the input types. In this case, the mctx would be a hybrid context, but the input would be expecting one of the child context types. I recall trying to add cast operators to the hybrid context so that it could be cast to the child types, but I think that caused problems because it didn't inherit from them. It seems like it might make sense for the hybrid cache to have some mechanism for acting like either of the child types. I'll play around a little and see if I can make a dummy version.

Contributor


I think the simplest solution (but one that would violate encapsulation pretty badly) would be to use dynamic_cast on mctx in the implementation of update for the input types that might need to consume different context types.

#include <iostream>

void use_a(const Interface* i) {
    // const TypA * a = static_cast<const TypA *>(i);
    // first try to downcast to the concrete child type...
    const TypA * a = dynamic_cast<const TypA *>(i);
    if (!a) {
        // ...otherwise assume the hybrid type and fetch its child
        a = static_cast<const TypAB *>(i)->get_a();
    }
    std::cout << "----" << std::endl;
    std::cout << "Using A" << std::endl;
    a->doit();
    a->just_a();
}

Contributor


(not suggesting this as a final solution, just brainstorming)

Contributor


Another not-that-good option would be to handle this with multiple inheritance. We could create an intermediate base class derived from llama_memory_context_i that holds the members common to the non-hybrid cache types (ubatches, status, possibly i_*), then have kv_cache_unified and recurrent derive from it with virtual inheritance. This would allow hybrid to inherit from both (rather than "has-a" both). It would also avoid the ubatch vector copying that's currently happening in both hybrid and iswa.
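A bare-bones sketch of that inheritance shape, with hypothetical type names: a shared intermediate base holds the common members, the two concrete context types derive from it virtually, and the hybrid type derives from both.

#include <vector>

// Hypothetical sketch of the virtual-inheritance option (illustrative names only).
struct memory_context_i { virtual ~memory_context_i() = default; };

// intermediate base holding the members shared by the non-hybrid cache types
struct memory_context_common : memory_context_i {
    std::vector<int> ubatches; // stand-in for the shared ubatch bookkeeping
    int              status = 0;
};

// virtual inheritance so the hybrid type gets a single shared common base
struct kv_cache_unified_ctx : virtual memory_context_common { /* attention K/V state */ };
struct recurrent_ctx        : virtual memory_context_common { /* recurrent state */ };

// "is-a" both child types instead of "has-a" both; the shared members
// (ubatches, status, ...) exist only once, so no ubatch vector copying
struct hybrid_ctx : kv_cache_unified_ctx, recurrent_ctx {};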

Collaborator Author

@compilade compilade Jul 3, 2025


If we can't figure a reasonable alternative for the mechanism in #14482 we will have to go back to using a separate hybrid memory input as on master.

@ggerganov In that case, I think the hybrid graph input could be re-added, but it should contain the sub-cache graph inputs, without adding them to res->inputs. This would still allow the mechanism in #14482 to be used (since the hybrid input expects a hybrid mctx and can update its sub-inputs), and it would also continue to allow passing the sub-cache graph inputs to build_rs and build_attn (to avoid unnecessary overloads).

I've implemented this in 20f8e43.
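For reference, a rough sketch of the shape described above, with illustrative names (20f8e43 has the actual implementation): the hybrid input owns the two sub-inputs without exposing them through res->inputs, expects a hybrid memory context in update(), and forwards the matching sub-context to each child.

#include <memory>

// Illustrative sketch only; see 20f8e43 for the real code.
struct memory_context_i { virtual ~memory_context_i() = default; };
struct recurrent_context : memory_context_i {};
struct attn_context      : memory_context_i {};

struct hybrid_context : memory_context_i {
    recurrent_context recr;
    attn_context      attn;
    const recurrent_context * get_recr() const { return &recr; }
    const attn_context      * get_attn() const { return &attn; }
};

struct graph_input_rs {
    const recurrent_context * mctx = nullptr;
    void update(const recurrent_context * cur) { mctx = cur; }
};
struct graph_input_attn {
    const attn_context * mctx = nullptr;
    void update(const attn_context * cur) { mctx = cur; }
};

// The hybrid input owns the sub-inputs (they are not added to res->inputs);
// it expects a hybrid mctx and updates its children with the sub-contexts,
// which can also be passed directly to build_rs and build_attn.
struct graph_input_mem_hybrid {
    std::unique_ptr<graph_input_rs>   inp_rs;
    std::unique_ptr<graph_input_attn> inp_attn;

    void update(const memory_context_i * mctx) {
        const auto * hyb = static_cast<const hybrid_context *>(mctx);
        inp_rs->update(hyb->get_recr());
        inp_attn->update(hyb->get_attn());
    }
};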

@@ -303,33 +303,6 @@ class llm_graph_input_attn_cross : public llm_graph_input_i {
const llama_cross * cross = nullptr;
};

class llm_graph_input_mem_hybrid : public llm_graph_input_i {
Contributor


Nice, I really like removing these duplicate inputs in favor of allowing the caller to explicitly fetch the hybrid caches' children and pass them to the creation methods.

compilade added 3 commits July 3, 2025 16:04
But this time it contains the sub-cache graph inputs.
This *should* make it easier to handle updating the inputs
when caching the graph (eventually).
Comment on lines +1155 to +1160
static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
ggml_context * ctx0,
const llama_ubatch & ubatch,
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified_context * mctx_cur) {
Collaborator Author

@compilade compilade Jul 3, 2025


(from changes in 20f8e43)

I'm not entirely sure where the most convenient place for shared graph-input building functions would be, so I've used static functions with an _impl suffix.

This is used both in llm_graph_context::build_attn_inp_kv_unified() and in llm_graph_context::build_inp_mem_hybrid().

I guess they could be added to llm_graph_context as private methods, or somewhere else.

Labels
android (Issues specific to Android), embeddings (embedding related topics), enhancement (New feature or request), examples, ggml (changes relating to the ggml tensor library for machine learning), model (Model specific), need feedback (Testing and feedback with results are needed), python (python script changes), refactoring, Review Complexity : High (Generally require in-depth knowledge of LLMs or GPUs), server

Successfully merging this pull request may close these issues.

Suport for Jamba JambaForCausalLM
10 participants