
MXFP Inference Tracking Doc #2229

Open · 0 of 1 issue completed

@drisspg

MXFP Inference and Performance Tracking

Summary

This issue tracks performance and E2E integration of MXFP formats (MXFP8, MXFP4, NVFP4) on B200 and other devices.

Status Overview

| Component | Status | Notes |
| --- | --- | --- |
| Dense Model Support | ✅ Done | Dense models are working E2E |
| MoE Support | 🟡 Not Started | Need to add MXFP8 support to scaled_grouped_gemm |
| vLLM Integration | 🟡 In Progress | Works, but performance is inconsistent |
| vLLM MXFP8 Performance | 🟡 Suboptimal | Currently ~11% slower than the BF16 baseline |
| vLLM MXFP4 Performance | 🟡 Suboptimal | Comparable to the BF16 baseline |

GEMM Kernels

| Format | Kernel Source | Performance Status | Notes |
| --- | --- | --- | --- |
| MXFP8 | cuBLAS / ScaledMM | ✅ Optimal | Performs as advertised |
| MXFP4 | CUTLASS / AO | 🟡 Suboptimal | Slower than expected on compute-bound shapes; needs tuning |
| NVFP4 | cuBLAS / ScaledMM | 🔴 Not Benchmarked | Needs to be wired up in TorchAO |

Casting Kernels

| Format | Kernel Source | Performance Status | Notes |
| --- | --- | --- | --- |
| MXFP | AO / Triton | 🟡 Pretty good | Dim-0 cast is optimal, as is the scale swizzle, but the two can be fused |
| MXFP | Inductor | 🟡 Suboptimal | Falls back to eager; needs fixes |
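
For reference, the core of the MX cast is computing one shared power-of-two (e8m0) scale per 32-element block and casting the block to FP8. Below is a minimal eager sketch of that logic along the last dim, assuming divisible shapes; the real AO/Triton kernels also handle the dim-0 layout, round the scale more carefully, and can fuse the scale swizzle:

```python
import torch

def mx_cast_reference(x: torch.Tensor, block_size: int = 32):
    """Simplified MXFP8 cast: one power-of-two (e8m0-style) scale per
    block_size elements along the last dim, data stored as float8_e4m3fn.

    Assumes x.shape[-1] is divisible by block_size; this is a sketch of
    the core logic, not the production kernel.
    """
    blocks = x.reshape(-1, block_size)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-30)
    # Pick a power-of-two scale so the block max maps into e4m3's range (max 448).
    scale = torch.exp2(torch.floor(torch.log2(amax / 448.0)))
    data = (blocks / scale).to(torch.float8_e4m3fn)
    return data.reshape(x.shape), scale.reshape(*x.shape[:-1], -1)
```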

Tasks

1. Kernel Optimization

  • Implement a custom swizzle kernel for feed-forward networks
    See: Add a triton kernel for swizziling #2168; currently working around Inductor
  • Investigate why Inductor is falling back to eager for the swizzle kernel
    Started, but a fix PR still needs to land: Inductor Perf MX to_blocked pytorch#153194
    Clean up the parent fallback logic: pytorch#154006
  • Optimize the MXFP4 kernel in AO, which isn't performing as expected
    The current CUTLASS template is very vanilla; we need to identify the shapes we care about and likely instantiate a few more templates
  • Implement scale caching for static weight tensors (post tensor parallelism)
    Once the quantized model is loaded and the weights are sharded for TP, the MX scales can be pre-swizzled. We should create a mechanism for caching these, which should show some speedup (see the sketch after this list)
  • Develop a single kernel for the MX cast + swizzled scale generation
    These exist in Triton (Add MXFP casting kernels from triton Repro #2217); the ideal end state is for Inductor to produce this for us
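
A rough sketch of what the scale-caching mechanism could look like. It assumes torchao's to_blocked swizzle (torchao.prototype.mx_formats.utils); the storage-pointer keying and module-level dict are placeholders, not the planned design:

```python
import torch
from torchao.prototype.mx_formats.utils import to_blocked

# Placeholder cache keyed by storage pointer; a real implementation would
# need to handle invalidation and device moves.
_swizzled_scale_cache: dict[int, torch.Tensor] = {}

def cached_swizzled_scales(weight_scales: torch.Tensor) -> torch.Tensor:
    """Pre-swizzle the MX scales of a static (post-TP) weight tensor once,
    then reuse the blocked layout on every subsequent matmul."""
    key = weight_scales.untyped_storage().data_ptr()
    out = _swizzled_scale_cache.get(key)
    if out is None:
        out = to_blocked(weight_scales)  # layout the GEMM kernels expect
        _swizzled_scale_cache[key] = out
    return out
```

Since post-TP weights are frozen, the swizzle cost is paid once per tensor at load time instead of on every forward.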

2. vLLM Integration

  • Add an option for the NVFP4 quant scheme
  • Debug inconsistent behavior in the vLLM integration
  • Optimize TTFT (Time To First Token) for the MXFP8 format
  • Ensure consistent throughput across different model sizes
  • Profile memory bandwidth utilization for the different formats
  • Compare latency patterns across BF16, MXFP8, and MXFP4 (see the sketch below)
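
For the latency comparison, a minimal sketch using vLLM's offline API. The quantization="torchao" kwarg is an assumption about the integration under test and depends on the vLLM build; everything else is stock vLLM:

```python
import time
from vllm import LLM, SamplingParams

def bench_e2e_latency(model: str, n_prompts: int = 32, **llm_kwargs) -> float:
    """Crude end-to-end generation latency for one model/format config."""
    llm = LLM(model=model, **llm_kwargs)
    params = SamplingParams(max_tokens=128, temperature=0.0)
    prompts = ["Summarize microscaling (MX) formats in one paragraph."] * n_prompts
    start = time.perf_counter()
    llm.generate(prompts, params)
    return time.perf_counter() - start

# e.g. bench_e2e_latency("Qwen/Qwen2-7B-Instruct")             # BF16 baseline
#      bench_e2e_latency("Qwen/Qwen2-7B-Instruct",
#                        quantization="torchao")               # assumed flag
```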

Traces:

Decode, 8B (batch-1 gemmv):
BF16: https://fburl.com/rj5cpto2
MXFP8: https://fburl.com/xo985lyg

In this case, Inductor is producing a fully unrolled gemmv (https://www.internalfb.com/intern/paste/P1817637421/) with no tensor cores; I wonder what we need to do to support this. cc @eellison
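
A minimal repro sketch for that shape (the sizes below are hypothetical; run with TORCH_LOGS="output_code" to inspect what Inductor generates for the M == 1 case):

```python
import torch

# At batch size 1, the decode projection reduces to a matrix-vector
# product (gemmv), the case Inductor currently handles poorly.
@torch.compile
def decode_proj(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return x @ w

x = torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
out = decode_proj(x, w)
```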

Performance Data

Performance by Format (Qwen2-7B-Instruct)

| Format | Throughput (req/s) | Total Token Throughput (tok/s) | Output Token Throughput (tok/s) |
| --- | --- | --- | --- |
| BF16 | 56.68 | 24053.96 | 11590.20 |
| MXFP8 | 50.52 | 21443.10 | 10332.18 |
| MXFP4 | 56.64 | 24039.96 | 11583.46 |

Performance by Format (Qwen2.5-72B)

| Format | Throughput (req/s) | Total Token Throughput (tok/s) | Output Token Throughput (tok/s) |
| --- | --- | --- | --- |
| BF16 | 26.28 | 11154.41 | 5374.66 |
| MXFP4 | 25.96 | 11018.18 | 5309.02 |

MXFP8 Serving Benchmark

```
============ Serving Benchmark Result ============
Successful requests:                     1024
Benchmark duration (s):                  13.43
Total input tokens:                      225502
Total generated tokens:                  185297
Request throughput (req/s):              76.26
Output token throughput (tok/s):         13800.30
Total Token throughput (tok/s):          30594.93
---------------Time to First Token----------------
Mean TTFT (ms):                          1119.68
Median TTFT (ms):                        1100.86
P99 TTFT (ms):                           1721.80
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          29.11
Median TPOT (ms):                        27.07
P99 TPOT (ms):                           46.91
---------------Inter-token Latency----------------
Mean ITL (ms):                           23.38
Median ITL (ms):                         17.28
P99 ITL (ms):                            49.45
==================================================
```

Numerics

Some quick LM evals:
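
A sketch of how these can be run via lm-evaluation-harness's Python API; the model and task list below are placeholders, not the configs actually used:

```python
import lm_eval

# Run the same tasks for the BF16 baseline and each MXFP checkpoint,
# then compare the per-task metrics side by side.
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=Qwen/Qwen2-7B-Instruct,dtype=bfloat16",
    tasks=["hellaswag", "winogrande"],
    batch_size=8,
)
print(results["results"])
```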

References

  • Llama 70B feed forward with custom kernel: fburl.com/125yv8hh
  • Llama 70B feed forward without custom kernel: fburl.com/a21gwjmc
  • BF16 reference: fburl.com/2lgn9xkx
  • Non-Quantized Trace: fburl.com/sput3bmn
  • Quantized Trace: fburl.com/0pgmyrge
  • BF16 70B MLP: fburl.com/aeqm5s4v
  • MXFP8 70B MLP: fburl.com/uxgoju4r
  • MXFP4 70B MLP: fburl.com/u95f6f39
  • Eager vs. Inductor Swizzle Profile: fburl.com/kqhm91ib
