diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 947c642c2c..3a9e3f9585 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -48,7 +48,7 @@ def attention_mask_func( attention_scores: torch.Tensor, attention_mask: torch.Tensor ) -> torch.Tensor: """Get attention mask""" - attention_scores.masked_fill_(attention_mask, -10000.0) + attention_scores.masked_fill_(attention_mask, float("-inf")) return attention_scores