@@ -133,6 +133,13 @@ def _compute_attention(
         query_normalization = 1 / np.sqrt(
             self.hidden_dim // self.num_query_heads
         )
+
+        if self.use_sliding_window_attention and attention_mask is not None:
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )
+
         if self._can_use_flash_attention():
             if attention_mask is not None:
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
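This first hunk applies the sliding-window mask before the code branches on `self._can_use_flash_attention()`, so both the flash path and the fallback path attend through the same windowed mask, and the added `attention_mask is not None` guard skips the call when no mask was passed. The body of `_mask_sliding_window` is not part of the diff; the snippet below is only a hedged, NumPy-only sketch of what restricting a boolean `[batch, q_len, kv_len]` mask to a trailing window could look like (the standalone function, the `window_size` argument, and the shapes are assumptions, while `cache_update_index` mirrors the call site).

import numpy as np

def mask_sliding_window(attention_mask, window_size, cache_update_index=0):
    # attention_mask: boolean array of shape [batch, q_len, kv_len]; this is an
    # illustrative stand-in, not the method from the diff.
    q_len, kv_len = attention_mask.shape[1], attention_mask.shape[2]
    # Absolute position of each query token; during cached decoding the
    # current queries start at `cache_update_index`, not at 0.
    q_pos = np.arange(q_len)[:, None] + cache_update_index
    kv_pos = np.arange(kv_len)[None, :]
    # Keep a key only if it lies within the last `window_size` positions of
    # the query; causality is assumed to be encoded in the incoming mask.
    in_window = (q_pos - kv_pos) < window_size
    return np.logical_and(attention_mask, in_window[None, :, :])

# Example: 6 causal tokens with a window of 3 keep at most 3 keys per row.
causal = np.tril(np.ones((1, 6, 6), dtype=bool))
print(mask_sliding_window(causal, window_size=3)[0].astype(int))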
@@ -172,13 +179,8 @@ def _compute_attention(
                 ops.tanh(attention_logits), self.logit_soft_cap
             )
 
-        if self.use_sliding_window_attention:
-            attention_mask = self._mask_sliding_window(
-                attention_mask,
-                cache_update_index=cache_update_index,
-            )
-
-        attention_mask = attention_mask[:, None, None, :, :]
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :, :]
         orig_dtype = attention_logits.dtype
         attention_softmax = self.softmax(attention_logits, mask=attention_mask)
         attention_softmax = ops.cast(attention_softmax, orig_dtype)
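With the window already folded into the mask before the branch, the second hunk removes the late `_mask_sliding_window` call and only adds the broadcast axes when a mask is actually present, which also avoids subscripting a `None` mask. As a hedged illustration (the shapes are assumptions, not shown in the diff): if the mask is `[batch, q_len, kv_len]` and the logits use a grouped-query layout of `[batch, num_kv_heads, group_size, q_len, kv_len]`, the `[:, None, None, :, :]` indexing inserts exactly the two singleton axes needed for broadcasting.

import numpy as np

# Assumed shapes for illustration only.
attention_mask = np.tril(np.ones((2, 6, 6), dtype=bool))    # [batch, q_len, kv_len]
attention_logits = np.zeros((2, 4, 2, 6, 6))                # [batch, kv_heads, groups, q, kv]

broadcast_mask = attention_mask[:, None, None, :, :]
print(broadcast_mask.shape)                                 # (2, 1, 1, 6, 6)
print(np.broadcast_shapes(broadcast_mask.shape, attention_logits.shape))  # (2, 4, 2, 6, 6)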