Skip to content

[V1] Large Block_size solution #21123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

nadathurv
Copy link

@nadathurv nadathurv commented Jul 17, 2025

This is the updated and current work on this issue. It is related to the To-Do item

Original Problem

Hybrid models were using extremely large block sizes (~400 tokens) due to individual layer constraints. Each attention layer was padded so that kv_hidden_size * block_size of one layer was larger than the mamba state size of one layer, leading to inefficient memory usage.

Solution

Implement aggregate constraint approach instead of individual layer constraints:

Before: Each layer individually satisfies mamba state requirement
After: Combined memory of all attention layers satisfies mamba state requirement

Key Changes

  1. kv_cache_coordinator.py: Add calculate_optimal_block_size() method

    • Implements aggregate constraint calculation: max_mamba_state / (num_attention_layers * min_per_token_bytes)
    • Provides fallback to OPTIMAL_BLOCK_FALLBACK when calculation fails
    • Includes cached version with LRU cache for performance optimization
  2. kv_cache_utils.py: Add _get_kv_cache_config_optimal_block_size() integration

    • Deep copies all specs to prevent mutation of original configurations
    • Applies calculated optimal block size uniformly across all layer specs
    • Wraps calculation in try-catch with fallback to existing uniform page size logic
    • Integrates with existing get_kv_cache_config() flow for hybrid models

cc @heheda12345 @tlrmchlsmth

Outdated links: Original Work

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Jul 17, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces an intelligent way to calculate the optimal block size for hybrid models, which should improve memory efficiency. The core logic for the calculation in kv_cache_coordinator.py is robust and handles edge cases well. The integration in kv_cache_utils.py is also well-structured.

I've identified one high-severity issue regarding error handling. The use of a broad, silent except Exception could mask bugs and should be updated to include logging for better maintainability and easier debugging. Other than that, the changes look good.

@WorldExplored WorldExplored force-pushed the large-block-size-solution branch from 5d7de66 to 0061d19 Compare July 19, 2025 02:46
@mergify mergify bot added documentation Improvements or additions to documentation ci/build frontend llama Related to Llama models new-model Requests to new models performance Performance-related issues qwen Related to Qwen models rocm Related to AMD ROCm tpu Related to Google TPUs labels Jul 19, 2025
Copy link

mergify bot commented Jul 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @nadathurv.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 19, 2025
@WorldExplored WorldExplored force-pushed the large-block-size-solution branch from 0061d19 to 99a755a Compare July 19, 2025 02:56
Introduces BlockHash and BlockHashWithGroupId NamedTuple classes for KV cache
prefix caching, including support for token IDs, extra keys, and group IDs to
facilitate multi‑group cache management and reduce hash collisions.

Signed-off-by: WorldExplored <[email protected]>
@WorldExplored WorldExplored force-pushed the large-block-size-solution branch from 99a755a to f38a044 Compare July 19, 2025 05:38
@mergify mergify bot removed tpu Related to Google TPUs needs-rebase labels Jul 19, 2025
Fixed some errors with BlockPool and other files in the core directory.

Signed-off-by: WorldExplored <[email protected]>
@WorldExplored WorldExplored force-pushed the large-block-size-solution branch from 2148f1c to 897fc1b Compare July 19, 2025 22:18
@DarkLight1337
Copy link
Member

cc @tdoublep @tlrmchlsmth

@nadathurv nadathurv changed the title [UPDATED] - Large Block_size solution [V1] Large Block_size solution Jul 20, 2025
@tdoublep
Copy link
Member

Thanks for the contribution. I'm struggling to understand this PR. There seems to be a bunch of unrelated changes and also stuff missing (e.g., where is the calculate_optimal_block_size method?).

Perhaps more fundamentally, I don't quite understand how it is as simple as choosing the attention block size such that on aggregate the page size matches the mamba page size. This is due to how the memory layout work.

Imagine we have a model today that has 4 attention layers and 4 mamba layers. We are going to have 2 KVCacheGroups each of size 4 layers. And the current logic will make the attention block size large enough such that the attention page size matches the mamba page size (attn_page_size = mamba_page_size).

