Skip to content

Commit a266dff

Browse files
wwl2755skyloevil
authored andcommitted
[BugFix][Spec Decode] Fix out-of-range index triggered by eagle3; re-enable test for LlamaForCausalLMEagle3 (vllm-project#24392)
Signed-off-by: wwl2755 <[email protected]>
1 parent c2ef645 commit a266dff

File tree

7 files changed

+58
-41
lines changed

7 files changed

+58
-41
lines changed

tests/models/registry.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -602,11 +602,10 @@ def check_available_online(
602602
trust_remote_code=True,
603603
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
604604
tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
605-
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
606-
# "LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3", # noqa: E501
607-
# trust_remote_code=True,
608-
# speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501
609-
# tokenizer="Qwen/Qwen3-8B"),
605+
"LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3", # noqa: E501
606+
trust_remote_code=True,
607+
speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501
608+
tokenizer="Qwen/Qwen3-8B"),
610609
"EagleLlama4ForCausalLM": _HfExamplesInfo(
611610
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
612611
trust_remote_code=True,

tests/v1/e2e/test_spec_decode.py

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -125,37 +125,30 @@ def test_ngram_correctness(
125125
cleanup_dist_env_and_memory()
126126

127127

128-
@pytest.mark.parametrize(
129-
["model_setup", "mm_enabled"],
130-
[
131-
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
132-
# (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
133-
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
134-
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
135-
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
136-
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
137-
pytest.param(
138-
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
139-
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
140-
False,
141-
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
142-
pytest.param(
143-
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
144-
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
145-
True,
146-
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
147-
(("eagle", "eagle618/deepseek-v3-random",
148-
"eagle618/eagle-deepseek-v3-random", 1), False),
149-
],
150-
ids=[
151-
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
152-
# "qwen3_eagle3",
153-
"llama3_eagle",
154-
"llama3_eagle3",
155-
"llama4_eagle",
156-
"llama4_eagle_mm",
157-
"deepseek_eagle"
158-
])
128+
@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
129+
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
130+
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
131+
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
132+
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
133+
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
134+
pytest.param(
135+
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
136+
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
137+
False,
138+
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
139+
pytest.param(
140+
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
141+
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
142+
True,
143+
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
144+
(("eagle", "eagle618/deepseek-v3-random",
145+
"eagle618/eagle-deepseek-v3-random", 1), False),
146+
],
147+
ids=[
148+
"qwen3_eagle3", "llama3_eagle", "llama3_eagle3",
149+
"llama4_eagle", "llama4_eagle_mm",
150+
"deepseek_eagle"
151+
])
159152
@pytest.mark.parametrize("attn_backend",
160153
get_attn_backend_list_based_on_platform())
161154
def test_eagle_correctness(

vllm/config/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2191,9 +2191,14 @@ def __post_init__(self):
21912191
# Automatically detect the method
21922192
if self.method in ('eagle', 'eagle3'):
21932193
pass
2194-
elif "eagle-" in self.draft_model_config.model.lower() or \
2195-
"eagle3-" in self.draft_model_config.model.lower():
2194+
# examples:
2195+
# yuhuili/EAGLE-LLaMA3-Instruct-8B
2196+
# yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
2197+
# AngelSlim/Qwen3-8B_eagle3
2198+
elif "eagle-" in self.draft_model_config.model.lower():
21962199
self.method = "eagle"
2200+
elif "eagle3" in self.draft_model_config.model.lower():
2201+
self.method = "eagle3"
21972202
elif self.draft_model_config.hf_config.model_type == "medusa":
21982203
self.method = "medusa"
21992204
elif (self.draft_model_config.hf_config.model_type ==

vllm/model_executor/models/llama.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,22 @@ def __init__(
171171

172172
sliding_window = None
173173
if layer_types := getattr(config, "layer_types", None):
174-
is_sliding = layer_types[layer_idx] == "sliding_attention"
174+
# Fix for Eagle3 compatibility:
175+
# for draft models, subtract target layer count
176+
# to get draft-relative layer index starting from 0
177+
if hasattr(config, 'target_layer_count'):
178+
# This is a draft model,
179+
# adjust layer_idx to be relative to draft layers
180+
effective_layer_idx = layer_idx - config.target_layer_count
181+
else:
182+
# This is a target model, use layer_idx directly
183+
effective_layer_idx = layer_idx
184+
assert effective_layer_idx < len(layer_types), \
185+
f"effective_layer_idx: {effective_layer_idx} \
186+
is out of bounds for layer_types: {layer_types}"
187+
188+
is_sliding = layer_types[
189+
effective_layer_idx] == "sliding_attention"
175190
if is_sliding:
176191
sliding_window = config.sliding_window
177192

vllm/model_executor/models/llama_eagle3.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
199199
speculative_config.draft_model_config.hf_config
200200
target_layer_num = vllm_config.model_config.get_num_layers(
201201
vllm_config.parallel_config)
202+
203+
# Store target layer count in draft config for
204+
# proper layer_types indexing in draft models
205+
self.config.target_layer_count = target_layer_num
202206
self.model = LlamaModel(vllm_config=vllm_config,
203207
prefix="model",
204208
start_layer_id=target_layer_num)

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,7 @@
277277
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
278278
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
279279
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
280-
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
281-
# "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
280+
"LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
282281
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
283282
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
284283
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),

vllm/transformers_utils/configs/eagle.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,15 @@ def __init__(self,
4646
# Eagle model name should follow naming convention of
4747
# LlamaForCausalLM -> EagleLlamaForCausalLM
4848
# LlamaForCausalLM -> Eagle3LlamaForCausalLM
49+
# LlamaForCausalLMEagle3 -> LlamaForCausalLMEagle3
4950
if method == "eagle":
5051
assert self.model is not None, \
5152
"model should not be None when method is eagle"
5253
kwargs["architectures"] = [
5354
f"Eagle{arch}" if not arch.startswith("Eagle") \
5455
else arch for arch in self.model.architectures
5556
]
57+
5658
elif method == "eagle3":
5759
assert self.model is not None, \
5860
"model should not be None when method is eagle3"

0 commit comments

Comments
 (0)