Commit b997444

Move sliding window attn before FA block for Gemma (#2187)
1 parent 050d032 commit b997444

1 file changed: +9 -7 lines changed

Diff for: keras_hub/src/models/gemma/gemma_attention.py

@@ -133,6 +133,13 @@ def _compute_attention(
         query_normalization = 1 / np.sqrt(
             self.hidden_dim // self.num_query_heads
         )
+
+        if self.use_sliding_window_attention and attention_mask is not None:
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )
+
         if self._can_use_flash_attention():
             if attention_mask is not None:
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
@@ -172,13 +179,8 @@ def _compute_attention(
                 ops.tanh(attention_logits), self.logit_soft_cap
             )

-        if self.use_sliding_window_attention:
-            attention_mask = self._mask_sliding_window(
-                attention_mask,
-                cache_update_index=cache_update_index,
-            )
-
-        attention_mask = attention_mask[:, None, None, :, :]
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :, :]
         orig_dtype = attention_logits.dtype
         attention_softmax = self.softmax(attention_logits, mask=attention_mask)
         attention_softmax = ops.cast(attention_softmax, orig_dtype)
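
Net effect of the move: the sliding-window restriction is now applied to attention_mask once, before the code branches into either the flash-attention path or the eager softmax path, and it is skipped entirely when no mask is present (hence the added `and attention_mask is not None` guard). For readers unfamiliar with the technique, the sketch below illustrates what a sliding-window mask restriction does conceptually. It is a hypothetical NumPy illustration, not KerasHub's `_mask_sliding_window`; the function name, the `window_size` parameter, and the assumption of a boolean (batch, q_len, kv_len) causal mask are assumptions made for the example.

    import numpy as np

    def sliding_window_mask_sketch(attention_mask, window_size, cache_update_index=0):
        # Hypothetical sketch only, not KerasHub's `_mask_sliding_window`.
        # Assumes `attention_mask` is a boolean (batch, q_len, kv_len) causal mask.
        _, q_len, kv_len = attention_mask.shape
        # During cached decoding the current query block starts at
        # `cache_update_index`, so shift the query positions by that offset.
        q_pos = np.arange(q_len)[:, None] + cache_update_index
        k_pos = np.arange(kv_len)[None, :]
        # A key stays visible only if it is at most `window_size - 1` positions
        # behind the query; the causal mask already hides future positions.
        in_window = (q_pos - k_pos) < window_size
        return np.logical_and(attention_mask, in_window[None, :, :])

    # Tiny usage example: 1 sequence of 6 tokens, window of 3 past tokens.
    causal = np.tril(np.ones((1, 6, 6), dtype=bool))
    print(sliding_window_mask_sketch(causal, window_size=3).astype(int)[0])

Placing the restriction before the flash-attention branch means both execution paths see the same, already-windowed mask.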
