
Conversation

ggerganov (Member) commented Nov 6, 2025

fix #17037 #17058
cont #16812

The small SWA caches can be padded to 256 without memory-usage concerns, so pad the cache size to 256. This is friendly to the CUDA backend, since the FA implementation benefits from round sizes of the K/V tensors, and it can also help other backends.
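
For illustration, this is roughly what padding to a multiple of 256 does to a cache size (a minimal sketch; `pad_to` mirrors the GGML_PAD rounding for power-of-two multiples, and the SWA window and micro-batch numbers are made up, not taken from this PR):

```cpp
#include <cstdint>
#include <cstdio>

// round x up to the next multiple of n (equivalent to GGML_PAD for power-of-two n)
static uint32_t pad_to(uint32_t x, uint32_t n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    // hypothetical SWA window and micro-batch, only to show the rounding
    const uint32_t n_swa    = 4097;
    const uint32_t n_ubatch = 512;

    const uint32_t size_unpadded = n_swa + n_ubatch;            // 4609
    const uint32_t size_padded   = pad_to(size_unpadded, 256);  // 4864

    printf("unpadded: %u, padded to 256: %u\n", size_unpadded, size_padded);
    return 0;
}
```

The extra slots are cheap for a small SWA cache, while the round K/V sizes are what the CUDA FA kernels prefer.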

This is essentially a partial revert of #16812.

Comparing against the state before the regression in #16812:

```sh
GGML_CUDA=ON CUDA_VISIBLE_DEVICES=0 ./scripts/compare-commits.sh a8ca18b4b d2c30c61a llama-bench -m ~/.cache/llama.cpp/ggml-org_gpt-oss-20b-GGUF_gpt-oss-20b-mxfp4.gguf -m /home/ggerganov/.cache/llama.cpp/ggml-org_Qwen2.5-Coder-3B-Q8_0-GGUF_qwen2.5-coder-3b-q8_0.gguf -ngl 99 -d 4096,8192,16384,32768 -ub 512,4096 -b 4096 -fa 1 -n 32 -mmp 0
```

```
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 5090, compute capability 12.0, VMM: yes
```

| Model | Microbatch size | Test | t/s a8ca18b | t/s d2c30c6 | Speedup |
| --- | --- | --- | --- | --- | --- |
| gpt-oss 20B MXFP4 MoE | 512 | pp512@d4096 | 9702.25 | 9724.85 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 512 | pp512@d8192 | 8609.17 | 8643.88 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 512 | pp512@d16384 | 7169.27 | 7224.84 | 1.01 |
| gpt-oss 20B MXFP4 MoE | 512 | pp512@d32768 | 5349.99 | 5380.43 | 1.01 |
| gpt-oss 20B MXFP4 MoE | 512 | tg32@d4096 | 306.13 | 338.25 | 1.10 |
| gpt-oss 20B MXFP4 MoE | 512 | tg32@d8192 | 292.57 | 317.68 | 1.09 |
| gpt-oss 20B MXFP4 MoE | 512 | tg32@d16384 | 265.66 | 304.90 | 1.15 |
| gpt-oss 20B MXFP4 MoE | 512 | tg32@d32768 | 236.92 | 262.29 | 1.11 |
| gpt-oss 20B MXFP4 MoE | 4096 | pp512@d4096 | 8720.64 | 8735.30 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 4096 | pp512@d8192 | 7908.52 | 7799.28 | 0.99 |
| gpt-oss 20B MXFP4 MoE | 4096 | pp512@d16384 | 6656.47 | 6583.46 | 0.99 |
| gpt-oss 20B MXFP4 MoE | 4096 | pp512@d32768 | 5063.06 | 4967.41 | 0.98 |
| gpt-oss 20B MXFP4 MoE | 4096 | tg32@d4096 | 296.76 | 318.91 | 1.07 |
| gpt-oss 20B MXFP4 MoE | 4096 | tg32@d8192 | 279.30 | 322.56 | 1.15 |
| gpt-oss 20B MXFP4 MoE | 4096 | tg32@d16384 | 251.35 | 283.65 | 1.13 |
| gpt-oss 20B MXFP4 MoE | 4096 | tg32@d32768 | 227.88 | 253.40 | 1.11 |
| qwen2 3B Q8_0 | 512 | pp512@d4096 | 17229.54 | 17278.24 | 1.00 |
| qwen2 3B Q8_0 | 512 | pp512@d8192 | 14011.21 | 14133.12 | 1.01 |
| qwen2 3B Q8_0 | 512 | pp512@d16384 | 10304.85 | 10307.07 | 1.00 |
| qwen2 3B Q8_0 | 512 | pp512@d32768 | 6612.16 | 6567.63 | 0.99 |
| qwen2 3B Q8_0 | 512 | tg32@d4096 | 279.56 | 294.93 | 1.05 |
| qwen2 3B Q8_0 | 512 | tg32@d8192 | 210.92 | 213.51 | 1.01 |
| qwen2 3B Q8_0 | 512 | tg32@d16384 | 185.15 | 188.19 | 1.02 |
| qwen2 3B Q8_0 | 512 | tg32@d32768 | 147.62 | 149.84 | 1.02 |
| qwen2 3B Q8_0 | 4096 | pp512@d4096 | 16599.01 | 16781.20 | 1.01 |
| qwen2 3B Q8_0 | 4096 | pp512@d8192 | 13664.67 | 13715.42 | 1.00 |
| qwen2 3B Q8_0 | 4096 | pp512@d16384 | 10056.36 | 10027.96 | 1.00 |
| qwen2 3B Q8_0 | 4096 | pp512@d32768 | 6585.53 | 6579.99 | 1.00 |
| qwen2 3B Q8_0 | 4096 | tg32@d4096 | 274.54 | 279.28 | 1.02 |
| qwen2 3B Q8_0 | 4096 | tg32@d8192 | 219.26 | 224.87 | 1.03 |
| qwen2 3B Q8_0 | 4096 | tg32@d16384 | 188.45 | 192.18 | 1.02 |
| qwen2 3B Q8_0 | 4096 | tg32@d32768 | 148.30 | 150.20 | 1.01 |

```diff
  llama_context_params cparams = llama_context_default_params();

- cparams.n_ctx = n_prompt + n_gen + n_depth;
+ cparams.n_ctx = GGML_PAD(n_prompt + n_gen + n_depth, 256);
```
Member:

I don't think this is good, applications shouldn't be responsible for this padding.

Member Author (ggerganov):

Move the padding to llama_context constructor?

Member:

If there is a significant performance advantage from doing that then yes, that seems like a good thing to do. Otherwise, every single application would need to pad n_ctx manually.

ggerganov (Member Author) commented Nov 7, 2025

To clarify, the padding in llama-bench here is important because we do the TG test at full context. So we hit the bounds of the unpadded cache buffer when generating tokens (given that the resulting cparams.n_ctx is not a multiple of 256):

```cpp
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
    uint32_t result = 0;

    // pad the n_kv value so that the graph remains constant across batches and can be reused
    // note: this also helps some backends with performance (f.ex https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220)
    const uint32_t n_pad_cur = std::max(n_pad, 256u);

    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
        const auto & cells = v_cells[sinfo.strm[s]];

        result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result);
    }

    return result;
}
```

Applications will rarely experience this problem, because it shows up only when running close to the full context.

I think it's ok for it to be in the constructor. The drawback is that you could have your requested n_ctx mutated, which is not ideal. But that's probably less of a problem than the current one.
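
A rough sketch of what that could look like (hypothetical placement inside the context constructor, not code from this PR):

```cpp
// hypothetical: round the requested context size up so the KV cache buffers
// get a round size; the caller can read the effective value back with llama_n_ctx()
cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
```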

Member:

> The drawback is that you could have your requested n_ctx mutated, which is not ideal.

As long as the value is never below what the application requested, it should be fine. The application can choose to use llama_n_ctx to take advantage of every slot, but they don't have to. I think this was already the situation before the padding was removed.
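
As an illustration of that usage, an application can read the effective (possibly padded) context size back after creating the context (a sketch; the model path and the requested size are placeholders):

```cpp
#include <cstdio>
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx = 33000; // arbitrary request, not a multiple of 256

    llama_context * ctx = llama_init_from_model(model, cparams);

    // use the effective value reported by the library, which is never below the request
    printf("requested n_ctx: %u, effective n_ctx: %u\n", cparams.n_ctx, llama_n_ctx(ctx));

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}
```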

ggerganov changed the title from "kv-cache : pad the size of the small SWA cache for performance" to "kv-cache : pad the cache size to 256 for performance" Nov 7, 2025
ggerganov requested a review from slaren November 7, 2025 15:09
Comment on lines 48 to 50:

```cpp
// note: the SWA cache is always padded to 256 for performance
// https://github.com/ggml-org/llama.cpp/issues/17037
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, 256));
```
Member:

Nit, I think this relies on kv_size being padded to 256. The GGML_PAD could be moved outside the std::min to ensure that it is still padded if this changes in the future.
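
One possible shape of that suggestion (a sketch, not the code that was merged):

```cpp
// sketch: pad after taking the min, so size_swa stays a multiple of 256
// even if size_base (kv_size) is no longer padded itself in the future
uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256);
```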

github-actions bot added the python (python script changes) and server labels Nov 7, 2025
ggerganov merged commit 16bcc12 into master Nov 7, 2025
50 of 66 checks passed
ggerganov deleted the gg/iswa-pad-256 branch November 7, 2025 18:03

Linked issue: Performance regression in prompt processing