
fix: prevent KV cache corruption on SWA/ISWA models + hot-path perf #2180

Open
avion23 wants to merge 3 commits into abetlen:main from avion23:fix/perf-and-iswa

Conversation


@avion23 avion23 commented Apr 12, 2026

Problem

Gemma-4 and any model with interleaved sliding window attention (ISWA) crashes on the second call to create_completion:

llm = Llama(model_path="gemma4-q4.gguf", n_gpu_layers=-1)

llm.create_chat_completion(messages=[{"role": "user", "content": "What is 2+2?"}])
# → OK

llm.create_chat_completion(messages=[{"role": "user", "content": "Write a hello world"}])
# → RuntimeError: error during generation: [end of text]

Root cause

ISWA KV caches store position tracking in global maps (g_iswa_pos_max / g_iswa_pos_min) that are cleared by llama_memory_clear() but not by llama_memory_seq_rm(). The generate() method detects a prefix match between consecutive prompts (shared BOS token), calls kv_cache_seq_rm() to remove the divergent tail, sees it return True, and skips the full reset. The stale position maps then cause batch allocator inconsistency → decode failure.

Additionally, reset() was a no-op on KV cache state (only reset n_tokens), so even explicitly calling llm.reset() between prompts didn't help.

Fix

  1. reset() now calls llama_memory_clear() unconditionally
  2. generate() bypasses prefix-match for SWA models, forcing full reset (same path as recurrent models)
  3. eval() only calls kv_cache_seq_rm() when actually rewinding (len(tokens) < n_tokens), skipping the FFI call on the common append path

After the fix, the same two calls succeed:
llm = Llama(model_path="gemma4-q4.gguf", n_gpu_layers=-1)

llm.create_chat_completion(messages=[{"role": "user", "content": "What is 2+2?"}])
# → "2 + 2 = 4"

llm.create_chat_completion(messages=[{"role": "user", "content": "Write a hello world"}])
# → "print('Hello, World!')"
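The cache-reuse decision in generate() amounts to the following pattern. This is an illustrative sketch, not the PR's actual code: the helper name and return convention are made up, but the logic mirrors the fix (SWA/ISWA models always take the full-reset path because kv_cache_seq_rm() leaves the global position maps stale).

```python
def plan_cache_reuse(prev_tokens, new_tokens, has_swa):
    """Sketch of the cache-reuse decision between consecutive prompts.

    Returns ("full_reset", 0) when the KV cache must be cleared, or
    ("trim", prefix_len) when the shared prefix can be kept and only the
    divergent tail removed via kv_cache_seq_rm().
    """
    # Length of the shared token prefix (e.g. a common BOS token).
    prefix = 0
    for a, b in zip(prev_tokens, new_tokens):
        if a != b:
            break
        prefix += 1
    # SWA/ISWA models must always do a full reset: kv_cache_seq_rm() does
    # not clear the global position maps, so trimming leaves stale state.
    if has_swa or prefix == 0:
        return ("full_reset", 0)
    # Dense-attention models can keep the prefix and drop the tail.
    return ("trim", prefix)
```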

Performance

| Location | Change |
| --- | --- |
| `set_batch()` | numpy bulk writes replace Python loop (512 tokens × 5 assignments → 5 numpy ops) |
| `_create_completion` | incremental `token_to_piece()` accumulation replaces O(n) re-detokenization per token |
| `_create_completion` | in-place logit_bias removes 128K-element `np.copy()` per token |
| `_create_completion` | `tokens[:-1]` → `tokens` fixes broken prefix detection (last token was excluded) |
| `sample()` | `np.argpartition` O(V) replaces `sorted()` O(V log V) for top-k logprobs |
| `_internals.token_to_piece` | returns actual byte length instead of padded 32-byte buffer |
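The sample() change boils down to the standard argpartition pattern for top-k selection. A minimal sketch (not the PR's exact code — function name and shapes are illustrative): partitioning is O(V) over the vocabulary, and only the k surviving entries are fully sorted.

```python
import numpy as np

def top_k_logprob_ids(logprobs: np.ndarray, k: int) -> np.ndarray:
    """Return indices of the k largest logprobs, highest first."""
    # O(V): gather the k largest entries in unspecified order.
    part = np.argpartition(-logprobs, k)[:k]
    # O(k log k): sort only those k entries, descending by logprob.
    return part[np.argsort(-logprobs[part])]
```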

Files

  • llama_cpp/llama.py — reset(), eval(), generate(), _create_completion(), sample()
  • llama_cpp/_internals.py — set_batch(), token_to_piece()

Five optimizations targeting the eval/generate/sample hot path:

- eval(): skip kv_cache_seq_rm FFI call when not rewinding (was always called)
- set_batch(): numpy bulk array writes replace per-element Python loop
- _create_completion: incremental token_to_piece accumulation instead of
  O(n) re-detokenization per generated token (21x speedup at 512 tokens)
- _create_completion: in-place logit_bias modification instead of np.copy
  on full vocab array (128K+ elements) (21x speedup)
- sample(): np.argpartition O(V) for top-k logprobs instead of sorted O(V log V)
  (218x speedup at top_k=10)
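The incremental-accumulation optimization reduces to the following pattern. This is a simplified sketch operating on raw byte pieces (standing in for `token_to_piece()` output), not the actual `_create_completion` code: rebuilding the full text each step costs O(n) per token, while extending a running buffer is O(1) amortized.

```python
def detok_quadratic(pieces):
    # Old approach: re-detokenize the whole sequence every step, O(n) per token.
    texts = []
    for i in range(1, len(pieces) + 1):
        texts.append(b"".join(pieces[:i]))
    return texts

def detok_incremental(pieces):
    # New approach: append each token's piece to a running buffer once.
    buf = bytearray()
    texts = []
    for p in pieces:
        buf.extend(p)
        texts.append(bytes(buf))
    return texts
```

Both produce identical per-step text, so streaming output is unchanged; only the per-token cost drops.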
@avion23 avion23 force-pushed the fix/perf-and-iswa branch from 5935c64 to 939fa72 Compare April 12, 2026 15:49
SWA/ISWA KV caches maintain global position maps (g_iswa_pos_max/min) that
are only cleared by llama_memory_clear(), not by kv_cache_seq_rm(). When
generate() finds a prefix match (e.g. shared BOS token), it calls
kv_cache_seq_rm, which returns True for ISWA, skipping the full reset. But
the stale position maps cause batch allocator inconsistency, and
llama_decode returns -1 on subsequent prompts.

Changes:
- Add _has_swa property via llama_model_n_swa() > 0
- reset() now calls llama_memory_clear() unconditionally
- generate() bypasses prefix-match optimization for SWA models,
  forcing full state reset (same path as recurrent models)
@avion23 avion23 force-pushed the fix/perf-and-iswa branch from 939fa72 to 9609c82 Compare April 12, 2026 15:58
@avion23 avion23 changed the title perf: vectorize hot-path operations + fix SWA/ISWA KV cache corruption (Gemma-4) fix: prevent KV cache corruption on SWA/ISWA models + hot-path perf Apr 12, 2026
@avion23 avion23 marked this pull request as ready for review April 12, 2026 16:04