
Conversation

@WoosukKwon (Collaborator) commented Apr 21, 2025

This PR always samples draft tokens with argmax, regardless of the request’s temperature or other sampling parameters.
This will NOT affect the quality of the sampled outputs, but it will lower the acceptance rate of the drafts, especially when the original temperature is high.

The reason behind this idea is that it’s tricky to handle the draft probability tensors efficiently.
If we use random sampling for draft tokens, we need to keep the draft probability tensors for rejection sampling.
However, because each of our scheduling steps is (large model -> rejection sampling -> draft model), the draft probs tensors are not used immediately after they are created.
They are used when the corresponding draft tokens are scheduled, and the timing of that is unpredictable.
Some of the draft tokens may be scheduled in the very next step, but others may be scheduled much later (e.g., because the request is preempted) or may never be scheduled (e.g., because the request has finished or been aborted).
This makes managing the draft probs tensors as difficult as managing the KV cache.

In contrast, if we use argmax sampling for draft tokens, we don't have to keep the draft probs tensors at all, which greatly simplifies the implementation.
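
For illustration, here is a minimal sketch of the difference between the two drafting modes. This is hypothetical code, not vLLM's actual API; the function name and tensor shapes are made up for the example:

```python
import torch

def sample_draft_tokens(draft_logits: torch.Tensor, use_argmax: bool = True):
    """Toy illustration of the trade-off; draft_logits is [num_reqs, vocab_size]."""
    if use_argmax:
        # Greedy drafting: the draft distribution is implicitly one-hot, so the
        # rejection sampler only needs the token ids. There is nothing extra to
        # keep alive across scheduler steps.
        return draft_logits.argmax(dim=-1), None

    # Random drafting: the rejection sampler later needs q(x) for each draft token,
    # so these [num_reqs, vocab_size] probs must be kept until the corresponding
    # tokens are verified -- which may happen many steps later, or never.
    draft_probs = torch.softmax(draft_logits, dim=-1)
    draft_tokens = torch.multinomial(draft_probs, num_samples=1).squeeze(-1)
    return draft_tokens, draft_probs
```

In the argmax case the second return value can simply be dropped, which is exactly what removes the KV-cache-like lifetime management described above.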


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

mergify bot added the v1 label Apr 21, 2025
@WoosukKwon changed the title from "[V1][Spec Decode] Use argmax for sampling draft tokens" to "[V1][Spec Decode] Always use argmax for sampling draft tokens" Apr 21, 2025
@ShangmingCai (Contributor) left a comment


It totally makes sense. The trade-off between implementation complexity and slightly lower acceptance rates (especially at high temperatures) seems reasonable given the current constraints. We can revisit this part later if community feedback indicates a strong need for higher acceptance rates under varying temperature settings.

@benchislett (Collaborator)

@WoosukKwon do you have any preliminary results for acceptance rate degradation? If the impact is minor (as I would expect), then I am comfortable with this change. If testing indicates a significant decrease in acceptance rates for high-temp sampling then we might want to reconsider.

@WoosukKwon (Collaborator, Author)

@benchislett Good point. However, it's a bit tricky since we haven't implemented EAGLE with random sampling yet.

@luyuzhe111 Could you help with this, or perhaps share the script you used earlier to measure the acceptance length?

@luyuzhe111 (Contributor)

@WoosukKwon Yeah, I can get the acceptance length with greedy drafting this week. The motivation behind this PR is also quite clear.

Though if people want to contribute multi-draft spec decode in the future, it will be impossible without the draft probs.

@WoosukKwon marked this pull request as ready for review April 21, 2025 21:08
@WoosukKwon added the ready label (ONLY add when PR is ready to merge/full CI is needed) Apr 21, 2025
@WoosukKwon (Collaborator, Author)

@benchislett @luyuzhe111 I measured the acceptance lengths based on the EAGLE3 PR #16937:

| Original temperature | 0.0 | 0.6 | 0.7 | 0.8 |
| --- | --- | --- | --- | --- |
| Argmax | 3.29 | 2.95 | 2.89 | 2.66 |
| Same temp | 3.29 | 3.30 | 3.25 | 3.11 |

mergify bot added the needs-rebase label Apr 22, 2025
@WoosukKwon (Collaborator, Author)

> so is a 10% drop in AL acceptable? For chatbots I think 0.7 is a common temp choice?

@luyuzhe111 I think this is definitely not acceptable in the long run, but I'm not sure whether we can use this PR as a temporary workaround.

mergify bot removed the needs-rebase label Apr 22, 2025
@wwl2755 (Contributor) commented Apr 22, 2025

I saw that PR #16077 already aims at part of this logic; maybe we could build on the attempt there?

@ekagra-ranjan (Contributor) commented Apr 23, 2025

FWIW, TRTLLM and TGI do not apply temperature (i.e., they use greedy sampling) when sampling draft tokens, even when the target model uses T != 0.

Is there any paper or resource that recommends applying the same temperature to draft sampling as to the target model for increased AL? I see this as empirical evidence but wonder if there is some theoretical insight too.

@WoosukKwon (Collaborator, Author)

> Is there any paper or resource that recommends applying the same temperature to draft sampling as to the target model for increased AL?

@ekagra-ranjan Great question. I think it is backed by theory as well as experience. A draft token is accepted by rejection sampling when target_prob / draft_prob >= u, where u is sampled from U(0, 1). Using argmax for draft tokens means treating draft_prob as always 1 (i.e., target_prob becomes the acceptance probability). Therefore, if the temperature for sampling target tokens is high, or if there are multiple good candidates for the next token, target_prob becomes small, and so does the acceptance rate. If we use random sampling for draft tokens, draft_prob becomes smaller than 1, so the problem is generally mitigated.
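
To make this concrete, here is a small numerical sketch. It is a toy example that assumes the idealized case where the draft distribution equals the target distribution (a real draft model is only an approximation, so the matched-temperature acceptance would be below 1 in practice); the logits are arbitrary illustrative values:

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(logits, temp):
    z = np.exp((logits - logits.max()) / temp)
    return z / z.sum()

logits = np.array([2.0, 1.8, 1.5, 0.5])   # toy next-token logits

for temp in (0.01, 0.6, 1.0):
    p = softmax(logits, temp)              # target distribution at this temperature
    q = softmax(logits, temp)              # draft distribution (idealized: equal to target)

    # Argmax drafting: draft_prob is treated as 1, so the acceptance prob is just target_prob.
    accept_argmax = p[np.argmax(q)]

    # Temperature-matched random drafting: accept with min(1, p[x] / q[x]) for x ~ q.
    xs = rng.choice(len(q), size=100_000, p=q)
    accept_random = np.minimum(1.0, p[xs] / q[xs]).mean()

    print(f"temp={temp:4.2f}  argmax acceptance={accept_argmax:.2f}  "
          f"matched-temp acceptance={accept_random:.2f}")
```

At near-zero temperature both modes accept almost everything; as the temperature rises and the target distribution flattens, the argmax acceptance probability falls while the matched-temperature one stays high, which is the gap reflected in the acceptance-length table above.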

@WoosukKwon (Collaborator, Author) commented Apr 23, 2025

@benchislett @luyuzhe111 @ekagra-ranjan @wwl2755 @ShangmingCai

Let me merge this PR first and get back to the acceptance rate issue later. The current main branch has a bug when using temp > 0, because it uses random sampling for draft tokens but does not consider draft probs in rejection sampling. This PR at least fixes this bug.

In the long run, I think the drop in the acceptance rate is unacceptable, so we should find a better solution. This PR should be regarded as a band-aid solution.

@WoosukKwon merged commit 41fb013 into main Apr 23, 2025 (22 of 25 checks passed)
@WoosukKwon deleted the draft-argmax branch April 23, 2025 21:57
@WoosukKwon mentioned this pull request Apr 23, 2025
gshtras added a commit to ROCm/vllm that referenced this pull request Apr 25, 2025
* [BugFix] Remove default multiproc executor `collective_rpc` timeout (vllm-project#17000)

Signed-off-by: Nick Hill <[email protected]>

* [Core][V1][TPU] Enable structured decoding on TPU V1 (vllm-project#16499)

Signed-off-by: Chenyaaang <[email protected]>

* [Bugfix] validate urls object for multimodal content parts (vllm-project#16990)

Signed-off-by: Guillaume Calmettes <[email protected]>

* add Dockerfile build vllm against torch nightly (vllm-project#16936)

Signed-off-by: Yang Wang <[email protected]>

* [Kernel][ROCM] Upstream prefix prefill speed up for vLLM V1 (vllm-project#13305)

Signed-off-by: Sage Moore <[email protected]>
Signed-off-by: root <[email protected]>
Signed-off-by: Aleksandr Malyshev <[email protected]>
Signed-off-by: root <[email protected]>
Signed-off-by: maleksan85 <[email protected]>
Signed-off-by: <>
Co-authored-by: Sage Moore <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: Aleksandr Malyshev <[email protected]>
Co-authored-by: qli88 <[email protected]>
Co-authored-by: root <[email protected]>

* [V1][DP] More robust DP/EP dummy request coordination (vllm-project#16277)

Signed-off-by: Nick Hill <[email protected]>

* [BugFix] Revert ROCm Custom Paged Attention Env Flag Check (vllm-project#17022)

Signed-off-by: vllmellm <[email protected]>

* Revert "[Misc] Add S3 environment variables for better support of MinIO." (vllm-project#17021)

* [misc] tune some env vars for GB200 (vllm-project#16992)

Signed-off-by: youkaichao <[email protected]>

* [INTEL-HPU][v0] Port delayed sampling to upstream (vllm-project#16949)

Signed-off-by: Michal Adamczyk <[email protected]>
Signed-off-by: Chendi Xue <[email protected]>
Co-authored-by: Michal Adamczyk <[email protected]>

* [doc] add download path tips (vllm-project#17013)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [Bugfix] Triton FA function takes no keyword arguments (vllm-project#16902)

Signed-off-by: vllmellm <[email protected]>

* [V1] Avoid socket errors during shutdown when requests are in in-flight (vllm-project#16807)

Signed-off-by: Nick Hill <[email protected]>

* [BugFix] llama4 fa3 fix - RuntimeError: scheduler_metadata must have shape (metadata_size) (vllm-project#16998)

Signed-off-by: Lucas Wilkinson <[email protected]>

* [Misc] Improve readability of get_open_port function. (vllm-project#17024)

Signed-off-by: gitover22 <[email protected]>

* [Bugfix] Fix AssertionError: skip_special_tokens=False is not supported for Mistral tokenizers (vllm-project#16964)

Signed-off-by: chaunceyjiang <[email protected]>

* [CI] Run v1/test_serial_utils.py in CI (vllm-project#16996)

Signed-off-by: Russell Bryant <[email protected]>

* Mistral-format support for compressed-tensors (vllm-project#16803)

Signed-off-by: mgoin <[email protected]>

* Categorize `tests/kernels/` based on kernel type (vllm-project#16799)

Signed-off-by: mgoin <[email protected]>

* [Doc] Add top anchor and a note to quantization/bitblas.md (vllm-project#17042)

Signed-off-by: windsonsea <[email protected]>

* Ensure that `pid` passed to `kill_process_tree` is `int` for `mypy` (vllm-project#17051)

Signed-off-by: Harry Mellor <[email protected]>

* [CI] Update structured-output label automation (vllm-project#17055)

Signed-off-by: Russell Bryant <[email protected]>

* Improve Transformers backend model loading QoL (vllm-project#17039)

Signed-off-by: Harry Mellor <[email protected]>

* `CacheConfig.block_size` should always be `int` when used (vllm-project#17052)

Signed-off-by: Harry Mellor <[email protected]>

* Use `@property` and private field for `data_parallel_rank_local` (vllm-project#17053)

Signed-off-by: Harry Mellor <[email protected]>

* [Frontend] Support guidance:no-additional-properties for compatibility with xgrammar (vllm-project#15949)

Signed-off-by: Travis Johnson <[email protected]>

* [BugFix][V1] Fix int32 token index overflow when preparing input ids (vllm-project#16806)

* [V1][Spec Decode] Always use argmax for sampling draft tokens  (vllm-project#16899)

Signed-off-by: Woosuk Kwon <[email protected]>

* [CI/Build] workaround for CI build failure (vllm-project#17070)

Signed-off-by: csy1204 <[email protected]>
Co-authored-by: Michael Goin <[email protected]>

* [Quantization]add prefix for commandA quantized model (vllm-project#17017)

* [Minor] Use larger batch sizes for A100/B100/B200/MI300x (vllm-project#17073)

Signed-off-by: Woosuk Kwon <[email protected]>

* [Bugfix] Enable V1 usage stats (vllm-project#16986)

Signed-off-by: mgoin <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Co-authored-by: Nick Hill <[email protected]>

* More informative error when using Transformers backend (vllm-project#16988)

Signed-off-by: Harry Mellor <[email protected]>

* Addendum Fix to support FIPS enabled machines with MD5 hashing (vllm-project#17043)

Signed-off-by: sydarb <[email protected]>

* [Bugfix][Core] add seq_id_to_seq_group clearing to avoid memory leak when s… (vllm-project#16472)

Signed-off-by: 开哲 <[email protected]>
Co-authored-by: 开哲 <[email protected]>

* [V1] Update structured output (vllm-project#16812)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [doc] update to hyperlink (vllm-project#17096)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* Add docs for runai_streamer_sharded (vllm-project#17093)

Signed-off-by: Omer Dayan (SW-GPU) <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>

* [Chore] Remove Sampler from Model Code (vllm-project#17084)

Signed-off-by: Woosuk Kwon <[email protected]>

* Disable enforce_eager for V1 TPU sampler and structured output tests (vllm-project#17016)

Signed-off-by: mgoin <[email protected]>

* Simplify `TokenizerGroup` (vllm-project#16790)

Signed-off-by: Harry Mellor <[email protected]>

* Fix OOT registration test (vllm-project#17099)

Signed-off-by: Harry Mellor <[email protected]>

* [V1][PP] Optimization: continue scheduling prefill chunks (vllm-project#17080)

Signed-off-by: Rui Qiao <[email protected]>

* [Misc] Remove OLMo2 config copy (vllm-project#17066)

Signed-off-by: Isotr0py <[email protected]>

* Improve static type checking in `LoRAModelRunnerMixin` (vllm-project#17104)

Signed-off-by: Harry Mellor <[email protected]>

* [V1][Structured Output] Clear xgrammar compiler object when engine core shut down to avoid nanobind leaked warning (vllm-project#16954)

Signed-off-by: shen-shanshan <[email protected]>

* [Frontend] Using matryoshka_dimensions control the allowed output dimensions. (vllm-project#16970)

* Add missing rocm_skinny_gemms kernel test to CI (vllm-project#17060)

Signed-off-by: mgoin <[email protected]>

* [Misc] refactor example series - structured outputs (vllm-project#17040)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [V1][Spec Decoding] Add num_drafts and num_accepted_tokens_per_position metrics (vllm-project#16665)

Signed-off-by: Mark McLoughlin <[email protected]>

* [CI] Add automation for the `tool-calling` github label (vllm-project#17118)

Signed-off-by: Russell Bryant <[email protected]>

* Updating builkite job for IBM Power  (vllm-project#17111)

Signed-off-by: Aaruni Aggarwal <[email protected]>

* existing torch installation pip command fix for docs (vllm-project#17059)

* Molmo Requirements (vllm-project#17026)

Signed-off-by: Eyshika Agarwal <[email protected]>
Signed-off-by: eyshika <[email protected]>

* Add `:markdownhelp:` to `EngineArgs` docs so markdown docstrings render properly (vllm-project#17124)

Signed-off-by: Harry Mellor <[email protected]>

* Improve configs - `LoRAConfig` + `PromptAdapterConfig` (vllm-project#16980)

Signed-off-by: Harry Mellor <[email protected]>

* [Docs] Generate correct github links for decorated functions (vllm-project#17125)

Signed-off-by: Russell Bryant <[email protected]>

* Add collective_rpc to llm engine (vllm-project#16999)

Signed-off-by: Yinghai Lu <[email protected]>

* Add chat template for Llama 4 models (vllm-project#16428)

Signed-off-by: Max de Bayser <[email protected]>

* [Misc] Add example to run DeepSeek with Ray Serve LLM (vllm-project#17134)

Signed-off-by: Rui Qiao <[email protected]>

* Better error message for missing mistral params.json (vllm-project#17132)

Signed-off-by: mgoin <[email protected]>

* Use custom address for listening socket (vllm-project#15988)

Signed-off-by: Jens Glaser <[email protected]>

* [FEAT] [ROCm]: AITER Fused MOE V1 Support (vllm-project#16752)

Signed-off-by: vllmellm <[email protected]>
Co-authored-by: tjtanaa <[email protected]>

* [Attention] FA3 decode perf improvement - single mma warp group support for head dim 128 (vllm-project#16864)

Signed-off-by: Lucas Wilkinson <[email protected]>

* fix float16 support for kimi-vl (vllm-project#17156)

Co-authored-by: zhouzaida <[email protected]>

* [Doc] V1 : Update LoRA status (vllm-project#17133)

Signed-off-by: varun sundar rabindranath <[email protected]>
Co-authored-by: varun sundar rabindranath <[email protected]>

* [Docs] Fix True->true in supported_models.md (vllm-project#17141)

* Move missed `SchedulerConfig` args into scheduler config group in `EngineArgs` (vllm-project#17131)

Signed-off-by: Harry Mellor <[email protected]>

* [Misc] Clean up redundant code in uniproc_executor.py (vllm-project#16762)

Signed-off-by: Lifu Huang <[email protected]>

* [Bugfix][Misc] Use TritonPlaceholderModule to defensively import triton (vllm-project#15099)

Signed-off-by: Mengqing Cao <[email protected]>

* [Misc] Benchmark Serving Script Support Appending Results (vllm-project#17028)

Signed-off-by: Lucas Wilkinson <[email protected]>

* [Perf]Optimize rotary_emb implementation to use Triton operator for improved inference performance (vllm-project#16457)

Signed-off-by: cynthieye <[email protected]>
Co-authored-by: MagnetoWang <[email protected]>

* [Bugfix] remove fallback in guided_json (int range, patterns) (vllm-project#16725)

Signed-off-by: csy1204 <[email protected]>
Co-authored-by: 조상연[플레이스 AI] <[email protected]>

* [Quantization][FP8] Add support for FP8 models with input_scale for output projection and QK quantization (vllm-project#15734)

Signed-off-by: Randall Smith <[email protected]>
Signed-off-by: Luka Govedič <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>

* [Doc] Add headings to improve gptqmodel.md (vllm-project#17164)

Signed-off-by: windsonsea <[email protected]>

* Only turn on FastIncrementalDetokenizer when tokenizers >= 0.21.1 (vllm-project#17158)

* [Doc] Add two links to disagg_prefill.md (vllm-project#17168)

Signed-off-by: windsonsea <[email protected]>

* [Doc] Move todo out of beam search docstring (vllm-project#17183)

Signed-off-by: Alex-Brooks <[email protected]>

* [Bugfix] Fix mistral model tests (vllm-project#17181)

Signed-off-by: DarkLight1337 <[email protected]>

* [Bugfix] Fix Mistral ChatCompletionRequest Body Exception (vllm-project#16769)

Signed-off-by: Jasmond Loh <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>

* Fix API typo and remove FP8 on V1 restriction

jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
@keyboardAnt

> In the long run, I think the drop in the acceptance rate is unacceptable, so we should find a better solution. This PR should be regarded as a band-aid solution.

@WoosukKwon, is this on the roadmap yet? Any timeline in mind?

@keyboardAnt commented Sep 1, 2025

> @WoosukKwon Yeah, I can get the acceptance length with greedy drafting this week. The motivation behind this PR is also quite clear.
>
> Though if people want to contribute multi-draft spec decode in the future, it will be impossible without the draft probs.

Yeah, my team and I would like to contribute, but we’re blocked by this issue 🥲
