Commit 7117049

albertvillanova and ngazagna-qc authored and committed
Fix CUDA index out of bounds for q_idx in VLM token type masking for Gemma3, PaliGemma, and example modular (huggingface#41757)
* Fix CUDA index out of bounds for q_idx in Gemma3 token type masking
* Fix CUDA index out of bounds for q_idx in modular modeling_new_task_model
* Revert "Fix CUDA index out of bounds for q_idx in Gemma3 token type masking"
  This reverts commit f8e5c2a.
* Fix CUDA index out of bounds for q_idx in PaliGemma token type masking
* Fix CUDA index out of bounds for q_idx in Gemma3 token type masking
1 parent 2725f83 commit 7117049

File tree:
- examples/modular-transformers/modeling_new_task_model.py
- src/transformers/models/gemma3/modeling_gemma3.py
- src/transformers/models/paligemma/modeling_paligemma.py

3 files changed: +39 −15 lines


examples/modular-transformers/modeling_new_task_model.py

Lines changed: 13 additions & 5 deletions
@@ -125,15 +125,23 @@ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
         # If it's 1 for both query and key/value, we are in an image block
         # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
         # Since vmap doesn't support `if statement` we workaround it with `torch.where`
-        safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
-        token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx]
+        safe_q_idx = torch.where(q_idx < token_type_ids.shape[1], q_idx, 0)
+        safe_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
+
+        token_type_ids_at_q_idx = token_type_ids[batch_idx, safe_q_idx]
+        token_type_ids_at_q_idx = torch.where(q_idx < token_type_ids.shape[1], token_type_ids_at_q_idx, 0)
+
+        token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_kv_idx]
         token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
 
-        image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx]
+        image_group_ids_at_q_idx = image_group_ids[batch_idx, safe_q_idx]
+        image_group_ids_at_q_idx = torch.where(q_idx < image_group_ids.shape[1], image_group_ids_at_q_idx, -1)
+
+        image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_kv_idx]
         image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
 
-        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
-        same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx
+        is_image_block = (token_type_ids_at_q_idx == 1) & (token_type_ids_at_kv_idx == 1)
+        same_image_block = image_group_ids_at_q_idx == image_group_ids_at_kv_idx
 
         # This is bidirectional attention whenever we are dealing with image tokens
         return is_image_block & same_image_block
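
The fix above extends to q_idx the same clamp-then-mask pattern already used for kv_idx: the index is first clamped with torch.where so the gather never reads past token_type_ids.shape[1], and the gathered value is then replaced with a neutral fallback for positions beyond the input sequence (which the static cache axis can exceed). Below is a minimal, self-contained sketch of that guard with made-up shapes and values; it only illustrates the pattern and is not the transformers code itself.

```python
import torch

# Illustrative tensors only: one batch, 4 input tokens, but indices may run
# up to a longer static-cache axis.
token_type_ids = torch.tensor([[0, 1, 1, 0]])  # (batch, seq_len)
seq_len = token_type_ids.shape[1]

def token_type_at(batch_idx: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # 1) Clamp the index so the gather itself never reads out of bounds.
    safe_idx = torch.where(idx < seq_len, idx, 0)
    gathered = token_type_ids[batch_idx, safe_idx]
    # 2) Replace the (meaningless) clamped read with a neutral 0 for
    #    positions past the input sequence length.
    return torch.where(idx < seq_len, gathered, 0)

# Index 5 is past seq_len == 4; without the clamp this would be an
# out-of-bounds read, surfacing on CUDA as a device-side assert.
print(token_type_at(torch.tensor(0), torch.tensor(5)))  # tensor(0)
print(token_type_at(torch.tensor(0), torch.tensor(1)))  # tensor(1)
```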

src/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 13 additions & 5 deletions
@@ -768,15 +768,23 @@ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
         # If it's 1 for both query and key/value, we are in an image block
         # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
         # Since vmap doesn't support `if statement` we workaround it with `torch.where`
-        safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
-        token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx]
+        safe_q_idx = torch.where(q_idx < token_type_ids.shape[1], q_idx, 0)
+        safe_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
+
+        token_type_ids_at_q_idx = token_type_ids[batch_idx, safe_q_idx]
+        token_type_ids_at_q_idx = torch.where(q_idx < token_type_ids.shape[1], token_type_ids_at_q_idx, 0)
+
+        token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_kv_idx]
         token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
 
-        image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx]
+        image_group_ids_at_q_idx = image_group_ids[batch_idx, safe_q_idx]
+        image_group_ids_at_q_idx = torch.where(q_idx < image_group_ids.shape[1], image_group_ids_at_q_idx, -1)
+
+        image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_kv_idx]
         image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
 
-        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
-        same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx
+        is_image_block = (token_type_ids_at_q_idx == 1) & (token_type_ids_at_kv_idx == 1)
+        same_image_block = image_group_ids_at_q_idx == image_group_ids_at_kv_idx
 
         # This is bidirectional attention whenever we are dealing with image tokens
         return is_image_block & same_image_block

src/transformers/models/paligemma/modeling_paligemma.py

Lines changed: 13 additions & 5 deletions
@@ -116,15 +116,23 @@ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
         # If it's 1 for both query and key/value, we are in an image block
         # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
         # Since vmap doesn't support `if statement` we workaround it with `torch.where`
-        safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
-        token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx]
+        safe_q_idx = torch.where(q_idx < token_type_ids.shape[1], q_idx, 0)
+        safe_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
+
+        token_type_ids_at_q_idx = token_type_ids[batch_idx, safe_q_idx]
+        token_type_ids_at_q_idx = torch.where(q_idx < token_type_ids.shape[1], token_type_ids_at_q_idx, 0)
+
+        token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_kv_idx]
         token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
 
-        image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx]
+        image_group_ids_at_q_idx = image_group_ids[batch_idx, safe_q_idx]
+        image_group_ids_at_q_idx = torch.where(q_idx < image_group_ids.shape[1], image_group_ids_at_q_idx, -1)
+
+        image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_kv_idx]
         image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
 
-        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
-        same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx
+        is_image_block = (token_type_ids_at_q_idx == 1) & (token_type_ids_at_kv_idx == 1)
+        same_image_block = image_group_ids_at_q_idx == image_group_ids_at_kv_idx
 
         # This is bidirectional attention whenever we are dealing with image tokens
         return is_image_block & same_image_block
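
All three files carry the same change because each defines an identical inner_mask that is evaluated under vmap, where a Python guard such as `if q_idx < token_type_ids.shape[1]:` cannot be used (as the in-code comment notes), hence the torch.where workaround. The sketch below shows a simplified, two-argument variant of such a mask function being expanded over full query/key index ranges with nested torch.vmap; the tensors, shapes, and the nested-vmap expansion are illustrative assumptions, not the actual transformers masking framework.

```python
import torch

token_type_ids = torch.tensor([[0, 1, 1, 0]])  # (batch, seq_len)
seq_len = token_type_ids.shape[1]

def inner_mask(q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor:
    # Under vmap the indices are batched tensors with no single truth value,
    # so both the clamp and the fallback are expressed with torch.where.
    safe_q_idx = torch.where(q_idx < seq_len, q_idx, 0)
    safe_kv_idx = torch.where(kv_idx < seq_len, kv_idx, 0)
    q_type = torch.where(q_idx < seq_len, token_type_ids[0, safe_q_idx], 0)
    kv_type = torch.where(kv_idx < seq_len, token_type_ids[0, safe_kv_idx], 0)
    # Attend bidirectionally only when both positions are image tokens.
    return (q_type == 1) & (kv_type == 1)

# Query/key index ranges longer than the input sequence, as with a static cache.
q = torch.arange(6)
kv = torch.arange(6)
mask = torch.vmap(torch.vmap(inner_mask, in_dims=(None, 0)), in_dims=(0, None))(q, kv)
print(mask.shape)  # torch.Size([6, 6]); rows/cols past seq_len are all False
```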
