Skip to content

Commit e681ef9

Browse files
authored
Spd fix for release 1.20 (#483)
Added the fix for SpDTransform in the release branch 1.20.

Signed-off-by: Dipankar Sarkar <[email protected]>
1 parent 1a52748 commit e681ef9

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1413,7 +1413,7 @@ def __init__(
14131413
self.num_layers = model.config.num_hidden_layers
14141414
self.continuous_batching = continuous_batching
14151415
self.model.qaic_config = qaic_config
1416-
1416+
self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None)
14171417
self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs)
14181418
self.is_tlm = transformed
14191419
self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None)

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -503,6 +503,7 @@ class SpDTransform:
503503
@classmethod
504504
def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]:
505505
transformed = False
506+
pretrained_model_name_or_path_temp = kwargs.pop("pretrained_model_name_or_path", None)
506507
if qaic_config is None or (speculative_model_type := qaic_config.get("speculative_model_type")) is None:
507508
return model, transformed
508509
elif speculative_model_type not in (
@@ -524,6 +525,7 @@ def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -
524525
raise NotImplementedError(
525526
f"model class {model_class} does not yet support returning multiple logits to keep."
526527
)
528+
kwargs["pretrained_model_name_or_path"] = pretrained_model_name_or_path_temp
527529
return model, transformed
528530

529531

0 commit comments

Comments
 (0)