
Commit e26554a

[Model] Support deepseek with eagle
Signed-off-by: Xin Yang <[email protected]>
1 parent b4b78d6 commit e26554a

3 files changed: +242 additions, −1 deletion

tests/v1/e2e/test_spec_decode.py

Lines changed: 4 additions & 1 deletion
@@ -144,14 +144,17 @@ def test_ngram_correctness(
               "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
              True,
             marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+        (("eagle", "eagle618/deepseek-v3-random",
+          "eagle618/eagle-deepseek-v3-random", 1), False),
     ],
     ids=[
         # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
         # "qwen3_eagle3",
         "llama3_eagle",
         "llama3_eagle3",
         "llama4_eagle",
-        "llama4_eagle_mm"
+        "llama4_eagle_mm",
+        "deepseek_eagle"
     ])
 @pytest.mark.parametrize("attn_backend",
                          get_attn_backend_list_based_on_platform())
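
For orientation, the new test case pairs the random DeepSeek-V3 target ("eagle618/deepseek-v3-random") with its EAGLE draft head ("eagle618/eagle-deepseek-v3-random") and one speculative token. A minimal sketch of the equivalent offline setup, assuming vLLM's speculative_config dict interface as used by the v1 spec-decode tests (the exact engine arguments the test helper passes are not shown in this diff):

# Minimal sketch, assuming the speculative_config dict form accepted by
# vllm.LLM; extra engine arguments used by the test harness are omitted.
from vllm import LLM, SamplingParams

llm = LLM(
    model="eagle618/deepseek-v3-random",  # target model from the test tuple
    speculative_config={
        "method": "eagle",
        "model": "eagle618/eagle-deepseek-v3-random",  # EAGLE draft head
        "num_speculative_tokens": 1,  # matches the test tuple above
    },
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0, max_tokens=32))
print(outputs[0].outputs[0].text)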
vllm/model_executor/models/deepseek_eagle.py

Lines changed: 237 additions & 0 deletions

@@ -0,0 +1,237 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer,
                                                    DeepseekV3ForCausalLM)
from vllm.model_executor.sampling_metadata import SamplingMetadata

from .utils import AutoWeightsLoader, maybe_prefix


@support_torch_compile
class DeepseekV2Model(nn.Module):

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
    ) -> None:
        super().__init__()
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.vocab_size = self.config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        self.layers = nn.ModuleList([
            DeepseekV2DecoderLayer(
                self.config,
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
            ) for i in range(self.config.num_hidden_layers)
        ])

        self.fc = nn.Linear(
            self.config.model.hidden_size * 2,
            self.config.model.hidden_size,
            bias=False,
        )

        self.enorm = RMSNorm(self.config.hidden_size,
                             eps=self.config.rms_norm_eps)
        self.hnorm = RMSNorm(self.config.hidden_size,
                             eps=self.config.rms_norm_eps)
        self.norm = RMSNorm(self.config.hidden_size,
                            eps=self.config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_embeds = self.embed_tokens(input_ids)

        inputs = torch.cat(
            [self.enorm(input_embeds),
             self.hnorm(hidden_states)], dim=-1)
        hidden_states = self.fc(inputs)

        # masking inputs at position=0
        hidden_states[positions == 0] = 0
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # if PP disabled then draft will share embed with target
                    if get_pp_group().world_size == 1 and \
                        "embed_tokens." in name:
                        continue

                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        quant_config = vllm_config.quant_config
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.model = DeepseekV2Model(vllm_config=vllm_config,
                                     prefix="model",
                                     start_layer_id=target_layer_num)

        self.lm_head = ParallelLMHead(self.config.vocab_size,
                                      self.config.hidden_size,
                                      quant_config=quant_config)

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if inputs_embeds is not None:
            raise NotImplementedError(
                f"{type(self).__name__} does not support multimodal inputs yet."
            )
        return self.model(input_ids, positions, hidden_states)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )

        model_weights = {}
        for name, loaded_weight in weights:
            if "lm_head" not in name:
                name = "model." + name
            model_weights[name] = loaded_weight
        loader.load_weights(model_weights.items())
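
Worth noting in the file above: the draft trunk concatenates the token embedding with the target model's hidden state, projects the pair back down through self.fc, and zeroes the result at position 0, while the outer load_weights re-homes every checkpoint tensor except lm_head.* under the "model." prefix before delegating to AutoWeightsLoader. A standalone sketch of that renaming step, using hypothetical weight names rather than a real checkpoint:

# Standalone illustration of the renaming done in
# EagleDeepseekV3ForCausalLM.load_weights: every checkpoint tensor except
# lm_head.* is re-homed under the "model." prefix. The weight names below
# are hypothetical examples.
import torch


def remap_draft_weight_names(weights):
    remapped = {}
    for name, tensor in weights:
        if "lm_head" not in name:
            name = "model." + name
        remapped[name] = tensor
    return remapped


example = [
    ("embed_tokens.weight", torch.zeros(8, 4)),
    ("layers.0.self_attn.q_a_proj.weight", torch.zeros(4, 4)),
    ("lm_head.weight", torch.zeros(8, 4)),
]
print(list(remap_draft_weight_names(example)))
# ['model.embed_tokens.weight',
#  'model.layers.0.self_attn.q_a_proj.weight',
#  'lm_head.weight']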

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -262,6 +262,7 @@
     "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
     # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
     # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
+    "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
     "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
     "MedusaModel": ("medusa", "Medusa"),
