
[Model][Speculative Decoding] DeepSeek MTP spec decode #12755

Open
wants to merge 4 commits into main
Conversation

luccafong

Implement DeepSeek MTP (#12181) to support DeepSeek MTP layers for next-n prediction.


github-actions bot commented Feb 4, 2025

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@comaniac (Collaborator) left a comment

Otherwise LGTM. It's pretty clean so no concerns.

Collaborator

Do you know how long it takes to run all tests in this file?

Resolved review threads:

  • vllm/spec_decode/draft_model_runner.py (outdated)
  • vllm/spec_decode/spec_decode_worker.py (outdated)
  • vllm/transformers_utils/configs/__init__.py
  • vllm/transformers_utils/configs/deepseek_mtp.py (outdated)
@@ -71,7 +71,7 @@ def __init__(
or (speculative_config.draft_model_config.model ==
model_config.model) \
or (speculative_config.draft_model_config.hf_config.model_type
not in ["medusa", "mlp_speculator", "eagle"]) \
not in ["medusa", "mlp_speculator", "eagle", "deepseek_mtp"]) \
Collaborator

Suggested change
not in ["medusa", "mlp_speculator", "eagle", "deepseek_mtp"]) \
not in ("medusa", "mlp_speculator", "eagle", "deepseek_mtp")) \

inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
Collaborator

QQ: where did we truncate the input_ids?

Author

For the 1st stage, position 0 is masked for MTP, but that only applies to k=1; I need to change the mask to [position <= k-1].
For stages 2+, the previous tokens from the last stage are marked as pre-computed; this is a bit complicated for k>1 across different layers and needs more investigation.

In short, the current change works for k=1 (which is what the DeepSeek V3 model sets), but more changes are needed for k>1.
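A minimal sketch of the masking being described (illustrative only, not the PR's exact code; the tensor shapes and the spec_step_index value are assumptions): for the k=1 case only position 0 is zeroed, while the generalization zeroes every position <= spec_step_index, since those positions have no shifted next-token target at that step.

    # Illustrative sketch of the mask discussed above, not the PR's exact code.
    import torch

    hidden_size = 4
    positions = torch.arange(6)                              # token positions in the prefill
    inputs_embeds = torch.randn(positions.numel(), hidden_size)

    spec_step_index = 1                                      # hypothetical 0-based MTP step
    inputs_embeds[positions <= spec_step_index] = 0          # k=1 reduces to masking position 0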

@Neo9061 left a comment

Is there any way to add a Markdown file with examples showing how to use MTP for speculative decoding?

Especially:

  1. num_nextn_predict_layers is 1; can we specify a speculation length greater than 1? And what are the requirements for formatting the draft model artifacts?
  2. Is this code compatible with multi-node inference? I assume so, since the draft is loaded on a single GPU?


# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 3

num_nextn_predict_layers in DeepSeek V3 is only 1. Does that mean the MTP head will be reused if I specify MAX_SPEC_TOKENS greater than 1?

Author

This is a test file using a dummy model. num_speculative_tokens should be <= num_nextn_predict_layers, since the transformer blocks differ across steps. I am adding an assertion for the case where the user passes a higher number.
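A small sketch of the kind of check described above (the helper name is hypothetical and this is not the PR's code; the actual assertion may live in the speculative config validation):

    # Hypothetical helper illustrating the check described above; not the PR's code.
    def validate_mtp_spec_config(num_speculative_tokens: int,
                                 num_nextn_predict_layers: int) -> None:
        # Each MTP step uses a distinct transformer block, so the speculation
        # length cannot exceed the number of MTP layers in the checkpoint.
        assert num_speculative_tokens <= num_nextn_predict_layers, (
            f"num_speculative_tokens ({num_speculative_tokens}) must be <= "
            f"num_nextn_predict_layers ({num_nextn_predict_layers})")

    validate_mtp_spec_config(num_speculative_tokens=1, num_nextn_predict_layers=1)  # DeepSeek V3 case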


Is there a way to just reuse the MTP layer to predict tokens for k > 1? Essentially they are the same, right?

You could print a warning that this is not expected, though.

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.mtp_start_layer_idx = config.num_hidden_layers

Shouldn't mtp_start_layer_idx be num_hidden_layers - 1?

num_hidden_layers is 61 in the DeepSeek config, so the index of the last layer is 60.

Author

https://huggingface.co/deepseek-ai/DeepSeek-V3/raw/main/model.safetensors.index.json shows that the last layer is 61, which is the MTP layer.

I see, thanks for clarifying!

Neo9061 commented Feb 5, 2025

@luccafong Sorry, I have to ask these questions since I hope to use your implementation.

  1. Have you tested it end-to-end with vLLM's multi-node distributed inference setup? Asking because I can only deploy the model in multi-node settings.
  2. If I want to reuse the MTP head for a speculation length k > 1, what hack would you recommend to make it work? k = 1 is too limited for my application.

@benchislett

@luccafong I have been working on a similar implementation locally and have faced a few challenges that I'm not sure are addressed here. Have you validated the acceptance rate for k=1 with real weights?

I believe that the final RMSNorm in the DeepSeek-V3 main model is not necessary for speculative decoding, since the hnorm already normalizes the previous hidden states received from the main model. It's unclear to me how it is classified in the DeepSeek-V3 technical report, but I think that the norm might be included in the output head and therefore not applied to the input of the MTP module. Anecdotally, I observe a small increase in acceptance rate with this change.

Also, I have noticed that the acceptance rate becomes very low (<50%) when I enable the recently added MLA attention. Have you noticed this too? I am not sure what could cause this; maybe it is a bug fixed in recent commits to vLLM. I would like to know if this is an issue for your implementation.
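For reference, a minimal sketch of the MTP input path under discussion, based on the DeepSeek-V3 report rather than this PR's code (the enorm/hnorm/eh_proj names mirror the checkpoint's module names; the dimensions are toy values, and nn.RMSNorm requires PyTorch >= 2.4): the token embedding and the hidden state handed over by the main model are each RMS-normalized, concatenated, and projected back to the hidden size before the MTP transformer block. The question above is whether the main model's final RMSNorm should also be applied to that hidden state.

    # Sketch of the MTP input combination from the DeepSeek-V3 report (illustrative only).
    import torch
    import torch.nn as nn

    hidden_size = 8
    enorm = nn.RMSNorm(hidden_size)                  # normalizes the token embedding
    hnorm = nn.RMSNorm(hidden_size)                  # normalizes the main model's hidden state
    eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)

    token_embed = torch.randn(1, hidden_size)        # embedding of the shifted input token
    prev_hidden = torch.randn(1, hidden_size)        # hidden state received from the main model
    mtp_input = eh_proj(torch.cat([enorm(token_embed), hnorm(prev_hidden)], dim=-1))
    # mtp_input then feeds the MTP transformer block and the shared output head.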

Collaborator

Do we need a copy of the configuration? I think for local usage we don't need this file if you have --trust-remote-code.

Collaborator

Ah, I see, it is used for the DeepSeek MTP config... hmm.

@luccafong (Author), Feb 5, 2025

I realized we don't need it if we add the tokenizer config to both models, so I'm going to remove these.

@luccafong (Author)

> @luccafong I have been working on a similar implementation locally and have faced a few challenges that I'm not sure are addressed here. Have you validated the acceptance rate for k=1 with real weights?
>
> I believe that the final RMSNorm in the DeepSeek-V3 main model is not necessary for speculative decoding, since the hnorm already normalizes the previous hidden states received from the main model. It's unclear to me how it is classified in the DeepSeek-V3 technical report, but I think that the norm might be included in the output head and therefore not applied to the input of the MTP module. Anecdotally, I observe a small increase in acceptance rate with this change.
>
> Also, I have noticed that the acceptance rate becomes very low (<50%) when I enable the recently added MLA attention. Have you noticed this too? I am not sure what could cause this; maybe it is a bug fixed in recent commits to vLLM. I would like to know if this is an issue for your implementation.

The acceptance rate is around 56% in my testing. MLA attention could take a different branch (https://github.com/luccafong/vllm/blob/ds_mtp/vllm/spec_decode/multi_step_worker.py#L98); I fixed that in a later commit.

Regarding the norm, thanks for pointing it out; let me try adjusting it to see if there is an improvement.

@luccafong (Author) commented Feb 6, 2025

> @luccafong Sorry, I have to ask these questions since I hope to use your implementation.
>
>   1. Have you tested it end-to-end with vLLM's multi-node distributed inference setup? Asking because I can only deploy the model in multi-node settings.
>   2. If I want to reuse the MTP head for a speculation length k > 1, what hack would you recommend to make it work? k = 1 is too limited for my application.

1. Not tested with multi-node settings. 2. We can reuse it if you do some model processing, e.g. duplicating the weights across layers; a hacky change would not be proper, since when num_nextn_predict_layers > 1 and k > num_nextn_predict_layers, it is difficult to decide which layer to forward multiple times.
Note that, as mentioned in the other thread, some changes are still needed for k > 1; that work is in progress, and I will update you when it works.
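Purely as an illustration of the ambiguity mentioned above (a hypothetical function, not part of this PR): if the MTP head were reused for k > num_nextn_predict_layers, some mapping from speculative step to MTP layer would have to be chosen, for example clamping to the last available layer, and there is no principled way to pick it.

    # Hypothetical mapping from speculative step to MTP layer index; not in this PR.
    def mtp_layer_for_step(spec_step_index: int, num_nextn_predict_layers: int) -> int:
        # Clamp extra steps onto the last MTP layer; a modulo scheme would be
        # another arbitrary choice, which is exactly the ambiguity noted above.
        return min(spec_step_index, num_nextn_predict_layers - 1)

    assert mtp_layer_for_step(0, 1) == 0
    assert mtp_layer_for_step(3, 1) == 0   # every step beyond the first reuses the single layer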


mergify bot commented Feb 6, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @luccafong.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label on Feb 6, 2025
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
assert inputs_embeds is not None
inputs_embeds[positions <= spec_step_index] = 0

Please excuse my multiple questions.

inputs_embeds[positions <= spec_step_index] = 0 is for the prefilling stage of each MTP head, correct? During the draft model (MTP head) decoding stage, inputs_embeds is a single hidden vector.

That is what I saw in the EAGLE workflow: it first enters the code from here with num_steps being 1 for prefilling (which is where the mask is effective); then num_steps becomes the speculation length k, and inputs_embeds for each forward pass is a single embedding vector.

But I didn't see your logic modified in

for step in range(num_steps):

to introduce spec_step_index. Where do you introduce it, then?
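A toy sketch of the question above (all names are hypothetical and this is not vLLM's actual runner code): if the draft runner loops over num_steps, a per-step index would need to be threaded into each forward call so that a mask like positions <= spec_step_index can distinguish MTP steps.

    # Toy illustration of threading a step index through a multi-step draft loop;
    # the callable stands in for a model forward and is not vLLM's API.
    from typing import Callable, List

    def run_draft_steps(forward: Callable[[int, int], str],
                        token_id: int, num_steps: int) -> List[str]:
        outputs = []
        for spec_step_index in range(num_steps):
            # the step index must reach the model so the embedding mask can depend on it
            outputs.append(forward(token_id, spec_step_index))
        return outputs

    print(run_draft_steps(lambda tok, step: f"token={tok} step={step}",
                          token_id=42, num_steps=2))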

str(idx):
DeepSeekMultiTokenPredictorLayer(
config,
f"{prefix}.layers.{idx}",

For my case of reusing the single MTP layer for k > 1, I will specify k to be the same as the number of MTP heads.

In that case, can I simply hack f"{prefix}.layers.{idx}" to be models.layers.61 and self.num_mtp_layers to be k?

Neo9061 commented Feb 6, 2025

> @luccafong Sorry, I have to ask these questions since I hope to use your implementation.
>
>   1. Have you tested it end-to-end with vLLM's multi-node distributed inference setup? Asking because I can only deploy the model in multi-node settings.
>   2. If I want to reuse the MTP head for a speculation length k > 1, what hack would you recommend to make it work? k = 1 is too limited for my application.
>
> 1. Not tested with multi-node settings. 2. We can reuse it if you do some model processing, e.g. duplicating the weights across layers; a hacky change would not be proper, since when num_nextn_predict_layers > 1 and k > num_nextn_predict_layers, it is difficult to decide which layer to forward multiple times. Note that, as mentioned in the other thread, some changes are still needed for k > 1; that work is in progress, and I will update you when it works.

Thank you for your reply! Are you testing on a single node of H200s to load the target model, then?

I just found a bug when using the vLLM LLM class to load the target model via distributed inference (the official Ray setup) with speculative decoding arguments passed in. Without those arguments, the target model loads successfully.

#12841

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
assert inputs_embeds is not None
inputs_embeds[positions <= spec_step_index] = 0


I have found that this line might interfere with CUDA graph recording. I am unsure why, but removing this line allowed the draft acceptance rate to go up to 85% in my testing, from <50% prior. This issue was not present for me when MLA was disabled, or when Enforce-Eager was enabled. I believe there might be some setup code in the MLA attention for cuda graphs causing issues with graph replay.

Regardless of the cause, I am curious if you can reproduce this issue. Do you see improved acceptance rates when this line is omitted?

Author

I tried, but got 58%; not much improvement. Could you share your implementation? There could be some difference.


Hi @luccafong, see my code here: #12915

I hope this can clear up any inconsistency. Please let me know if you cannot reproduce using my implementation, or if you identify any issues.

@WhatGhost

@luccafong
Nice work, it inspired me a lot!
I have two very small questions that have been bothering me.

  1. I notice you use the luccafong/deepseek_mtp_main_random model. Is it just a small DeepSeek-like model for testing?
  2. How can I run a DeepSeek-V3 model with MTP? Just set speculative_model to DeepSeek-V3 itself?

Thanks!

@LiuXiaoxuanPKU (Collaborator)

> @luccafong Nice work, it inspired me a lot! I have two very small questions that have been bothering me.
>
>   1. I notice you use the luccafong/deepseek_mtp_main_random model. Is it just a small DeepSeek-like model for testing?
>   2. How can I run a DeepSeek-V3 model with MTP? Just set speculative_model to DeepSeek-V3 itself?
>
> Thanks!

Yes, you can just plug in the DeepSeek model path as shown below; the following is my setup on 8xH200:

llm = LLM(
    model="/path/to/DeepSeek-R1",
    tensor_parallel_size=8,
    speculative_draft_tensor_parallel_size=1,
    max_model_len=8192, # If you have enough memory with your hardware, you can ignore this
    num_speculative_tokens=1, # only 1 is supported for now
)
