diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py
index f54cda6429..bb799ef8e5 100644
--- a/transformer_engine/pytorch/fp8.py
+++ b/transformer_engine/pytorch/fp8.py
@@ -485,9 +485,9 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:

 @contextmanager
 def fp8_model_init(
-    enabled: bool = True,
-    preserve_high_precision_init_val: bool = False,
-    ) -> None:
+    enabled: bool = True,
+    preserve_high_precision_init_val: bool = False,
+) -> None:
     """
     Context manager for FP8 initialization of parameters.

diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index ceb4a286e4..3aebc1729b 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -883,6 +883,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
             # a parameter so we always re-apply it just for extra safety.
             param = torch.nn.Parameter(param)
             if high_precision_init_val is not None:
+
                 def get(self):
                     if hasattr(self, "_high_precision_init_val"):
                         return self._high_precision_init_val
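For context, the hunks above touch the public `fp8_model_init` signature and the machinery in `reset_parameters` that stashes a high-precision copy of FP8-initialized parameters. Below is a minimal usage sketch; it assumes the `get` helper shown in the diff backs a parameter-level accessor named `get_high_precision_init_val`, and the `te.Linear(768, 768)` sizes are illustrative, not taken from this diff:

```python
import transformer_engine.pytorch as te

# Create a module whose weights are initialized directly in FP8, while also
# asking TE to preserve a high-precision copy of the initial values.
with te.fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
    linear = te.Linear(768, 768)

# Assumed accessor backed by the `get` helper in the diff; per that helper it
# returns the stashed high-precision tensor, or None if none was stored.
init_val = linear.weight.get_high_precision_init_val()
```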