[Metal] Reject tensor-scale nvfp4 in qqmm#3551
Open
Brooooooklyn wants to merge 1 commit into
QQMatmul::eval_gpu on Metal silently dropped global_scale_x / global_scale_w in the gemv special case (pre-quantized w, M==1), producing numerically incorrect results when tensor-scale nvfp4 weights were in use. The general case already throws NYI, and quantize/dequantize already reject the same combination at the op level. Mirror those guards in qqmm() so the request is rejected at graph-construction time rather than silently mis-computed. Fixes ml-explore#3550.
Force-pushed 20f4211 to 0964646
Summary
`QQMatmul::eval_gpu` on Metal silently dropped `global_scale_x` / `global_scale_w` in its gemv special case (pre-quantized `w`, `x.shape(-2) == 1`), producing numerically incorrect results when tensor-scale nvfp4 weights were in use. The general case already throws `[QQMatmul] NYI` for the general case, and `quantize()` / `dequantize()` already reject the same combination at the op level (`mlx/ops.cpp:4940-4945`, `5205-5210`). `qqmm()` was missed when tensor-scale nvfp4 landed in #3022. This change mirrors the existing guards:

- `qqmm()` now throws `std::invalid_argument` (Python `ValueError`) at graph-construction time when the stream is a Metal GPU and either `global_scale_x` or `global_scale_w` is set, rather than letting the request reach `eval_gpu`, where it is silently mis-computed. CUDA is unaffected.
- Per-group nvfp4 (no `global_scale`) on Metal continues to work; that path is exercised by the existing `test_qqmv` and is unchanged.

Fixes #3550.
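The guard's shape can be sketched in pure Python (a stand-in for the actual C++ check in `qqmm()`; the function name and `on_metal` parameter are illustrative, not MLX's API):

```python
# Minimal sketch of the mirrored guard, assuming the check runs at
# graph-construction time. Not MLX's actual code; names are hypothetical.
def check_qqmm_global_scales(on_metal, global_scale_x=None, global_scale_w=None):
    """Reject tensor-scale nvfp4 on Metal before it can reach eval_gpu."""
    if on_metal and (global_scale_x is not None or global_scale_w is not None):
        raise ValueError(
            "[qqmm] Tensor-scale nvfp4 (global_scale_x / global_scale_w) "
            "is not supported on Metal."
        )

# Per-group nvfp4 (no global scale) on Metal remains allowed:
check_qqmm_global_scales(on_metal=True)
# CUDA (non-Metal) with global scales is unaffected:
check_qqmm_global_scales(on_metal=False, global_scale_x=1.0)
```

Raising at construction time means the Python caller sees an immediate `ValueError` instead of a silently wrong result after evaluation.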
Test plan
- `test_qqmm_metal_global_scale_rejected` in `python/tests/test_quantized.py` asserts `ValueError` when `mx.qqmm` is called on Metal with both global scales set. Verified the test fails on `main` (silently runs to completion) and passes with this fix.
- `python/tests/test_quantized.py` (29 tests) still passes locally on Apple Silicon, including `test_qqmv`, which exercises the gemv branch being guarded.
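The new test's shape can be sketched with `unittest` (MLX's test suite style). The `qqmm_stub` below stands in for `mx.qqmm` after the fix, since a self-contained example cannot depend on an MLX build; the stub's arguments are illustrative:

```python
import unittest

def qqmm_stub(on_metal, global_scale_x=None, global_scale_w=None):
    # Hypothetical stand-in for mx.qqmm post-fix: reject tensor-scale
    # nvfp4 on Metal, allow everything else.
    if on_metal and (global_scale_x is not None or global_scale_w is not None):
        raise ValueError("tensor-scale nvfp4 not supported on Metal")
    return "ok"

class TestQQMMGlobalScale(unittest.TestCase):
    def test_metal_global_scale_rejected(self):
        # Both global scales set on Metal must raise at call time.
        with self.assertRaises(ValueError):
            qqmm_stub(on_metal=True, global_scale_x=1.0, global_scale_w=1.0)

    def test_per_group_still_works(self):
        # Per-group nvfp4 (no global scale) stays functional on Metal.
        self.assertEqual(qqmm_stub(on_metal=True), "ok")
```

Run with `python -m unittest` to execute both cases.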