|
15 | 15 | import os |
16 | 16 | import re |
17 | 17 | import warnings |
18 | | -from typing import Any, ClassVar, Literal |
| 18 | +from typing import Any, ClassVar, Literal, cast |
19 | 19 |
|
20 | 20 | import torch |
21 | 21 | from pydantic import Field, field_serializer, field_validator, model_validator |
@@ -308,7 +308,7 @@ def attach_verifier( |
308 | 308 | self, |
309 | 309 | verifier: str | os.PathLike | PreTrainedModel, |
310 | 310 | mode: Literal["full", "train_only"] | None = None, |
311 | | - ) -> PreTrainedModel: |
| 311 | + ): |
312 | 312 | """ |
313 | 313 | Attach a verifier model to the EagleSpeculator for speculative decoding. |
314 | 314 | Utilizes the verifier's embed_tokens, rotary_emb, and lm_head layers |
@@ -349,25 +349,25 @@ def attach_verifier( |
349 | 349 | perform generation until a full verifier is attached. |
350 | 350 | :return: None. The verifier's embed_tokens, rotary_emb, and lm_head layers are attached in place. |
351 | 351 | """ |
352 | | - verifier = super().attach_verifier( |
353 | | - verifier=verifier, |
354 | | - mode=mode, |
355 | | - ) |
| 352 | + super().attach_verifier(verifier=verifier, mode=mode) |
356 | 353 |
|
357 | | - # Extract layers from the verifier model |
| 354 | + if self.verifier_attachment_mode == "train_only": |
| 355 | + verifier_model = self.resolve_verifier(verifier) |
| 356 | + elif self.verifier_attachment_mode == "full": |
| 357 | + verifier_model = cast("PreTrainedModel", self.verifier) |
| 358 | + else: |
| 359 | + return |
358 | 360 |
|
359 | | - if hasattr(verifier, "model"): |
360 | | - self.embed_tokens = verifier.model.embed_tokens # type: ignore[assignment,union-attr] |
361 | | - self.rotary_emb = verifier.model.rotary_emb # type: ignore[assignment,union-attr] |
| 361 | + if hasattr(verifier_model, "model"): |
| 362 | + self.embed_tokens = verifier_model.model.embed_tokens # type: ignore[assignment,union-attr] |
| 363 | + self.rotary_emb = verifier_model.model.rotary_emb # type: ignore[assignment,union-attr] |
362 | 364 | else: |
363 | 365 | # Bare model structure |
364 | | - self.embed_tokens = verifier.embed_tokens # type: ignore[assignment,attr-defined] |
365 | | - self.rotary_emb = verifier.rotary_emb # type: ignore[assignment,attr-defined] |
| 366 | + self.embed_tokens = verifier_model.embed_tokens # type: ignore[assignment,attr-defined] |
| 367 | + self.rotary_emb = verifier_model.rotary_emb # type: ignore[assignment,attr-defined] |
366 | 368 |
|
367 | 369 | # lm_head is always at the top level of the verifier |
368 | | - self.lm_head = verifier.lm_head # type: ignore[assignment,attr-defined] |
369 | | - |
370 | | - return verifier |
| 370 | + self.lm_head = verifier_model.lm_head # type: ignore[assignment,attr-defined] |
371 | 371 |
|
372 | 372 | def detach_verifier(self): |
373 | 373 | """ |
|
0 commit comments