Merged. Changes from all 27 commits.
6 changes: 6 additions & 0 deletions examples/convert/eagle3/apply_eagle3_llama4_maverick.sh
@@ -0,0 +1,6 @@
+speculators convert nvidia/Llama-4-Maverick-17B-128E-Eagle3 \
+    --algorithm eagle3 \
+    --verifier RedHatAI/Llama-4-Maverick-17B-128E-Instruct-quantized.w4a16 \
+    --output-path Llama4-Maverick-Eagle3-Speculators \
+    --validate-device cuda:0 \
+    --algorithm-kwargs '{"eagle_aux_hidden_state_layer_ids": [1,23,44], "norm_before_residual": false}'
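For comparison, the same conversion invoked programmatically through the entrypoint this PR touches (src/speculators/convert/entrypoints.py) might look like the sketch below. Only algorithm, output_path, and the two Eagle3 kwargs are confirmed by this diff; the first positional argument and the verifier parameter name are assumptions mirroring the CLI flags above.

from speculators.convert.entrypoints import convert_model

# Sketch only: the checkpoint positional and "verifier" kwarg are assumed names,
# inferred from the CLI flags; they are not confirmed by this diff.
convert_model(
    "nvidia/Llama-4-Maverick-17B-128E-Eagle3",
    algorithm="eagle3",
    output_path="Llama4-Maverick-Eagle3-Speculators",
    verifier="RedHatAI/Llama-4-Maverick-17B-128E-Instruct-quantized.w4a16",
    norm_before_residual=False,
    eagle_aux_hidden_state_layer_ids=[1, 23, 44],
)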
3 changes: 2 additions & 1 deletion src/speculators/__main__.py
@@ -118,7 +118,8 @@ def convert(
help=(
"Additional keyword args for the conversion alg as a JSON string. "
'Options for Eagle: {"layernorms": true, "fusion_bias": true}. '
-                'Options for Eagle3: {"norm_before_residual": true}.'
+                'Options for Eagle3: {"norm_before_residual": true, '
+                '"eagle_aux_hidden_state_layer_ids": [1,23,44]}.'
),
),
] = None,
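The --algorithm-kwargs value is plain JSON that gets decoded into Python keyword arguments. A minimal sketch of that decoding (json.loads is the standard-library call; forwarding into the converter via ** expansion is assumed):

import json

raw = '{"norm_before_residual": true, "eagle_aux_hidden_state_layer_ids": [1, 23, 44]}'
kwargs = json.loads(raw)
# JSON true/false become Python True/False, arrays become lists:
assert kwargs == {
    "norm_before_residual": True,
    "eagle_aux_hidden_state_layer_ids": [1, 23, 44],
}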
125 changes: 38 additions & 87 deletions src/speculators/convert/eagle/eagle3_converter.py
@@ -7,11 +7,12 @@

import torch
from loguru import logger
-from transformers import AutoModelForCausalLM, LlamaConfig, PretrainedConfig
+from transformers import LlamaConfig, PretrainedConfig

from speculators.config import SpeculatorsConfig, VerifierConfig
from speculators.convert.eagle.utils import (
ensure_checkpoint_is_local,
+    find_vocab_size,
load_checkpoint_config,
load_checkpoint_weights,
)
@@ -34,6 +35,7 @@ def convert(
base_model: str,
validate: bool = True,
norm_before_residual: bool = False,
+        eagle_aux_hidden_state_layer_ids: Optional[list[int]] = None,
cache_dir: Optional[Union[str, Path]] = None,
) -> None:
logger.info(f"Converting Eagle-3 checkpoint: {input_path}")
@@ -44,107 +46,50 @@
weights = load_checkpoint_weights(local_checkpoint_path)
logger.info(f"Loaded {len(weights)} weights")

-        # Patch: ensure target_vocab_size matches t2d tensor shape
-        eagle_config["target_vocab_size"] = weights["t2d"].shape[0]
+        reduce_vocab_size = False
+        # Get target_vocab_size from t2d tensor shape if available
+        if "t2d" in weights:
+            eagle_config["target_vocab_size"] = weights["t2d"].shape[0]
+            logger.debug(
+                f"Using target_vocab_size from t2d tensor: "
+                f"{eagle_config['target_vocab_size']}"
+            )
+            reduce_vocab_size = True
+        else:
+            # fall back to target model config - search for vocab_size at any level
+            target_config_dict, _ = PretrainedConfig.get_config_dict(base_model)
+            vocab_size = find_vocab_size(target_config_dict)
+            if vocab_size is None:
+                raise ValueError(
+                    "Could not determine vocab_size from target model config."
+                )
+            eagle_config["target_vocab_size"] = vocab_size
+            logger.debug(
+                f"Using target_vocab_size from config: "
+                f"{eagle_config['target_vocab_size']}"
+            )

config = self._build_eagle3_speculator_config(
eagle_config,
base_model,
norm_before_residual,
+            eagle_aux_hidden_state_layer_ids,
)

-        # Process weights and ensure embeddings are properly handled
-        processed_weights = self._process_checkpoint_weights(weights, base_model)
+        has_drafter_embedding = "embed_tokens.weight" in weights

saved_path = self._save_converted_checkpoint(
-            config, processed_weights, output_path
+            config,
+            weights,
+            output_path,
+            reduce_vocab_size,
+            has_drafter_embedding,
)
logger.success(f"Saved to: {saved_path}")

if validate:
self._validate_converted_checkpoint(saved_path, base_model)

-    def _process_checkpoint_weights(
-        self, weights: dict[str, torch.Tensor], base_model: str
-    ) -> dict[str, torch.Tensor]:
-        """
-        Process and validate Eagle3 checkpoint weights.
-
-        Eagle3 models need embeddings that match the verifier model for good acceptance.
-        We ALWAYS replace embeddings with verifier embeddings for compatibility.
-
-        :param weights: Original checkpoint weights
-        :param base_model: Base model name to load verifier embeddings from
-        :return: Processed weights with verifier embeddings
-        """
-        logger.debug(f"Processing {len(weights)} Eagle3 weights")
-
-        # Remap weight names: midlayer.* -> layers.0.*
-        processed_weights = {}
-        for original_name, tensor in weights.items():
-            # Remap midlayer.* -> layers.0.*
-            if original_name.startswith("midlayer."):
-                new_name = original_name.replace("midlayer.", "layers.0.")
-                processed_weights[new_name] = tensor
-                logger.debug(f"Remapped: {original_name} -> {new_name}")
-            # Keep layers.0.* as is (already correct)
-            elif original_name.startswith("layers.0."):
-                processed_weights[original_name] = tensor
-            else:
-                processed_weights[original_name] = tensor
-
-        # Only add verifier embeddings if not present in eagle model
-        if "embed_tokens.weight" not in processed_weights:
-            logger.info("Eagle model missing embeddings - adding verifier embeddings")
-            return self._add_verifier_embeddings(processed_weights, base_model)
-        else:
-            logger.info("Eagle model already has embeddings - keeping originals")
-            return processed_weights
-
-    def _add_verifier_embeddings(
-        self, weights: dict[str, torch.Tensor], base_model: str
-    ) -> dict[str, torch.Tensor]:
-        """
-        Add embeddings from the verifier model to the checkpoint.
-
-        :param weights: Current checkpoint weights
-        :param base_model: Base model to load embeddings from
-        :return: Updated weights with verifier embeddings
-        """
-        logger.info(f"Loading embeddings from verifier model: {base_model}")
-
-        try:
-            # Load verifier model to get embeddings
-            verifier = AutoModelForCausalLM.from_pretrained(
-                base_model, torch_dtype=torch.float32
-            )
-
-            # Extract embeddings from verifier
-            if hasattr(verifier, "model") and hasattr(verifier.model, "embed_tokens"):
-                embed_tokens = verifier.model.embed_tokens.weight.data.clone()  # type: ignore[assignment,union-attr,operator,attr-defined]
-            elif hasattr(verifier, "embed_tokens"):
-                embed_tokens = verifier.embed_tokens.weight.data.clone()  # type: ignore[assignment,union-attr,operator,attr-defined]
-            else:
-                raise RuntimeError(
-                    f"Could not find embed_tokens in verifier model {base_model}"
-                )
-
-            logger.info(f"Loaded embeddings with shape: {embed_tokens.shape}")
-            weights["embed_tokens.weight"] = embed_tokens
-
-            # Clean up verifier model to save memory
-            del verifier
-            torch.cuda.empty_cache() if torch.cuda.is_available() else None
-
-        except (OSError, ValueError, RuntimeError) as e:
-            logger.error(f"Failed to load embeddings from verifier: {e}")
-            raise RuntimeError(
-                f"Could not load embeddings from verifier model {base_model}. "
-                "This is required for Eagle3 models without trained embeddings."
-            ) from e
-
-        return weights
-
def _create_verifier_config(self, base_model: str) -> VerifierConfig:
config_dict, _ = PretrainedConfig.get_config_dict(base_model)
return VerifierConfig(
@@ -157,6 +102,7 @@ def _build_eagle3_speculator_config(
eagle_config: dict,
base_model: str,
norm_before_residual: bool = False,
+        eagle_aux_hidden_state_layer_ids: Optional[list[int]] = None,
) -> Eagle3SpeculatorConfig:
transformer_config = self._create_transformer_config_from_eagle(
eagle_config, base_model
@@ -181,6 +127,7 @@ def _build_eagle3_speculator_config(
draft_vocab_size=eagle_config.get("draft_vocab_size", 32000),
norm_before_residual=norm_before_residual,
target_hidden_size=eagle_config.get("target_hidden_size"),
+            eagle_aux_hidden_state_layer_ids=eagle_aux_hidden_state_layer_ids,
)

def _create_transformer_config_from_eagle(
@@ -223,11 +170,15 @@ def _save_converted_checkpoint(
config: Eagle3SpeculatorConfig,
weights: dict[str, torch.Tensor],
output_dir: Union[str, Path],
+        reduce_vocab_size: bool,
+        has_drafter_embedding: bool,
) -> Path:
model = Eagle3Speculator(
config=config,
verifier=None,
verifier_attachment_mode="detached",
+            reduce_vocab_size=reduce_vocab_size,
+            has_drafter_embedding=has_drafter_embedding,
)
model.load_state_dict(weights, strict=False) # type: ignore[attr-defined]
weights_dtype = getattr(config.transformer_layer_config, "torch_dtype", None)
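In short, the rewritten convert flow no longer loads the verifier to patch embeddings; it derives its flags directly from what the checkpoint contains. A minimal sketch of that decision logic (the tensor shape is invented for illustration):

import torch

weights = {"t2d": torch.zeros(128256, dtype=torch.bool)}  # hypothetical target-to-draft map
reduce_vocab_size = "t2d" in weights                      # True: vocab size comes from t2d
has_drafter_embedding = "embed_tokens.weight" in weights  # False: no drafter embeddings shipped
target_vocab_size = weights["t2d"].shape[0]               # 128256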
18 changes: 18 additions & 0 deletions src/speculators/convert/eagle/utils.py
@@ -12,6 +12,24 @@
from safetensors import safe_open


+def find_vocab_size(config_dict: dict) -> Optional[int]:
+    """
+    Recursively search for vocab_size in a nested config dictionary.
+
+    :param config_dict: Configuration dictionary to search
+    :return: vocab_size if found, None otherwise
+    """
+    if isinstance(config_dict, dict):
+        if "vocab_size" in config_dict:
+            return config_dict["vocab_size"]
+        for value in config_dict.values():
+            if isinstance(value, dict):
+                result = find_vocab_size(value)
+                if result is not None:
+                    return result
+    return None
+
+
def download_checkpoint_from_hub(
model_id: str, cache_dir: Optional[str] = None
) -> Path:
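A quick illustration of the recursive lookup on a nested config of the kind multimodal models ship (the values here are invented for the example):

from speculators.convert.eagle.utils import find_vocab_size

nested = {"text_config": {"vocab_size": 202048}, "vision_config": {"depth": 34}}
assert find_vocab_size(nested) == 202048                  # found one level down
assert find_vocab_size({"vocab_size": 128256}) == 128256  # found at the top level
assert find_vocab_size({"text_config": {"hidden_size": 5120}}) is None  # absent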
3 changes: 2 additions & 1 deletion src/speculators/convert/entrypoints.py
@@ -76,7 +76,8 @@ def convert_model(
:param output_path: Directory path where the converted model will be saved.
:param kwargs: Additional keyword arguments for the conversion algorithm.
Options for Eagle: {"layernorms": true, "fusion_bias": true}.
-        Options for Eagle3: {"norm_before_residual": true}.
+        Options for Eagle3: {"norm_before_residual": true,
+        "eagle_aux_hidden_state_layer_ids": [1,23,44]}.
"""

if algorithm == "eagle":
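For quick reference, the two documented kwarg sets expressed as Python dicts (values mirror the docstring above):

eagle_kwargs = {"layernorms": True, "fusion_bias": True}
eagle3_kwargs = {
    "norm_before_residual": True,
    "eagle_aux_hidden_state_layer_ids": [1, 23, 44],
}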