
Conversation


@protonu protonu commented Dec 23, 2025

This adds quantized scaled MM ops to our Python benchmark.

This will create/quantize the module to:

        (feed_forward): Llama4MoE(
          (gate): NVFP4InferenceLinear()
          (shared_experts): NVFP4InferenceSwiGLU(
            (gate_proj): NVFP4InferenceLinear()
            (up_proj): NVFP4InferenceLinear()
            (down_proj): NVFP4InferenceLinear()
          )
          (routed_experts): NVFP4InferenceGroupedSwiGLU(
            (gate_proj): NVFP4InferenceGroupedLinear()
            (up_proj): NVFP4InferenceGroupedLinear()
            (down_proj): NVFP4InferenceGroupedLinear()
          )
        )

A small bug was also fixed: when inferring the output allocation, we no longer call tensor_.view when one of the splits is not a divisible split. The problem shows up when we pad the inner dimension by 4 and the "padded" outer split dimension is one.
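
As a rough illustration of the failure mode, here is a small, hypothetical PyTorch sketch (not the nvFuser code path): view() requires the element count to match exactly, so it cannot describe a padded allocation, while as_strided() works from explicit sizes and strides, which is roughly what the fix falls back to.

    import torch

    # Hypothetical shapes for illustration: the allocation is padded to 8
    # elements, but only 6 of them are logically valid.
    storage = torch.zeros(8)
    logical = storage[:6]

    try:
        logical.view(2, 4)  # 2 * 4 != 6, so view() raises a RuntimeError
    except RuntimeError as err:
        print("view failed:", err)

    # The padded layout can still be described from the full allocation with
    # explicit sizes and strides.
    shaped = storage.as_strided(size=(2, 4), stride=(4, 1))
    print(shaped.shape)  # torch.Size([2, 4])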

@protonu protonu changed the title from Pbasu/nvfp4 linear bench to [DO NOT REVIEW] benchmark for nvfp4 scaled mm on Dec 23, 2025

github-actions bot commented Jan 5, 2026

Review updated until commit 5a40972

Description

  • Add quantized scaled MM operations to Python benchmark for NVFP4 inference

  • Implement NVFP4InferenceLinear and NVFP4InferenceSwiGLU classes for efficient inference

  • Register nvfuser_f16a_nvfp4weight_scaled_mm custom operation with Thunder

  • Fix tensor view bug when split dimensions are not divisible

Changes walkthrough

Relevant files

Bug fix
  csrc/runtime/allocations.cpp: Fix tensor view bug for non-divisible splits (+5/-1)
  • Add a divisibility check before calling tensor_.view() for dimension merging
  • Evaluate the split factor and input extent to determine whether the split is divisible
  • Only merge contiguous dimensions when the split is actually divisible

Enhancement
  benchmarks/python/benchmark_inference.py: Add NVFP4 scaled MM benchmark integration (+57/-0)
  • Import the new NVFP4 layer classes and scaled MM function
  • Register the nvfuser_f16a_nvfp4weight_scaled_mm custom operation
  • Add nvfp4_scaled_mm_translator for Thunder integration
  • Quantize the SwiGLU and gate projection layers in the Llama4MoE model

  benchmarks/python/layers_for_inference_benchmark.py: Implement NVFP4 inference layer classes (+154/-0)
  • Implement the nvfuser_f16a_nvfp4weight_scaled_mm custom operation
  • Add NVFP4InferenceLinear class for quantized linear inference
  • Add NVFP4InferenceSwiGLU class for quantized SwiGLU inference
  • Include fake implementations for the custom operations

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review
    Missing Tests

    This PR adds significant new functionality with NVFP4InferenceLinear, NVFP4InferenceSwiGLU classes, and custom ops, but no tests are included. Given that this is a performance-critical benchmark with quantization logic, tests should verify correctness of the new layers, custom ops, and quantization conversions.

    _replace_with_custom_fn_if_matches_filter_with_name(
        model,
        NVFP4InferenceGroupedSwiGLU.from_grouped_swiglu,
        lambda model, cur_fqn: isinstance(model, GroupedSwiGLU),
    )
    
    _replace_with_custom_fn_if_matches_filter_with_name(
        model,
        NVFP4InferenceSwiGLU.from_swiglu,
        lambda model, cur_fqn: isinstance(model, SwiGLU),
    )
    
    # Find and return all submodules of the model that are instances of Llama4MoE
    def _find_llama4moe_recursive(module):
        found = []
        for child in module.children():
            if isinstance(child, Llama4MoE):
                found.append(child)
            found.extend(_find_llama4moe_recursive(child))
        return found
    
    llama4moe_module = _find_llama4moe_recursive(model)
    assert len(llama4moe_module) == 1, f"Expected exactly one Llama4MoE module, found {len(llama4moe_module)}"
    
    # Quantize the gate projection layer
    llama4moe_module[0].gate = NVFP4InferenceLinear.from_linear(llama4moe_module[0].gate)
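
    A minimal, hypothetical sketch of such a test could look like the following; only the names that appear in this PR (NVFP4InferenceLinear and its from_linear constructor) are taken from the source, while the import path, shapes, dtypes, device placement, and tolerance are assumptions.

        import torch
        # Assumed import path for the benchmark layer classes.
        from layers_for_inference_benchmark import NVFP4InferenceLinear

        def test_nvfp4_linear_matches_reference():
            torch.manual_seed(0)
            # Reference bf16 linear layer (bias-free is an assumption).
            ref = torch.nn.Linear(256, 128, bias=False, dtype=torch.bfloat16, device="cuda")
            quantized = NVFP4InferenceLinear.from_linear(ref)

            x = torch.randn(4, 16, 256, dtype=torch.bfloat16, device="cuda")
            expected = ref(x)
            actual = quantized(x)

            # forward() currently returns a flattened 2D tensor (see the review
            # comments below), so reshape before comparing. FP4 quantization is
            # lossy, hence the deliberately loose tolerance.
            torch.testing.assert_close(
                actual.reshape(expected.shape), expected, rtol=0.0, atol=0.25
            )
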
    Potential Performance Issue

    The NVFP4InferenceLinear.forward() method reshapes input to 2D (line 613) but doesn't restore the original 3D shape. This could cause issues if the input has more than 2 dimensions, potentially breaking the expected tensor shape for subsequent operations.

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Forward pass using nvfp4_scaled_mm.
    
        Args:
            hidden_states: Input tensor of shape [batch, seq_len, in_features]
        """
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    
        # Use nvfp4_scaled_mm which handles the full computation
        output = torch.ops.nvf_cutlass.f16a_nvfp4weight_scaled_mm(
            hidden_states,
            self.fp4_weight,
            self.weight_scaling_factor,
            self.weight_global_scale,
        )
    
        return output
    Incomplete Error Handling

    The fake registration for nvfuser_f16a_nvfp4weight_scaled_mm validates device consistency and dimensionality but doesn't validate critical tensor shapes and dtypes that are essential for the quantized matmul operation to work correctly.

    @torch.library.register_fake("nvf_cutlass::f16a_nvfp4weight_scaled_mm")
    def _(
        activation: torch.Tensor,
        fp4_weight: torch.Tensor,
        weight_scaling_factor: torch.Tensor,
        weight_global_scale: torch.Tensor,
    ) -> torch.Tensor:
        # fp4_weight shape: (in_features // 2, out_features)
        # Validate that activation has at least 1 dimension
        if activation.ndim == 0:
            raise ValueError(f"Expected activation to have at least 1 dimension, got {activation.ndim}")
    
    
        if (
            len(
                {
                    t.device
                    for t in [
                        activation,
                        fp4_weight,
                        weight_scaling_factor,
                        weight_global_scale,
                    ]
                }
            )
            != 1
        ):
            raise ValueError("Expected all inputs to be on the same device.")
    
    
        a = torch.empty((activation.shape[0], fp4_weight.t().shape[0]), device=activation.device, dtype=torch.bfloat16)
        return a
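
    On the "Incomplete Error Handling" point above, a hedged sketch of the extra validation could look like this. The packed-weight convention (in_features // 2, out_features) is taken from the comment in the snippet above; the helper name, the expected activation dtypes, and the scalar global scale are assumptions rather than the PR's actual contract.

        import torch

        def _validate_nvfp4_scaled_mm_inputs(
            activation: torch.Tensor,
            fp4_weight: torch.Tensor,
            weight_scaling_factor: torch.Tensor,
            weight_global_scale: torch.Tensor,
        ) -> None:
            # Assumed dtype contract: fp16/bf16 activations.
            if activation.dtype not in (torch.float16, torch.bfloat16):
                raise ValueError(f"Expected fp16/bf16 activation, got {activation.dtype}")
            if fp4_weight.ndim != 2:
                raise ValueError(f"Expected 2D packed fp4_weight, got {fp4_weight.ndim}D")
            # fp4_weight packs two FP4 values per byte along its first dimension,
            # so the activation's inner dimension should be twice the packed extent.
            if activation.shape[-1] != 2 * fp4_weight.shape[0]:
                raise ValueError(
                    f"Activation inner dim {activation.shape[-1]} does not match "
                    f"2 * packed weight dim {fp4_weight.shape[0]}"
                )
            # Assumed: the global scale is a single scalar; the blocked layout of
            # weight_scaling_factor is not validated here.
            if weight_global_scale.numel() != 1:
                raise ValueError("Expected weight_global_scale to be a scalar tensor")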

    @protonu protonu changed the title from [DO NOT REVIEW] benchmark for nvfp4 scaled mm to Benchmark for nvfp4 scaled mm on Jan 5, 2026
    @protonu protonu marked this pull request as ready for review January 5, 2026 22:21

    greptile-apps bot commented Jan 5, 2026

    Greptile Summary

    adds NVFP4 quantized scaled matmul operations for non-grouped linear layers in the Python benchmark

    Changes

    • benchmarks/python/layers_for_inference_benchmark.py: added NVFP4InferenceLinear and NVFP4InferenceSwiGLU classes for quantized inference with custom op nvfuser_f16a_nvfp4weight_scaled_mm
    • benchmarks/python/benchmark_inference.py: registered the new custom op with Thunder/nvFuser translator, added quantization logic to convert SwiGLU modules and the gate projection in Llama4MoE to NVFP4 versions
    • csrc/runtime/allocations.cpp: fixed bug where tensor_.view() was incorrectly called on non-divisible splits when padding creates non-divisible dimensions (e.g., padded inner dimension by 4 with outer split dimension of 1)

    Issues Found

    • shape inconsistency in NVFP4InferenceLinear.forward where input is flattened but output shape isn't restored to match the documented [batch, seq_len, in_features] input format

    Confidence Score: 3/5

    • generally safe but contains a shape-handling bug that could cause runtime issues
    • the C++ bug fix is correct and necessary. The Python additions follow existing patterns but have a shape inconsistency: NVFP4InferenceLinear.forward flattens the input without restoring the original shape, which is inconsistent with the docstring and the grouped implementation. This could cause downstream shape mismatches.
    • pay close attention to benchmarks/python/layers_for_inference_benchmark.py - the NVFP4InferenceLinear.forward method needs shape restoration logic

    Important Files Changed

    • csrc/runtime/allocations.cpp: fixes a bug where .view() was called on non-divisible splits during allocation inference by adding a divisibility check
    • benchmarks/python/layers_for_inference_benchmark.py: adds NVFP4InferenceLinear and NVFP4InferenceSwiGLU for non-grouped quantized matmul operations; the fake implementation has a shape mismatch with the forward method's flattening behavior
    • benchmarks/python/benchmark_inference.py: registers the nvfp4 scaled_mm custom op, adds a translator for Thunder, and quantizes the gate projection and SwiGLU modules in Llama4MoE

    Sequence Diagram

    sequenceDiagram
        participant Benchmark as benchmark_inference.py
        participant Quantize as _quantize_llama4()
        participant Linear as NVFP4InferenceLinear
        participant Op as nvfuser_f16a_nvfp4weight_scaled_mm
        participant Thunder as Thunder/nvFuser
        participant Alloc as allocations.cpp
        
        Note over Benchmark: Register custom ops
        Benchmark->>Thunder: _register_nvfp4_ops()
        Thunder->>Thunder: register nvfp4_scaled_mm_symbol
        Thunder->>Thunder: register nvfp4_scaled_mm_translator
        
        Note over Benchmark: Model preparation
        Benchmark->>Quantize: _quantize_llama4(model)
        Quantize->>Quantize: replace GroupedSwiGLU with NVFP4InferenceGroupedSwiGLU
        Quantize->>Quantize: replace SwiGLU with NVFP4InferenceSwiGLU
        Quantize->>Quantize: find Llama4MoE modules
        Quantize->>Linear: gate = NVFP4InferenceLinear.from_linear()
        Linear->>Linear: quantize_linear_weight_to_nvfp4()
        Linear-->>Quantize: NVFP4InferenceLinear instance
        
        Note over Benchmark: Forward pass
        Benchmark->>Linear: forward(hidden_states)
        Linear->>Linear: flatten: view(-1, in_features)
        Linear->>Op: f16a_nvfp4weight_scaled_mm(activation, fp4_weight, ...)
        Op->>Op: dequantize_to_dtype()
        Op->>Op: torch.nn.functional.linear()
        Op-->>Linear: output (2D, bfloat16)
        Linear-->>Benchmark: output (shape not restored)
        
        Note over Thunder: nvFuser translation
        Thunder->>Thunder: nvfp4_scaled_mm_translator()
        Thunder->>Thunder: nv_block_quantize(activation)
        Thunder->>Thunder: scaled_mm(quantized_activation, fp4_weight, ...)
        
        Note over Alloc: Bug fix for padded dimensions
        Alloc->>Alloc: handle(Split)
        Alloc->>Alloc: check: is_divisible = (in_extent % factor == 0)
        alt is_divisible && contiguous
            Alloc->>Alloc: tensor_.view(new_shape)
        else non-divisible or non-contiguous
            Alloc->>Alloc: tensor_.as_strided(tensor_new_shape, strides)
        end
    

    @greptile-apps greptile-apps bot left a comment

    3 files reviewed, 1 comment


    protonu commented Jan 5, 2026

    !test

    @greptile-apps greptile-apps bot left a comment

    3 files reviewed, 1 comment


    protonu commented Jan 5, 2026

    !test

    @protonu protonu requested review from jjsjann123 and tbqh January 5, 2026 22:51

    protonu commented Jan 6, 2026

    !test


    protonu commented Jan 6, 2026

    !test


    protonu commented Jan 7, 2026

    !test

    @greptile-apps greptile-apps bot left a comment

    3 files reviewed, 2 comments

    raise ValueError("Expected all inputs to be on the same device.")


    a = torch.empty((activation.shape[0], fp4_weight.t().shape[0]), device=activation.device, dtype=torch.bfloat16)

    logic: output shape assumes 2D activation but doesn't preserve batch structure. If activation was originally 3D (batch, seq, hidden) and flattened to 2D (batch*seq, hidden) before this call, the output remains 2D (batch*seq, out_features) instead of restoring to 3D (batch, seq, out_features). This is inconsistent with the docstring in the forward method at line 611 which documents input as [batch, seq_len, in_features].

    Suggested change
    a = torch.empty((activation.shape[0], fp4_weight.t().shape[0]), device=activation.device, dtype=torch.bfloat16)
    output_shape = activation.shape[:-1] + (fp4_weight.size(1),)
    return torch.empty(output_shape, device=activation.device, dtype=torch.bfloat16)

    Comment on lines +607 to +623
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Forward pass using nvfp4_scaled_mm.

        Args:
            hidden_states: Input tensor of shape [batch, seq_len, in_features]
        """
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        # Use nvfp4_scaled_mm which handles the full computation
        output = torch.ops.nvf_cutlass.f16a_nvfp4weight_scaled_mm(
            hidden_states,
            self.fp4_weight,
            self.weight_scaling_factor,
            self.weight_global_scale,
        )

        return output

    logic: flattens input to 2D but doesn't restore original shape before returning. Docstring says input is [batch, seq_len, in_features] but output remains flattened as (batch*seq, out_features). Compare with grouped version's fake (lines 332-336) which preserves input shape.

    The forward method needs to store the original shape and restore it:

    original_shape = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    output = torch.ops.nvf_cutlass.f16a_nvfp4weight_scaled_mm(...)
    return output.view(*original_shape[:-1], -1)
    
