
Commit 5a04da4

Prevent unnecessary Verifier Model load

Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent 077020d

File tree

3 files changed: +25 -22 lines

src/speculators/model.py

Lines changed: 4 additions & 5 deletions
@@ -375,7 +375,7 @@ def attach_verifier(
         self,
         verifier: str | os.PathLike | PreTrainedModel,
         mode: Literal["full", "train_only"] | None = None,
-    ) -> PreTrainedModel:
+    ):
         """
         Attach a verifier model for the speculator that is used to attach to
         for running inference/training with the speculator and validates the
@@ -417,14 +417,13 @@ def attach_verifier(
             "Must be one of 'full', 'train_only', or None."
         )

-        verifier = self.resolve_verifier(verifier)
         self.verifier_attachment_mode = mode or "full"
         self.verifier = (
-            verifier if self.verifier_attachment_mode == "full" else None
+            self.resolve_verifier(verifier)
+            if self.verifier_attachment_mode == "full"
+            else None
         )  # Expect subclasses to handle references if train_only

-        return verifier
-
     def detach_verifier(self):
         """
         Removes the reference to the attached verifier model and frees up the
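
In short, the base class now resolves the verifier only when it will actually be kept. A minimal usage sketch of the two modes, assuming a speculator subclass and illustrative checkpoint paths (neither is confirmed by this commit):

# Hypothetical usage; the class name and paths are placeholders.
speculator = EagleSpeculator.from_pretrained("path/to/speculator")

# "full": the verifier is resolved (loaded if given as a path) and kept
# on `speculator.verifier` for running inference with the speculator.
speculator.attach_verifier("path/to/verifier", mode="full")

# "train_only": the base class no longer loads the verifier here;
# `self.verifier` stays None and subclasses hold only the pieces they need.
speculator.detach_verifier()
speculator.attach_verifier("path/to/verifier", mode="train_only")
assert speculator.verifier is None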

src/speculators/models/eagle.py

Lines changed: 15 additions & 15 deletions
@@ -15,7 +15,7 @@
 import os
 import re
 import warnings
-from typing import Any, ClassVar, Literal
+from typing import Any, ClassVar, Literal, cast

 import torch
 from pydantic import Field, field_serializer, field_validator, model_validator
@@ -308,7 +308,7 @@ def attach_verifier(
         self,
         verifier: str | os.PathLike | PreTrainedModel,
         mode: Literal["full", "train_only"] | None = None,
-    ) -> PreTrainedModel:
+    ):
         """
         Attach a verifier model to the EagleSpeculator for speculative decoding.
         Utilizes the verifier's embed_tokens, rotary_emb, and lm_head layers
@@ -349,25 +349,25 @@ def attach_verifier(
             perform generation until a full verifier is attached.
         :return: The PreTrainedModel instance for the verifier that was attached.
         """
-        verifier = super().attach_verifier(
-            verifier=verifier,
-            mode=mode,
-        )
+        super().attach_verifier(verifier=verifier, mode=mode)

-        # Extract layers from the verifier model
+        if self.verifier_attachment_mode == "train_only":
+            verifier_model = self.resolve_verifier(verifier)
+        elif self.verifier_attachment_mode == "full":
+            verifier_model = cast("PreTrainedModel", self.verifier)
+        else:
+            return

-        if hasattr(verifier, "model"):
-            self.embed_tokens = verifier.model.embed_tokens  # type: ignore[assignment,union-attr]
-            self.rotary_emb = verifier.model.rotary_emb  # type: ignore[assignment,union-attr]
+        if hasattr(verifier_model, "model"):
+            self.embed_tokens = verifier_model.model.embed_tokens  # type: ignore[assignment,union-attr]
+            self.rotary_emb = verifier_model.model.rotary_emb  # type: ignore[assignment,union-attr]
         else:
             # Bare model structure
-            self.embed_tokens = verifier.embed_tokens  # type: ignore[assignment,attr-defined]
-            self.rotary_emb = verifier.rotary_emb  # type: ignore[assignment,attr-defined]
+            self.embed_tokens = verifier_model.embed_tokens  # type: ignore[assignment,attr-defined]
+            self.rotary_emb = verifier_model.rotary_emb  # type: ignore[assignment,attr-defined]

         # lm_head is always at the top level of the verifier
-        self.lm_head = verifier.lm_head  # type: ignore[assignment,attr-defined]
-
-        return verifier
+        self.lm_head = verifier_model.lm_head  # type: ignore[assignment,attr-defined]

     def detach_verifier(self):
         """

src/speculators/train/checkpointer.py

Lines changed: 6 additions & 2 deletions
@@ -84,7 +84,8 @@ def load_model_state_dict(self, model: PreTrainedModel):
         full_state_dict = load_safetensors_state_dict(
             self.model_path(self.previous_epoch), "cuda:0"
         )
-        model.load_state_dict(full_state_dict)
+        # Note: `strict=False` because we don't load the verifier weights
+        model.load_state_dict(full_state_dict, strict=False)

     def load_optimizer_state_dict(
         self,
@@ -110,10 +111,13 @@ def load_model_state_dict(self, model: PreTrainedModel):
         full_state_dict = load_safetensors_state_dict(
             self.model_path(self.previous_epoch), "cpu"
         )
+        # Note: `strict=False` because we don't load the verifier weights
         set_model_state_dict(
             model,
             full_state_dict,  # type: ignore[arg-type]
-            options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
+            options=StateDictOptions(
+                full_state_dict=True, broadcast_from_rank0=True, strict=False
+            ),
         )
         dist.barrier()
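The `strict=False` flags follow from the attachment change above: checkpoints written during training no longer contain the shared verifier weights, so a strict load would fail on those missing keys. A small sketch of what the non-strict path reports, using the standard `nn.Module.load_state_dict` return value (variable names are illustrative):

# load_state_dict returns a NamedTuple of the keys it could not match;
# with strict=False they are reported instead of raising a RuntimeError.
result = model.load_state_dict(full_state_dict, strict=False)
print("missing keys (expected: verifier-shared layers):", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)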
