From a326e351a1fb9c4ff8ee970a407c1f4f35f663af Mon Sep 17 00:00:00 2001
From: Marks101 <46690260+Marks101@users.noreply.github.com>
Date: Thu, 15 Aug 2024 02:40:35 +0200
Subject: [PATCH] [PyTorch] Fix issues with cross attention (#1069)

Signed-off-by: Markus Schnoes
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
---
 transformer_engine/pytorch/attention.py   | 3 ++-
 transformer_engine/pytorch/transformer.py | 4 +++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 3fc805bdc6..b2fb22c8fc 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -5778,7 +5778,8 @@ def forward(
             assert (
                 attention_mask is not None
             ), "Please provide attention_mask for padding!"
-            if max_seqlen_q == max_seqlen_kv:
+            if self.attention_type == "self":
+                assert max_seqlen_q == max_seqlen_kv
                 cu_seqlens_q = get_cu_seqlens(attention_mask)
                 cu_seqlens_kv = cu_seqlens_q
             else:
diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py
index 130cf91f0e..e40653d998 100644
--- a/transformer_engine/pytorch/transformer.py
+++ b/transformer_engine/pytorch/transformer.py
@@ -652,7 +652,7 @@ def forward(
             hidden_states,
             attention_mask=attention_mask,
             attn_mask_type=self_attn_mask_type,
-            window_size=enc_dec_window_size,
+            window_size=window_size,
             inference_params=inference_params,
             is_first_microbatch=is_first_microbatch,
             checkpoint_core_attention=checkpoint_core_attention,
@@ -679,6 +679,8 @@ def forward(
             inter_attention_outputs = self.inter_attention(
                 hidden_states,
                 attention_mask=enc_dec_attn_mask,
+                attn_mask_type=enc_dec_attn_mask_type,
+                window_size=enc_dec_window_size,
                 encoder_output=encoder_output,
                 is_first_microbatch=is_first_microbatch,
                 checkpoint_core_attention=checkpoint_core_attention,
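
Note on the attention.py hunk: with padding masks, cross attention generally has
different query and key/value sequence lengths, so the cumulative-sequence-length
tensors must be derived from separate masks rather than shared; only self attention
can assert max_seqlen_q == max_seqlen_kv and reuse cu_seqlens_q for the key/value
side. The sketch below illustrates this in plain PyTorch. It is not the library's
get_cu_seqlens implementation; the helper name and the "True means valid token"
mask convention are assumptions made for illustration only.

import torch

def cu_seqlens_from_padding_mask(mask: torch.Tensor) -> torch.Tensor:
    """Cumulative sequence lengths from a [batch, seqlen] bool mask (True = valid token)."""
    seqlens = mask.sum(dim=1, dtype=torch.int32)            # valid tokens per sequence
    zero = torch.zeros(1, dtype=torch.int32)
    return torch.cat([zero, torch.cumsum(seqlens, dim=0, dtype=torch.int32)])

# Hypothetical padded batch: queries padded to 8 tokens, keys/values padded to 12.
q_mask  = torch.tensor([[1] * 5 + [0] * 3, [1] * 8]).bool()
kv_mask = torch.tensor([[1] * 10 + [0] * 2, [1] * 12]).bool()

cu_seqlens_q  = cu_seqlens_from_padding_mask(q_mask)    # tensor([ 0,  5, 13])
cu_seqlens_kv = cu_seqlens_from_padding_mask(kv_mask)   # tensor([ 0, 10, 22])

The transformer.py hunks are the companion fix on the TransformerLayer side:
self attention now receives window_size (not the encoder-decoder value), and
inter_attention is passed enc_dec_attn_mask_type and enc_dec_window_size so the
cross-attention block uses its own mask type and sliding-window setting instead
of falling back to its defaults.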