[Model] Support deepseek with eagle #21086
@@ -0,0 +1,246 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer,
                                                    DeepseekV3ForCausalLM)
from vllm.model_executor.sampling_metadata import SamplingMetadata

from .utils import AutoWeightsLoader, maybe_prefix


@support_torch_compile
class DeepseekV2Model(nn.Module):

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
    ) -> None:
        super().__init__()
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.vocab_size = self.config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        self.layers = nn.ModuleList([
            DeepseekV2DecoderLayer(
                self.config,
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
            ) for i in range(self.config.num_hidden_layers)
        ])

        self.fc = nn.Linear(
            self.config.model.hidden_size * 2,
            self.config.model.hidden_size,
            bias=False,
        )

        self.enorm = RMSNorm(self.config.hidden_size,
                             eps=self.config.rms_norm_eps)
        self.hnorm = RMSNorm(self.config.hidden_size,
                             eps=self.config.rms_norm_eps)
        self.norm = RMSNorm(self.config.hidden_size,
                            eps=self.config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_embeds = self.embed_tokens(input_ids)

        inputs = torch.cat(
            [self.enorm(input_embeds),
             self.hnorm(hidden_states)], dim=-1)
Review comment: This looks like too many norms being applied. In the Llama_Eagle reference code, the input layernorm to each layer is disabled, and IIRC there is no output layernorm. Here, there are two norms applied to the input (pre-concat and input layernorm after concat) and two more norms applied after (`post_attention_layernorm` and `self.norm`). This does not seem correct.

Author reply: I have taken a look at the …
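If the double normalization does turn out to be unwanted, one possible direction (mirroring how the Llama EAGLE draft in vLLM skips the first decoder layer's input layernorm) is sketched below. The wrapper class and its `disable_input_layernorm` flag are illustrative only and not part of this diff; the sketch also assumes `DeepseekV2DecoderLayer` defines an `input_layernorm` attribute.

```python
# Sketch only (not part of this PR): a decoder-layer wrapper whose input
# layernorm can be turned off, following the Llama EAGLE reference pattern.
class DeepseekV2EagleDecoderLayer(DeepseekV2DecoderLayer):

    def __init__(self, config, *, disable_input_layernorm: bool, **kwargs):
        super().__init__(config, **kwargs)
        # The inputs are already normalized by enorm/hnorm and projected by
        # self.fc, so the first draft layer could skip its own input norm.
        if disable_input_layernorm:
            del self.input_layernorm
            self.input_layernorm = nn.Identity()
```

The draft model would then build its layers with `disable_input_layernorm=(i == 0)` so that only the first layer skips the extra norm.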
        hidden_states = self.fc(inputs)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
Review comment: Does this need to be made compatible with the `fused_qkv_a_proj` optimization from #21116? I have observed multiple issues with weight loading in MTP not being consistent with the DeepSeek base model weight loading. Will similar issues apply here?

Author reply: I have updated the `stacked_params_mapping`. Thanks!
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
        ]
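To make the rewrite discussed above concrete, the updated mapping turns an unfused checkpoint name into the fused parameter name as shown below; the tensor name is invented purely for illustration.

```python
# Illustration only: how stacked_params_mapping rewrites a checkpoint name
# when the fused_qkv_a_proj module exists in params_dict.
stacked_params_mapping = [
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
    ("fused_qkv_a_proj", "q_a_proj", 0),
    ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]

name = "layers.61.self_attn.q_a_proj.weight"  # hypothetical checkpoint name
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name:
        print(name.replace(weight_name, param_name), shard_id)
        # -> layers.61.self_attn.fused_qkv_a_proj.weight 0
        break
```

When the fused module is not present (fusion disabled), the loader falls back to the original name, as handled further down in `load_weights`.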
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name_mapped = name.replace(weight_name, param_name)

                # QKV fusion is optional, fall back to normal
                # weight loading if it's not enabled
                # if go with fusion option, then update name
                if ((param_name == "fused_qkv_a_proj")
                        and name_mapped not in params_dict):
                    continue
                else:
                    name = name_mapped

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # if PP disabled then draft will share embed with target
                    if get_pp_group().world_size == 1 and \
                            "embed_tokens." in name:
                        continue

                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        quant_config = vllm_config.quant_config
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.model = DeepseekV2Model(vllm_config=vllm_config,
                                     prefix="model",
                                     start_layer_id=target_layer_num)

        self.lm_head = ParallelLMHead(self.config.vocab_size,
                                      self.config.hidden_size,
                                      quant_config=quant_config)

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if inputs_embeds is not None:
            raise NotImplementedError(
                f"{type(self).__name__} does not support multimodal inputs yet."
            )
        return self.model(input_ids, positions, hidden_states)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )

        model_weights = {}
        for name, loaded_weight in weights:
            if "lm_head" not in name:
                name = "model." + name
            model_weights[name] = loaded_weight
        loader.load_weights(model_weights.items())
Review comment: The implementation of `DeepseekV2Model` assumes that the draft model and the target model share the same hidden size. For instance, `self.hnorm` is initialized with the draft model's hidden size (`self.config.hidden_size`) but is applied to `hidden_states` from the target model. This assumption is incorrect for the models used in testing (`deepseek-r1` has a hidden size of 4096, while `eagle-deepseek-r1` has 1024), and will lead to a runtime error due to shape mismatch. To fix this, you should explicitly use the hidden sizes from both the draft and target model configurations. You can access the target model's configuration via `vllm_config.model_config`.
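A minimal sketch of that suggestion for `DeepseekV2Model.__init__`, assuming both the draft and target `hf_config` objects expose `hidden_size`; the local variable names are illustrative and not part of this diff.

```python
# Sketch only (not part of this PR): size hnorm and fc using the target
# model's hidden size, keeping draft-sized norms everywhere else.
target_config = vllm_config.model_config.hf_config
target_hidden_size = target_config.hidden_size  # target model width
draft_hidden_size = self.config.hidden_size     # draft model width

self.enorm = RMSNorm(draft_hidden_size, eps=self.config.rms_norm_eps)
self.hnorm = RMSNorm(target_hidden_size, eps=self.config.rms_norm_eps)
# Project [draft embedding ; target hidden state] down to the width the
# draft decoder layers expect.
self.fc = nn.Linear(draft_hidden_size + target_hidden_size,
                    draft_hidden_size,
                    bias=False)
```

This keeps `hnorm` consistent with the target hidden states it normalizes, while the draft decoder layers continue to operate at the draft width.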