From 5c649e9350acc97fd2936aad1256d0e42cc30948 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 27 Aug 2024 10:50:30 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 transformer_engine/pytorch/fp8.py         | 6 +++---
 transformer_engine/pytorch/module/base.py | 1 +
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py
index f54cda6429..bb799ef8e5 100644
--- a/transformer_engine/pytorch/fp8.py
+++ b/transformer_engine/pytorch/fp8.py
@@ -485,9 +485,9 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
 
 @contextmanager
 def fp8_model_init(
-        enabled: bool = True,
-        preserve_high_precision_init_val: bool = False,
-    ) -> None:
+    enabled: bool = True,
+    preserve_high_precision_init_val: bool = False,
+) -> None:
     """
     Context manager for FP8 initialization of parameters.
 
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 27e7469434..f4aa9de2d5 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -882,6 +882,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
             # a parameter so we always re-apply it just for extra safety.
             param = torch.nn.Parameter(param)
             if self.primary_weights_in_fp8 and self.preserve_high_precision_init_val:
+
                 def get(self):
                     if hasattr(self, "_high_precision_init_val"):
                         return self._high_precision_init_val
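
A minimal usage sketch of the fp8_model_init context manager whose signature is
reformatted above (an illustration, not part of the patch: the alias "te", the
Linear sizes, and the availability of an FP8-capable GPU are all assumed).
Setting preserve_high_precision_init_val=True asks the module to keep the
original high-precision initialization values alongside the FP8 parameters
created inside the context:

    import transformer_engine.pytorch as te

    # Parameters of modules built inside this context are created directly in
    # FP8; the flag additionally retains their high-precision init values.
    with te.fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
        linear = te.Linear(768, 768)  # illustrative sizes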