
Conversation

@jjsjann123
Collaborator

@jjsjann123 jjsjann123 commented Jan 8, 2026

Context

This series of PRs enables a single kernel that performs both quantization and layout handling of the block scaling factor on grouped tensors.

The existing solution for nvfp4 quantization of an activation tensor for grouped_mm relies on two operations:
i. BlockQuantizationOp produces scaled_tv and block_scaling_factor.
ii. block_scaling_factor then needs to be processed by PreprocessGroupedMatmulInputSf to satisfy the swizzled layout required by grouped_mm kernels.

This series of PRs merges the two operations into a single one.
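
For orientation, a minimal sketch of the existing two-step flow; the wrapper names (blockQuantize, preprocessGroupedMatmulInputSf) and their signatures below are assumptions for illustration, not the exact APIs:

// Sketch only: wrapper names and signatures are assumed.
// Step i: BlockQuantizationOp produces the quantized data and a per-block scaling factor.
auto [scaled_tv, block_scaling_factor] = blockQuantize(activation);
// Step ii: a second kernel (PreprocessGroupedMatmulInputSf) rewrites the block scaling
// factor into the swizzled layout expected by the grouped_mm kernels, using per-group offsets.
auto swizzled_sf = preprocessGroupedMatmulInputSf(
    block_scaling_factor, input_offsets, output_offsets, layout);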

Stacked PRs

#5775 GroupedBlockQuantizationOp PR0: Adding runtime function
#5776 GroupedBlockQuantizationOp PR1: Adding codegen support
#5777 GroupedBlockQuantizationOp PR2: Adding python API and updating llama4 benchmark

What's in this PR

  1. Adding the Fusion IR node GroupedBlockQuantizationOp. The operation is a combination of BlockQuantizationOp and PreprocessGroupedMatmulInputSf and inherits all of the validation/checks from the two operations.
    The operation is similar to BlockQuantizationOp, with the exception that:
    i. The block scaling factor output doesn't have the swizzle logic represented as allocation domain transformations;
    ii. It takes additional inputs (input_offsets and output_offsets) to facilitate group indexing, similar to PreprocessGroupedMatmulInputSf. (A usage sketch follows this list.)

  2. Adding a cpp test case for GroupedBlockQuantizationOp.
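
For orientation, here is a rough usage sketch of the new op through the groupedBlockQuantize wrapper added in csrc/ops/arith.h and arith.cpp. The exact signature, return-value shape, and the layout enumerator name are assumptions here, not the declaration from this PR:

// Sketch only: groupedBlockQuantize's signature, its return-value member names,
// and the BlockScalingFactorLayout enumerator are assumed for illustration.
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* activation = makeContigTensor(2, DataType::BFloat16); // [m, k], grouped along m
TensorView* input_offsets = makeContigTensor(1, DataType::Int32); // per-group row offsets
TensorView* output_offsets = makeContigTensor(1, DataType::Int32); // per-group offsets into the scale buffer
fusion.addInput(activation);
fusion.addInput(input_offsets);
fusion.addInput(output_offsets);

// A single op now produces both the nvfp4 quantized tensor and the block scaling
// factor, already indexed per group for grouped_mm consumption.
auto result = groupedBlockQuantize(
    activation, input_offsets, output_offsets, /*layout=*/BlockScalingFactorLayout::Block128x4);
fusion.addOutput(result.quantized_tensor);
fusion.addOutput(result.block_scales);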

@jjsjann123 jjsjann123 changed the base branch from main to jj/grouped_block_quantize_op_0 January 8, 2026 00:36
@jjsjann123 jjsjann123 changed the title Jj/grouped block quantize op 1 PR1: adding codegen support for GroupedBlockQuantizationOp Jan 8, 2026
@github-actions

github-actions bot commented Jan 8, 2026

Review updated until commit fc79a9c

Description

  • Added codegen support for GroupedBlockQuantizationOp in CudaKernelGenerator

  • Integrated GroupedBlockQuantizationOp across device lower analysis passes

  • Added GroupedBlockQuantizationOp IR node definition and infrastructure

  • Updated scheduler integration for grouped block quantization operations

  • Added comprehensive test coverage for grouped block quantization functionality

Changes walkthrough

Relevant files

Enhancement (25 files)

codegen.cpp: Added GroupedBlockQuantizationOp codegen handler (+114/-0)
non_divisible_split.cpp: Extended non-divisible split analysis for grouped quantization (+6/-1)
sync_information.cpp: Updated sync map analysis for grouped block quantization (+10/-5)
trivial_broadcast.cpp: Added GroupedBlockQuantizationOp broadcast domain handling (+11/-0)
index.cpp: Implemented index lowering for GroupedBlockQuantizationOp (+54/-0)
utils.cpp: Updated TV operation utilities for grouped quantization (+1/-0)
validation.cpp: Added validation logic for GroupedBlockQuantizationOp (+196/-1)
fusion_segmenter.cpp: Updated fusion segmentation for grouped block quantization (+6/-1)
composite_nodes.cpp: Implemented GroupedBlockQuantizationOp IR node (+58/-0)
utils.cpp: Updated IR utilities for grouped quantization operations (+5/-1)
kernel.cpp: Extended kernel IR scanner for GroupedBlockQuantizationOp (+4/-0)
logical_domain_map.cpp: Updated logical domain mapping for grouped quantization (+29/-9)
arith.cpp: Added groupedBlockQuantize arithmetic operation (+141/-0)
pointwise.cpp: Updated pointwise scheduler for grouped block quantization (+23/-1)
pointwise_non_tma.cpp: Extended non-TMA scheduler for grouped quantization (+8/-1)
registry_utils.cpp: Updated scheduler topology checks for grouped operations (+20/-0)
domain_map.cpp: Extended domain mapping tools for grouped quantization (+13/-0)
utils.cpp: Updated scheduler utilities for grouped block quantization (+12/-7)
tensor_metadata.cpp: Extended tensor metadata handling for grouped quantization (+6/-0)
trivial_broadcast.h: Updated broadcast analysis header for grouped quantization (+2/-0)
index.h: Updated index lowering header for GroupedBlockQuantizationOp (+1/-0)
dispatch.h: Added GroupedBlockQuantizationOp to dispatch system (+1/-0)
composite_nodes.h: Defined GroupedBlockQuantizationOp IR node interface (+92/-0)
logical_domain_map.h: Updated logical domain map header for grouped quantization (+4/-0)
arith.h: Added groupedBlockQuantize function declaration (+9/-0)

Tests (1 file)

test_layout_op.cpp: Added test case for GroupedBlockQuantizationOp (+69/-1)

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review
Runtime Function Interface

The codegen calls the 'bq::grouped_block_quantize_to_nvfp4' runtime function, but the actual runtime function implementation is not part of this PR (it is added in #5775). Verify that the runtime function interface matches the generated call arguments, particularly the template arguments and function parameters.

indent() << genCall(
                "bq::grouped_block_quantize_to_nvfp4",
                template_args,
                func_args)
         << ";\n";
Validation Logic Completeness

The validation logic for GroupedBlockQuantizationOp is extensive, but it only rejects Float8_e4m3fn output (the output_dtype != Float8_e4m3fn check on lines 893-894), while the operation creation in arith.cpp accepts both Float4_e2m1fn and Float8_e4m3fn. Verify whether Float8_e4m3fn support is complete or whether validations are missing.

void handle(GroupedBlockQuantizationOp* bqop) final {
  auto inp_tv = bqop->input(0)->as<TensorView>();
  auto quantized_output = bqop->quantizedOutput()->as<TensorView>();
  auto block_scaling_factor = bqop->blockScales()->as<TensorView>();
  auto output_dtype = quantized_output->dtype();

  NVF_ERROR_EQ(
      inp_tv->getMemoryType(),
      MemoryType::Local,
      "Input must be a local memory tensor. Found: ",
      inp_tv->getMemoryType());

  NVF_ERROR_EQ(
      quantized_output->getMemoryType(),
      MemoryType::Local,
      "Quantized output must be a local memory tensor. Found: ",
      quantized_output->getMemoryType());

  NVF_ERROR_EQ(
      block_scaling_factor->getMemoryType(),
      MemoryType::Global,
      "Block scaling factor must be a global memory tensor. Found: ",
      block_scaling_factor->getMemoryType());

  NVF_ERROR(
      output_dtype != DataType::Float8_e4m3fn,
      "output of Float8_e4m3fn is not yet implemented");
Data Type Support Consistency

The groupedBlockQuantize function validates out_dtype, but the GroupedBlockQuantizationOp creation doesn't explicitly pass an out_dtype parameter. Verify that the output data type is properly handled throughout the pipeline and matches the intended quantization format.

IrBuilder::create<GroupedBlockQuantizationOp>(
    block_scales,
    quantized_tensor,
    input,
    input_offsets,
    output_offsets,
    layout,
    inp_domain[1]->getMaybeExpandedExtent(),
    num_groups,
    global_scaling_factor,
    block_size);
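
One way to sanity-check this during review (a sketch, not code from the PR): since the node doesn't carry an explicit out_dtype attribute, the output dtype should be recoverable from the quantized output TensorView itself, the same way the validation handler above reads it:

// Sketch: recover the quantization format from the op's output rather than an attribute.
auto* bqop = quantized_tensor->definition()->as<GroupedBlockQuantizationOp>();
auto output_dtype = bqop->quantizedOutput()->as<TensorView>()->dtype();
NVF_ERROR(
    output_dtype == DataType::Float4_e2m1fn ||
        output_dtype == DataType::Float8_e4m3fn,
    "unexpected quantized output dtype");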

@jjsjann123 jjsjann123 changed the title PR1: adding codegen support for GroupedBlockQuantizationOp GroupedBlockQuantizeOp PR1: Adding codegen support Jan 8, 2026
@jjsjann123 jjsjann123 marked this pull request as ready for review January 8, 2026 02:17
@greptile-apps
Contributor

greptile-apps bot commented Jan 8, 2026

Greptile Summary

  • Adds GroupedBlockQuantizationOp IR node that merges BlockQuantizationOp and PreprocessGroupedMatmulInputSf into a single operation for performance optimization in grouped matrix multiplication scenarios
  • Implements comprehensive codegen support including dispatch registration, kernel handling, scheduling integration, and validation across all compiler passes
  • Includes test case validation ensuring the new grouped operation maintains correctness while enabling single-kernel quantization and layout handling

Important Files Changed

csrc/ir/composite_nodes.h and csrc/ir/composite_nodes.cpp: New GroupedBlockQuantizationOp class implementation with constructor, accessors, and evaluation methods combining quantization and layout handling functionality
csrc/codegen.cpp: Added code generation handler for GroupedBlockQuantizationOp that generates runtime function calls with template arguments for block scaling layouts and group sizes
csrc/device_lower/pass/index.cpp: Implemented index lowering handler with validation for 2D matrices and runtime divisibility checks for block size compatibility
csrc/device_lower/validation.cpp: Added extensive validation logic duplicating BlockQuantizationOp constraints while supporting grouped indexing with ParallelType::Group
csrc/ops/arith.cpp and csrc/ops/arith.h: New groupedBlockQuantize function implementation and declaration that creates IR nodes with proper domain setup and allocation handling

Confidence score: 4/5

  • This PR is generally safe to merge but requires careful review due to the complexity of the new composite operation
  • Score reflects the extensive changes across critical code generation and validation systems, though the implementation follows established patterns consistently
  • Pay close attention to the codegen handler in csrc/codegen.cpp and validation logic in csrc/device_lower/validation.cpp for correctness of the complex template argument and runtime function call generation


@greptile-apps greptile-apps bot left a comment

26 files reviewed, 6 comments

Comment on lines +1893 to +1907
GroupedBlockQuantizationOp::GroupedBlockQuantizationOp(
    IrBuilderPasskey passkey,
    Val* output_scales,
    Val* output,
    Val* input,
    Val* input_offsets,
    Val* output_offsets,
    BlockScalingFactorLayout layout,
    Val* k,
    Val* g,
    Val* global_scale,
    int64_t block_size,
    Val* row_idx,
    Val* col_idx)
    : Expr(passkey) {

style: Consider adding parameter validation similar to the parent operations (BlockQuantizationOp and PreprocessGroupedMatmulInputSf) to ensure inputs meet expected constraints
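
A minimal sketch of the kind of constructor-side checks this suggestion points at, using only the parameters shown above; whether these belong in the constructor or in the lowering validation pass is left to the authors:

// Sketch only: illustrative checks, not code from this PR.
NVF_ERROR(input != nullptr, "GroupedBlockQuantizationOp requires an input tensor");
NVF_ERROR(
    input_offsets != nullptr && output_offsets != nullptr,
    "GroupedBlockQuantizationOp requires input/output offsets for group indexing");
NVF_ERROR(block_size > 0, "block_size must be positive, found: ", block_size);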

Comment on lines +892 to +894
NVF_ERROR(
    output_dtype != DataType::Float8_e4m3fn,
    "output of Float8_e4m3fn is not yet implemented");

logic: Different error handling compared to BlockQuantizationOp - this throws an error for Float8_e4m3fn output while BlockQuantizationOp only restricts global scale usage for this type. Is this difference in Float8_e4m3fn handling intentional, or should GroupedBlockQuantizationOp support this data type with the same restrictions as BlockQuantizationOp?

Comment on lines +359 to 365
  } else if (
      tv->definition() && tv->definition()->isA<GroupedBlockQuantizationOp>()) {
    auto bqop = tv->definition()->as<GroupedBlockQuantizationOp>();
    if (tv == bqop->blockScales()) {
      skip_validation = true;
    }
  }

style: Missing swizzle condition check for consistency with BlockQuantizationOp. The existing logic checks bqop->isSwizzledScales() but this version only checks if the tensor is blockScales. Should this also check for a swizzled condition like the BlockQuantizationOp case, or are all GroupedBlockQuantizationOp blockScales always meant to skip validation?

Comment on lines +425 to +434
  auto has_block_quantization_ops =
      HeuristicDataCacheEntry<HeuristicCompileTime::HasBlockQuantizationOps>(
          data_cache,
          [fusion]() {
            return std::make_unique<bool>(
                !ir_utils::getOpsOfType<BlockQuantizationOp>(fusion).empty() ||
                !ir_utils::getOpsOfType<GroupedBlockQuantizationOp>(fusion)
                     .empty());
          })
          .get();

style: Duplicate logic - the same check for block quantization ops is performed twice in this function. Consider extracting this into a helper function or reusing the cache entry from canScheduleRunTime.
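
One possible shape for the suggested helper, reusing the exact check above (name and placement assumed, not part of this PR):

// Sketch: factor the repeated check into a helper shared with canScheduleRunTime.
bool hasBlockQuantizationOps(Fusion* fusion) {
  return !ir_utils::getOpsOfType<BlockQuantizationOp>(fusion).empty() ||
      !ir_utils::getOpsOfType<GroupedBlockQuantizationOp>(fusion).empty();
}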

Comment on lines +1154 to +1159
  bool hasGlobalScale() const {
    if (inputs().size() > 5) {
      return true;
    }
    return false;
  }

style: Inconsistent global scale detection logic compared to BlockQuantizationOp which checks inputs().size() > 1. Is the different threshold (5 vs 1) intentional based on the different input structure?

  auto block_scales_dtype = (out_dtype == DataType::Float4_e2m1fn)
      ? DataType::Float8_e4m3fn
      : DataType::Float8_e8m0fnu;
  NVF_ERROR_EQ(inp_domain.size(), 2);

style: This constraint requires exactly 2D input tensors. Should this validation be moved earlier in the function so it fails fast?
