[Model][Speculative Decoding] Add EAGLE-style MTP module reference code for DeepSeek-R1 #12915

Open · wants to merge 1 commit into base: main
21 changes: 21 additions & 0 deletions run.sh
@@ -0,0 +1,21 @@
# Run config for DeepSeek-R1 on a single 8xH200 node
# Using one MTP module for speculative execution,
# called recursively for k=2 speculative tokens.
# Expected draft acceptance rate is ~70%
# (~80% for token 1, ~60% for token 2 due to accuracy decay)
python3 \
-m vllm.entrypoints.openai.api_server \
--disable-log-requests \
Reviewer: wonder which dataset is used in this testing?

Contributor Author: sample requests from ShareGPT are used.

--gpu-memory-utilization 0.85 \
--quantization fp8 \
--max-model-len 65536 \
--max-num-seqs 128 \
--seed 0 \
--tensor-parallel-size 8 \
--swap-space 0 \
--block-size 32 \
--model deepseek-ai/DeepSeek-R1 \
--distributed-executor-backend=mp \
--trust-remote-code \
--num-speculative-tokens 2 \
--speculative-model DeepSeekV3MTP
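
For reference, a minimal client call against the server launched by run.sh could look like the sketch below. It assumes the vLLM defaults of port 8000 and the OpenAI-compatible /v1 route and uses the `openai` Python client; none of this is part of the diff itself.

# Minimal sketch (assumptions: server on http://localhost:8000, `openai` >= 1.0 installed).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="deepseek-ai/DeepSeek-R1",
    prompt="Explain speculative decoding in one sentence.",
    max_tokens=64,
    temperature=0.0,
)
print(completion.choices[0].text)
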
7 changes: 4 additions & 3 deletions vllm/config.py
@@ -27,7 +27,7 @@
from vllm.platforms import CpuArchEnum
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,

Check failure on line 30 in vllm/config.py (GitHub Actions / pre-commit, Ruff F401):
vllm/config.py:30:31: F401 `vllm.transformers_utils.config.get_hf_image_processor_config` imported but unused
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
try_get_generation_config, uses_mrope)
@@ -313,8 +313,7 @@

self.hf_text_config = get_hf_text_config(self.hf_config)
self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
self.hf_image_processor_config = {}  # get_hf_image_processor_config(self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.use_async_output_proc = use_async_output_proc
self.mm_processor_kwargs = mm_processor_kwargs
@@ -420,6 +419,8 @@
def _init_multimodal_config(
self, limit_mm_per_prompt: Optional[Mapping[str, int]]
) -> Optional["MultiModalConfig"]:
return None

architectures = getattr(self.hf_config, "architectures", [])
if ModelRegistry.is_multimodal_model(architectures):
return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
@@ -756,7 +757,7 @@
def is_deepseek_mla(self) -> bool:
return (hasattr(self.hf_text_config, "model_type")) \
and (self.hf_text_config.model_type in \
('deepseek_v2', 'deepseek_v3'))\
('deepseek_v2', 'deepseek_v3', 'eagle'))\
and (self.hf_text_config.kv_lora_rank is not None)

def get_head_size(self) -> int:
2 changes: 2 additions & 0 deletions vllm/model_executor/model_loader/loader.py
@@ -383,6 +383,8 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
model = _initialize_model(vllm_config=vllm_config)

weights_to_load = {name for name, _ in model.named_parameters()}
if hasattr(model_config.hf_config, 'model_type') and model_config.hf_config.model_type == 'eagle':
model_config.model = 'deepseek-ai/DeepSeek-R1'
loaded_weights = model.load_weights(
self._get_all_weights(model_config, model))
# We only enable strict check for non-quantized models
188 changes: 188 additions & 0 deletions vllm/model_executor/models/deepseek_mtp.py
@@ -0,0 +1,188 @@
from typing import Iterable, List, Optional, Set, Tuple

import torch
from torch import nn

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .utils import is_pp_missing_parameter
from .deepseek_v2 import DeepseekV2DecoderLayer

class DeepseekV3MTPSpeculator(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = "", mtp_layer_index: int = 0):
super().__init__()
config = vllm_config.model_config.hf_config
config.first_k_dense_replace = 0
self.config = config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.quant_config = vllm_config.quant_config

self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)

self.shared_head = nn.ModuleDict({
"head": ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=self.quant_config),
"norm": RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
})

layer_index = 61

self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
self.transformer = DeepseekV2DecoderLayer(config, f"{prefix}.layers.{layer_index}", quant_config=self.quant_config, cache_config=self.cache_config, model_config=self.model_config)

self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
previous_hidden_states: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
if inputs_embeds is not None:
embedding = inputs_embeds
else:
embedding = self.embed_tokens(input_ids)

h_normed = self.hnorm(previous_hidden_states)
e_normed = self.enorm(embedding)

cat_in = torch.cat([e_normed, h_normed], dim=-1) # swapped from the paper
proj_out = self.eh_proj(cat_in)

(mtp_hidden, mtp_residual) = self.transformer(
positions,
proj_out,
kv_cache=kv_caches[0],
attn_metadata=attn_metadata,
residual=None
)

return mtp_hidden + mtp_residual
# hidden_states = mtp_hidden
# hidden_states, _ = self.shared_head["norm"](hidden_states, mtp_residual)
# return hidden_states

def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.shared_head["head"], self.shared_head["norm"](hidden_states), sampling_metadata)
return logits

def sample(self, logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> SamplerOutput:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:

stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),

Check failure on line 99 in vllm/model_executor/models/deepseek_mtp.py (GitHub Actions / pre-commit, Ruff E501):
vllm/model_executor/models/deepseek_mtp.py:99:81: E501 Line too long (82 > 80)
("gate_up_proj", "up_proj", 1),
]

# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts)

params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue

assert self.config.num_nextn_predict_layers == 1
layer_idx = 61
if name.startswith(f"model.layers.{layer_idx}"):
name = name.replace(f"model.layers.{layer_idx}.", "")
if name.startswith("input_layernorm") or name.startswith("post_attention_layernorm") or name.startswith("mlp") or name.startswith("self_attn"):
name = "transformer." + name
else:
continue

for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

if is_pp_missing_parameter(name, self):
continue

if name not in params_dict:
breakpoint()
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)

if is_pp_missing_parameter(name, self):
continue

if name not in params_dict:
breakpoint()
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

if is_pp_missing_parameter(name, self):
continue

if name not in params_dict:
breakpoint()
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params

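As a reading aid, the drafting loop that the run.sh comments describe ("one MTP module ... called recursively for k=2 speculative tokens") can be sketched in plain PyTorch. Everything below is a simplified stand-in: toy shapes, LayerNorm instead of RMSNorm, a generic encoder layer instead of DeepseekV2DecoderLayer, no KV cache or attention metadata, and greedy sampling. It illustrates the combine-project-decode step, not the vLLM code path above.

import torch
from torch import nn

# Toy dimensions; the real model uses DeepSeek-R1's hidden size and vocabulary.
hidden_size, vocab_size, k = 16, 128, 2

embed = nn.Embedding(vocab_size, hidden_size)
enorm = nn.LayerNorm(hidden_size)   # stand-in for RMSNorm
hnorm = nn.LayerNorm(hidden_size)   # stand-in for RMSNorm
eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)
decoder_layer = nn.TransformerEncoderLayer(hidden_size, nhead=4, batch_first=True)  # stand-in for DeepseekV2DecoderLayer
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

last_token = torch.tensor([3])             # last token accepted from the target model (dummy id)
prev_hidden = torch.randn(1, hidden_size)  # target model's hidden state at that position (dummy)

draft_tokens = []
for _ in range(k):  # --num-speculative-tokens 2
    # Normalize the token embedding and the previous hidden state,
    # concatenate, and project back down to the hidden size.
    combined = torch.cat([enorm(embed(last_token)), hnorm(prev_hidden)], dim=-1)
    h = decoder_layer(eh_proj(combined).unsqueeze(1)).squeeze(1)
    last_token = lm_head(h).argmax(dim=-1)  # greedy draft token
    draft_tokens.append(int(last_token))
    prev_hidden = h                         # feed the new hidden state back in

print(draft_tokens)  # two draft token ids, to be verified by the target model
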
6 changes: 3 additions & 3 deletions vllm/model_executor/models/deepseek_v2.py
@@ -640,8 +640,8 @@ def forward(
"residual": residual
})

hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
# hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states + residual


class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
@@ -684,7 +684,7 @@ def compute_logits(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
logits = self.logits_processor(self.lm_head, self.model.norm(hidden_states),
sampling_metadata)
return logits

1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -184,6 +184,7 @@
_SPECULATIVE_DECODING_MODELS = {
"EAGLEModel": ("eagle", "EAGLE"),
"MedusaModel": ("medusa", "Medusa"),
"DeepseekV3MTPModel": ("deepseek_mtp", "DeepseekV3MTPSpeculator"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

10 changes: 10 additions & 0 deletions vllm/spec_decode/multi_step_worker.py
@@ -96,12 +96,22 @@ def sampler_output(
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
for _ in range(sample_len):
if expanded_request.previous_hidden_states is not None:
self.worker.model_runner.return_hidden_states = True
model_output: List[SamplerOutput] = self.worker.execute_model(
execute_model_req=expanded_request)
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]

if expanded_request.previous_hidden_states is not None:
assert hasattr(model_output, 'hidden_states')
seq_group_meta_with_hidden = [
sg for sg in expanded_request.seq_group_metadata_list
if sg.do_sample
]
expanded_request.previous_hidden_states = HiddenStates(model_output.hidden_states, seq_group_meta_with_hidden, expanded_request.previous_hidden_states.hidden_states)

self._append_new_tokens(
model_output, expanded_request.seq_group_metadata_list,
indices_of_seq_with_bonus_tokens)
8 changes: 4 additions & 4 deletions vllm/spec_decode/spec_decode_worker.py
@@ -178,13 +178,13 @@ def create_worker(
proposer_worker = MedusaWorker(**draft_worker_kwargs)
else:
if draft_tp == 1:
if current_platform.is_cuda_alike():
if current_platform.is_cuda_alike() and not draft_model_config.use_mla:
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
else:
if draft_model_config.hf_config.model_type == "eagle":
raise NotImplementedError(
"EAGLE does not support TP > 1 yet")
# if draft_model_config.hf_config.model_type == "eagle":
# raise NotImplementedError(
# "EAGLE does not support TP > 1 yet")

allow_zero_draft_token_step = False
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
18 changes: 14 additions & 4 deletions vllm/transformers_utils/config.py
@@ -176,24 +176,29 @@ def get_config(
) -> PretrainedConfig:
# Separate model folder from file path for GGUF models

if model == "DeepSeekV3MTP":
model_base_name = "deepseek-ai/DeepSeek-R1"
else:
model_base_name = model

is_gguf = check_gguf_file(model)
if is_gguf:
kwargs["gguf_file"] = Path(model).name
model = Path(model).parent

if config_format == ConfigFormat.AUTO:
if is_gguf or file_or_path_exists(
model, HF_CONFIG_NAME, revision=revision):
model_base_name, HF_CONFIG_NAME, revision=revision):
config_format = ConfigFormat.HF
elif file_or_path_exists(model, MISTRAL_CONFIG_NAME,
elif file_or_path_exists(model_base_name, MISTRAL_CONFIG_NAME,
revision=revision):
config_format = ConfigFormat.MISTRAL
else:
# If we're in offline mode and found no valid config format, then
# raise an offline mode error to indicate to the user that they
# don't have files cached and may need to go online.
# This is conveniently triggered by calling file_exists().
file_exists(model,
file_exists(model_base_name,
HF_CONFIG_NAME,
revision=revision,
token=HF_TOKEN)
@@ -202,13 +207,18 @@

if config_format == ConfigFormat.HF:
config_dict, _ = PretrainedConfig.get_config_dict(
model,
model_base_name,
revision=revision,
code_revision=code_revision,
token=HF_TOKEN,
**kwargs,
)

if model == "DeepSeekV3MTP":
config_dict["model_type"] = "eagle"
config_dict["num_hidden_layers"] = 1
config_dict["architectures"] = ["DeepseekV3MTPModel"]

# Use custom model class if it's in our registry
model_type = config_dict.get("model_type")
if model_type in _CONFIG_REGISTRY:
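
Taken together, these changes mean that passing --speculative-model DeepSeekV3MTP loads the DeepSeek-R1 config and then overrides a handful of fields so the registry resolves the draft model to DeepseekV3MTPSpeculator. A standalone sketch of that patching is below; it needs network access or a local Hugging Face cache for the DeepSeek-R1 config, and the printed keys are exactly the ones overridden in the hunk above.

from transformers import PretrainedConfig

# Reproduce the config patching that get_config() applies for the
# "DeepSeekV3MTP" alias; purely illustrative, not the vLLM call site.
config_dict, _ = PretrainedConfig.get_config_dict("deepseek-ai/DeepSeek-R1")
config_dict["model_type"] = "eagle"                    # routed through the EAGLE spec-decode path
config_dict["num_hidden_layers"] = 1                   # the MTP head is a single decoder layer
config_dict["architectures"] = ["DeepseekV3MTPModel"]  # resolved via _SPECULATIVE_DECODING_MODELS

print({k: config_dict[k] for k in ("model_type", "num_hidden_layers", "architectures")})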