
Conversation

@albertvillanova (Member) commented Oct 21, 2025

Fix CUDA index out of bounds error that occurs during generation with static caches when using token type IDs for bidirectional image attention masking.

Background

After an earlier PR changed cache initialization behavior in generate(), a latent bug in the VLM masking code was exposed. The error manifests as:

CUDA error: device-side assert triggered
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:113: operator(): block: [0,0,0], thread: [0,0,0]
Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.

Bug

In the token_type_ids_mask_function inner mask, the code correctly handles out-of-bounds kv_idx values but fails to handle out-of-bounds q_idx values.

The PR that originally fixed the bidirectional image masking added bounds checking for kv_idx, but overlooked that q_idx needed the same protection.

During generation with static caches:

  • Cache shapes can exceed the actual input sequence length (e.g., static cache of 2048 positions with 512 token input)
  • The masking function receives q_idx and kv_idx values that can both exceed token_type_ids.shape[1]
  • Direct indexing such as token_type_ids[batch_idx, q_idx] causes a CUDA index out of bounds error when q_idx >= token_type_ids.shape[1], as sketched below
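
For illustration, here is a minimal, standalone Python sketch of that failure mode; the shapes and variable names are toy values chosen for this example, not the actual Transformers code:

import torch

# Toy setup: pretend the static cache has 8 slots but the real input is only 4 tokens,
# so token_type_ids only covers positions 0..3.
token_type_ids = torch.tensor([[0, 1, 1, 0]])  # shape (batch, input_seq_len) == (1, 4)
batch_idx = torch.tensor(0)
q_idx = torch.tensor(6)  # a query position that exists only in the static cache

# Indexing with q_idx >= token_type_ids.shape[1] is out of bounds: on CPU it raises
# IndexError, while on CUDA it trips the asynchronous device-side assert shown above.
# token_type_ids[batch_idx, q_idx]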

The code comment on line 740 already acknowledged this issue:

"NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length"

Bounds checking was implemented for kv_idx, but q_idx was overlooked.

Fix

This PR adds the same torch.where bounds-checking pattern for q_idx that already existed for kv_idx:

  1. Create safe_q_idx to clamp indices within the valid range
  2. Use the safe indices for tensor access
  3. Apply torch.where to mask out-of-bounds values with appropriate defaults (0 for token_type_ids, -1 for image_group_ids), as sketched below
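
As a rough illustration of that pattern, here is a self-contained sketch; the function name, signature, and simplified mask logic are illustrative assumptions rather than the exact mask closure used in the Transformers code:

import torch

def bidirectional_image_mask_sketch(token_type_ids, image_group_ids, batch_idx, q_idx, kv_idx):
    # token_type_ids: (batch, input_seq_len), 1 for image tokens, 0 for text tokens.
    # image_group_ids: (batch, input_seq_len), per-token image group id, -1 for text tokens.
    # q_idx / kv_idx: index tensors that, with a static cache, may exceed input_seq_len.
    seq_len = token_type_ids.shape[1]
    q_in_bounds = q_idx < seq_len
    kv_in_bounds = kv_idx < seq_len

    # 1. Clamp indices into the valid range so the gather itself never goes out of bounds.
    safe_q_idx = torch.where(q_in_bounds, q_idx, torch.zeros_like(q_idx))
    safe_kv_idx = torch.where(kv_in_bounds, kv_idx, torch.zeros_like(kv_idx))

    # 2. Gather with the safe indices.
    q_is_image = token_type_ids[batch_idx, safe_q_idx] == 1
    kv_is_image = token_type_ids[batch_idx, safe_kv_idx] == 1
    q_group = image_group_ids[batch_idx, safe_q_idx]
    kv_group = image_group_ids[batch_idx, safe_kv_idx]

    # 3. Mask out-of-bounds positions back to their defaults
    #    (0 / "not an image token" for token_type_ids, -1 for image_group_ids).
    q_is_image = torch.where(q_in_bounds, q_is_image, torch.zeros_like(q_is_image))
    kv_is_image = torch.where(kv_in_bounds, kv_is_image, torch.zeros_like(kv_is_image))
    q_group = torch.where(q_in_bounds, q_group, torch.full_like(q_group, -1))
    kv_group = torch.where(kv_in_bounds, kv_group, torch.full_like(kv_group, -1))

    # Image tokens attend bidirectionally within the same image group.
    return q_is_image & kv_is_image & (q_group == kv_group)

In the real mask function, this torch.where guard was already in place for kv_idx; the change described above extends it to q_idx so both index streams stay in bounds when the static cache is longer than the input.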

Affected Models

  • Gemma3ForConditionalGeneration
  • PaliGemmaForConditionalGeneration
  • Example modular transformer template (modeling_new_task_model.py)

Testing

This PR fixes the downstream failing test in TRL:

tests/test_grpo_trainer.py::TestGRPOTrainer::test_training_vlm_0_trl_internal_testing_tiny_Gemma3ForConditionalGeneration

Related Issues

  • CI fails with dev dependencies: torch.AcceleratorError: CUDA error: device-side assert triggered

@mellowpraful commented Oct 21, 2025 via email

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: gemma3, paligemma

@albertvillanova changed the title from "Fix CUDA index out of bounds for q_idx in Gemma3 token type masking" to "Fix CUDA index out of bounds for q_idx in VLM token type masking for Gemma3, PaliGemma, and example modular" on Oct 21, 2025
@zucchini-nlp (Member) left a comment

Makes sense to me! Can you also check pytest -k test_generate_with_static_cache tests/models/gemma3/test_modeling_gemma3.py, since it was supposed to fail for gemma3 in that case?

Prob the test doesn't pass token type ids, or it is already failing on main and we didn't notice it?

@albertvillanova (Member, Author)

Thanks for your review, @zucchini-nlp.

I have run the tests as requested and everything is OK:

pytest -k test_generate_with_static_cache tests/models/gemma3/test_modeling_gemma3.py

tests/models/gemma3/test_modeling_gemma3.py::Gemma3TextModelTest::test_generate_with_static_cache PASSED                                                                                            [ 50%]
tests/models/gemma3/test_modeling_gemma3.py::Gemma3Vision2TextModelTest::test_generate_with_static_cache PASSED                                                                                     [100%]

============================================================================== 2 passed, 373 deselected, 2 warnings in 9.62s ==============================================================================

@zucchini-nlp (Member)

Thanks

@zucchini-nlp (Member) left a comment

Oh right, have to approve to merge

@zucchini-nlp merged commit 9a27302 into huggingface:main on Oct 22, 2025
17 checks passed
@albertvillanova (Member, Author)

Thank YOU, @zucchini-nlp for your fast and efficient review!

ngazagna-qc pushed a commit to ngazagna-qc/transformers that referenced this pull request Oct 23, 2025
…Gemma3, PaliGemma, and example modular (huggingface#41757)

* Fix CUDA index out of bounds for q_idx in Gemma3 token type masking

* Fix CUDA index out of bounds for q_idx in modular modeling_new_task_model

* Revert "Fix CUDA index out of bounds for q_idx in Gemma3 token type masking"

This reverts commit f8e5c2a.

* Fix CUDA index out of bounds for q_idx in PaliGemma token type masking

* Fix CUDA index out of bounds for q_idx in Gemma3 token type masking


Development

Successfully merging this pull request may close these issues.

CI fails with dev dependencies: torch.AcceleratorError: CUDA error: device-side assert triggered
