diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 5947bfdf23..eb00089e80 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -528,7 +528,7 @@ def fp8_init(self, num_gemms: int = 1) -> None:
         self.fp8_calibration = is_fp8_calibration()
         self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration

-        if self.fp8 or self.fp8_calibration or self.primary_weights_in_fp8:
+        if self.fp8 or self.fp8_calibration:
             # FP8 init has already been run and recipe is the same, don't do anything.
             if self.fp8_initialized and get_fp8_recipe() == self.fp8_meta["recipe"]:
                 return
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 0dd3325a67..84b60cbe18 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -552,8 +552,6 @@ def __init__(
         self.ub_split_rs = ub_split_rs
         self.ub_split_ag = ub_split_ag

-        self.primary_weights_in_fp8 = primary_weights_in_fp8
-
         if ub_split_rs or ub_split_ag:
             assert (
                 tex.userbuf_comm_available()
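
Net effect of the two hunks: `primary_weights_in_fp8` no longer forces FP8 metadata initialization in `fp8_init`, and `Linear.__init__` no longer stores the flag, so modules enter the FP8 init path only under an active `fp8_autocast` region or calibration. Below is a minimal sketch of the resulting gate logic. The stub helpers stand in for Transformer Engine's autocast-state functions (`is_fp8_enabled`, `is_fp8_calibration`, `get_fp8_recipe` appear in the diff context); the bookkeeping after the early-return check is an assumption, since the hunk cuts off at `return`.

```python
# Sketch only: stubs emulate the autocast-state helpers used in the diff.
def is_fp8_enabled() -> bool:      # assumed: True inside an fp8_autocast region
    return True

def is_fp8_calibration() -> bool:  # assumed: True in calibration mode
    return False

def get_fp8_recipe() -> str:       # assumed: recipe of the active autocast region
    return "delayed-scaling"

class ModuleSketch:
    def __init__(self) -> None:
        self.fp8 = False
        self.fp8_initialized = False
        self.fp8_meta: dict = {}

    def fp8_init(self, num_gemms: int = 1) -> None:
        self.fp8 = is_fp8_enabled()
        self.fp8_calibration = is_fp8_calibration()
        self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration

        # After this change, primary-weights-in-FP8 no longer forces this branch.
        if self.fp8 or self.fp8_calibration:
            # FP8 init has already been run and recipe is the same, don't do anything.
            if self.fp8_initialized and get_fp8_recipe() == self.fp8_meta["recipe"]:
                return
            # Assumed bookkeeping (not shown in the hunk): record the recipe
            # and mark the module as initialized.
            self.fp8_meta["recipe"] = get_fp8_recipe()
            self.fp8_initialized = True
        else:
            self.fp8_initialized = False
```

With this shape, a second `fp8_init()` call under the same active recipe hits the early return, while re-entering with a different recipe re-runs initialization.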