[VLM] Add MLA with pure RoPE support for deepseek-vl2 models (#12729)
Isotr0py authored Feb 5, 2025
1 parent 249824c commit 98fd089
Showing 3 changed files with 30 additions and 6 deletions.
30 changes: 26 additions & 4 deletions vllm/attention/backends/mla/utils.py
@@ -26,7 +26,8 @@
     apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_dequantize, scaled_quantize)
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import (
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -174,6 +175,8 @@ def __init__(
         self.v_head_dim = v_head_dim
 
         self.rotary_emb = rotary_emb
+        self.use_yarn_rope = isinstance(rotary_emb,
+                                        DeepseekScalingRotaryEmbedding)
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
@@ -420,6 +423,24 @@ def _forward_decode(
     ) -> torch.Tensor:
         raise NotImplementedError
 
+    def apply_pure_rope(
+        self,
+        input_positions: torch.Tensor,
+        q_pe: torch.Tensor,
+        k_pe: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        seq_len = input_positions.size(0)
+        ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
+
+        q_pe, k_pe = self.rotary_emb(
+            input_positions,
+            q_pe.reshape(seq_len, -1),
+            k_pe.reshape(seq_len, -1),
+        )
+        q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
+
+        return q_pe, k_pe
+
     def forward(
         self,
         layer: AttentionLayer,
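
A note on the new helper: the base RotaryEmbedding consumes query/key tensors with the head dims flattened into the trailing dimension, while MLA carries q_pe/k_pe with an explicit head dimension, so apply_pure_rope flattens, applies RoPE, and restores the original shapes. Below is a minimal shape-only sketch of that round trip; fake_rope is a hypothetical stand-in for self.rotary_emb and the dimensions are illustrative only.

import torch

def fake_rope(positions, q, k):
    # Stand-in for RotaryEmbedding.forward: consumes and returns 2-D
    # (num_tokens, num_heads * rope_dim) tensors.
    return q, k

num_tokens, num_heads, rope_dim = 4, 16, 64
positions = torch.arange(num_tokens)
q_pe = torch.randn(num_tokens, num_heads, rope_dim)
k_pe = torch.randn(num_tokens, 1, rope_dim)  # k_pe carries a single rope head

ori_q_shape, ori_k_shape = q_pe.shape, k_pe.shape
q_flat, k_flat = fake_rope(positions,
                           q_pe.reshape(num_tokens, -1),  # (4, 1024)
                           k_pe.reshape(num_tokens, -1))  # (4, 64)
q_pe, k_pe = q_flat.view(ori_q_shape), k_flat.view(ori_k_shape)
assert q_pe.shape == ori_q_shape and k_pe.shape == ori_k_shape
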
@@ -444,21 +465,22 @@ def forward(
         # Restore head dim (for rotary embedding)
         k_pe = k_pe.unsqueeze(1)
         assert hasattr(attn_metadata, "input_positions")
+        rope_fn = (self.rotary_emb
+                   if self.use_yarn_rope else self.apply_pure_rope)
 
         if is_decode:
             q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
             q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
                 .view(-1, self.num_heads, self.qk_rope_head_dim)
-            q_pe, k_pe = \
-                self.rotary_emb(attn_metadata.input_positions, q_pe, k_pe)
+            q_pe, k_pe = rope_fn(attn_metadata.input_positions, q_pe, k_pe)
         else:
             assert is_prefill
             q = self.q_proj(hidden_states_or_q_c)[0]\
                 .view(-1, self.num_heads, self.qk_head_dim)
 
             # TODO(lucas): there must be a nicer way to write this line
             q[..., self.qk_nope_head_dim:], k_pe = \
-                self.rotary_emb(
+                rope_fn(
                     attn_metadata.input_positions,
                     q[..., self.qk_nope_head_dim:], k_pe)
 
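Taken together, the MLA backend now picks the rope path per call: a DeepseekScalingRotaryEmbedding (YaRN) is invoked directly on the head-shaped q_pe/k_pe tensors, while a plain RotaryEmbedding is routed through apply_pure_rope. A condensed, self-contained sketch of that dispatch, assuming rotary_emb was built by get_rope; the helper name select_rope_fn is hypothetical, not part of vLLM.

from typing import Callable

from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, RotaryEmbedding)

def select_rope_fn(rotary_emb: RotaryEmbedding,
                   apply_pure_rope: Callable) -> Callable:
    # Mirrors the rope_fn selection added in forward(): the YaRN-scaled
    # embedding handles (num_tokens, num_heads, rope_dim) inputs itself,
    # so it is called directly; anything else goes through the
    # flatten-and-restore helper.
    if isinstance(rotary_emb, DeepseekScalingRotaryEmbedding):
        return rotary_emb
    return apply_pure_rope
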
3 changes: 2 additions & 1 deletion vllm/model_executor/models/deepseek_v2.py
@@ -414,7 +414,8 @@ def __init__(
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.o_proj")
 
-        rope_scaling["rope_type"] = 'deepseek_yarn'
+        if rope_scaling:
+            rope_scaling["rope_type"] = 'deepseek_yarn'
         self.rotary_emb = get_rope(qk_rope_head_dim,
                                    rotary_dim=qk_rope_head_dim,
                                    max_position=max_position_embeddings,
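The guard matters because deepseek-vl2 checkpoints can ship without rope_scaling in their config; with it, get_rope returns a plain RotaryEmbedding (handled by the new pure-RoPE path above) instead of crashing on item assignment to None. A hedged sketch of the two configurations follows; the argument values and the scaling dict are illustrative, not taken from a real checkpoint.

from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, RotaryEmbedding, get_rope)

qk_rope_head_dim = 64

# No rope_scaling in the model config (e.g. a deepseek-vl2 variant):
# plain RoPE, which is what apply_pure_rope expects.
plain = get_rope(qk_rope_head_dim,
                 rotary_dim=qk_rope_head_dim,
                 max_position=4096,
                 base=10000,
                 rope_scaling=None,
                 is_neox_style=False)
assert not isinstance(plain, DeepseekScalingRotaryEmbedding)

# rope_scaling present: force DeepSeek YaRN, as before this commit.
rope_scaling = {"factor": 40.0, "original_max_position_embeddings": 4096}
rope_scaling["rope_type"] = "deepseek_yarn"
scaled = get_rope(qk_rope_head_dim,
                  rotary_dim=qk_rope_head_dim,
                  max_position=4096 * 40,
                  base=10000,
                  rope_scaling=rope_scaling,
                  is_neox_style=False)
assert isinstance(scaled, DeepseekScalingRotaryEmbedding)
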
3 changes: 2 additions & 1 deletion vllm/model_executor/models/deepseek_v3.py
@@ -422,7 +422,8 @@ def __init__(
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.o_proj")
 
-        rope_scaling["rope_type"] = 'deepseek_yarn'
+        if rope_scaling:
+            rope_scaling["rope_type"] = 'deepseek_yarn'
         self.rotary_emb = get_rope(qk_rope_head_dim,
                                    rotary_dim=qk_rope_head_dim,
                                    max_position=max_position_embeddings,
