
Commit 0866c0b

xyang16 and amd-xiaoyu12
authored and committed
[Model] Support deepseek with eagle (vllm-project#21086)
Signed-off-by: Xin Yang <[email protected]>
Signed-off-by: Xiao Yu <[email protected]>
1 parent 69d9cd9 commit 0866c0b
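
For context, a minimal usage sketch (not part of this commit) of pairing a DeepSeek target with the EAGLE draft support added here, using the random test checkpoints referenced in the diffs below. It assumes vLLM's offline LLM API with a speculative_config dict ("method", "model", "num_speculative_tokens"); exact option names may vary between vLLM versions.

# Hedged sketch: DeepSeek target + EAGLE draft via speculative decoding.
from vllm import LLM, SamplingParams

llm = LLM(
    model="eagle618/deepseek-v3-random",  # target model (random test weights)
    speculative_config={
        "method": "eagle",
        "model": "eagle618/eagle-deepseek-v3-random",  # draft enabled by this commit
        "num_speculative_tokens": 1,  # matches the new test parametrization
    },
    trust_remote_code=True,
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)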

File tree

4 files changed: +255 -1 lines

tests/models/registry.py
tests/v1/e2e/test_spec_decode.py
vllm/model_executor/models/deepseek_eagle.py
vllm/model_executor/models/registry.py


tests/models/registry.py

Lines changed: 3 additions & 0 deletions
@@ -530,6 +530,9 @@ def check_available_online(
     "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random",
                                         speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501
                                         trust_remote_code=True),
+    "EagleDeepSeekMTPModel": _HfExamplesInfo("eagle618/deepseek-v3-random",
+                                             speculative_model="eagle618/eagle-deepseek-v3-random", # noqa: E501
+                                             trust_remote_code=True),
     "EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B",
                                              trust_remote_code=True,
                                              speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",

tests/v1/e2e/test_spec_decode.py

Lines changed: 5 additions & 1 deletion
@@ -144,14 +144,17 @@ def test_ngram_correctness(
              "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
             True,
             marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+        (("eagle", "eagle618/deepseek-v3-random",
+          "eagle618/eagle-deepseek-v3-random", 1), False),
     ],
     ids=[
         # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
         # "qwen3_eagle3",
         "llama3_eagle",
         "llama3_eagle3",
         "llama4_eagle",
-        "llama4_eagle_mm"
+        "llama4_eagle_mm",
+        "deepseek_eagle"
     ])
 @pytest.mark.parametrize("attn_backend",
                          get_attn_backend_list_based_on_platform())
@@ -177,6 +180,7 @@ def test_eagle_correctness(
     '''
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
+        m.setenv("VLLM_MLA_DISABLE", "1")
         m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

         if (attn_backend == "TRITON_ATTN_VLLM_V1"
vllm/model_executor/models/deepseek_eagle.py

Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Iterable
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
+from vllm.distributed.parallel_state import get_pp_group
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer,
+                                                    DeepseekV3ForCausalLM)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+
+from .utils import AutoWeightsLoader, maybe_prefix
+
+
+@support_torch_compile
+class DeepseekV2Model(nn.Module):
+
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        start_layer_id: int = 0,
+    ) -> None:
+        super().__init__()
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
+        model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        self.vocab_size = self.config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            self.config.vocab_size,
+            self.config.hidden_size,
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "embed_tokens"),
+        )
+
+        self.layers = nn.ModuleList([
+            DeepseekV2DecoderLayer(
+                self.config,
+                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
+                model_config=model_config,
+                cache_config=cache_config,
+                quant_config=quant_config,
+            ) for i in range(self.config.num_hidden_layers)
+        ])
+
+        self.fc = nn.Linear(
+            self.config.model.hidden_size * 2,
+            self.config.model.hidden_size,
+            bias=False,
+        )
+
+        self.enorm = RMSNorm(self.config.hidden_size,
+                             eps=self.config.rms_norm_eps)
+        self.hnorm = RMSNorm(self.config.hidden_size,
+                             eps=self.config.rms_norm_eps)
+        self.norm = RMSNorm(self.config.hidden_size,
+                            eps=self.config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        input_embeds = self.embed_tokens(input_ids)
+
+        inputs = torch.cat(
+            [self.enorm(input_embeds),
+             self.hnorm(hidden_states)], dim=-1)
+        hidden_states = self.fc(inputs)
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states, hidden_states
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+            ("fused_qkv_a_proj", "q_a_proj", 0),
+            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
+        ]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts)
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if ("mlp.experts." in name) and name not in params_dict:
+                    continue
+                name_mapped = name.replace(weight_name, param_name)
+
+                # QKV fusion is optional, fall back to normal
+                # weight loading if it's not enabled
+                # if go with fusion option, then update name
+                if ((param_name == "fused_qkv_a_proj")
+                        and name_mapped not in params_dict):
+                    continue
+                else:
+                    name = name_mapped
+
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    # if PP disabled then draft will share embed with target
+                    if get_pp_group().world_size == 1 and \
+                        "embed_tokens." in name:
+                        continue
+
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+
+class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        nn.Module.__init__(self)
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
+        quant_config = vllm_config.quant_config
+        target_layer_num = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config)
+        self.model = DeepseekV2Model(vllm_config=vllm_config,
+                                     prefix="model",
+                                     start_layer_id=target_layer_num)
+
+        self.lm_head = ParallelLMHead(self.config.vocab_size,
+                                      self.config.hidden_size,
+                                      quant_config=quant_config)
+
+        logit_scale = getattr(self.config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(self.config.vocab_size,
+                                                scale=logit_scale)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if inputs_embeds is not None:
+            raise NotImplementedError(
+                f"{type(self).__name__} does not support multimodal inputs yet."
+            )
+        return self.model(input_ids, positions, hidden_states)
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=None,
+        )
+
+        model_weights = {}
+        for name, loaded_weight in weights:
+            if "lm_head" not in name:
+                name = "model." + name
+            model_weights[name] = loaded_weight
+        loader.load_weights(model_weights.items())
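
For orientation only (not part of the diff): the draft model above fuses the RMS-normalized token embedding with the RMS-normalized hidden state handed over by the target, then projects the concatenation back to hidden_size with self.fc before running the extra decoder layers. Below is a standalone PyTorch sketch of just that fusion step, with hypothetical toy dimensions and a simplified, weight-free RMS norm standing in for vLLM's RMSNorm.

# Toy illustration of the enorm/hnorm/fc fusion in DeepseekV2Model.forward.
import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Simplified stand-in for vLLM's RMSNorm (no learned weight).
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

hidden_size, num_tokens = 64, 4  # hypothetical sizes
input_embeds = torch.randn(num_tokens, hidden_size)   # embed_tokens(input_ids)
target_hidden = torch.randn(num_tokens, hidden_size)  # hidden states from the target model
fc = torch.nn.Linear(hidden_size * 2, hidden_size, bias=False)

fused = fc(torch.cat([rms_norm(input_embeds), rms_norm(target_hidden)], dim=-1))
print(fused.shape)  # torch.Size([4, 64]) -- the input consumed by the draft's decoder layers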

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -264,6 +264,7 @@
     "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
     # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
     # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
+    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
     "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
     "MedusaModel": ("medusa", "Medusa"),
