# MXFP Inference and Performance Tracking

## Summary
This issue tracks performance and E2E integration of MXFP formats (MXFP8, MXFP4, NVFP4) on B200 and other devices.
## Status Overview
| Component | Status | Notes |
|---|---|---|
| Dense Model Support | ✅ Done | Dense models work E2E |
| MoE Support | 🟡 Not Started | Need to add MXFP8 support to `scaled_grouped_gemm` (see the sketch below the table) |
| vLLM Integration | 🟡 In Progress | Works E2E, but performance is inconsistent |
| vLLM MXFP8 Performance | 🟡 Suboptimal | Currently ~11% slower than the BF16 baseline |
| vLLM MXFP4 Performance | 🟡 Suboptimal | Only comparable to the BF16 baseline |
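For context on the MoE row, this is the call pattern `scaled_grouped_gemm` needs to cover, written out as a pure-PyTorch reference. The `offsets` layout and signature here are assumptions for illustration; an MXFP8 version would do this in one kernel and carry fp8 payloads plus per-32-element e8m0 block scales for both operands.

```python
import torch

# Reference semantics for the MoE grouped GEMM: tokens arrive sorted by
# expert, and each expert multiplies its contiguous token slice by its own
# weight matrix. NOTE: the `offsets` layout is an assumption for illustration.
def grouped_mm_reference(
    tokens: torch.Tensor,          # (total_tokens, K), sorted by expert
    expert_weights: torch.Tensor,  # (num_experts, K, N)
    offsets: torch.Tensor,         # (num_experts + 1,) token boundaries per expert
) -> torch.Tensor:
    outs = []
    for e in range(expert_weights.shape[0]):
        chunk = tokens[offsets[e] : offsets[e + 1]]  # tokens routed to expert e
        outs.append(chunk @ expert_weights[e])
    return torch.cat(outs, dim=0)
```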
## GEMM Kernels
| Format | Kernel Source | Performance Status | Notes |
|---|---|---|---|
| MXFP8 | cuBLAS / ScaledMM | ✅ Optimal | Performs as advertised |
| MXFP4 | CUTLASS / AO | 🟡 Suboptimal | Slower than expected on compute-bound shapes; needs tuning |
| NVFP4 | cuBLAS / ScaledMM | 🔴 Not Benchmarked | Needs to be wired up in TorchAO |
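For quick A/B numbers on the kernels above, a small timing harness helps. The sketch below times a BF16 matmul against the per-tensor-scaled fp8 `torch._scaled_mm` path (a private PyTorch API) as a proxy; the MX/NVFP4 paths pass blocked e8m0 scale layouts rather than the scalar scales used here, so treat this as a template rather than the exact call.

```python
import torch
from torch.utils import benchmark

# Quick GEMM A/B harness: BF16 matmul vs. the per-tensor fp8 scaled_mm path.
# The MX formats pass blocked (swizzled) scales instead of the scalar scales
# used here, so this is a timing template rather than the exact MXFP8 call.
M, K, N = 8192, 8192, 8192
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

a_fp8 = a.to(torch.float8_e4m3fn)
b_fp8 = b.to(torch.float8_e4m3fn).t()  # _scaled_mm expects mat2 column-major
one = torch.tensor(1.0, device="cuda")

candidates = {
    "bf16": lambda: a @ b.t(),
    "fp8_scaled_mm": lambda: torch._scaled_mm(
        a_fp8, b_fp8, scale_a=one, scale_b=one, out_dtype=torch.bfloat16
    ),
}
for name, fn in candidates.items():
    t = benchmark.Timer(stmt="fn()", globals={"fn": fn}).blocked_autorange()
    print(f"{name}: {t.median * 1e3:.3f} ms")
```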
## Casting Kernels
| Format | Kernel Source | Performance Status | Notes |
|---|---|---|---|
| MXFP | AO / Triton | 🟡 Pretty good | Dim-0 cast is optimal, as is the scale swizzle, but the two could be fused |
| MXFP | Inductor | 🟡 Suboptimal | Falls back to eager; needs fixes |
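At their core, these kernels compute one shared power-of-two scale per 32 contiguous elements plus an fp8 payload. A simplified reference, assuming numel is divisible by 32, returning float scales instead of device e8m0, and skipping the saturation handling the real kernels do:

```python
import torch

BLOCK = 32  # MX spec: one shared scale per 32 contiguous elements

# Simplified reference for the MX cast: one power-of-two scale per block,
# divide, and cast the payload to fp8 e4m3. Real kernels (AO/Triton,
# Inductor) also handle saturation and emit the scales directly in the
# swizzled/blocked layout the GEMM expects -- which is why fusing the cast
# with the scale swizzle is attractive.
def mx_cast_reference(x: torch.Tensor):
    orig_shape = x.shape
    blocks = x.reshape(-1, BLOCK).float()
    amax = blocks.abs().amax(dim=1, keepdim=True)
    # 2^(floor(log2(amax)) - 8) maps each block's max into e4m3 range (max 448)
    scale = torch.exp2(torch.floor(torch.log2(amax.clamp(min=2.0**-126))) - 8)
    data = (blocks / scale).to(torch.float8_e4m3fn).reshape(orig_shape)
    return data, scale.squeeze(1)  # scales would be stored as e8m0 on device
```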
## Tasks

### 1. Kernel Optimization
- Implement a custom swizzle kernel for feed-forward networks
  - See: Add a triton kernel for swizziling #2168; currently working around Inductor
- Investigate why Inductor is falling back for the swizzle kernel
  - Started, but a fix PR still needs to land: Inductor Perf MX to_blocked pytorch#153194
  - Cleanup parent fallback logic pytorch#154006
- Optimize the MXFP4 kernel in AO, which isn't performing as expected
  - The CUTLASS template is very vanilla; we need to identify the shapes we care about and likely instantiate a few more templates
- Implement scale caching for static weight tensors (post tensor parallelism); see the sketch after this list
  - Once we load the quantized model and shard the weights for TP, the MX scales can be pre-swizzled. We should create a mechanism for caching these; it should show some speedup
- Develop a single kernel for the MX cast + swizzled scale generation
  - Triton versions exist (Add MXFP casting kernels from triton Repro #2217); the ideal end state is for Inductor to produce this for us
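A minimal sketch of the scale-caching idea, assuming the `to_blocked` swizzle referenced above; the import path and the keying scheme are assumptions, not the final design:

```python
import torch

# Import path is an assumption and may differ across torchao versions.
from torchao.prototype.mx_formats.utils import to_blocked

# Hypothetical cache of pre-swizzled MX scales for static inference weights,
# keyed by the scale tensor's storage pointer (a placeholder keying scheme
# that assumes the weights stay alive for the process lifetime).
_swizzled_scale_cache: dict[int, torch.Tensor] = {}

def get_swizzled_scale(weight_scale: torch.Tensor) -> torch.Tensor:
    # Weights (and hence their scales) are static after TP sharding, so the
    # swizzle only needs to run once per weight rather than once per forward.
    key = weight_scale.data_ptr()
    if key not in _swizzled_scale_cache:
        _swizzled_scale_cache[key] = to_blocked(weight_scale)
    return _swizzled_scale_cache[key]
```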
### 2. vLLM Integration
- Add an option for the NVFP4 quant scheme
- Debug inconsistent behavior in the vLLM integration
- Optimize TTFT (Time to First Token) for the MXFP8 format
- Ensure consistent throughput across different model sizes
- Profile memory bandwidth utilization for the different formats (see the back-of-envelope sketch after this list)
- Compare latency patterns across BF16, MXFP8, and MXFP4
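Since batch-1 decode GEMMs are memory-bandwidth bound, a quick sanity check for the bandwidth item is bytes-of-weights-streamed divided by kernel time. A back-of-envelope sketch; the shape and kernel times below are hypothetical placeholders, not measured numbers:

```python
# Back-of-envelope achieved bandwidth for a batch-1 decode GEMM (gemmv),
# which is dominated by streaming the weight matrix once per call.
def achieved_gbps(K: int, N: int, bytes_per_elem: float, kernel_ms: float) -> float:
    weight_bytes = K * N * bytes_per_elem  # activation/output traffic is negligible
    return weight_bytes / (kernel_ms * 1e-3) / 1e9

# MXFP8 moves about half the bytes of BF16 (1-byte payload plus 1/32 byte of
# e8m0 scale per element), which is the headroom the decode traces should
# show. Kernel times here are hypothetical placeholders.
print(achieved_gbps(4096, 14336, 2.0, 0.065))           # BF16
print(achieved_gbps(4096, 14336, 1.0 + 1 / 32, 0.04))   # MXFP8
```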
Traces:

Decode, 8B model (batch 1, gemmv):

- BF16: https://fburl.com/rj5cpto2
- MXFP8: https://fburl.com/xo985lyg

In this case Inductor is producing a fully unrolled gemmv (https://www.internalfb.com/intern/paste/P1817637421/) with no tensor cores; I wonder what we need to do to support this. cc @eellison
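For anyone reproducing these, the traces were collected with the standard profiler flow; a minimal sketch, where the linear layer stands in for the model's decode-step projection:

```python
import torch
from torch.profiler import ProfilerActivity, profile

# Capture a Chrome trace of a batch-1 decode matmul (gemmv). The Linear
# layer is a stand-in for a real decode step through the model.
layer = torch.nn.Linear(4096, 4096, dtype=torch.bfloat16, device="cuda")
x = torch.randn(1, 4096, dtype=torch.bfloat16, device="cuda")  # batch-1 input

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with torch.no_grad():
        layer(x)
prof.export_chrome_trace("decode_gemmv_trace.json")  # view in Perfetto / chrome://tracing
```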
## Performance Data

### Performance by Format (Qwen2-7B-Instruct)
| Format | Throughput (req/s) | Total Token Throughput (tok/s) | Output Token Throughput (tok/s) |
|---|---|---|---|
| BF16 | 56.68 | 24053.96 | 11590.20 |
| MXFP8 | 50.52 | 21443.10 | 10332.18 |
| MXFP4 | 56.64 | 24039.96 | 11583.46 |
### Performance by Format (Qwen2.5-72B)
| Format | Throughput (req/s) | Total Token Throughput (tok/s) | Output Token Throughput (tok/s) |
|---|---|---|---|
| BF16 | 26.28 | 11154.41 | 5374.66 |
| MXFP4 | 25.96 | 11018.18 | 5309.02 |
### MXFP8 Serving Benchmark
============ Serving Benchmark Result ============
Successful requests: 1024
Benchmark duration (s): 13.43
Total input tokens: 225502
Total generated tokens: 185297
Request throughput (req/s): 76.26
Output token throughput (tok/s): 13800.30
Total Token throughput (tok/s): 30594.93
---------------Time to First Token----------------
Mean TTFT (ms): 1119.68
Median TTFT (ms): 1100.86
P99 TTFT (ms): 1721.80
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 29.11
Median TPOT (ms): 27.07
P99 TPOT (ms): 46.91
---------------Inter-token Latency----------------
Mean ITL (ms): 23.38
Median ITL (ms): 17.28
P99 ITL (ms): 49.45
==================================================
## Numerics
Some quick LM evals:
## References
- Llama 70B feed-forward with custom kernel: fburl.com/125yv8hh
- Llama 70B feed-forward without custom kernel: fburl.com/a21gwjmc
- BF16 reference: fburl.com/2lgn9xkx
- Non-Quantized Trace: fburl.com/sput3bmn
- Quantized Trace: fburl.com/0pgmyrge
- BF16 70B MLP: fburl.com/aeqm5s4v
- MXFP8 70B MLP: fburl.com/uxgoju4r
- MXFP4 70B MLP: fburl.com/u95f6f39
- Eager vs. Inductor Swizzle Profile: fburl.com/kqhm91ib