[PyTorch] Fix issues with cross attention (#1069)
Signed-off-by: Markus Schnoes <[email protected]>
Co-authored-by: Charlene Yang <[email protected]>
Marks101 and cyanguwa authored Aug 15, 2024
1 parent cc329b7 commit a326e35
Showing 2 changed files with 5 additions and 2 deletions.
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/attention.py
@@ -5778,7 +5778,8 @@ def forward(
             assert (
                 attention_mask is not None
             ), "Please provide attention_mask for padding!"
-            if max_seqlen_q == max_seqlen_kv:
+            if self.attention_type == "self":
+                assert max_seqlen_q == max_seqlen_kv
                 cu_seqlens_q = get_cu_seqlens(attention_mask)
                 cu_seqlens_kv = cu_seqlens_q
             else:
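
Why the change, in sketch form: the helper below is an assumption for illustration, not TransformerEngine's get_cu_seqlens, and the mask layout ([batch, 1, 1, seqlen] with True marking padded positions) is likewise assumed. It shows how cumulative sequence lengths are derived from a padding mask and why the cu_seqlens_kv = cu_seqlens_q shortcut is only safe for self attention, where queries and keys/values share a single mask.

import torch

def get_cu_seqlens_sketch(padding_mask: torch.Tensor) -> torch.Tensor:
    # Assumed layout: [batch, 1, 1, seqlen]; True marks padded positions.
    seqlens = (~padding_mask).sum(dim=-1).reshape(-1).to(torch.int32)
    zero = torch.zeros(1, dtype=torch.int32, device=seqlens.device)
    return torch.cat([zero, torch.cumsum(seqlens, dim=0).to(torch.int32)])

# Self attention: one mask describes both queries and keys/values, so
# cu_seqlens_kv can alias cu_seqlens_q and max_seqlen_q == max_seqlen_kv holds.
q_mask = torch.tensor([[[[False, False, True]]], [[[False, True, True]]]])
print(get_cu_seqlens_sketch(q_mask))   # tensor([0, 2, 3], dtype=torch.int32)

# Cross attention: keys/values come from a different (encoder) sequence, so
# equal max lengths do not imply equal per-sample lengths; the key/value mask
# must be processed separately, hence the new check on self.attention_type.
kv_mask = torch.tensor([[[[False, True, True]]], [[[False, False, False]]]])
print(get_cu_seqlens_sketch(kv_mask))  # tensor([0, 1, 4], dtype=torch.int32)
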
4 changes: 3 additions & 1 deletion transformer_engine/pytorch/transformer.py
@@ -652,7 +652,7 @@ def forward(
                 hidden_states,
                 attention_mask=attention_mask,
                 attn_mask_type=self_attn_mask_type,
-                window_size=enc_dec_window_size,
+                window_size=window_size,
                 inference_params=inference_params,
                 is_first_microbatch=is_first_microbatch,
                 checkpoint_core_attention=checkpoint_core_attention,
@@ -679,6 +679,8 @@ def forward(
             inter_attention_outputs = self.inter_attention(
                 hidden_states,
                 attention_mask=enc_dec_attn_mask,
+                attn_mask_type=enc_dec_attn_mask_type,
+                window_size=enc_dec_window_size,
                 encoder_output=encoder_output,
                 is_first_microbatch=is_first_microbatch,
                 checkpoint_core_attention=checkpoint_core_attention,
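
The transformer.py change restores the intended argument routing in the decoder layer's forward: self attention gets the decoder's own mask type and window size, while inter (cross) attention gets the encoder-decoder variants. Below is a minimal hypothetical sketch of that wiring; the surrounding class, constructor, and defaults are assumptions, and only the keyword names mirror the diff.

import torch.nn as nn

class DecoderLayerSketch(nn.Module):
    # Hypothetical container: the attention submodules are injected and are
    # assumed to accept the keyword arguments shown in the diff above.
    def __init__(self, self_attention: nn.Module, inter_attention: nn.Module):
        super().__init__()
        self.self_attention = self_attention
        self.inter_attention = inter_attention

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        self_attn_mask_type=None,
        window_size=None,
        encoder_output=None,
        enc_dec_attn_mask=None,
        enc_dec_attn_mask_type=None,
        enc_dec_window_size=None,
    ):
        # Self attention over the decoder sequence uses the decoder's own
        # mask type and sliding-window size (window_size, not enc_dec_window_size).
        hidden_states = self.self_attention(
            hidden_states,
            attention_mask=attention_mask,
            attn_mask_type=self_attn_mask_type,
            window_size=window_size,
        )
        # Cross attention over the encoder output uses the encoder-decoder
        # mask, mask type, and window size, plus the encoder output itself.
        hidden_states = self.inter_attention(
            hidden_states,
            attention_mask=enc_dec_attn_mask,
            attn_mask_type=enc_dec_attn_mask_type,
            window_size=enc_dec_window_size,
            encoder_output=encoder_output,
        )
        return hidden_states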
