3 changes: 3 additions & 0 deletions tests/models/registry.py
@@ -529,6 +529,9 @@ def check_available_online(
"DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random",
speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501
trust_remote_code=True),
"EagleDeepSeekMTPModel": _HfExamplesInfo("eagle618/deepseek-v3-random",
speculative_model="eagle618/eagle-deepseek-v3-random", # noqa: E501
trust_remote_code=True),
"EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B",
trust_remote_code=True,
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
6 changes: 5 additions & 1 deletion tests/v1/e2e/test_spec_decode.py
@@ -144,14 +144,17 @@ def test_ngram_correctness(
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
True,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
(("eagle", "eagle618/deepseek-v3-random",
"eagle618/eagle-deepseek-v3-random", 1), False),
],
ids=[
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
# "qwen3_eagle3",
"llama3_eagle",
"llama3_eagle3",
"llama4_eagle",
"llama4_eagle_mm"
"llama4_eagle_mm",
"deepseek_eagle"
])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
@@ -177,6 +180,7 @@ def test_eagle_correctness(
    '''
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        m.setenv("VLLM_MLA_DISABLE", "1")
        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

        if (attn_backend == "TRITON_ATTN_VLLM_V1"
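For reference, a minimal offline-usage sketch mirroring the new test case above (method "eagle", target eagle618/deepseek-v3-random, draft eagle618/eagle-deepseek-v3-random, one speculative token). This is not part of the PR; it assumes the speculative_config dict form used in tests/v1/e2e/test_spec_decode.py and mirrors the VLLM_MLA_DISABLE=1 setting added there.

import os

# Mirror the environment the test sets before vLLM is imported.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_MLA_DISABLE"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(
    model="eagle618/deepseek-v3-random",  # target model from the test params
    trust_remote_code=True,
    speculative_config={
        "method": "eagle",
        "model": "eagle618/eagle-deepseek-v3-random",  # draft model
        "num_speculative_tokens": 1,
    },
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)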
246 changes: 246 additions & 0 deletions vllm/model_executor/models/deepseek_eagle.py
@@ -0,0 +1,246 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer,
                                                    DeepseekV3ForCausalLM)
from vllm.model_executor.sampling_metadata import SamplingMetadata

from .utils import AutoWeightsLoader, maybe_prefix


@support_torch_compile
class DeepseekV2Model(nn.Module):

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
    ) -> None:
        super().__init__()
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.vocab_size = self.config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        self.layers = nn.ModuleList([
            DeepseekV2DecoderLayer(
                self.config,
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
            ) for i in range(self.config.num_hidden_layers)
        ])

        self.fc = nn.Linear(
            self.config.model.hidden_size * 2,
            self.config.model.hidden_size,
            bias=False,
        )

        self.enorm = RMSNorm(self.config.hidden_size,
                             eps=self.config.rms_norm_eps)
        self.hnorm = RMSNorm(self.config.hidden_size,
                             eps=self.config.rms_norm_eps)
        self.norm = RMSNorm(self.config.hidden_size,
                            eps=self.config.rms_norm_eps)
Comment on lines +62 to +73
Contributor


critical

The implementation of DeepseekV2Model assumes that the draft model and the target model share the same hidden size. For instance, self.hnorm is initialized with the draft model's hidden size (self.config.hidden_size) but is applied to hidden_states from the target model.

This assumption is incorrect for the models used in testing (deepseek-r1 has a hidden size of 4096, while eagle-deepseek-r1 has 1024), and will lead to a runtime error due to shape mismatch.

To fix this, you should explicitly use the hidden sizes from both the draft and target model configurations. You can access the target model's configuration via vllm_config.model_config.

        target_config = vllm_config.model_config.hf_config
        draft_hidden_size = self.config.hidden_size
        target_hidden_size = target_config.hidden_size

        self.fc = nn.Linear(
            draft_hidden_size + target_hidden_size,
            draft_hidden_size,
            bias=False,
        )

        self.enorm = RMSNorm(draft_hidden_size,
                             eps=self.config.rms_norm_eps)
        self.hnorm = RMSNorm(target_hidden_size,
                             eps=target_config.rms_norm_eps)
        self.norm = RMSNorm(draft_hidden_size,
                            eps=self.config.rms_norm_eps)


    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_embeds = self.embed_tokens(input_ids)

        inputs = torch.cat(
            [self.enorm(input_embeds),
             self.hnorm(hidden_states)], dim=-1)
Collaborator


This looks like too many norms being applied. In the Llama_Eagle reference code, the input layernorm to each layer is disabled, and IIRC there is no output layernorm. Here, there are two norms applied to the input (pre-concat and input-layernorm after concat) and two more norms applied after (post_attention_layernorm and self.norm). This does not seem correct.

Contributor Author

@xyang16 Aug 14, 2025


I have taken a look at deepseek_mtp.py at https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_mtp.py#L64. The only difference is the output self.norm. But in our benchmarking, we found that including the output norm increases the acceptance rate.
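For comparison, a hedged sketch of what the Llama EAGLE reference code mentioned above does to avoid the extra pre-attention norm: the first decoder layer's input_layernorm is swapped for an identity, so only enorm/hnorm run before attention. Whether the same swap is appropriate for DeepseekV2DecoderLayer is an assumption; it is not what this PR implements.

import torch.nn as nn

def disable_first_input_layernorm(model: nn.Module) -> None:
    # Hypothetical helper (not in this PR): drop the first layer's input
    # RMSNorm so the fc projection of [enorm(embeds), hnorm(hidden)] is not
    # normalized a second time before attention. Assumes the layer calls
    # input_layernorm(hidden_states) with residual=None on the first pass.
    model.layers[0].input_layernorm = nn.Identity()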

        hidden_states = self.fc(inputs)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
Collaborator


Does this need to be made compatible with the fused_qkv_a_proj optimization from #21116? I have observed multiple issues with weight loading in MTP not being consistent with the DeepSeek base model weight loading. Will similar issues apply here?

Contributor Author


I have updated the stacked_params_mapping. Thanks! (A standalone sketch of the remapping follows this file's diff below.)

("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]

# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts)

params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue

for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
name_mapped = name.replace(weight_name, param_name)

# QKV fusion is optional, fall back to normal
# weight loading if it's not enabled
# if go with fusion option, then update name
if ((param_name == "fused_qkv_a_proj")
and name_mapped not in params_dict):
continue
else:
name = name_mapped

# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)

param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# if PP disabled then draft will share embed with target
if get_pp_group().world_size == 1 and \
"embed_tokens." in name:
continue

# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        quant_config = vllm_config.quant_config
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.model = DeepseekV2Model(vllm_config=vllm_config,
                                     prefix="model",
                                     start_layer_id=target_layer_num)

        self.lm_head = ParallelLMHead(self.config.vocab_size,
                                      self.config.hidden_size,
                                      quant_config=quant_config)

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if inputs_embeds is not None:
            raise NotImplementedError(
                f"{type(self).__name__} does not support multimodal inputs yet."
            )
        return self.model(input_ids, positions, hidden_states)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )

        model_weights = {}
        for name, loaded_weight in weights:
            if "lm_head" not in name:
                name = "model." + name
            model_weights[name] = loaded_weight
        loader.load_weights(model_weights.items())
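To make the fused_qkv_a_proj discussion above concrete, here is a standalone sketch of how a checkpoint key flows through stacked_params_mapping. The parameter names are illustrative only; just the control flow mirrors load_weights.

# Illustrative remapping walk-through; names are hypothetical.
stacked_params_mapping = [
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
    ("fused_qkv_a_proj", "q_a_proj", 0),
    ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]

# Pretend the runtime built the fused parameter (QKV-a fusion enabled).
params_dict = {"layers.61.self_attn.fused_qkv_a_proj.weight"}

def remap(name: str):
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        mapped = name.replace(weight_name, param_name)
        # Fall back to per-tensor loading if the fused param does not exist.
        if param_name == "fused_qkv_a_proj" and mapped not in params_dict:
            return None
        return mapped, shard_id
    return None

print(remap("layers.61.self_attn.q_a_proj.weight"))
# ('layers.61.self_attn.fused_qkv_a_proj.weight', 0)
print(remap("layers.61.self_attn.kv_a_proj_with_mqa.weight"))
# ('layers.61.self_attn.fused_qkv_a_proj.weight', 1)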
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -263,6 +263,7 @@
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
# "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
"MedusaModel": ("medusa", "Medusa"),
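A quick, hedged sanity check that the new registry entry resolves (this assumes ModelRegistry.resolve_model_cls keeps its architectures -> (model_cls, arch) signature; it is not part of the PR):

from vllm import ModelRegistry

# Resolve the architecture string added to the registry above.
model_cls, arch = ModelRegistry.resolve_model_cls(["EagleDeepSeekMTPModel"])
assert arch == "EagleDeepSeekMTPModel"
print(model_cls.__name__)  # expected: EagleDeepseekV3ForCausalLM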