Fix Qwen3Next dtype API usage #41735
Conversation
LGTM overall, let's revert the unrelated changes tho!
src/transformers/generation/utils.py (Outdated)

```python
# Safety: if the model is sharded across multiple devices (hf_device_map/device_map) and we are
# doing sampling, enable `remove_invalid_values` by default to avoid NaN/Inf logits causing CUDA
# asserts during multinomial sampling. Users can still override this by passing the flag explicitly.
```
Unrelated changes? Also in logits_process.py
Yeah, if you could revert the unrelated changes this PR is good!
Replace torch.get_current_dtype() with torch.get_default_dtype() to fix FLA compatibility
Force-pushed 15485b7 to 6717030
[For maintainers] Suggested jobs to run (before merge): `run-slow: qwen3_next`
Thanks for pointing it out, I have cleaned up this PR and moved the other changes to #41734. Sorry for the delay. Please check and let me know if any further improvements are needed.
Thx for the PR
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Replace torch.get_current_dtype() with torch.get_default_dtype() to fix FLA compatibility
This PR fixes an invalid PyTorch API usage in the Qwen3Next model.
Changes

- Replaced `torch.get_current_dtype()` with `torch.get_default_dtype()` in both the modular and modeling files

Technical Details

`torch.get_current_dtype()` is not a valid PyTorch API. The correct function is `torch.get_default_dtype()`, which returns the global default dtype setting.

Testing

The changes have been tested with `make fixup` and pass the repository consistency checks.

Fix #41732
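As a minimal illustration of the API difference behind this fix (not the actual model code, just a sketch of the two calls involved):

```python
import torch

# torch.get_default_dtype() is the real API: it returns the process-wide
# default floating-point dtype used when tensors are created without an
# explicit dtype argument.
print(torch.get_default_dtype())  # torch.float32 unless overridden

# It can be changed globally, which is why code should query it rather
# than hard-code a dtype:
torch.set_default_dtype(torch.float64)
print(torch.get_default_dtype())  # torch.float64
torch.set_default_dtype(torch.float32)  # restore the default

# The function this PR removes does not exist in PyTorch at all:
print(hasattr(torch, "get_current_dtype"))  # False
```

Calling the non-existent `torch.get_current_dtype()` raises an `AttributeError` at runtime, which is why the FLA (flash-linear-attention) code path failed before this change.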