
Commit c04ba8e

Create a constant value for MIN_MASKED_ATTN_VALUE
Signed-off-by: Amit Raj <[email protected]>
1 parent 299ef79 commit c04ba8e

Showing 21 changed files with 86 additions and 26 deletions.
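
The constant itself is added in QEfficient/utils/constants.py, whose hunk is not among the diffs reproduced below. Since every call site simply swaps the literal -10000.0 for the new name, the definition presumably boils down to the following sketch (exact placement and comment wording are assumptions):

# QEfficient/utils/constants.py -- assumed shape of the new definition
# Large finite negative value used to mask out attention positions before softmax.
MIN_MASKED_ATTENTION_VALUE = -10000.0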

QEfficient/transformers/modeling_utils.py

Lines changed: 8 additions & 3 deletions
@@ -88,6 +88,7 @@
 )
 
 from QEfficient.customop import CustomRMSNormAIC
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 # Placeholder for all non-transformer models
 from .models.codegen.modeling_codegen import (
@@ -303,12 +304,12 @@ def _prepare_cross_attention_mask(
     # invert the mask
     inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
     cross_attention_mask = inverted_cross_attn_mask.masked_fill(
-        inverted_cross_attn_mask.to(torch.bool), torch.tensor(-10000.0, dtype=torch.float32)
+        inverted_cross_attn_mask.to(torch.bool), torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32)
     )
 
     # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's
     # last dimension contains negative infinity values, otherwise it's 1
-    negative_inf_value = torch.tensor(-10000.0, dtype=torch.float32)
+    negative_inf_value = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32)
     full_text_row_masked_out_mask = (
         (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
     )
@@ -338,7 +339,11 @@ def _prepare_aspect_ratio_attention_mask(
     # Reshape to 2D and create 4D attention mask
     # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
     attention_mask = attention_mask.reshape(batch_size, max_num_tiles * target_length, 1)
-    attention_mask = attention_mask @ attention_mask.transpose(-1, -2) * torch.tensor(-10000.0, dtype=torch.float32)
+    attention_mask = (
+        attention_mask
+        @ attention_mask.transpose(-1, -2)
+        * torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32)
+    )
     attention_mask = attention_mask.unsqueeze(1)
 
     return attention_mask
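
A minimal standalone sketch of the masking pattern used in _prepare_cross_attention_mask above. This is illustrative toy code, not the library function, and the constant's value (-10000.0) is assumed from the replaced literal:

import torch

MIN_MASKED_ATTENTION_VALUE = -10000.0  # assumed value of the constant

# 1 = attend, 0 = masked; the second row is fully masked
cross_attention_mask = torch.tensor([[1.0, 0.0], [0.0, 0.0]])
inverted = 1.0 - cross_attention_mask
masked = inverted.masked_fill(inverted.to(torch.bool), MIN_MASKED_ATTENTION_VALUE)
# A row counts as "fully masked out" when every entry equals the sentinel
full_row_kept = (masked != MIN_MASKED_ATTENTION_VALUE).any(dim=-1).type_as(masked)[..., None]
print(masked)         # [[0., -10000.], [-10000., -10000.]]
print(full_row_kept)  # [[1.], [0.]]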

QEfficient/transformers/models/codegen/modeling_codegen.py

Lines changed: 3 additions & 3 deletions
@@ -23,6 +23,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffCodeGenAttention(CodeGenAttention):
@@ -47,11 +48,10 @@ def _attn(
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
         attn_weights = attn_weights / self.scale_attn
-        # Minimum value for causal mask
-        mask_value = -10000.0
+
         # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
         # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+        mask_value = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype).to(attn_weights.device)
 
         if attention_mask is not None:
             # Apply the attention mask
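
The comments kept in the hunk above explain why the sentinel has to be materialised as a tensor with the weights' dtype and device rather than used as a bare Python float. A minimal sketch of that pattern, with toy shapes and the constant's value assumed:

import torch

MIN_MASKED_ATTENTION_VALUE = -10000.0  # assumed value of the constant

attn_weights = torch.randn(1, 4, 4)  # stand-in for query @ key^T
causal_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))
# Same dtype and device as attn_weights, so torch.where sees matching operands
mask_value = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)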

QEfficient/transformers/models/falcon/modeling_falcon.py

Lines changed: 4 additions & 1 deletion
@@ -31,6 +31,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffFalconRotaryEmbedding(FalconRotaryEmbedding):
@@ -148,7 +149,9 @@ def forward(
 
         attention_scores = query_layer @ key_layer.transpose(-1, -2)
         attention_scores /= math.sqrt(self.head_dim)
-        attention_scores = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attention_scores)
+        attention_scores = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attention_scores
+        )
         attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
         # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi).
         attn_output = attention_scores @ value_layer

QEfficient/transformers/models/gemma/modeling_gemma.py

Lines changed: 4 additions & 1 deletion
@@ -27,6 +27,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffGemmaRotaryEmbedding(GemmaRotaryEmbedding):
@@ -110,7 +111,9 @@ def eager_attention_forward(
 
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
 
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)

QEfficient/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 4 additions & 1 deletion
@@ -30,6 +30,7 @@
 
 # from transformers.utils import is_torchdynamo_compiling
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffGemma2RotaryEmbedding(Gemma2RotaryEmbedding):
@@ -116,7 +117,9 @@ def eager_attention_forward(
 
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
 
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)

QEfficient/transformers/models/gpt2/modeling_gpt2.py

Lines changed: 5 additions & 2 deletions
@@ -17,6 +17,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs):
@@ -30,15 +31,17 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask
         # if only "normal" attention layer implements causal mask
         query_length, key_length = query.size(-2), key.size(-2)
         causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
-        mask_value = -10000.0
+        mask_value = MIN_MASKED_ATTENTION_VALUE
         # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
         mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
         attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
 
     if attention_mask is not None:
         # Apply the attention mask
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
 
     attn_weights = nn.functional.softmax(attn_weights, dim=-1)
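
For reference, a small self-contained demonstration of what this masking achieves: positions set to the large negative sentinel end up with (numerically) zero probability after softmax. Toy values; the constant's value is assumed:

import torch
import torch.nn as nn

MIN_MASKED_ATTENTION_VALUE = -10000.0  # assumed value of the constant

attn_weights = torch.tensor([[0.5, 1.2, -0.3, 0.0]])
attention_mask = torch.tensor([[False, False, True, True]])  # True = do not attend
attn_weights = torch.where(
    attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
)
print(nn.functional.softmax(attn_weights, dim=-1))  # masked positions come out ~ 0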

QEfficient/transformers/models/gptj/modeling_gptj.py

Lines changed: 4 additions & 1 deletion
@@ -28,6 +28,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
@@ -62,7 +63,9 @@ def _attn(
 
         if attention_mask is not None:
             # Apply the attention mask
-            attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+            attn_weights = torch.where(
+                attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+            )
 
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_weights = attn_weights.to(value.dtype)

QEfficient/transformers/models/granite/modeling_granite.py

Lines changed: 4 additions & 1 deletion
@@ -26,6 +26,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffGraniteRotaryEmbedding(GraniteRotaryEmbedding):
@@ -107,7 +108,9 @@ def eager_attention_forward(
 
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
 
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)

QEfficient/transformers/models/granitemoe/modeling_granitemoe.py

Lines changed: 4 additions & 1 deletion
@@ -29,6 +29,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffGraniteMoeRotaryEmbedding(GraniteMoeRotaryEmbedding):
@@ -153,7 +154,9 @@ def forward(
 
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
         if attention_mask is not None:
-            attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+            attn_weights = torch.where(
+                attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+            )
 
         attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         dropout = 0.0 if not self.training else self.attention_dropout

QEfficient/transformers/models/llama/modeling_llama.py

Lines changed: 4 additions & 1 deletion
@@ -27,6 +27,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
@@ -109,7 +110,9 @@ def eager_attention_forward(
 
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
 
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
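
The commit does not state why a large finite value is used instead of float("-inf"), so the following is only a plausible rationale: with -inf, a row that ends up fully masked produces NaN after softmax, while the finite sentinel keeps the result well defined. A quick check (constant value assumed):

import torch

MIN_MASKED_ATTENTION_VALUE = -10000.0  # assumed value of the constant

fully_masked_inf = torch.full((1, 4), float("-inf"))
fully_masked_big = torch.full((1, 4), MIN_MASKED_ATTENTION_VALUE)
print(torch.softmax(fully_masked_inf, dim=-1))  # all NaN
print(torch.softmax(fully_masked_big, dim=-1))  # uniform 0.25, finite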
