Commit

no need to use primary_weights_in_fp8 since already using `fp8_autocast` around model init to initialize the fp8_meta tensors

Signed-off-by: Sudhakar Singh <[email protected]>
sudhakarsingh27 committed Aug 14, 2023
1 parent 567f289 commit 04c11f7
Showing 2 changed files with 1 addition and 3 deletions.
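For context, a minimal sketch of the usage pattern the commit message refers to: constructing the Transformer Engine module inside `fp8_autocast` so the fp8_meta scaling tensors are initialized at model-init time, which is why the extra `primary_weights_in_fp8` check in `fp8_init()` is no longer needed. The recipe choice, layer sizes, and input shape below are illustrative assumptions, not taken from this commit.

```python
# Sketch only: recipe, dimensions, and input shape are assumed for illustration.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling()  # default delayed-scaling recipe (assumed choice)

# Model init wrapped in fp8_autocast, as the commit message describes,
# so the module's FP8 metadata is set up during construction.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    layer = te.Linear(768, 768)

# Forward pass under fp8_autocast as usual.
inp = torch.randn(16, 768, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp)
```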
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/base.py
@@ -528,7 +528,7 @@ def fp8_init(self, num_gemms: int = 1) -> None:
         self.fp8_calibration = is_fp8_calibration()
         self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
 
-        if self.fp8 or self.fp8_calibration or self.primary_weights_in_fp8:
+        if self.fp8 or self.fp8_calibration:
             # FP8 init has already been run and recipe is the same, don't do anything.
             if self.fp8_initialized and get_fp8_recipe() == self.fp8_meta["recipe"]:
                 return
2 changes: 0 additions & 2 deletions transformer_engine/pytorch/module/linear.py
@@ -552,8 +552,6 @@ def __init__(
         self.ub_split_rs = ub_split_rs
         self.ub_split_ag = ub_split_ag
 
-        self.primary_weights_in_fp8 = primary_weights_in_fp8
-
         if ub_split_rs or ub_split_ag:
             assert (
                 tex.userbuf_comm_available()
