1 change: 1 addition & 0 deletions pyproject.toml
@@ -156,6 +156,7 @@ ignore = [
     "COM812",
     "ISC001",
     "TC002",
+    "TC003", # allow imports outside of type checking blocks
     "S311", # allow random number generators
     "PLW1514", # allow Path.open without encoding
     "RET505", # allow `else` blocks
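For context, TC003 is ruff's flake8-type-checking rule that flags standard-library imports used only in type annotations and asks for them to be moved into an `if TYPE_CHECKING:` block; ignoring it allows such imports at module level, as the inline comment says. A minimal illustration of the kind of code this ignore permits (the function name below is made up for the example):

# Illustrative sketch only: with TC003 ignored, a standard-library import that is
# used purely for annotations (pathlib.Path here) may stay at module level instead
# of being moved under `if TYPE_CHECKING:`.
from pathlib import Path


def convert_checkpoint(input_path: Path) -> None:
    # hypothetical function, only here to show the annotation-only import
    print(f"would convert {input_path}")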
18 changes: 8 additions & 10 deletions src/speculators/convert/eagle/eagle3_converter.py
@@ -10,13 +10,12 @@
 from transformers import AutoModelForCausalLM, LlamaConfig, PretrainedConfig

 from speculators.config import SpeculatorsConfig, VerifierConfig
-from speculators.convert.eagle.utils import (
-    ensure_checkpoint_is_local,
-    load_checkpoint_config,
-    load_checkpoint_weights,
-)
 from speculators.models.eagle3 import Eagle3Speculator, Eagle3SpeculatorConfig
 from speculators.proposals.greedy import GreedyTokenProposalConfig
+from speculators.utils import (
+    load_model_checkpoint_config_dict,
+    load_model_checkpoint_state_dict,
+)

 __all__ = ["Eagle3Converter"]

@@ -39,11 +38,10 @@ def convert(
         cache_dir: Optional[Union[str, Path]] = None,
     ) -> None:
         logger.info(f"Converting Eagle-3 checkpoint: {input_path}")
-
-        local_checkpoint_path = ensure_checkpoint_is_local(input_path, cache_dir)
-
-        eagle_config = load_checkpoint_config(local_checkpoint_path)
-        weights = load_checkpoint_weights(local_checkpoint_path)
+        eagle_config = load_model_checkpoint_config_dict(
+            input_path, cache_dir=cache_dir
+        )
+        weights = load_model_checkpoint_state_dict(input_path, cache_dir=cache_dir)
         logger.info(f"Loaded {len(weights)} weights")

         # Patch: ensure target_vocab_size matches t2d tensor shape
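To make the refactor concrete, here is a minimal sketch of how the shared helpers replace the old eagle-specific utilities. The call shapes are taken from this diff; the checkpoint path is a placeholder, and the return types (a plain config dict and a state dict of tensors) are inferred from how the converter uses the results rather than from documented signatures:

from speculators.utils import (
    load_model_checkpoint_config_dict,
    load_model_checkpoint_state_dict,
)

# "path/to/eagle3-checkpoint" is a placeholder for a local directory or Hub id;
# cache_dir is forwarded the same way the converter's convert() call does above.
eagle_config = load_model_checkpoint_config_dict("path/to/eagle3-checkpoint", cache_dir=None)
weights = load_model_checkpoint_state_dict("path/to/eagle3-checkpoint", cache_dir=None)
print(f"Loaded {len(weights)} weight tensors")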
51 changes: 41 additions & 10 deletions src/speculators/convert/eagle/eagle_converter.py
@@ -10,18 +10,49 @@
 from transformers import LlamaConfig, PretrainedConfig

 from speculators.config import SpeculatorsConfig, VerifierConfig
-from speculators.convert.eagle.utils import (
-    detect_fusion_bias_and_layernorms,
-    ensure_checkpoint_is_local,
-    load_checkpoint_config,
-    load_checkpoint_weights,
-)
 from speculators.models.eagle import EagleSpeculator, EagleSpeculatorConfig
 from speculators.proposals.greedy import GreedyTokenProposalConfig
+from speculators.utils import (
+    load_model_checkpoint_config_dict,
+    load_model_checkpoint_state_dict,
+)

 __all__ = ["EagleConverter"]


+def detect_fusion_bias_and_layernorms(
+    weights: dict[str, torch.Tensor],
+) -> tuple[bool, bool]:
+    """
+    Auto-detect fusion bias and extra layernorms presence based on weight names.
+
+    :param weights: Dictionary of weight tensors
+    :return: Tuple of (has_fusion_bias, has_layernorms)
+
+    :Example:
+
+    >>> weights = {
+    ...     "fc.bias": torch.randn(4096),
+    ...     "embed_layernorm.weight": torch.randn(4096)
+    ... }
+    >>> has_bias, has_ln = detect_fusion_bias_and_layernorms(weights)
+    >>> print(f"Fusion bias: {has_bias}, Layernorms: {has_ln}")
+    Fusion bias: True, Layernorms: True
+    """
+    has_fusion_bias = "fc.bias" in weights
+    has_layernorms = any(
+        name in weights
+        for name in ["embed_layernorm.weight", "post_embedding_layernorm.weight"]
+    )
+
+    if has_fusion_bias:
+        logger.info("Detected fusion bias in checkpoint")
+    if has_layernorms:
+        logger.info("Detected extra layernorms in checkpoint")
+
+    return has_fusion_bias, has_layernorms
+
+
 class EagleConverter:
     """
     Converter for Eagle/HASS checkpoints to speculators format.
@@ -98,10 +129,10 @@ def convert(
"""
logger.info(f"Converting Eagle checkpoint: {input_path}")

local_checkpoint_path = ensure_checkpoint_is_local(input_path, cache_dir)

eagle_config = load_checkpoint_config(local_checkpoint_path)
weights = load_checkpoint_weights(local_checkpoint_path)
eagle_config = load_model_checkpoint_config_dict(
input_path, cache_dir=cache_dir
)
weights = load_model_checkpoint_state_dict(input_path, cache_dir=cache_dir)
logger.info(f"Loaded {len(weights)} weights")

detected_fusion_bias, detected_layernorms = detect_fusion_bias_and_layernorms(
180 changes: 0 additions & 180 deletions src/speculators/convert/eagle/utils.py

This file was deleted.

18 changes: 18 additions & 0 deletions src/speculators/utils/__init__.py
@@ -1,10 +1,28 @@
 from .auto_importer import AutoImporterMixin
 from .pydantic_utils import PydanticClassRegistryMixin, ReloadableBaseModel
 from .registry import RegistryMixin
+from .transformers_utils import (
+    check_download_model_checkpoint,
+    check_download_model_config,
+    download_model_checkpoint_from_hub,
+    load_model_checkpoint_config_dict,
+    load_model_checkpoint_index_weight_files,
+    load_model_checkpoint_state_dict,
+    load_model_checkpoint_weight_files,
+    load_model_config,
+)

 __all__ = [
     "AutoImporterMixin",
     "PydanticClassRegistryMixin",
     "RegistryMixin",
     "ReloadableBaseModel",
+    "check_download_model_checkpoint",
+    "check_download_model_config",
+    "download_model_checkpoint_from_hub",
+    "load_model_checkpoint_config_dict",
+    "load_model_checkpoint_index_weight_files",
+    "load_model_checkpoint_state_dict",
+    "load_model_checkpoint_weight_files",
+    "load_model_config",
 ]
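The practical effect of these re-exports is that converter code (and any downstream callers) can import the checkpoint helpers from the package-level speculators.utils namespace rather than the deleted speculators.convert.eagle.utils module. For example, the following import now resolves at the package level (names taken from the __all__ list above; their signatures are not shown in this diff):

from speculators.utils import download_model_checkpoint_from_hub, load_model_config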