Conversation

Contributor

@SrijanUpadhyay SrijanUpadhyay commented Oct 19, 2025

This PR fixes an invalid PyTorch API usage in the Qwen3Next model.

Changes

  • Replaced torch.get_current_dtype() with torch.get_default_dtype() in both modular and modeling files
  • This fixes FLA (Flash Linear Attention) compatibility when running with dtypes such as float32 or float16

Technical Details

  • The original code called torch.get_current_dtype(), which is not a PyTorch API
  • The correct API is torch.get_default_dtype(), which returns the global default floating-point dtype
  • This change ensures proper dtype handling in the Qwen3Next GatedDeltaNet normalization layer (see the sketch below)
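
A minimal sketch of the pattern being fixed, assuming a hypothetical normalization helper (the actual GatedDeltaNet code differs; only the torch.get_default_dtype() call is taken from this PR):

```python
import torch

def _norm_in_default_dtype(hidden_states: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Before: dtype = torch.get_current_dtype()  # no such attribute -> AttributeError
    # After: query the global default floating-point dtype
    dtype = torch.get_default_dtype()
    x = hidden_states.to(dtype)
    # RMS-style normalization as an illustrative stand-in for the layer's math
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
```

Because the helper reads the global setting, calling torch.set_default_dtype(torch.float64) beforehand changes which dtype the computation runs in, rather than silently hard-coding float32.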

Testing

The changes have been tested with make fixup and pass the repository consistency checks.

Fixes #41732

Contributor

@vasqu vasqu left a comment


LGTM overall, let's revert the unrelated changes tho!

Comment on lines 1848 to 1850
# Safety: if the model is sharded across multiple devices (hf_device_map/device_map) and we are
# doing sampling, enable `remove_invalid_values` by default to avoid NaN/Inf logits causing CUDA
# asserts during multinomial sampling. Users can still override this by passing the flag explicitly.
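
For context, a hedged sketch of what the quoted (later reverted) safety logic could look like; the helper name and attribute checks are assumptions, not the actual transformers implementation:

```python
def _maybe_enable_invalid_value_removal(model, generation_config) -> None:
    # Hypothetical helper (assumed names): default `remove_invalid_values` to True
    # when the model is sharded across devices and sampling is enabled, so NaN/Inf
    # logits are sanitized before multinomial sampling instead of tripping CUDA asserts.
    device_map = getattr(model, "hf_device_map", None)
    is_sharded = device_map is not None and len(set(device_map.values())) > 1
    if is_sharded and getattr(generation_config, "do_sample", False) and not generation_config.remove_invalid_values:
        generation_config.remove_invalid_values = True
```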

Unrelated changes? Also in logits_process.py

@Rocketknight1
Member

Yeah, if you could revert the unrelated changes this PR is good!

Replace torch.get_current_dtype() with torch.get_default_dtype() to fix FLA compatibility
@SrijanUpadhyay SrijanUpadhyay force-pushed the fix-qwen3next-dtype-api branch from 15485b7 to 6717030 on October 23, 2025 04:36
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: qwen3_next

@SrijanUpadhyay
Contributor Author

Thanks for pointing it out. I have cleaned up this PR and moved the other changes to #41734. Sorry for the delay. Please review and let me know if any further improvements are needed.

Contributor

@vasqu vasqu left a comment


Thx for the PR

@vasqu vasqu enabled auto-merge (squash) October 23, 2025 11:54
@vasqu vasqu merged commit d4562bb into huggingface:main Oct 23, 2025
17 checks passed
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

ngazagna-qc pushed a commit to ngazagna-qc/transformers that referenced this pull request Oct 24, 2025
Replace torch.get_current_dtype() with torch.get_default_dtype() to fix FLA compatibility


Development

Successfully merging this pull request may close these issues.

Qwen3Next compatibility issue with flash-linear-attention

4 participants