diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 3fc805bdc6..b2fb22c8fc 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -5778,7 +5778,8 @@ def forward(
                 assert (
                     attention_mask is not None
                 ), "Please provide attention_mask for padding!"
-                if max_seqlen_q == max_seqlen_kv:
+                if self.attention_type == "self":
+                    assert max_seqlen_q == max_seqlen_kv
                     cu_seqlens_q = get_cu_seqlens(attention_mask)
                     cu_seqlens_kv = cu_seqlens_q
                 else:
diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py
index 130cf91f0e..e40653d998 100644
--- a/transformer_engine/pytorch/transformer.py
+++ b/transformer_engine/pytorch/transformer.py
@@ -652,7 +652,7 @@ def forward(
             hidden_states,
             attention_mask=attention_mask,
             attn_mask_type=self_attn_mask_type,
-            window_size=enc_dec_window_size,
+            window_size=window_size,
             inference_params=inference_params,
             is_first_microbatch=is_first_microbatch,
             checkpoint_core_attention=checkpoint_core_attention,
@@ -679,6 +679,8 @@ def forward(
         inter_attention_outputs = self.inter_attention(
             hidden_states,
             attention_mask=enc_dec_attn_mask,
+            attn_mask_type=enc_dec_attn_mask_type,
+            window_size=enc_dec_window_size,
             encoder_output=encoder_output,
             is_first_microbatch=is_first_microbatch,
             checkpoint_core_attention=checkpoint_core_attention,
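
For reviewers: the attention.py hunk narrows the shared-cu_seqlens path so that only self-attention reuses one cu_seqlens tensor (now asserting max_seqlen_q == max_seqlen_kv), while cross-attention always derives cu_seqlens_q and cu_seqlens_kv from their own masks. The transformer.py hunks are argument-plumbing fixes: self_attention receives window_size instead of enc_dec_window_size, and inter_attention now receives enc_dec_attn_mask_type and enc_dec_window_size. Below is a minimal, self-contained sketch of the patched cu_seqlens control flow in plain PyTorch; get_cu_seqlens and build_cu_seqlens here are illustrative stand-ins (assuming a [batch, 1, 1, seqlen] padding mask where True marks padded positions, and a (q_mask, kv_mask) tuple for cross-attention), not the library's actual implementation.

    import torch

    def get_cu_seqlens(attention_mask: torch.Tensor) -> torch.Tensor:
        # Stand-in for TE's internal helper (illustrative only): cumulative
        # sequence lengths from a [batch, 1, 1, seqlen] padding mask where
        # True marks padded positions.
        seqlens = (~attention_mask).squeeze(1).squeeze(1).sum(dim=1, dtype=torch.int32)
        zero = torch.zeros(1, dtype=torch.int32, device=seqlens.device)
        return torch.cat([zero, torch.cumsum(seqlens, dim=0, dtype=torch.int32)])

    def build_cu_seqlens(attention_type, attention_mask, max_seqlen_q, max_seqlen_kv):
        # Mirrors the patched branch: only self-attention may share one
        # cu_seqlens tensor; cross-attention computes q and kv separately
        # (here assuming a (query_mask, key_value_mask) tuple).
        if attention_type == "self":
            assert max_seqlen_q == max_seqlen_kv
            cu_seqlens_q = get_cu_seqlens(attention_mask)
            cu_seqlens_kv = cu_seqlens_q
        else:
            cu_seqlens_q = get_cu_seqlens(attention_mask[0])
            cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
        return cu_seqlens_q, cu_seqlens_kv

    # Example: batch of 2, query length 4, key/value length 6 (cross-attention).
    q_mask = torch.tensor([[[[False, False, True, True]]],
                           [[[False, False, False, True]]]])
    kv_mask = torch.tensor([[[[False] * 4 + [True] * 2]],
                            [[[False] * 6]]])
    print(build_cu_seqlens("cross", (q_mask, kv_mask), 4, 6))
    # -> cu_seqlens_q = [0, 2, 5], cu_seqlens_kv = [0, 4, 10]

The point of the stand-in is only to show why the old `if max_seqlen_q == max_seqlen_kv` test was too permissive: a cross-attention call whose q and kv lengths happen to match would silently take the self-attention path and reuse the query mask for key/value.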