
Conversation

Contributor

@elaineyz elaineyz commented Oct 15, 2025

Purpose

After PR #14097, we always subtract 1 usable block from the block pool and designate it as the null_block, even when sliding window attention is not used. This creates an edge case when users specify the minimum required number of blocks for their model and config via --num-gpu-blocks-override: vLLM then deducts 1 from that count, so if a user specifies 1 block they effectively get 0.

This PR proposes to lazily initialize the null_block only when it is actually needed (mainly in remove_skipped_blocks). As a result, for KV cache managers like FullAttentionManager, the null_block never materializes and all blocks remain available for allocation.
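
For illustration, here is a minimal sketch of the lazy-initialization idea. The LazyNullBlockPool class, its constructor, and its methods are illustrative assumptions, not vLLM's actual BlockPool API; they only show the pattern of deferring the null_block until first access:

from typing import Optional


class Block:
    """Illustrative stand-in for a KV cache block."""

    def __init__(self, block_id: int) -> None:
        self.block_id = block_id


class LazyNullBlockPool:
    """Sketch of a block pool whose null_block is created on first use.

    Nothing is reserved up front, so managers that never need a
    placeholder block (e.g. full attention) keep every block allocatable.
    """

    def __init__(self, num_blocks: int) -> None:
        self._free_blocks = [Block(i) for i in range(num_blocks)]
        self._null_block: Optional[Block] = None

    @property
    def null_block(self) -> Block:
        # Materialize the placeholder block only when something asks for
        # it (e.g. a sliding-window path dropping skipped blocks).
        if self._null_block is None:
            self._null_block = self._free_blocks.pop(0)
        return self._null_block

    def get_num_free_blocks(self) -> int:
        return len(self._free_blocks)


pool = LazyNullBlockPool(num_blocks=4)
assert pool.get_num_free_blocks() == 4  # no block lost to null_block up front
_ = pool.null_block                     # first access reserves one block
assert pool.get_num_free_blocks() == 3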

Test Plan

Added unit tests.

Also tested the change E2E locally on a Neuron device with the following server command and request:

python3 -m vllm.entrypoints.openai.api_server \
  --model "TinyLlama/TinyLlama-1.1B-Chat-v1.0" \
  --max-num-seqs 1 \
  --max-model-len 128 \
  --tensor-parallel-size 64 \
  --enable-prefix-caching \
  --block-size 32 \
  --num-gpu-blocks-override 4 \
  --additional-config '{
    "override_neuron_config": {
        "is_prefix_caching": true,
        "is_block_kv_layout": true,
        "pa_num_blocks": 4,
        "pa_block_size": 32,
    }
  }' \
  --port 8000
curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "prompt": ["You are an expert software engineer and DevOps specialist. Tell me a story."],
  "min_tokens": 109,
  "max_tokens": 109,
  "temperature": 0
}'

Test Result

Without the change, the vLLM scheduler is unable to schedule the request, since it only sees 3 blocks available (the 4 overridden blocks minus the eagerly reserved null_block). It repeatedly sends an empty SchedulerOutput to the model runner and the server hangs indefinitely.

After the change, the request goes through successfully.
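
For reference, the arithmetic behind the hang under the configuration above (a worked restatement of the numbers from the test plan, not an additional measurement):

import math

max_model_len = 128   # --max-model-len
block_size = 32       # --block-size
num_gpu_blocks = 4    # --num-gpu-blocks-override

blocks_needed = math.ceil(max_model_len / block_size)  # 4
usable_before = num_gpu_blocks - 1  # 3: one block eagerly reserved as null_block
usable_after = num_gpu_blocks       # 4: null_block never materializes

assert blocks_needed > usable_before   # cannot schedule -> scheduler hangs
assert blocks_needed <= usable_after   # request fits after this PR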


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Elaine Zhao <[email protected]>
@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

@heheda12345
Collaborator

I think lazy initialization is a bit magic. Can you take null_block and --num-gpu-blocks-override into consideration in check_enough_kv_cache_memory of vllm/v1/core/kv_cache_utils.py, so that vLLM can raise an error if it can't schedule a single request with max-model-len tokens?

@elaineyz
Contributor Author

elaineyz commented Oct 21, 2025

I think lazy initialization is a bit magic. Can you take null_block and --num-gpu-blocks-override into consideration in check_enough_kv_cache_memory of vllm/v1/core/kv_cache_utils.py, so that vLLM can raise an error if it can't schedule a single request with max-model-len tokens?

Thanks @heheda12345.

I agree that check_enough_kv_cache_memory should be updated to account for:

  • --num-gpu-blocks-override, which may artificially limit the usable KV cache (now tracked under issue 27181, thanks!)

In addition, I think it should also account for:

  • the presence (or lazy creation) of null_block, which could reduce the effective number of available blocks.

However, those validations seem orthogonal to this PR. The goal of this PR is purely to introduce lazy initialization of null_block to avoid unnecessary early allocations. I would like to:

  1. Land this PR to keep the lazy-initialization logic self-contained. I understand that lazy initialization can feel a bit "magic," but the intent here is purely to defer allocation of null_block until it is actually needed, which reduces memory footprint and improves startup performance without changing functional correctness. In fact, it is more "logically correct" for paths like full attention, where a null_block is never needed in the first place.
  2. In a separate PR, extend check_enough_kv_cache_memory to explicitly consider both num_gpu_blocks_override and null_block, so vLLM can fail early if the configuration cannot handle a full max_model_len request (a rough sketch of such a check follows below). I see that "v1/kv_cache_utils: Respect num_gpu_blocks_override in memory check" #27238 is already created. Happy to support that effort by @khaled-wsa.
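
For concreteness, a rough sketch of the kind of validation described in item 2; the function name, parameters, and error message below are simplified assumptions, not the actual check_enough_kv_cache_memory implementation:

def check_enough_usable_blocks(
    num_gpu_blocks: int,
    reserves_null_block: bool,
    max_model_len: int,
    block_size: int,
) -> None:
    """Fail fast at startup instead of hanging at schedule time.

    num_gpu_blocks may come from --num-gpu-blocks-override; one block is
    subtracted when the block pool reserves a null_block.
    """
    reserved = 1 if reserves_null_block else 0
    usable_blocks = num_gpu_blocks - reserved
    needed_blocks = -(-max_model_len // block_size)  # ceiling division
    if needed_blocks > usable_blocks:
        raise ValueError(
            f"Cannot serve a request of max_model_len={max_model_len}: it needs "
            f"{needed_blocks} blocks of size {block_size}, but only {usable_blocks} "
            f"usable blocks are available ({num_gpu_blocks} total, {reserved} "
            f"reserved for null_block)."
        )


# With the test-plan config and an eagerly reserved null_block, this
# raises at startup rather than letting the scheduler spin forever.
try:
    check_enough_usable_blocks(
        num_gpu_blocks=4,
        reserves_null_block=True,
        max_model_len=128,
        block_size=32,
    )
except ValueError as exc:
    print(exc)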

Let me know your thoughts.
