fix: prevent KV cache corruption on SWA/ISWA models + hot-path perf #2180
Open
avion23 wants to merge 3 commits into abetlen:main from
Conversation
Five optimizations targeting the eval/generate/sample hot path:

- `eval()`: skip the `kv_cache_seq_rm` FFI call when not rewinding (it was previously always called)
- `set_batch()`: numpy bulk array writes replace a per-element Python loop
- `_create_completion`: incremental `token_to_piece` accumulation instead of O(n) re-detokenization per generated token (21x speedup at 512 tokens)
- `_create_completion`: in-place `logit_bias` modification instead of `np.copy` on the full vocab array (128K+ elements) (21x speedup)
- `sample()`: `np.argpartition` O(V) for top-k logprobs instead of `sorted` O(V log V) (218x speedup at top_k=10)
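The top-k logprob change can be sketched as follows (an illustrative standalone function, not the PR's exact code; `logprobs` and `k` are stand-in names):

```python
import numpy as np

def top_k_logprobs(logprobs: np.ndarray, k: int) -> dict[int, float]:
    """O(V) top-k selection instead of a full O(V log V) sort."""
    # argpartition places the k largest values in the last k slots
    # without ordering the rest of the vocab-sized array.
    top = np.argpartition(logprobs, -k)[-k:]
    # Only those k entries need an O(k log k) sort for output order.
    top = top[np.argsort(-logprobs[top])]
    return {int(i): float(logprobs[i]) for i in top}
```

The pattern it replaces is a full sort of the vocab, e.g. `sorted(enumerate(logprobs), key=lambda x: -x[1])[:k]`; for a 128K vocab and `top_k=10`, almost all of that sorting work is wasted.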
Force-pushed from 5935c64 to 939fa72
SWA/ISWA KV caches maintain global position maps (`g_iswa_pos_max`/`min`) that are only cleared by `llama_memory_clear()`, not by `kv_cache_seq_rm()`. When `generate()` finds a prefix match (e.g. a shared BOS token), it calls `kv_cache_seq_rm`, which returns `True` for ISWA, skipping the full reset. But the stale position maps cause batch allocator inconsistency, and `llama_decode` returns -1 on subsequent prompts.

Changes:
- Add a `_has_swa` property via `llama_model_n_swa() > 0`
- `reset()` now calls `llama_memory_clear()` unconditionally
- `generate()` bypasses the prefix-match optimization for SWA models, forcing a full state reset (same path as recurrent models)
Force-pushed from 939fa72 to 9609c82
Problem
Gemma-4 and any model with interleaved sliding window attention (ISWA) crashes on the second call to `create_completion`:

Root cause
ISWA KV caches store position tracking in global maps (`g_iswa_pos_max`/`g_iswa_pos_min`) that are cleared by `llama_memory_clear()` but not by `llama_memory_seq_rm()`. The `generate()` method detects a prefix match between consecutive prompts (shared BOS token), calls `kv_cache_seq_rm()` to remove the divergent tail, sees it return `True`, and skips the full reset. The stale position maps then cause batch allocator inconsistency → decode failure.

Additionally, `reset()` was a no-op on KV cache state (it only reset `n_tokens`), so even explicitly calling `llm.reset()` between prompts didn't help.

Fix
- `reset()` now calls `llama_memory_clear()` unconditionally
- `generate()` bypasses prefix-match for SWA models, forcing a full reset (same path as recurrent models)
- `eval()` only calls `kv_cache_seq_rm()` when actually rewinding (`len(tokens) < n_tokens`), skipping the FFI call on the common append path

Performance
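The `eval()` change amounts to guarding the FFI call behind a rewind check. A minimal sketch (`seq_rm` is a stand-in for the `kv_cache_seq_rm` FFI call, and the counts are passed in rather than read from `self`):

```python
def maybe_rewind_cache(n_cached: int, n_new: int, seq_rm) -> bool:
    """Call seq_rm only when the KV cache must actually be rewound.

    n_cached: tokens currently in the KV cache (self.n_tokens)
    n_new:    length of the token prefix being evaluated
    Returns True if the FFI call was made.
    """
    if n_new < n_cached:
        # Rewinding: remove cached entries from position n_new onward.
        seq_rm(n_new)
        return True
    # Common append path: the cache is already a valid prefix,
    # so the per-eval FFI round trip can be skipped entirely.
    return False
```

On the hot generation path `n_new` only grows, so the call is skipped on every step after the first.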
- `set_batch()` — numpy bulk array writes replace the per-element Python loop
- `_create_completion` — incremental `token_to_piece()` accumulation (`_internals.token_to_piece`) replaces O(n) re-detokenization per token
- `_create_completion` — in-place `logit_bias` modification replaces `np.copy()` per token
- `_create_completion` — `tokens[:-1]` → `tokens` fixes broken prefix detection (the last token was excluded)
- `sample()` — `np.argpartition` O(V) replaces `sorted()` O(V log V) for top-k logprobs

Files
- `llama_cpp/llama.py` — `reset()`, `eval()`, `generate()`, `_create_completion()`, `sample()`
- `llama_cpp/_internals.py` — `set_batch()`, `token_to_piece()`
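The idea behind the detokenization speedup can be illustrated with a toy `piece_of` function (the real code converts tokens via the model's `token_to_piece`; names here are illustrative):

```python
def detokenize_quadratic(tokens, piece_of):
    # Old behavior: re-detokenize the entire completion after every
    # generated token -> O(n) work per token, O(n^2) overall.
    return b"".join(piece_of(t) for t in tokens)

class IncrementalDetokenizer:
    """Accumulate pieces as tokens arrive: O(1) amortized per token."""

    def __init__(self, piece_of):
        self._piece_of = piece_of
        self.text = b""

    def push(self, token: int) -> bytes:
        piece = self._piece_of(token)  # only the new token is converted
        self.text += piece
        return piece
```

Both forms produce the same bytes for a given token sequence, but the generation loop runs once per emitted token, so only the incremental form keeps total work linear in the completion length.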