[ROCm][MLA] Support block-size > 1 for AITER MLA backend #27224
base: main
Conversation
Signed-off-by: ganyi <[email protected]>
Code Review
This pull request enables support for block sizes greater than 1 for the AITER MLA backend on ROCm, which was previously a limitation. The approach of remapping the block table to token-level indices to match the expectation of the underlying AITER kernel is sound. The implementation is largely correct, but I've identified a potential issue with a hardcoded device string that could lead to runtime errors in multi-GPU environments. Addressing this will improve the robustness of the change.
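A minimal sketch of the kind of fix the review is pointing at (function and tensor names here are illustrative, not the PR's actual code): derive the device from an input tensor that is already on the right GPU, rather than hardcoding a device string.

```python
import torch

def build_paged_kv_indptr(seq_lens: torch.Tensor) -> torch.Tensor:
    # Derive the device from an input tensor instead of hardcoding a
    # string such as "cuda:0", so each worker in a multi-GPU setup
    # allocates metadata on its own device.
    device = seq_lens.device
    indptr = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32, device=device)
    indptr[1:] = torch.cumsum(seq_lens, dim=0)
    return indptr
```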
Signed-off-by: ganyi <[email protected]>
💡 Codex Review
Here are some automated review suggestions for this pull request.
```python
).unsqueeze(0) < seq_lens_device.unsqueeze(1)
paged_kv_indices = block_table_tensor[mask]

paged_kv_last_page_len = seq_lens_device % page_size
```
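For context on the fragment above, here is a self-contained sketch of the token-level remapping idea: each block-table entry of size `block_size` is expanded into per-token slot indices, as if the cache used `block_size == 1`. Names are illustrative, not the PR's exact code.

```python
import torch

def remap_block_table(block_table: torch.Tensor,
                      seq_lens: torch.Tensor,
                      block_size: int) -> torch.Tensor:
    """Flatten a [num_seqs, max_blocks] block table into per-token KV
    indices, as if the KV cache used block_size == 1."""
    num_seqs, max_blocks = block_table.shape
    # Slot index = block_id * block_size + offset_within_block.
    offsets = torch.arange(block_size, device=block_table.device)
    token_ids = block_table.unsqueeze(-1) * block_size + offsets
    token_ids = token_ids.reshape(num_seqs, max_blocks * block_size)
    # Keep only the positions covered by each sequence's length.
    positions = torch.arange(token_ids.shape[1], device=block_table.device)
    mask = positions.unsqueeze(0) < seq_lens.unsqueeze(1)
    return token_ids[mask]
```

For example, with `block_size=2`, a sequence of length 3 occupying blocks `[2, 5]` maps to token slots `[4, 5, 10]`.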
Recompute last page lengths after token-level remapping
After expanding each block table entry into per-token indices, the code still derives paged_kv_last_page_len from the original page_size (seq_lens % page_size, falling back to page_size). Once the remapping is done, each entry represents a single token, so the last-page length for any non-empty request should always be 1. Keeping the old computation causes the decode kernel to believe that the final page contains page_size tokens (e.g. 128) and it will read that many elements starting from the last token’s index, potentially stepping past the valid token range when block_size > 1. This defeats the goal of supporting larger block sizes and can lead to out-of-bounds accesses or garbage attention results for any request longer than one token.
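If the remapping is adopted, the correction this review asks for could look roughly like the sketch below (not the PR's actual code): once every entry represents a single token, the last "page" of any non-empty request holds exactly one token.

```python
import torch

def last_page_len_after_remap(seq_lens: torch.Tensor) -> torch.Tensor:
    # After token-level remapping, every page holds one token, so the
    # last-page length is 1 for any non-empty request (0 if empty),
    # rather than seq_lens % page_size computed with the original
    # page_size.
    return (seq_lens > 0).to(torch.int32)
```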
@HAIAI Please kindly help to review this PR.
Just sharing the performance metrics of this amazing optimization PR. There is improvement even in the originally supported case. Here's a comparison table of the benchmark results on DeepSeek-R1 PTPC FP8:

General Metrics

Latency Metrics

Workload
@tjtanaa Thanks for the benchmark metrics you shared! I'm actually quite surprised that the
@ganyi1996ppo can you share the performance values of your experiment for
Hi @tjtanaa, we tested the performance with 8k/1k and 32/1k configurations for the different scenarios; the
LGTM
Does this change require any particular AITER version or branch?
@gshtras No specific AITER version is required; it just maps




Purpose
The `AITERMLABackend` currently only supports the `block-size=1` scenario for inference. This constraint can cause serious host overhead when allocating or freeing cache blocks for long-context requests, because there may be a large number of blocks to operate on. Thanks to the insights of @gyu-amd.

In this PR, we remap the `block_table` to the block-size-1 case at every step in `AITERMLAMetadataBuilder` to alleviate the host overhead of allocating and deallocating blocks. This change also supports a wider range of block sizes for `AITERMLABackend`, making the `AITERMLABackend` on the ROCm platform align with vLLM's official usage and become more flexible.

Test Plan
Verified on gsm8k for accuracy, performance improvement will also be attached later
test script:
Test Result