
Commit b50eef2

Value update for mask
Signed-off-by: Amit Raj <[email protected]>
1 parent bc01da8 commit b50eef2

3 files changed: +2 -6 lines changed


QEfficient/transformers/models/gpt2/modeling_gpt2.py

Lines changed: 1 addition & 2 deletions
@@ -31,10 +31,9 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask
     # if only "normal" attention layer implements causal mask
     query_length, key_length = query.size(-2), key.size(-2)
     causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
-    mask_value = MIN_MASKED_ATTENTION_VALUE
     # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
     # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-    mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+    mask_value = torch.full([], MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype, device=attn_weights.device)
     attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

     if attention_mask is not None:
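
For context, the pattern above can be exercised in isolation. The sketch below uses illustrative shapes and a locally defined MIN_MASKED_ATTENTION_VALUE rather than anything imported from the repository:

import torch

# Assumed stand-in for the constant in QEfficient/utils/constants.py after this commit.
MIN_MASKED_ATTENTION_VALUE = float("-inf")

# Illustrative scores for 1 batch, 1 head, 4 query and 4 key positions.
attn_weights = torch.randn(1, 1, 4, 4)

# Lower-triangular causal mask: True where attending is allowed.
causal_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool)).view(1, 1, 4, 4)

# 0-dim tensor so dtype and device match attn_weights, following the reasoning in the diff's comments.
mask_value = torch.full([], MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype, device=attn_weights.device)

attn_weights = torch.where(causal_mask, attn_weights, mask_value)
probs = torch.softmax(attn_weights, dim=-1)  # masked positions receive exactly zero probability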

QEfficient/transformers/models/mllama/modeling_mllama.py

Lines changed: 0 additions & 3 deletions
@@ -179,9 +179,6 @@ def forward(
     if attention_mask is not None:  # no matter the length, we just slice it
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask
-        # attn_weights = torch.where(
-        #     attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
-        # )

     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
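
The deleted lines were an already commented-out torch.where fallback; the path that remains applies the mask additively. A minimal sketch of that additive pattern, with purely illustrative shapes and names:

import torch

MIN_MASKED_ATTENTION_VALUE = float("-inf")  # assumed stand-in for the repository constant

attn_weights = torch.randn(1, 1, 4, 4)                    # raw attention scores
allowed = torch.tril(torch.ones(4, 4, dtype=torch.bool))  # causal pattern: True where attending is allowed

# Additive mask: 0.0 for allowed positions, MIN_MASKED_ATTENTION_VALUE elsewhere.
causal_mask = torch.zeros(4, 4).masked_fill(~allowed, MIN_MASKED_ATTENTION_VALUE).view(1, 1, 4, 4)

attn_weights = attn_weights + causal_mask
probs = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)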

QEfficient/utils/constants.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@
 DEFAULT_AIC_NUM_CORES = 16
 DEFAULT_AIC_MXPF6_MATMUL = False
 # Minimum value for causal mask
-MIN_MASKED_ATTENTION_VALUE = -1e4
+MIN_MASKED_ATTENTION_VALUE = float("-inf")


 # Store the qeff_models inside the ~/.cache directory or over-ride with an env variable.
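
One observable numeric difference between the old and new constant is how a fully masked row behaves under softmax. The snippet below only illustrates that behaviour and makes no claim about the motivation for this commit:

import torch

# A row in which every position is masked, once with each constant value.
for mask_value in (-1e4, float("-inf")):
    row = torch.full((4,), mask_value)
    print(mask_value, torch.softmax(row, dim=-1).tolist())
# -1e4          -> uniform probabilities [0.25, 0.25, 0.25, 0.25]
# float("-inf") -> [nan, nan, nan, nan]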
