
Conversation

wwl2755 (Contributor) commented on Sep 7, 2025

Purpose

  1. Fix a bug that occurs when an eagle3 model's layer_types is not None (see the sketch after this list). Fixes [Bug]: Support qwen3 Models in eagle3 Speculative Decoding #23464. The failing run reported:
     (EngineCore_0 pid=830911) layer_types: ['full_attention']
     (EngineCore_0 pid=830911) layer_idx: 40
  2. Re-enable the test disabled in [BugFix] [Spec Decode] Remove LlamaForCausalLMEagle3 to fix CI #22611.
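For context, the draft head's layer index continues the target model's numbering, while the draft config only lists its own layer(s). A minimal standalone sketch of the mismatch, using the values from the log above (illustrative only, not the actual vLLM code path):

# Values reported by the failing run (AngelSlim/Qwen3-14B_eagle3 on Qwen/Qwen3-14B).
layer_types = ["full_attention"]  # the draft config lists only its own single decoder layer
layer_idx = 40                    # the draft head's index, presumably continuing after the 40 target layers

# The pre-PR code in llama.py indexed the list directly, which raises here:
try:
    is_sliding = layer_types[layer_idx] == "sliding_attention"
except IndexError as err:
    print(f"IndexError: {err}")   # "list index out of range", matching the traceback below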

cc: @lec77 @22quinn

Test Result

Before PR:

(EngineCore_0 pid=3736098)   File "~/vllm/vllm/model_executor/models/llama.py", line 174, in __init__
(EngineCore_0 pid=3736098)     is_sliding = layer_types[layer_idx] == "sliding_attention"
(EngineCore_0 pid=3736098)                  ~~~~~~~~~~~^^^^^^^^^^^
(EngineCore_0 pid=3736098) IndexError: list index out of range

After PR:
Server successfully launched.

Command:

vllm serve Qwen/Qwen3-14B \
	--port 18000 \
	--max-model-len 8080 \
	--tensor-parallel-size 1 \
	--max-num-seqs 8 \
	--max-num-batched-tokens 2048 \
	--speculative-config '{ "method": "eagle3", "model": "AngelSlim/Qwen3-14B_eagle3","num_speculative_tokens": 3}' 

Unit tests:

pytest tests/v1/e2e/test_spec_decode.py -k qwen3_eagle3
...
=========================================================== 1 passed, 2 skipped, 16 deselected, 1 warning in 81.76s (0:01:21) ===========================================================

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

mergify bot added the llama (Related to Llama models) and new-model (Requests to new models) labels on Sep 7, 2025
wwl2755 changed the title from "[Bug][Spec Decode] Fix bug in eagle3" to "[BugFix][Spec Decode] Fix bug in eagle3" on Sep 7, 2025
gemini-code-assist bot left a comment:

Code Review

This pull request correctly fixes an IndexError that occurred with eagle3 speculative decoding models. The fix in vllm/model_executor/models/llama.py properly adds a bounds check before accessing the layer_types list, preventing a crash. The corresponding change in vllm/model_executor/models/registry.py to re-enable the LlamaForCausalLMEagle3 model is a logical consequence of this bug fix. The changes are accurate and well-implemented, and I have no further recommendations.

wwl2755 changed the title from "[BugFix][Spec Decode] Fix bug in eagle3" to "[BugFix][Spec Decode] Fix out-of-range index triggered by eagle3" on Sep 7, 2025
22quinn (Collaborator) commented on Sep 7, 2025

Shall we re-enable everything that was disabled in #22611?

mergify bot added the v1 label on Sep 7, 2025
wwl2755 (Contributor, Author) commented on Sep 7, 2025

> Shall we re-enable everything that was disabled in #22611?

Sure. Updated in the new commit.

The local test passes: pytest tests/v1/e2e/test_spec_decode.py -k qwen3_eagle3

DarkLight1337 added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Sep 8, 2025
wwl2755 changed the title from "[BugFix][Spec Decode] Fix out-of-range index triggered by eagle3" to "[BugFix][Spec Decode] Fix out-of-range index triggered by eagle3; re-enable test for LlamaForCausalLMEagle3" on Sep 8, 2025
ywang96 (Member) left a comment:

LGTM - left a nit

vllm/model_executor/models/llama.py (review diff):

  sliding_window = None
- if layer_types := getattr(config, "layer_types", None):
+ if (layer_types := getattr(config, "layer_types",
+                            None)) and layer_idx < len(layer_types):
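With the failing config from the issue (layer_types=['full_attention'], layer_idx=40), the new condition is simply false, so sliding_window stays None and the draft layer falls back to full attention. A standalone sanity check of the expression (the Cfg class below is a stand-in for the draft HF config, not vLLM code):

class Cfg:
    # Stand-in for the draft model's HF config; only the attribute being indexed matters here.
    layer_types = ["full_attention"]

config, layer_idx = Cfg(), 40
sliding_window = None
if (layer_types := getattr(config, "layer_types", None)) and layer_idx < len(layer_types):
    sliding_window = "sliding_attention path"   # not reached for this config
print(sliding_window)   # None -> no IndexError; the draft layer uses full attention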
Contributor:

Should we instead have an assert that layer_idx < len(layer_types)? That would help catch bugs that would silently pass the current if check and might manifest as a quality drop if layer_idx goes past layer_types due to a misconfiguration.

Contributor:

Also, which model uses a sliding window in the eagle head?

wwl2755 (Author):

In AngelSlim/Qwen3-14B_eagle3, the config has layer_types: ['full_attention'] while layer_idx is 40. If we used an assertion, startup would stop here.

ekagra-ranjan (Contributor) commented on Sep 8, 2025:

Is it correct to assume that a layer_idx of 40 is the layer index corresponding to the eagle 3 draft head? What is the value of len(layer_types) when layer_idx is 40?

wwl2755 (Author):

  1. I printed these values right before the bug, so len(layer_types) is 1 while layer_idx is 40:
     (EngineCore_0 pid=830911) layer_types: ['full_attention']
     (EngineCore_0 pid=830911) layer_idx: 40

> which model is using sliding window in the eagle head?

None, to my understanding; the check is only there to avoid entering that code path.

Contributor:

The current fix won't work for a draft that has SWA: it bypasses applying SWA even if the draft's layer_types is ['sliding_attention']. A foolproof fix would be to redefine the draft's layer_types as draft.layer_types = target.layer_types + draft.layer_types and pass this new value to llama.py from here.
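A minimal sketch of that suggestion, purely illustrative (the helper name and example shapes are assumptions, not the actual vLLM change):

def extend_draft_layer_types(target_layer_types, draft_layer_types):
    # Concatenate so that the draft head's absolute layer index, which continues
    # after the target model's layers, stays in range of the combined list.
    return list(target_layer_types) + list(draft_layer_types)

# Shapes from this thread: Qwen3-14B has 40 target layers, the eagle3 head has one layer.
combined = extend_draft_layer_types(["full_attention"] * 40, ["sliding_attention"])
assert combined[40] == "sliding_attention"   # a draft that uses SWA would now be honored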

wwl2755 (Author):

Right. But I'm not sure whether other eagle3 models hold this assumption. Let me run some tests.

# Eagle model name should follow naming convention of
# LlamaForCausalLM -> EagleLlamaForCausalLM
# LlamaForCausalLM -> Eagle3LlamaForCausalLM
# LlamaForCausalLMEagle3 -> LlamaForCausalLMEagle3
Contributor:

I didn't get this part: why do we need both LlamaForCausalLMEagle3 and Eagle3LlamaForCausalLM to represent the same thing, i.e., Llama EAGLE-3?

wwl2755 (Author):

This comment means we are able to handle both cases. There are two different names because the people who create the eagle models define them that way; we cannot control what they put in the HF repo.

wwl2755 (Author):

But I definitely agree we should unify it (maybe in a future PR). It took me some time to understand.

ekagra-ranjan (Contributor) commented on Sep 8, 2025:

I see. Just to confirm:

  1. LlamaForCausalLM and LlamaForCausalLMEagle3 are model names defined in HF?
  2. There is no LlamaForCausalLMEagle to define EAGLE-1 in HF?

wwl2755 (Author):

  1. Yes. Both yuhuili/EAGLE-LLaMA3.1-Instruct-8B and yuhuili/EAGLE3-LLaMA3.1-Instruct-8B use LlamaForCausalLM, even though they are eagle1 and eagle3 respectively. Newer models like AngelSlim/Qwen3-8B_eagle3 use LlamaForCausalLMEagle3.

  2. None, to my understanding. If a popular one existed, some bug would have been reported.

wwl2755 (Contributor, Author) commented on Sep 9, 2025

@ekagra-ranjan Updated according to your suggestion. PTAL. I changed it a little by using effective_layer_idx = layer_idx - target_num_layers, because I think it makes more sense for the draft model to only see its own layers.

Tested with AngelSlim/Qwen3-14B_eagle3 and yuhuili/EAGLE3-LLaMA3.1-Instruct-8B.
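A rough standalone sketch of that remapping (the helper below is illustrative only; it just mirrors the effective_layer_idx = layer_idx - target_num_layers idea described above, not the actual vLLM code):

def effective_layer_idx(layer_idx: int, target_num_layers: int) -> int:
    # The draft head's absolute index continues after the target layers;
    # subtracting target_num_layers lets the draft index only its own layers.
    return layer_idx - target_num_layers

# Qwen3-14B has 40 target layers, so the eagle3 head's absolute index 40 maps to 0,
# which is a valid index into the draft's own layer_types list.
draft_layer_types = ["full_attention"]
assert draft_layer_types[effective_layer_idx(40, 40)] == "full_attention"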

ekagra-ranjan (Contributor) left a comment:

Thanks for fixing SWA as well. LGTM

simon-mo merged commit 53b42f4 into vllm-project:main on Sep 10, 2025 (36 of 39 checks passed)
wwl2755 deleted the qwen3 branch on September 10, 2025 at 05:36
ggg-s commented on Sep 10, 2025

error:
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] Error in inspecting model architecture 'Eagle3LlamaForCausalLM'
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] Traceback (most recent call last):
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "/workspace/vllm/vllm/vllm/model_executor/models/registry.py", line 863, in _run_in_subprocess
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] returned.check_returncode()
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "/opt/miniconda/envs/vllm/lib/python3.12/subprocess.py", line 502, in check_returncode
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] raise CalledProcessError(self.returncode, self.args, self.stdout,
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] subprocess.CalledProcessError: Command '['/opt/miniconda/envs/vllm/bin/python3.12', '-m', 'vllm.model_executor.models.registry']' returned non-zero exit status 1.
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445]
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] The above exception was the direct cause of the following exception:
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445]
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] Traceback (most recent call last):
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "/workspace/vllm/vllm/vllm/model_executor/models/registry.py", line 443, in _try_inspect_model_cls
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] return model.inspect_model_cls()
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] ^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "/workspace/vllm/vllm/vllm/model_executor/models/registry.py", line 414, in inspect_model_cls
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] return _run_in_subprocess(
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] ^^^^^^^^^^^^^^^^^^^
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "/workspace/vllm/vllm/vllm/model_executor/models/registry.py", line 866, in _run_in_subprocess
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] raise RuntimeError(f"Error raised in subprocess:\n"
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] RuntimeError: Error raised in subprocess:
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] :128: RuntimeWarning: 'vllm.model_executor.models.registry' found in sys.modules after import of package 'vllm.model_executor.models', but prior to execution of 'vllm.model_executor.models.registry'; this may result in unpredictable behaviour
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] Traceback (most recent call last):
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "", line 198, in _run_module_as_main
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "", line 88, in _run_code
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "/workspace/vllm/vllm/vllm/model_executor/models/registry.py", line 887, in
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] _run()
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "/workspace/vllm/vllm/vllm/model_executor/models/registry.py", line 880, in _run
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] result = fn()
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] ^^^^
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "/workspace/vllm/vllm/vllm/model_executor/models/registry.py", line 415, in
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] ^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "/workspace/vllm/vllm/vllm/model_executor/models/registry.py", line 415, in
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] ^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "/workspace/vllm/vllm/vllm/model_executor/models/registry.py", line 418, in load_model_cls
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] mod = importlib.import_module(self.module_name)
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "/opt/miniconda/envs/vllm/lib/python3.12/importlib/init.py", line 90, in import_module
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] return _bootstrap._gcd_import(name[level:], package, level)
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "", line 1381, in _gcd_import
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "", line 1354, in _find_and_load
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "", line 1325, in _find_and_load_unlocked
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "", line 929, in _load_unlocked
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "", line 994, in exec_module
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "", line 488, in _call_with_frames_removed
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "/workspace/vllm/vllm/vllm/model_executor/models/llama_eagle3.py", line 24, in
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] from vllm.v1.sample.metadata import SamplingMetadata
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "/workspace/vllm/vllm/vllm/v1/sample/metadata.py", line 9, in
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] from vllm.v1.sample.logits_processor import LogitsProcessors
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "/workspace/vllm/vllm/vllm/v1/sample/logits_processor/init.py", line 16, in
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor,
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] File "/workspace/vllm/vllm/vllm/v1/sample/logits_processor/builtin.py", line 8, in
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] from vllm import SamplingParams
(APIServer pid=45967) ERROR 09-10 09:32:46 [registry.py:445] ImportError: cannot import name 'SamplingParams' from 'vllm' (unknown location). Did you mean: 'sampling_params'?

DarkLight1337 (Member) commented:

This PR is causing Basic Models Test to fail on main

wwl2755 (Contributor, Author) commented on Sep 10, 2025

> This PR is causing Basic Models Test to fail on main

Hi @DarkLight1337, do you know how to reproduce this in an environment like the CI's? Directly running pytest -v -s tests/models/test_initialization.py::test_can_initialize[Eagle3LlamaForCausalLM] causes a CUDA initialization problem for me (this error occurs locally with any model), so I relied on the CI check on this branch, which seemed okay before merging.

I pulled the latest main and it works fine end-to-end. 😕

python examples/offline_inference/spec_decode.py --method eagle3 --num-prompts 10 --model-dir Qwen/Qwen3-14B --eagle-dir AngelSlim/Qwen3-14B_eagle3
2025-09-10 17:24:54.896041: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
INFO 09-10 17:24:57 [__init__.py:216] Automatically detected platform cuda.
INFO 09-10 17:25:00 [datasets.py:517] Sampling input_len from [1024, 1024] and output_len from [128, 128]
INFO 09-10 17:25:00 [utils.py:328] non-default args: {'trust_remote_code': True, 'max_model_len': 2048, 'gpu_memory_utilization': 0.8, 'limit_mm_per_prompt': {'image': 5}, 'enable_chunked_prefill': False, 'disable_chunked_mm_input': True, 'speculative_config': {'method': 'eagle3', 'model': 'AngelSlim/Qwen3-14B_eagle3', 'num_speculative_tokens': 2}, 'model': 'Qwen/Qwen3-14B'}
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
INFO 09-10 17:25:12 [__init__.py:742] Resolved architecture: Qwen3ForCausalLM
INFO 09-10 17:25:12 [__init__.py:1797] Using max model len 2048
INFO 09-10 17:25:26 [__init__.py:742] Resolved architecture: LlamaForCausalLMEagle3
INFO 09-10 17:25:26 [__init__.py:1797] Using max model len 40960
INFO 09-10 17:25:26 [scheduler.py:222] Chunked prefill is enabled with max_num_batched_tokens=8192.
WARNING 09-10 17:25:27 [__init__.py:2971] We must use the `spawn` multiprocessing start method. Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See https://docs.vllm.ai/en/latest/usage/troubleshooting.html#python-multiprocessing for more information. Reasons: CUDA is initialized
INFO 09-10 17:25:34 [__init__.py:216] Automatically detected platform cuda.
(EngineCore_DP0 pid=1729231) INFO 09-10 17:25:37 [core.py:654] Waiting for init message from front-end.
(EngineCore_DP0 pid=1729231) INFO 09-10 17:25:37 [core.py:76] Initializing a V1 LLM engine (v0.1.dev9129+ga74232209) with config: model='Qwen/Qwen3-14B', speculative_config=SpeculativeConfig(method='eagle3', model='AngelSlim/Qwen3-14B_eagle3', num_spec_tokens=2), tokenizer='Qwen/Qwen3-14B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=Qwen/Qwen3-14B, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, pooler_config=None, compilation_config={"level":3,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output","vllm.mamba_mixer2","vllm.mamba_mixer","vllm.short_conv","vllm.linear_attention","vllm.plamo2_mamba_mixer"],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"cudagraph_mode":1,"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"pass_config":{},"max_capture_size":512,"local_cache_dir":null}
(EngineCore_DP0 pid=1729231) W0910 17:25:38.434000 1729231 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
(EngineCore_DP0 pid=1729231) W0910 17:25:38.434000 1729231 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
(EngineCore_DP0 pid=1729231) 2025-09-10 17:25:38,437 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend
[W910 17:25:39.401798369 ProcessGroupNCCL.cpp:981] Warning: TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated. (function operator())
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
(EngineCore_DP0 pid=1729231) INFO 09-10 17:25:40 [parallel_state.py:1164] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
(EngineCore_DP0 pid=1729231) INFO 09-10 17:25:40 [topk_topp_sampler.py:58] Using FlashInfer for top-p & top-k sampling.
(EngineCore_DP0 pid=1729231) INFO 09-10 17:25:40 [gpu_model_runner.py:2197] Starting to load model Qwen/Qwen3-14B...
(EngineCore_DP0 pid=1729231) INFO 09-10 17:25:42 [gpu_model_runner.py:2229] Loading model from scratch...
(EngineCore_DP0 pid=1729231) INFO 09-10 17:25:42 [cuda.py:338] Using Flash Attention backend on V1 engine.
(EngineCore_DP0 pid=1729231) INFO 09-10 17:25:42 [weight_utils.py:348] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/8 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  12% Completed | 1/8 [00:00<00:05,  1.23it/s]
Loading safetensors checkpoint shards:  25% Completed | 2/8 [00:01<00:04,  1.20it/s]
Loading safetensors checkpoint shards:  38% Completed | 3/8 [00:02<00:03,  1.47it/s]
Loading safetensors checkpoint shards:  50% Completed | 4/8 [00:02<00:02,  1.40it/s]
Loading safetensors checkpoint shards:  62% Completed | 5/8 [00:03<00:02,  1.27it/s]
Loading safetensors checkpoint shards:  75% Completed | 6/8 [00:04<00:01,  1.18it/s]
Loading safetensors checkpoint shards:  88% Completed | 7/8 [00:05<00:00,  1.15it/s]
Loading safetensors checkpoint shards: 100% Completed | 8/8 [00:06<00:00,  1.20it/s]
Loading safetensors checkpoint shards: 100% Completed | 8/8 [00:06<00:00,  1.23it/s]
(EngineCore_DP0 pid=1729231) 
(EngineCore_DP0 pid=1729231) INFO 09-10 17:25:49 [default_loader.py:267] Loading weights took 6.63 seconds
(EngineCore_DP0 pid=1729231) INFO 09-10 17:25:49 [gpu_model_runner.py:2239] Loading drafter model...
(EngineCore_DP0 pid=1729231) INFO 09-10 17:25:49 [weight_utils.py:348] Using model weights format ['*.safetensors', '*.bin', '*.pt']
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  2.10it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  2.10it/s]
(EngineCore_DP0 pid=1729231) 
(EngineCore_DP0 pid=1729231) INFO 09-10 17:25:50 [default_loader.py:267] Loading weights took 0.77 seconds
(EngineCore_DP0 pid=1729231) INFO 09-10 17:25:50 [eagle.py:634] Assuming the EAGLE head shares the same vocab embedding with the target model.
(EngineCore_DP0 pid=1729231) INFO 09-10 17:25:52 [gpu_model_runner.py:2251] Model loading took 28.6541 GiB and 8.400862 seconds
(EngineCore_DP0 pid=1729231) INFO 09-10 17:25:59 [backends.py:538] Using cache directory: /home/cc/.cache/vllm/torch_compile_cache/06ea8de55c/rank_0_0/backbone for vLLM's torch.compile
(EngineCore_DP0 pid=1729231) INFO 09-10 17:25:59 [backends.py:549] Dynamo bytecode transform time: 6.26 s
(EngineCore_DP0 pid=1729231) INFO 09-10 17:25:59 [backends.py:194] Cache the graph for dynamic shape for later use
(EngineCore_DP0 pid=1729231) INFO 09-10 17:26:05 [backends.py:215] Compiling a graph for dynamic shape takes 5.94 s
(EngineCore_DP0 pid=1729231) INFO 09-10 17:26:13 [monitor.py:34] torch.compile takes 12.19 s in total
(EngineCore_DP0 pid=1729231) INFO 09-10 17:26:14 [backends.py:538] Using cache directory: /home/cc/.cache/vllm/torch_compile_cache/06ea8de55c/rank_0_0/eagle_head for vLLM's torch.compile
(EngineCore_DP0 pid=1729231) INFO 09-10 17:26:14 [backends.py:549] Dynamo bytecode transform time: 0.38 s
(EngineCore_DP0 pid=1729231) INFO 09-10 17:26:14 [backends.py:194] Cache the graph for dynamic shape for later use
(EngineCore_DP0 pid=1729231) INFO 09-10 17:26:14 [backends.py:215] Compiling a graph for dynamic shape takes 0.14 s
(EngineCore_DP0 pid=1729231) INFO 09-10 17:26:14 [monitor.py:34] torch.compile takes 12.71 s in total
(EngineCore_DP0 pid=1729231) 2025-09-10 17:26:14,829 - INFO - flashinfer.jit: Loading JIT ops: sampling
(EngineCore_DP0 pid=1729231) [rank0]:W0910 17:26:14.830000 1729231 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
(EngineCore_DP0 pid=1729231) [rank0]:W0910 17:26:14.830000 1729231 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
(EngineCore_DP0 pid=1729231) [rank0]:W0910 17:26:14.838000 1729231 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
(EngineCore_DP0 pid=1729231) [rank0]:W0910 17:26:14.838000 1729231 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
(EngineCore_DP0 pid=1729231) 2025-09-10 17:26:14,858 - INFO - flashinfer.jit: Finished loading JIT ops: sampling
(EngineCore_DP0 pid=1729231) INFO 09-10 17:26:16 [gpu_worker.py:276] Available KV cache memory: 1.57 GiB
(EngineCore_DP0 pid=1729231) INFO 09-10 17:26:16 [kv_cache_utils.py:864] GPU KV cache size: 10,048 tokens
(EngineCore_DP0 pid=1729231) INFO 09-10 17:26:16 [kv_cache_utils.py:868] Maximum concurrency for 2,048 tokens per request: 4.91x
Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:04<00:00, 14.26it/s]
(EngineCore_DP0 pid=1729231) INFO 09-10 17:26:22 [gpu_model_runner.py:2953] Graph capturing finished in 6 secs, took 0.72 GiB
(EngineCore_DP0 pid=1729231) INFO 09-10 17:26:22 [core.py:218] init engine (profile, create kv cache, warmup model) took 30.20 seconds
INFO 09-10 17:26:23 [llm.py:285] Supported_tasks: ['generate']
INFO 09-10 17:26:23 [__init__.py:36] No IOProcessor plugins requested by the model
Adding requests: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 2879.32it/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:12<00:00,  1.30s/it, est. speed input: 783.40 toks/s, output: 197.02 toks/s]
--------------------------------------------------
total_num_output_tokens: 2560
num_drafts: 1911
num_draft_tokens: 3822
num_accepted_tokens: 643
mean acceptance length: 1.34
--------------------------------------------------
acceptance at token 0: 0.31
acceptance at token 1: 0.03
[rank0]:[W910 17:26:36.221861814 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

DarkLight1337 (Member) commented:

Maybe you can use the error log to help debug: https://buildkite.com/vllm/ci/builds/30179#01993427-9309-4bf2-a37a-e63f8b87dc4f

wwl2755 (Contributor, Author) commented on Sep 10, 2025

> Maybe you can use the error log to help debug: https://buildkite.com/vllm/ci/builds/30179#01993427-9309-4bf2-a37a-e63f8b87dc4f

Thanks! It could be related to how we use the layer_idx. I'm looking more closely at it.

wwl2755 (Contributor, Author) commented on Sep 10, 2025

Update: solved by adding VLLM_WORKER_MULTIPROC_METHOD=spawn.

Sorry, but I still cannot figure out what's going on in the CI environment; I may need some help understanding it.

Here I print the layer names right before the assertion in my e2e test. You can see that all the layers are loaded correctly and the indexes are correct, so I'm thinking there must be something different in either the test case or the CI environment.

And I'm not sure where the error ValueError: Duplicate layer name: model.layers.1.self_attn.attn comes from. Even if the index were wrong, the errors should start from layers.0.

The worse part is that I could not reproduce it locally (pytest failed for all models, with error logs different from the CI's).

Could you reproduce the error locally? Or, in case this cannot be solved shortly, should we skip the test to recover our CI? This test was previously skipped, and it seems there is more than one error hidden in this case.

[screenshot: layer names printed right before the assertion in the e2e test]

wwl2755 (Contributor, Author) commented on Sep 10, 2025

Hi @ggg-s, I replied to you at #23464 (comment).

qandrew (Contributor) commented on Sep 10, 2025

My CI is failing on a PR (#24127) that doesn't touch the code here: https://buildkite.com/vllm/ci/builds/30242#01993572-2b87-4ab4-a155-8c7b1a4a6570

Can we disable the CI test again, since it's blocking forward progress for other PRs?

wwl2755 (Contributor, Author) commented on Sep 10, 2025

> My CI is failing on a PR (#24127) that doesn't touch the code here: https://buildkite.com/vllm/ci/builds/30242#01993572-2b87-4ab4-a155-8c7b1a4a6570
>
> Can we disable the CI test again, since it's blocking forward progress for other PRs?

Just a sec, I'm doing the final testing on a PR to fix it.

I saw your PR is only metrics-related, so it should be okay to go if this is the only failure. The failure is known; you can ask in the Slack channel. Sorry for the inconvenience.

wwl2755 (Contributor, Author) commented on Sep 10, 2025

The fix is in #24613.

skyloevil pushed a commit to skyloevil/vllm referencing this pull request on Sep 13, 2025
rogeryoungh pushed a commit to MiniMax-AI/vllm referencing this pull request on Sep 15, 2025
cboss6 pushed commits to cboss6/vllm referencing this pull request on Sep 16, 2025
langc23 pushed a commit to zte-riscv/vllm referencing this pull request on Sep 23, 2025
Labels: llama (Related to Llama models), new-model (Requests to new models), ready (ONLY add when PR is ready to merge/full CI is needed), speculative-decoding, v1
Closes: [Bug]: Support qwen3 Models in eagle3 Speculative Decoding
8 participants