
Conversation

khaled-wsa

Summary

  • Fixes bug where check_enough_kv_cache_memory ignored num_gpu_blocks_override, allowing engine initialization with an insufficient number of KV blocks for max_model_len.
  • Adds a unit test to ensure that when num_gpu_blocks_override is too small (e.g., 1), initialization raises a clear error even if raw available_memory is large.

Context

Technical Details

  • In check_enough_kv_cache_memory:
    • Validate the override against per-layer requirements: compute ceil(spec.max_memory_usage_bytes / spec.page_size_bytes) for each layer and ensure num_gpu_blocks_override >= max(required_blocks_per_layer). This closes the hole for heterogeneous specs (e.g., cross-attn vs self-attn); see the sketch after this list.
    • Cap raw available_memory by sum(page_size_bytes) * num_gpu_blocks_override to form effective_available_memory.
    • Compare needed_memory to effective_available_memory, and pass the effective capacity to estimate_max_model_len for accurate guidance.
    • Improve the error message to explicitly mention when the override constrains effective capacity.
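
A minimal sketch of the adjusted check described above, assuming each layer spec exposes max_memory_usage_bytes and page_size_bytes; the helper name and error text here are illustrative, not the exact code in the PR:

import math

def effective_capacity_sketch(kv_cache_spec: dict, available_memory: int,
                              num_gpu_blocks_override: int) -> int:
    # Worst-case number of blocks any single layer needs to cover max_model_len.
    required_blocks = max(
        math.ceil(spec.max_memory_usage_bytes / spec.page_size_bytes)
        for spec in kv_cache_spec.values())
    if num_gpu_blocks_override < required_blocks:
        raise ValueError(
            f"num_gpu_blocks_override={num_gpu_blocks_override} is below the "
            f"{required_blocks} blocks required by the largest layer.")
    # One block per layer costs sum(page_size_bytes) bytes, so the override
    # also caps how much of the raw free memory can actually be used.
    bytes_per_block = sum(spec.page_size_bytes for spec in kv_cache_spec.values())
    return min(available_memory, bytes_per_block * num_gpu_blocks_override)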

Files Changed

  • vllm/v1/core/kv_cache_utils.py
    • Enforce per-layer minimum blocks for num_gpu_blocks_override and apply memory cap. Adjust error message.
  • tests/v1/core/test_kv_cache_utils.py
    • Add test_check_enough_kv_cache_memory_respects_num_gpu_blocks_override.
    • Add test_override_must_cover_worst_layer_blocks_in_heterogeneous_model to cover cross-attn vs self-attn scenario.

How To Test

  • Unit test (CPU; see the test sketch after this list):
    • pytest -q tests/v1/core/test_kv_cache_utils.py::test_check_enough_kv_cache_memory_respects_num_gpu_blocks_override
  • Manual sanity:
    • Start the server with a small model and --num_gpu_blocks_override=1.
    • Expect initialization to fail with a ValueError that mentions the effective capacity and the override value.
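
For reference, the shape of the assertion in the new unit test is roughly as follows. This is only a sketch: the fixture names used to build the config and spec are hypothetical, and the real test in tests/v1/core/test_kv_cache_utils.py constructs them with vLLM's own helpers.

import pytest
from vllm.v1.core.kv_cache_utils import check_enough_kv_cache_memory

def test_override_too_small_raises(vllm_config_with_override, kv_cache_spec):
    # Hypothetical fixture: a config with num_gpu_blocks_override=1.
    # Even with a large raw available_memory (last argument, in bytes),
    # the pre-initialization check should fail fast.
    with pytest.raises(ValueError):
        check_enough_kv_cache_memory(vllm_config_with_override, kv_cache_spec,
                                     64 * 1024**3)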

Notes

  • The change is localized and only alters the pre-initialization capacity check; runtime behavior is unchanged.
  • Works for both uniform and hybrid KV cache specs since the per-block total uses each layer's page_size_bytes.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which covers a small but essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the v1 label Oct 21, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly fixes a bug where num_gpu_blocks_override was ignored during memory checks, which could lead to engine initialization with insufficient KV blocks. The changes in check_enough_kv_cache_memory are logical and now correctly validate the override value and use it to cap the effective available memory. The new unit tests are comprehensive, covering both the basic case and a more complex heterogeneous model scenario, ensuring the fix is robust.

I have one high-severity comment regarding a latent bug due to a function call with side effects within the modified code block. While it doesn't cause an issue in the current execution path, it's a potential source of future bugs and should be addressed for better code maintainability and correctness.

Comment on lines 699 to 705
  estimated_max_len = estimate_max_model_len(
-     vllm_config, kv_cache_spec, available_memory
+     vllm_config, kv_cache_spec, effective_available_memory
  )


Severity: high

The function estimate_max_model_len modifies the vllm_config.model_config.max_model_len attribute as a side effect of its binary search implementation. This does not currently cause a bug because this code path always raises an exception afterward, but it is a latent bug that could surface if the function is ever called on a path that does not end in an exception.

A function should not have hidden side effects on its arguments. It would be best to refactor estimate_max_model_len to not modify vllm_config, for example by restoring the original value before returning or by working on a copy.

Since the definition of estimate_max_model_len is not in this diff, I'm pointing this out here at the call site. A fix could look like this inside estimate_max_model_len:

def estimate_max_model_len(...):
    original_max_len = vllm_config.model_config.max_model_len
    try:
        # ... existing logic ...
        return result
    finally:
        vllm_config.model_config.max_model_len = original_max_len

@khaled-wsa khaled-wsa force-pushed the fix/kv-cache-check-override branch 3 times, most recently from 53e6b20 to 2aad22e on October 21, 2025 at 02:41
@elaineyz
Contributor

Hi @khaled-wsa, please see discussion under PR #26939.

In addition to the num_gpu_blocks_override param, the initialization of a null_block may also reduce the total available memory. Could you factor that part into this PR as well? It will make check_enough_kv_cache_memory more robust.
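
If the null block is modeled as permanently consuming one block from the pool, the cap sketched earlier could be adjusted along these lines. This is only a rough illustration of the suggestion, not an agreed design; how vLLM actually charges the null block may differ.

# Reserve one block for the null block before applying the override cap.
usable_blocks = max(num_gpu_blocks_override - 1, 0)
bytes_per_block = sum(spec.page_size_bytes for spec in kv_cache_spec.values())
effective_available_memory = min(available_memory, bytes_per_block * usable_blocks)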

