cleanup and add comments
Signed-off-by: Lu Fang <[email protected]>
luccafong committed Feb 5, 2025
1 parent 9de0bdf commit 793ef4f
Showing 5 changed files with 5 additions and 6 deletions.
1 change: 0 additions & 1 deletion vllm/model_executor/models/deepseek_mtp.py
@@ -97,7 +97,6 @@ class DeepSeekMultiTokenPredictor(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        print(f"{config=}")
         self.mtp_start_layer_idx = config.num_hidden_layers
         self.num_mtp_layers = config.num_nextn_predict_layers
         # to map the exact layer index from weights
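Note on the hunk above: DeepSeek stores its MTP layers after the num_hidden_layers main decoder layers in the checkpoint, so mtp_start_layer_idx marks where they begin and num_mtp_layers how many follow. A minimal sketch of that index mapping, with invented names (not the vLLM source):

# Hypothetical illustration: map a speculative step to the checkpoint
# layer index of its MTP layer.
class MTPLayerIndexer:

    def __init__(self, num_hidden_layers: int, num_nextn_predict_layers: int):
        # MTP layers follow the main decoder layers in the weights,
        # so the first MTP layer sits at index num_hidden_layers.
        self.mtp_start_layer_idx = num_hidden_layers
        self.num_mtp_layers = num_nextn_predict_layers

    def layer_idx(self, step_idx: int) -> int:
        assert 0 <= step_idx < self.num_mtp_layers, "step out of range"
        return self.mtp_start_layer_idx + step_idx

# Example: with 61 decoder layers and 1 MTP layer (the DeepSeek-V3 shape),
# the MTP layer lives at weight index 61.
indexer = MTPLayerIndexer(num_hidden_layers=61, num_nextn_predict_layers=1)
print(indexer.layer_idx(0))  # -> 61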
2 changes: 2 additions & 0 deletions vllm/spec_decode/draft_model_runner.py
@@ -277,6 +277,8 @@ def execute_model(
             compute_logits_kwargs = {}
             # Run model
             if hasattr(self.model.config, "num_nextn_predict_layers"):
+                # for DeepSeek MTP only to use the corresponding layer for
+                # each step
                 kwargs["step_idx"] = step
                 compute_logits_kwargs["step_idx"] = step
             with set_forward_context(model_input.attn_metadata,
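The added comment is the substance of this change: step_idx is threaded into both the model forward and compute_logits so that each speculative step can run through its own MTP layer. A hypothetical toy model (structure and names are illustrative only, not the vLLM DeepSeek implementation) showing how such a step_idx argument might be consumed:

import torch
import torch.nn as nn

class ToyMTPModel(nn.Module):

    def __init__(self, hidden_size: int, num_mtp_layers: int, vocab_size: int):
        super().__init__()
        # One predictor layer per speculative step.
        self.mtp_layers = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size) for _ in range(num_mtp_layers)])
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, hidden_states: torch.Tensor, step_idx: int) -> torch.Tensor:
        # Route each speculative step through its own MTP layer.
        return self.mtp_layers[step_idx](hidden_states)

    def compute_logits(self, hidden_states: torch.Tensor, step_idx: int) -> torch.Tensor:
        # step_idx is accepted here too, mirroring compute_logits_kwargs above;
        # a real model could select a per-step output head with it.
        return self.lm_head(hidden_states)

model = ToyMTPModel(hidden_size=8, num_mtp_layers=2, vocab_size=16)
h = torch.randn(1, 8)
for step in range(2):
    h = model(h, step_idx=step)
    logits = model.compute_logits(h, step_idx=step)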
5 changes: 2 additions & 3 deletions vllm/spec_decode/spec_decode_worker.py
@@ -182,9 +182,8 @@ def create_worker(
                 draft_worker_kwargs[
                     "model_runner_cls"] = TP1DraftModelRunner
             else:
-                if draft_model_config.hf_config.model_type in [
-                        "eagle", "deepseek_mtp"
-                ]:
+                if draft_model_config.hf_config.model_type in (
+                        "eagle", "deepseek_mtp"):
                     raise NotImplementedError(
                         f"{draft_model_config.hf_config.model_type} "
                         "does not support TP > 1 yet")
1 change: 0 additions & 1 deletion vllm/transformers_utils/configs/deepseek_mtp.py
@@ -13,7 +13,6 @@ class DeepSeekMTPConfig(PretrainedConfig):
     def __init__(self,
                  model: Union[PretrainedConfig, dict, None] = None,
                  **kwargs):
-        print("model: %s", model)
         if model is not None:
             self.model = DeepseekV3Config.from_dict(model, **kwargs)
         else:
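For context on the surviving branch: model may arrive as a plain dict parsed from config.json, and DeepseekV3Config.from_dict promotes it to a typed sub-config. A minimal, hypothetical sketch of the same pattern built only on transformers.PretrainedConfig (the class name is invented):

from transformers import PretrainedConfig

class ToyMTPConfig(PretrainedConfig):
    model_type = "toy_mtp"

    def __init__(self, model=None, **kwargs):
        if isinstance(model, dict):
            # A raw dict (e.g. read from config.json) is promoted to a config.
            self.model = PretrainedConfig.from_dict(model)
        else:
            # Already a PretrainedConfig instance, or None.
            self.model = model
        super().__init__(**kwargs)

cfg = ToyMTPConfig(model={"num_hidden_layers": 61, "num_nextn_predict_layers": 1})
print(cfg.model.num_hidden_layers)  # -> 61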
2 changes: 1 addition & 1 deletion vllm/worker/worker.py
@@ -71,7 +71,7 @@ def __init__(
             or (speculative_config.draft_model_config.model ==
                 model_config.model) \
             or (speculative_config.draft_model_config.hf_config.model_type
-                not in ["medusa", "mlp_speculator", "eagle", "deepseek_mtp"]) \
+                not in ("medusa", "mlp_speculator", "eagle", "deepseek_mtp")) \
             else {"return_hidden_states": True}
 
         ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
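The continued conditional above decides whether the worker passes return_hidden_states=True to the model runner. A hedged restatement as a standalone helper (function and constant names are invented for illustration, not vLLM code):

HIDDEN_STATE_DRAFT_TYPES = ("medusa", "mlp_speculator", "eagle", "deepseek_mtp")

def needs_hidden_states(has_spec_config: bool, draft_model: str,
                        target_model: str, draft_model_type: str) -> bool:
    # No speculative decoding configured: plain decoding, nothing extra.
    if not has_spec_config:
        return False
    # Draft and target are the same model: no hidden states requested.
    if draft_model == target_model:
        return False
    # Only these draft architectures consume the target's hidden states.
    return draft_model_type in HIDDEN_STATE_DRAFT_TYPES

# Example: a separate deepseek_mtp draft next to a DeepSeek-V3 target.
print(needs_hidden_states(True, "ds-mtp-draft", "deepseek-v3", "deepseek_mtp"))  # True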
