diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index b921888633..227c5c178d 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -575,7 +575,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None:
         """Get activation data type for AMP."""
         # Native AMP (`torch.autocast`) gets highest priority
         if torch.is_autocast_enabled():
-            self.activation_dtype = torch.get_autocast_dtype('cuda')
+            self.activation_dtype = torch.get_autocast_dtype("cuda")
             return

         # All checks after this have already been performed once, thus skip