We will create 4 KVCacheTensors for each of the 4 layers within the groups, and attention blocks and mamba blocks are interleaved across the 4 KVCacheTensors as follows:
image

Since an attention block and a mamba block consume the same amount of memory, we can free an attention blocks and replace it with a mamba block (and vice-versa) interchangeably.

If I understand correctly, you are proposing that we could instead choosing the attention block size such that 4*attention_page_size = mamba_page_size. But, unless we make additional changes to how the memory layout works, will result in something like this:

image

How do we handle this? Do we make the number of attention blocks larger (e.g., 4x)? I don't think that really makes sense because a sequence will all need the same number of attention blocks as mamba blocks. The other option would be to the change the layout so the attention blocks are not interleaved across 4 KVCacheTensors but rather one. How does this PR propose to solve these issues?

@nadathurv
Copy link
Author

nadathurv commented Jul 21, 2025

This PR has mostly fallen apart. The force-push in the commit history wiped my original edits and fixes, but I had worked out a solution to your proposed changes and could use feedback and a bit of guidance on next steps.

Restating the original problem: Block sizes are extremely large (~400 tokens) because the current constraint requires: kv_hidden_size * block_size of one attention layer ≥ mamba state size of one layer

Main idea: Reduce attention block size to ~100 tokens (what they actually need) instead of ~400 tokens (padded to match mamba).

Let's say 1 Aggregation unit = physical memory size of 1 mamba block.
Each unit can hold either:

  • 1 mamba block (400 tokens), OR
  • 4 attention blocks (4 × 100 tokens)

this is only possible if all allocation/deallocation happens at the unit level, not block level. This gives us the memory savings (attention uses 100-token blocks) while maintaining interchangeability (units are swappable).

The memory layout would look something like this:

Unit 0: [A₁A₂A₃A₄] or [M]  (interchangeable)
Unit 1: [A₅A₆A₇A₈] or [M]  (interchangeable)
...

This preserves the interleaving pattern across KVCacheTensors at the unit level.

Key Changes:

  • KVCacheCoordinator: track allocation at unit granularity, but expose multiple blocks for attention layers
  • Block allocation returns 4 blocks per unit for attention, 1 block per unit for mamba
  • Maintain same number of units across all requests

The most important thing to this idea is decoupling the logical block management at unit level from physical block access at the sub-block level.

My concerns are:

  • How to manage sub-blocks within units
  • How prefix caching works with aggregated blocks

    For this, hash blocks individually but track the units for allocation (hash at block level, allocate/evict at unit level)

  • Performance impact of different access patterns (attention accessing 4 blocks vs mamba accessing 1)

    For this, we can make attention blocks within each unit contiguous so we can maintain kernel performance. The migration path can also have feature flag.

What are your thoughts? And should I make a new PR for this, (citing and closing this one). Let me know if I missed a concern you had. Apologies for the confusion. Thanks for the support.

@heheda12345
Copy link
Collaborator

I'm thinking of changing the layout to:
image

One magic is that we don't need to change attention kernel with this layout. Just set the pointer of $k_i$ to $tensor + 2 \times i \times page\_size$ and $v_i$ to $tensor + (2 \times i + 1) \times page\_size$, and the block_id to $block\_id\_alloc \times num\_attention\_layer \times 2$, where $block\_id\_alloc$ is the $block\_id$ allocated by kv cache manager.

And we also need benchmarking. This will lead to (num_mamba_layer+1) kv cache groups, and I'm not sure whether it will cause some performance issue.

@heheda12345
Copy link
Collaborator

I think large block size and (2, num_block) layout may be coupled. But for simplifying code review, @nadathurv can you first work on large block size problem, and then extend it to (2, num_block) layout to enable more attention backends?

@nadathurv
Copy link
Author

@heheda12345 Got it. It will be addressed in a new PR I will cite it here once done.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ci/build documentation Improvements or additions to documentation frontend llama Related to Llama models new-model Requests to new models performance Performance-related issues qwen Related to Qwen models rocm Related to AMD ROCm v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants