NVfp4 #2408

Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2408
✅ No failures as of commit d5bded3 with merge base 4e25496.
Weight-only fails with compile; bisected to the lowering below, likely from the workaround to get Triton to not error on e2m1. Disabling that lowering fixed the issue.
Starting bisect by getting upper bound.
Upper bound of 38 found for inductor.
Bisecting inductor - lowerings (Range: [0, 38], Midpoint: 19)
Bisecting inductor - lowerings (Range: [20, 38], Midpoint: 29)
Bisecting inductor - lowerings (Range: [30, 38], Midpoint: 34)
Bisecting inductor - lowerings (Range: [35, 38], Midpoint: 36)
Bisecting inductor - lowerings (Range: [35, 36], Midpoint: 35)
Binary search completed for inductor - lowerings. The bisect number is 36. Debug info: convert_element_type_5
Bisection status deleted.
Bisection result: BisectionResult(backend='inductor', subsystem='lowerings', bisect_number=36, debug_info='convert_element_type_5')
6. Testing inductor config workarounds for WEIGHT_ONLY:
{'inductor.coordinate_descent_tuning': False} ERROR
{'inductor.force_fuse_int_mm_with_mul': False} ERROR
{'inductor.post_grad_passes': False} ERROR
{'inductor.pattern_matcher': False} ERROR
{'inductor.epilogue_fusion': False} ERROR
{'inductor.max_autotune': False} ERROR
{'triton.autotune_pointwise': False} ✗ 3.1dB
{'inductor.benchmark_kernel': False} ERROR
{'inductor.aggressive_fusion': False} ERROR
7. Testing other compile backends:
Backend 'eager': SQNR = 20.00 dB ✓
Backend 'aot_eager': SQNR = 20.00 dB ✓
skipping cudagraphs due to cpu device (_tensor_constant0). Found from:
File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/external_utils.py", line 70, in inner
return fn(*args, **kwargs)
Backend 'cudagraphs': SQNR = 20.00 dB ✓
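For context, a minimal sketch of how these per-backend SQNR checks could be reproduced (the helper names here are illustrative, not the PR's actual test code):

```python
import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> float:
    # Signal-to-quantization-noise ratio in dB between a reference and a test output.
    noise = (ref - test).float().pow(2).mean()
    signal = ref.float().pow(2).mean()
    return (10 * torch.log10(signal / noise)).item()

def check_backend(model: torch.nn.Module, x: torch.Tensor, backend: str) -> float:
    ref = model(x)
    compiled = torch.compile(model, backend=backend)
    return sqnr_db(ref, compiled(x))

# e.g. check_backend(quantized_model, activation, "eager"),
#      check_backend(quantized_model, activation, "aot_eager"),
#      check_backend(quantized_model, activation, "cudagraphs")
```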
@vkuzo updated to use the mm_config
Just some comments
```python
assert self.activation_dtype == torch.float4_e2m1fn_x2, (
    f"NVFP4 requires activation_dtype=float4_e2m1fn_x2, got {self.activation_dtype}"
)
assert self.weight_dtype == torch.float4_e2m1fn_x2, (
    f"NVFP4 requires weight_dtype=float4_e2m1fn_x2, got {self.weight_dtype}"
)
assert self.scale_dtype == torch.float8_e4m3fn, (
    f"NVFP4 requires scale_dtype=float8_e4m3fn, got {self.scale_dtype}"
)
assert self.block_size == 16, (
    f"NVFP4 requires block_size=16, got {self.block_size}"
)
```
Just curious: what's the point of exposing all of these when only a specific value is accepted?
This is a good point. My original intent was to not make another subclass and just merge in with MXFP. cc @vkuzo, I imagine that we want this separated? I started to work on an observer for this, since without one this is just a worse MXFP4.
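One possible shape for that (a sketch only, not what this PR does): a dedicated NVFP4 config where the format-fixed values are constants rather than user-facing fields. The class and option names below are hypothetical.

```python
from dataclasses import dataclass

import torch

@dataclass
class NVFP4MMConfig:
    # Only knobs that can actually vary are exposed; everything below is fixed
    # by the NVFP4 format, so there is nothing to validate in __post_init__.
    use_dynamic_per_tensor_scale: bool = False  # illustrative option

    # Format-fixed constants (class attributes, not dataclass fields).
    ACTIVATION_DTYPE = torch.float4_e2m1fn_x2
    WEIGHT_DTYPE = torch.float4_e2m1fn_x2
    SCALE_DTYPE = torch.float8_e4m3fn
    BLOCK_SIZE = 16
```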
```python
a_scale_blocked = to_blocked(a_scale)
b_scale_blocked = to_blocked(b_scale)
```
I have wondered about this for MX dtypes as well. It makes sense for MX dtypes to have scale swizzling here since we may want the layout to be vendor-neutral. But NVFP4 is specific to NVIDIA, so why not put this under to_nvfp4()?
This is a good question. For MX I recently confirmed that AMD does not require swizzling. Good point about NVFP4; I am actually going to update this code path to cache the swizzled layout for the weight.
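A sketch of what caching the swizzled layout on the weight could look like (class shape and import path assumed, not the PR's actual implementation):

```python
import torch
# to_blocked is assumed to be the same swizzling helper used in the snippet above.
from torchao.prototype.mx_formats.utils import to_blocked

class NVFP4WeightScales:
    """Illustrative holder that swizzles the weight's block scales once and reuses them."""

    def __init__(self, scale: torch.Tensor):
        self.scale = scale
        self._blocked: torch.Tensor | None = None

    def blocked(self) -> torch.Tensor:
        # The weight is static at inference time, so its scales only need to be
        # swizzled once; activation scales still need per-call swizzling.
        if self._blocked is None:
            self._blocked = to_blocked(self.scale)
        return self._blocked
```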
Stacked PRs:
Add NVFP4 Inference flow
Details:
I kept this separate from MX, but realistically we should probably merge the two. Basic support for block size 16 + e4m3 scales.
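As a usage sketch (the config name and import path here are assumptions; adjust to wherever the inference flow actually lands):

```python
import torch
from torchao.quantization import quantize_
# Hypothetical import; the exact location/name of the NVFP4 config may differ.
from torchao.prototype.mx_formats import NVFP4InferenceConfig

model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096, bias=False)
).to(device="cuda", dtype=torch.bfloat16)

# Swap Linear weights to NVFP4 (block size 16, e4m3 block scales).
quantize_(model, NVFP4InferenceConfig())

with torch.inference_mode():
    x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
    y = model(x)
```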
Double Quant Update
Ignore previous comments; the double quant is actually really similar to NF4, where you just scale the fp32 scales prior to casting to e4m3 to try and reduce scale quant error. I have that implemented now in the NVFP4 code if a tensor_scale is given; I just need to figure out how to thread it to the cuBLAS param scale_in_d, or how we want to expose this. We currently don't expose the C matrix in the Python API, so we could use alpha as @gau-nernst pointed out to me; however, we don't expose alpha either 🙃. And if we wanted to use alpha we would need the value on the host, so the sync would likely rule out this option. I might keep this double quant on hold until we have the public API, since I am thinking about adding scale overloads to addmm. That said, I read the cuBLAS docs many times and it feels as though passing the scale for the result should work, since we don't set d_mode and its default value should work.
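For concreteness, a minimal sketch of the double-quant scale computation described above (per-row blocks of 16; constants and names are illustrative, not the PR's code):

```python
import torch

F4_E2M1_MAX = 6.0    # largest magnitude representable in fp4 e2m1
F8_E4M3_MAX = 448.0  # largest magnitude representable in fp8 e4m3

def nvfp4_double_quant_scales(x: torch.Tensor, block_size: int = 16):
    # Per-block amax over contiguous groups of `block_size` elements in the last dim.
    blocks = x.reshape(*x.shape[:-1], -1, block_size)
    block_amax = blocks.abs().amax(dim=-1).float()

    # fp32 per-block scales that map each block into the e2m1 range.
    fp32_scales = block_amax / F4_E2M1_MAX

    # Per-tensor scale chosen so the largest block scale fits in e4m3; dividing
    # by it before the e4m3 cast reduces scale quantization error.
    tensor_scale = (fp32_scales.max() / F8_E4M3_MAX).clamp(min=1e-12)
    e4m3_scales = (fp32_scales / tensor_scale).to(torch.float8_e4m3fn)
    return e4m3_scales, tensor_scale
```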
Early Perf

No double quant here; it's even worse than MXFP4... will profile later.
Micro Bench
Llama 70B MLP, no TP:
Diffusers
Errors
Annoyingly, we are getting an error due to the view as fp4x2 + packing (https://fburl.com/cd92w431), because this gets bitcast inside the Triton kernel, which is very annoying. Not sure how this didn't show up until vLLM / with MXFP4. This is similar to triton-lang/triton#6054; we need to make the same changes in _inductor/utils.py as we did for float8_e8m0.
Numerics
Script: https://gist.github.com/drisspg/4024ed055a6db911495102614c674c4c (still emulating until we fix this bug in the cublasLt bindings)

Double quant really helps with tensors that have very small amax values, likely by reducing the amount of underflow; will verify:
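A rough way to check the underflow hypothesis (a sketch only, reusing the illustrative helper from the double-quant snippet above):

```python
import torch

def e4m3_scale_underflow_fraction(
    fp32_scales: torch.Tensor, tensor_scale: torch.Tensor | None = None
) -> float:
    # Fraction of nonzero block scales that collapse to zero when cast to e4m3.
    s = fp32_scales if tensor_scale is None else fp32_scales / tensor_scale
    casted = s.to(torch.float8_e4m3fn).float()
    return ((casted == 0) & (fp32_scales != 0)).float().mean().item()

# For tensors with very small amax, the raw fp32 block scales sit near or below
# the smallest e4m3 subnormal and round to zero; rescaling by the per-tensor
# scale before the cast should drive this fraction toward zero.
```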