[VLM] Add MLA with pure RoPE support for deepseek-vl2 models #12729
Conversation
Signed-off-by: Isotr0py <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can use one of the available options. 🚀
deepseek-vl2-small with the deepseek-v2 backbone should work with the MLA backend now.
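For illustration, here is a minimal, untested sketch of how one might exercise this (not from the original comment; the model id and the `hf_overrides` value are assumptions following vLLM's deepseek-vl2 example):

```python
from vllm import LLM, SamplingParams

# Assumed model id; deepseek-vl2 checkpoints need an architecture override in vLLM.
llm = LLM(
    model="deepseek-ai/deepseek-vl2-small",
    hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
)

# A text-only prompt is enough to exercise the deepseek-v2 (MLA) backbone.
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```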
```python
def apply_pure_rope(
    self,
    input_positions: torch.Tensor,
    q_pe: torch.Tensor,
    k_pe: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    seq_len = input_positions.size(0)
    ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape

    # Flatten the head dim into the last axis so the rotary op sees
    # [num_tokens, num_heads * head_dim].
    q_pe, k_pe = self.rotary_emb(
        input_positions,
        q_pe.reshape(seq_len, -1),
        k_pe.reshape(seq_len, -1),
    )
    # Restore the original [seq_len, num_heads, head_dim] shapes.
    q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)

    return q_pe, k_pe
```
Could you say a bit about why you needed to wrap `rotary_embedding` when using `pure_rope`? Wondering if we could clean things up by always doing this reshape, so that we could always call `self.rotary_embedding` without the special cases for pure RoPE vs. YaRN.
I wrapped `rotary_embedding` to reshape with pure RoPE because if `q_pe` and `k_pe` have shape `[seq_len, num_heads, head_dim]` and are passed to pure RoPE directly, it causes an illegal memory access on `q_pe` when applying `flash_attn_varlen_func`:
```
[rank0]: File "/home/zifeng/develop-projects/vllm/vllm/attention/backends/mla/utils.py", line 531, in _forward_prefill_flash
[rank0]: attn_output = flash_attn_varlen_func(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/zifeng/develop-projects/vllm/vllm/vllm_flash_attn/flash_attn_interface.py", line 172, in flash_attn_varlen_func
[rank0]: out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/zifeng/miniconda3/envs/vllm/lib/python3.12/site-packages/torch/_ops.py", line 1116, in __call__
[rank0]: return self._op(*args, **(kwargs or {}))
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: CUDA error: an illegal memory access was encountered
[rank0]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank0]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1
[rank0]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```
BTW, if we use `forward_native` for pure RoPE without the reshape, the error is not encountered and it also works with the `[seq_len, num_heads, head_dim]` shape, so the issue is `forward_cuda` specific. Perhaps we should add a shape check in `RotaryEmbedding`'s `forward_cuda`?
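Not part of this PR, but a rough sketch of the kind of check being suggested (the helper name, signature, and placement are assumptions; it simply mirrors the size arithmetic the CUDA op performs):

```python
import torch

def check_rope_cuda_shapes(positions: torch.Tensor, query: torch.Tensor,
                           key: torch.Tensor, head_size: int) -> None:
    # Hypothetical guard: the CUDA op derives num_heads from the last dim
    # and num_tokens from numel // last_dim, so heads must be flattened
    # into the last axis.
    if query.size(-1) % head_size != 0 or key.size(-1) % head_size != 0:
        raise ValueError(
            "last dim of query/key must be num_heads * head_size")
    num_tokens = query.numel() // query.size(-1)
    if num_tokens != positions.numel():
        raise ValueError(
            f"query implies {num_tokens} tokens but positions has "
            f"{positions.numel()}; flatten [seq_len, num_heads, head_dim] "
            "into [seq_len, num_heads * head_dim] before calling the op")
```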
Oh, it seems that's because the `num_heads` calculation in the `rotary_embedding` CUDA op is unsuitable for tensors with shape `[seq_len, num_heads, head_dim]`:

vllm/csrc/pos_encoding_kernels.cu, lines 124 to 136 in b3a0d01:
```cpp
void rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
                           // [num_tokens, num_heads * head_size]
    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or
                           // [num_tokens, num_kv_heads * head_size]
    int64_t head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox) {
  int64_t num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;
```
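To make the mismatch concrete, here is a small illustrative computation mirroring the op's size arithmetic (the dimension values are made up):

```python
import torch

seq_len, num_heads, head_dim = 8, 16, 64
head_size = 64  # rotary dim passed to the op

# Flattened layout the op expects: [num_tokens, num_heads * head_size]
flat = torch.empty(seq_len, num_heads * head_size)
print(flat.numel() // flat.size(-1))   # 8  -> num_tokens
print(flat.size(-1) // head_size)      # 16 -> num_heads

# Unflattened MLA layout: [seq_len, num_heads, head_dim]
unflat = torch.empty(seq_len, num_heads, head_dim)
print(unflat.numel() // unflat.size(-1))  # 128 -> treated as num_tokens
print(unflat.size(-1) // head_size)       # 1   -> treated as num_heads
```

With the unflattened layout the op sees 128 "tokens" and 1 head, which no longer lines up with `positions`, hence the reshape in `apply_pure_rope` above.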
Let's fix it in a separate PR to avoid blocking the v0.7.2 release, especially since it's on the kernel side and I need some time to rebuild with compilation. :)
Sounds like a bug in the kernel -- I'll look into it tomorrow. In the meantime, I like the idea of adding a shape check in `forward_cuda` if you have a good idea of which shapes are problematic.
> Let's fix it in a separate PR to avoid blocking the v0.7.2 release, especially since it's on the kernel side and I need some time to rebuild with compilation. :)

Nice find, sounds good to me!
Thanks for the fix!

I had one comment from earlier today asking if the implementation could be cleaned up a bit - I still wonder about that. For now, I'm accepting because the functionality looks good and I think we should get it into 0.7.2.
FIX #11578 (comment)