@naoyam
Collaborator

@naoyam naoyam commented Dec 18, 2025

Extends IterDomainBuilder and IterDomain::cloneWithoutRFactor for RaggedIterDomain so that utilities like ops::newOutputTV can correctly create a RaggedIterDomain when an input ID is ragged.

This is mainly to keep ops like set, add, etc. from generating invalid output tensors. We are not doing lowering, so this only exercises Fusion IR construction. Specifically, when an input tensor of a unary op has a RaggedIterDomain, its output should have a RaggedIterDomain at the corresponding position of its logical domain.

The ops in csrc/ops/arith.h, csrc/ops/alias.h, and csrc/ops/indexing.h should either generate valid tensors or fail immediately. The ops in other files are not yet considered.

@github-actions

github-actions bot commented Dec 18, 2025

Review updated until commit bfc3da9

Description

  • Extends IterDomainBuilder to support RaggedIterDomain cloning by adding ragged_extents field and build logic

  • Implements cloneWithoutRFactor override in RaggedIterDomain to preserve ragged structure during domain cloning

  • Adds newOutputRaggedIterDomain utility function to create ragged output domains from ragged input domains

  • Adds validation checks across ops (alias, arith, indexing) to reject RaggedIterDomain in unsupported operations

  • Adds comprehensive test coverage for RaggedIterDomain propagation through various tensor operations

Changes walkthrough

Relevant files
Enhancement
internal_base_nodes.cpp
IterDomainBuilder and RaggedIterDomain cloning implementation

csrc/ir/internal_base_nodes.cpp

  • Modified IterDomainBuilder constructor to copy ragged extents from
    RaggedIterDomain inputs
  • Added ragged_extents() setter method to IterDomainBuilder
  • Updated build() method to create RaggedIterDomain when ragged_extents
    are provided
  • Implemented RaggedIterDomain constructor accepting IterDomainBuilder
  • Added cloneWithoutRFactor() override for RaggedIterDomain to preserve
    ragged structure
  • Added hasRaggedIterDomain() method to TensorDomain for validation
  • +128/-22

utils.cpp
RaggedIterDomain output domain creation utilities

csrc/ops/utils.cpp

  • Added newOutputRaggedIterDomain() function to create ragged output
    domains from ragged inputs
  • Modified newOutputIterDomain() to detect and handle RaggedIterDomain
    inputs
  • Added validation to ensure that either all or none of the inputs are
    ragged when ragged domains are detected
  • +41/-0

internal_base_nodes.h
Header declarations for RaggedIterDomain cloning

csrc/ir/internal_base_nodes.h

  • Updated IterDomainBuilder method signatures to use consistent
    parameter naming
  • Added ragged_extents() method declaration to IterDomainBuilder
  • Added ragged_extents_ member field to IterDomainBuilder
  • Made cloneWithoutRFactor() virtual in IterDomain base class
  • Added RaggedIterDomain constructor accepting IterDomainBuilder
  • Added cloneWithoutRFactor() override declaration for RaggedIterDomain
  • Added hasRaggedIterDomain() method declaration to TensorDomain
  • +22/-11

utils.h
Utility function declaration for ragged output domains

csrc/ops/utils.h

  • Added newOutputRaggedIterDomain() function declaration for creating
    ragged output domains
  • +6/-0

Error handling
alias.cpp
RaggedIterDomain validation in alias operations

csrc/ops/alias.cpp

  • Added hasRaggedIterDomain() checks in reshape(), flatten(), pad(),
    cat(), slice(), broadcast(), expand(), repeat(), asNested()
  • Operations now reject tensors with RaggedIterDomain to prevent invalid
    transformations
  • +48/-1

arith.cpp
RaggedIterDomain validation in reduction operations

csrc/ops/arith.cpp

  • Added RaggedIterDomain validation check in newForReduction() to
    prevent reduction of ragged dimensions
  • Ensures reduction operations maintain tensor validity
  • +8/-0

indexing.cpp
RaggedIterDomain validation in indexing operations

csrc/ops/indexing.cpp

  • Added hasRaggedIterDomain() checks in select(), indexSelect(),
    indexPutAccumulate(), gather(), scatter()
  • Indexing operations now reject tensors with RaggedIterDomain for
    safety
  • +52/-0

Tests
test_ragged_iter_domain.cpp
Comprehensive test coverage for RaggedIterDomain operations

tests/cpp/test_ragged_iter_domain.cpp

  • Added LoadStoreWithNestedTensor test to verify RaggedIterDomain
    propagation through set operation
  • Added BinaryOpWithNestedTensors test to verify ragged structure
    preservation in binary ops
  • Added BinaryOpMixedInputsError test to verify error handling for mixed
    ragged/non-ragged inputs
  • Added BinaryOpDifferentRaggedStructures test for compatibility
    validation
  • Added UnaryOpWithNestedTensors test for unary operation ragged
    propagation
  • Added BroadcastWithNestedTensors test for broadcast operation ragged
    handling
  • Added SqueezeNonRaggedDim, UnsqueezeWithNestedTensors,
    PermuteWithNestedTensors tests
  • Added ReductionOnNonRaggedDim test and ReductionOnRaggedDimError test
    for reduction validation
  • Added error tests for unsupported operations:
    ReshapeWithNestedTensorsError, FlattenWithNestedTensorsError
  • Added error tests for SliceRaggedDimensionError,
    CatRaggedDimensionError, PadRaggedDimensionError
  • +483/-0 

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    RaggedIterDomain cloneWithoutRFactor implementation

    The cloneWithoutRFactor() implementation at lines 985-995 creates a clone using IterDomainBuilder(this).resetRfactor().build() but throws "Not implemented" when mapping with the original is requested. The mapping feature is therefore incomplete, though basic cloning works. Verify this is acceptable for the current scope.

    IterDomain* RaggedIterDomain::cloneWithoutRFactor(bool map_with_original) {
      auto cloned = IterDomainBuilder(this).resetRfactor().build();
    
      // Optionally map the clone with the original in the Exact graph
      if (map_with_original) {
        // TODO: Implement mapping if needed
        NVF_THROW("Not implemented");
      }
    
      return cloned;
    }
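
A minimal, self-contained sketch of the cloning pattern discussed here. It uses hypothetical stand-in types (`ToyIterDomain`, `ToyRaggedIterDomain`, and a string in place of the extents tensor), not nvFuser's classes. It only illustrates why the clone method is virtual and how an unimplemented mapping path can fail loudly.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for a regular iteration domain.
struct ToyIterDomain {
  bool is_rfactor = false;
  virtual ~ToyIterDomain() = default;
  virtual bool isRagged() const { return false; }

  // Virtual so that cloning through a base pointer keeps the dynamic type.
  virtual std::unique_ptr<ToyIterDomain> cloneWithoutRFactor(
      bool map_with_original = false) const {
    (void)map_with_original;  // mapping is out of scope for this sketch
    auto cloned = std::make_unique<ToyIterDomain>(*this);
    cloned->is_rfactor = false;
    return cloned;
  }
};

// Hypothetical stand-in for a ragged iteration domain.
struct ToyRaggedIterDomain : ToyIterDomain {
  std::string extents_name;  // placeholder for the extents tensor
  bool isRagged() const override { return true; }

  std::unique_ptr<ToyIterDomain> cloneWithoutRFactor(
      bool map_with_original = false) const override {
    // Copy as the derived type; the base implementation would slice.
    auto cloned = std::make_unique<ToyRaggedIterDomain>(*this);
    cloned->is_rfactor = false;
    if (map_with_original) {
      // Mirrors the explicit failure in the snippet above.
      throw std::logic_error("Not implemented");
    }
    return cloned;
  }
};
```

Calling `cloneWithoutRFactor()` on a `ToyRaggedIterDomain` through a `ToyIterDomain` reference still yields a ragged clone, while `cloneWithoutRFactor(true)` throws.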
    newOutputRaggedIterDomain validation

    The newOutputRaggedIterDomain() function at lines 318-335 assumes all input IDs are equivalent and uses the first one as the reference. This may not be safe if the input ragged domains have different extents tensors or properties. Consider whether additional validation is needed.

    RaggedIterDomain* newOutputRaggedIterDomain(
        const std::vector<IterDomain*>& input_ids) {
      NVF_ERROR(
          std::ranges::all_of(
              input_ids,
              [](IterDomain* input_id) {
                return input_id->isA<RaggedIterDomain>();
              }),
          "All input iter domains must be RaggedIterDomain");
    
      NVF_ERROR(!input_ids.empty());
    
      // Just using the first ragged ID as all input IDs are assumed to be
      // equivalent
      RaggedIterDomain* ref_input_id = input_ids.front()->as<RaggedIterDomain>();
    
      return IterDomainBuilder(ref_input_id).build()->as<RaggedIterDomain>();
    }
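
One way the extra validation raised here could look, as a sketch under stated assumptions: the types `ToyExtents` and `ToyRaggedId` and the helper `pickReferenceRaggedId` are hypothetical, and pointer identity of the extents is used as a deliberately conservative notion of "equivalent". The PR itself does not say how equivalence should be decided.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

struct ToyExtents {};  // hypothetical placeholder for an extents tensor

struct ToyRaggedId {
  const ToyExtents* extents = nullptr;
};

// Returns the first input as the reference, but only after checking that
// every input refers to the same extents object (assumed equivalence rule).
inline const ToyRaggedId& pickReferenceRaggedId(
    const std::vector<const ToyRaggedId*>& input_ids) {
  if (input_ids.empty()) {
    throw std::invalid_argument("No input IDs given");
  }
  const ToyRaggedId* ref = input_ids.front();
  for (const ToyRaggedId* id : input_ids) {
    if (id->extents != ref->extents) {
      throw std::invalid_argument("Ragged inputs have different extents");
    }
  }
  return *ref;
}
```

A graph-based or symbolic comparison may be more appropriate in practice; identity comparison would reject inputs whose extents are equal but stored in distinct objects.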
    RaggedIterDomain operation restrictions

    Multiple operations (reshape, flatten, pad, cat, slice, expand, repeat, asNested) explicitly reject tensors with RaggedIterDomain. While this is conservative and safe, verify this aligns with the intended functionality and that these restrictions are well-documented for users.

    NVF_CHECK(
        !inp_tv->domain()->hasRaggedIterDomain(),
        "Reshape operation is not supported for tensors with RaggedIterDomain. "
        "Input tensor: ",
        inp_tv->toString());
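
The guard shown above follows a simple reject-early pattern. The sketch below reproduces it with hypothetical toy types (`ToyId`, `ToyTensorDomain`, `checkNoRagged`, `toyFlatten`); it is not nvFuser code and only shows how a domain-level predicate plus a shared check keeps unsupported ops from building invalid outputs.

```cpp
#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

struct ToyId {
  bool ragged = false;
};

struct ToyTensorDomain {
  std::vector<ToyId> logical;

  // True if any logical ID is ragged.
  bool hasRaggedIterDomain() const {
    return std::any_of(
        logical.begin(), logical.end(), [](const ToyId& id) {
          return id.ragged;
        });
  }
};

// Shared guard: throws before any output is constructed.
inline void checkNoRagged(
    const ToyTensorDomain& domain,
    const std::string& op_name) {
  if (domain.hasRaggedIterDomain()) {
    throw std::runtime_error(
        op_name +
        " operation is not supported for tensors with RaggedIterDomain.");
  }
}

// Toy op: collapses all dimensions into one, rejecting ragged inputs.
inline ToyTensorDomain toyFlatten(const ToyTensorDomain& in) {
  checkNoRagged(in, "Flatten");
  return ToyTensorDomain{{ToyId{}}};
}
```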

    Test failures

    • (High, 96) CUDA driver too old for runtime – nvFuser matmul/tutorial test suite failures on dlcluster_h100

      Test Name H100 Source
      ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/1024_3_1_1 Link
      ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/512_2_1_0 Link
      ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/512_3_0_1 Link
      BlockSizeAndItemsPerThread/ArgSortComprehensiveTest.ComprehensiveValidation/BlockSize64_ItemsPerThread1 Link
      ClusterReductionTest.SimpleFusionNotAllReduce/cluster_16_dtype_float Link
      ClusterReductionTest.SimpleFusionNotAllReduce/cluster_5_dtype_float Link
      ClusterReductionTest.SimpleFusionNotAllReduce/cluster_6_dtype_float Link
      FusionProfilerTest.Profile3Segments Link
      General/HopperPlusMatmulSchedulerTest.FusedMultiplySum/KK_512_256_128_MmaMacro_m128_n128_k16_tma_store Link
      General/HopperPlusMatmulSchedulerTest.FusedMultiplySumBiasNeg/MK_512_256_128_MmaMacro_m128_n128_k16_tma_store Link
      ... with 86 more test failures omitted. Check internal logs.

    @naoyam naoyam marked this pull request as ready for review December 18, 2025 06:47
    @greptile-apps
    Contributor

    greptile-apps bot commented Dec 18, 2025

    Greptile Summary

    Extends IterDomainBuilder and RaggedIterDomain::cloneWithoutRFactor to support ragged domain cloning, enabling ops like set, add, etc. to correctly propagate RaggedIterDomain to output tensors.

    • Adds ragged_extents parameter to IterDomainBuilder for constructing RaggedIterDomain from existing ragged domains
    • Implements RaggedIterDomain::cloneWithoutRFactor override with TODO for mapping support
    • Creates newOutputRaggedIterDomain helper to generate ragged output domains from ragged inputs
    • Modifies newOutputIterDomain to detect ragged inputs and delegate to ragged-specific logic
    • Adds validation checks across alias, indexing, and arithmetic ops to reject unsupported ragged operations
    • Comprehensive test coverage validates ragged domain propagation through unary/binary ops and detects mixed ragged/non-ragged input errors

    Note: Previous thread identified a critical logic mismatch in csrc/ops/utils.cpp:352 where std::any_of is used but newOutputRaggedIterDomain requires std::all_of, which could cause runtime failures when mixing ragged and non-ragged domains.

    Confidence Score: 3/5

    • PR has one critical logic bug that will cause runtime failures when mixing ragged/non-ragged tensors, but implementation is otherwise sound
    • Score reflects a critical logic mismatch between any_of and all_of checks (previously identified) that breaks the mixed input validation test. The incomplete mapping implementation is less critical as it explicitly throws. Other changes are well-structured with comprehensive validation.
    • Pay close attention to csrc/ops/utils.cpp - the logic mismatch at line 352 must be fixed

    Important Files Changed

    | Filename | Overview |
    | --- | --- |
    | csrc/ir/internal_base_nodes.cpp | Implements IterDomainBuilder and RaggedIterDomain::cloneWithoutRFactor for ragged domain cloning, with comprehensive validation |
    | csrc/ops/utils.cpp | Implements ragged domain creation for ops, but has logic mismatch between any_of check and all_of requirement |
    | csrc/ops/alias.cpp | Adds validation checks to reject ragged tensors in unsupported operations (reshape, flatten, pad, cat, slice, expand, repeat) |

    Sequence Diagram

    sequenceDiagram
        participant Op as Binary/Unary Op
        participant Utils as ops::newOutputTV
        participant NewID as newOutputIterDomain
        participant NewRID as newOutputRaggedIterDomain
        participant Builder as IterDomainBuilder
        participant RID as RaggedIterDomain
    
        Op->>Utils: Create output tensor
        Utils->>NewID: newOutputIterDomain(input_ids)
        NewID->>NewID: Check if any_of(isRaggedIterDomain)
        
        alt Has ragged input
            NewID->>NewRID: newOutputRaggedIterDomain(input_ids)
            NewRID->>NewRID: Verify all_of(isRaggedIterDomain)
            NewRID->>RID: IrBuilder::create<RaggedIterDomain>
            RID-->>NewRID: New ragged domain
            NewRID-->>NewID: Return ragged domain
        else No ragged input
            NewID->>Builder: Create standard IterDomain
            Builder-->>NewID: Return standard domain
        end
        
        NewID-->>Utils: Return new IterDomain
        Utils-->>Op: Output tensor with cloned domains
    
    @greptile-apps greptile-apps bot left a comment

    8 files reviewed, 4 comments

    IterDomainBuilder& is_rfactor_domain(bool _is_rfactor_domain);
    IterDomainBuilder& is_padded_dimension(bool _is_padded_dimension);
    IterDomainBuilder& padded_to_size(std::optional<int64_t> _padded_to_size);
    IterDomainBuilder& ragged_extents(TensorView* _ragged_extents);
    Contributor

    syntax: missing implementation for ragged_extents method - declared in header but not implemented in csrc/ir/internal_base_nodes.cpp

    Suggested change
    - IterDomainBuilder& ragged_extents(TensorView* _ragged_extents);
    + // IterDomainBuilder& ragged_extents(TensorView* _ragged_extents); // TODO: implement or use direct member access

    Comment on lines +330 to +332
    // Just using the first ragged ID as all input IDs are assumed to be
    // equivalent
    RaggedIterDomain* ref_input_id = input_ids.front()->as<RaggedIterDomain>();
    Contributor

    style: assumes all input RaggedIterDomains are equivalent without validation - consider verifying extents tensors match

    Comment on lines +985 to +988
    if (map_with_original) {
      // TODO: Implement mapping if needed
      NVF_THROW("Not implemented");
    }
    Contributor

    style: mapping implementation missing - should call fusion()->registerExactMapping(this, cloned) like base IterDomain::cloneWithoutRFactor does (line 334)

    @naoyam
    Collaborator Author

    naoyam commented Dec 18, 2025

    !test

    @naoyam
    Collaborator Author

    naoyam commented Dec 18, 2025

    !test

    @greptile-apps greptile-apps bot left a comment

    8 files reviewed, 1 comment

    Comment on lines 352 to 361
    bool has_ragged =
        std::any_of(input_ids.begin(), input_ids.end(), [](IterDomain* id) {
          return id->isA<RaggedIterDomain>();
        });

    if (has_ragged) {
      NVF_ERROR(
          !force_iter_type.has_value(),
          "force_iter_type not supported for RaggedIterDomain");
      return newOutputRaggedIterDomain(input_ids);
    Contributor

    logic: mismatch between the two checks. This code tests whether ANY input is ragged (std::any_of), but newOutputRaggedIterDomain requires ALL inputs to be ragged (std::ranges::all_of at line 321). The call will fail when ragged and non-ragged domains are mixed at the same position.

    Should either:

    1. Check std::ranges::all_of here instead of std::any_of
    2. Or add validation in newOutputRaggedIterDomain to filter/handle mixed cases
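
A small sketch of how the two checks could be reconciled into a single all-or-none classification. The names `RaggedKind` and `classifyInputs` are hypothetical, and each input is reduced to a boolean "is ragged" flag for illustration; this is not the code that was merged.

```cpp
#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <vector>

enum class RaggedKind { None, All };

// Classifies the inputs at one logical position. Mixed inputs are rejected
// here, with a dedicated message, instead of failing later in a helper.
inline RaggedKind classifyInputs(const std::vector<bool>& is_ragged) {
  const auto flag = [](bool b) { return b; };
  const bool any = std::any_of(is_ragged.begin(), is_ragged.end(), flag);
  const bool all = std::all_of(is_ragged.begin(), is_ragged.end(), flag);
  if (any && !all) {
    throw std::invalid_argument(
        "All or none of the input IDs must be ragged");
  }
  return any ? RaggedKind::All : RaggedKind::None;
}
```

With this shape, the caller dispatches to the ragged path only for `RaggedKind::All`, so the ragged helper's own all-ragged precondition always holds when it is reached.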

    @naoyam naoyam changed the title [WIP] RaggedIterDomain cloning RaggedIterDomain cloning Dec 19, 2025
    Base automatically changed from raggediterdomain-asnested to main January 7, 2026 20:12
    Comment on lines +70 to +109
      IterDomainBuilder& IterDomainBuilder::start(Val* start) {
        start_ = start;
        return *this;
      }

    - IterDomainBuilder& IterDomainBuilder::extent(Val* _extent) {
    -   extent_ = _extent;
    + IterDomainBuilder& IterDomainBuilder::extent(Val* extent) {
    +   extent_ = extent;
        return *this;
      }

    - IterDomainBuilder& IterDomainBuilder::expanded_extent(Val* _expanded_extent) {
    -   expanded_extent_ = _expanded_extent;
    + IterDomainBuilder& IterDomainBuilder::expanded_extent(Val* expanded_extent) {
    +   expanded_extent_ = expanded_extent;
        return *this;
      }

    - IterDomainBuilder& IterDomainBuilder::stop_offset(Val* _stop_offset) {
    -   stop_offset_ = _stop_offset;
    + IterDomainBuilder& IterDomainBuilder::stop_offset(Val* stop_offset) {
    +   stop_offset_ = stop_offset;
        return *this;
      }

      IterDomainBuilder& IterDomainBuilder::parallel_type(
    -     ParallelType _parallel_type) {
    -   parallel_type_ = _parallel_type;
    +     ParallelType parallel_type) {
    +   parallel_type_ = parallel_type;
        return *this;
      }

    - IterDomainBuilder& IterDomainBuilder::iter_type(IterType _iter_type) {
    -   iter_type_ = _iter_type;
    + IterDomainBuilder& IterDomainBuilder::iter_type(IterType iter_type) {
    +   iter_type_ = iter_type;
        return *this;
      }

      IterDomainBuilder& IterDomainBuilder::is_rfactor_domain(
    -     bool _is_rfactor_domain) {
    -   is_rfactor_domain_ = _is_rfactor_domain;
    +     bool is_rfactor_domain) {
    +   is_rfactor_domain_ = is_rfactor_domain;
        return *this;
      }

      IterDomainBuilder& IterDomainBuilder::is_padded_dimension(
    -     bool _is_padded_dimension) {
    -   is_padded_dimension_ = _is_padded_dimension;
    +     bool is_padded_dimension) {
    +   is_padded_dimension_ = is_padded_dimension;
    Collaborator Author

    Just renaming

    Comment on lines +33 to +34
    IterDomainBuilder::IterDomainBuilder(Val* start, Val* extent)
    : start_(start), extent_(extent) {
    Collaborator Author

    Just renaming

    Comment on lines +55 to +63
    IterDomainBuilder& start(Val* start);
    IterDomainBuilder& extent(Val* extent);
    IterDomainBuilder& expanded_extent(Val* expanded_extent);
    IterDomainBuilder& stop_offset(Val* stop_offset);
    IterDomainBuilder& parallel_type(ParallelType parallel_type);
    IterDomainBuilder& iter_type(IterType iter_type);
    IterDomainBuilder& is_rfactor_domain(bool is_rfactor_domain);
    IterDomainBuilder& is_padded_dimension(bool is_padded_dimension);
    IterDomainBuilder& padded_to_size(std::optional<int64_t> padded_to_size);
    Collaborator Author

    Just renaming

    @naoyam
    Collaborator Author

    naoyam commented Jan 9, 2026

    !test

    @greptile-apps greptile-apps bot left a comment

    Greptile Overview

    Greptile Summary

    This PR extends IterDomainBuilder and IterDomain::cloneWithoutRFactor to support RaggedIterDomain, enabling operations like set, add, etc. to correctly create ragged output tensors when inputs contain ragged dimensions.

    Key Changes:

    • IterDomainBuilder now captures ragged_extents_ from RaggedIterDomain inputs and uses it to construct ragged outputs
    • Added RaggedIterDomain::cloneWithoutRFactor override to preserve ragged type during cloning
    • Created ops::newOutputRaggedIterDomain to handle output domain creation for ragged operations
    • Added validation checks in alias.cpp, indexing.cpp, and arith.cpp to reject unsupported operations on ragged tensors

    Critical Issues Found:

    • RaggedIterDomain::cloneWithoutRFactor throws an error when map_with_original=true instead of implementing the mapping. This will cause failures if the exact mapping feature is used.
    • newOutputIterDomain has inconsistent validation logic: it checks if any input is ragged (line 348-351) but then calls newOutputRaggedIterDomain which requires all inputs to be ragged (line 320-325). Mixed inputs will pass the first check but fail the second.

    Confidence Score: 2/5

    • This PR has critical logic bugs that will cause runtime failures
    • Two critical issues prevent merging: (1) cloneWithoutRFactor throws when map_with_original=true, breaking exact mapping functionality, and (2) newOutputIterDomain has inconsistent validation that accepts mixed ragged/non-ragged inputs but then fails. These are blocking bugs, not edge cases.
    • Pay special attention to csrc/ops/utils.cpp (logic bug in newOutputIterDomain) and csrc/ir/internal_base_nodes.cpp (incomplete cloneWithoutRFactor implementation)

    Important Files Changed

    File Analysis

    | Filename | Score | Overview |
    | --- | --- | --- |
    | csrc/ir/internal_base_nodes.cpp | 3/5 | Extended IterDomainBuilder and RaggedIterDomain::cloneWithoutRFactor to support ragged domains. Critical issue: cloneWithoutRFactor throws when map_with_original=true instead of implementing the mapping. |
    | csrc/ops/utils.cpp | 2/5 | Added newOutputRaggedIterDomain and updated newOutputIterDomain to handle ragged domains. Major logic issue: inconsistent validation allows any_of ragged but requires all_of ragged, causing runtime failures. |
    | csrc/ops/alias.cpp | 4/5 | Added validation checks to reject ragged tensors for unsupported operations (reshape, flatten, pad, cat, slice, expand, repeat). Checks are consistent and well-placed. |
    | csrc/ops/indexing.cpp | 4/5 | Added validation to reject ragged tensors for indexing operations (select, indexSelect, gather, scatter). All validation checks are thorough and handle all parameters. |

    Sequence Diagram

    sequenceDiagram
        participant User as User Code
        participant Builder as IterDomainBuilder
        participant RID as RaggedIterDomain
        participant Utils as ops::newOutputIterDomain
        participant Ops as Binary/Unary Ops
        
        Note over User,Ops: Creating RaggedIterDomain
        User->>Builder: IterDomainBuilder(id)
        activate Builder
        Builder->>Builder: Check if id is RaggedIterDomain
        Builder->>Builder: Set ragged_extents_ if ragged
        Builder->>Builder: build()
        alt ragged_extents_ != nullptr
            Builder->>RID: Create RaggedIterDomain
            RID-->>Builder: ragged_id
        else
            Builder->>Builder: Create IterDomain
            Builder-->>Builder: iter_id
        end
        Builder-->>User: IterDomain*
        deactivate Builder
        
        Note over User,Ops: Cloning RaggedIterDomain
        User->>RID: cloneWithoutRFactor(map_with_original)
        activate RID
        RID->>RID: Create new RaggedIterDomain
        alt map_with_original == true
            RID->>RID: NVF_THROW("Not implemented")
            RID-->>User: ❌ Error
        else
            RID-->>User: ✓ cloned RaggedIterDomain
        end
        deactivate RID
        
        Note over User,Ops: Creating Output for Operations
        User->>Ops: add(ragged_tv1, ragged_tv2)
        activate Ops
        Ops->>Utils: newOutputDomain({ragged_tv1, ragged_tv2})
        activate Utils
        Utils->>Utils: newOutputIterDomain(input_ids)
        Utils->>Utils: Check any_of(isA<RaggedIterDomain>)
        alt has_ragged == true
            Utils->>Utils: newOutputRaggedIterDomain(input_ids)
            Utils->>Utils: Check all_of(isA<RaggedIterDomain>)
            alt all inputs ragged
                Utils-->>Ops: ✓ RaggedIterDomain
            else mixed inputs
                Utils-->>Ops: ❌ Error
            end
        else
            Utils-->>Ops: ✓ Regular IterDomain
        end
        deactivate Utils
        Ops-->>User: output TensorView
        deactivate Ops
    

    Comment on lines +318 to +328
    RaggedIterDomain* newOutputRaggedIterDomain(
        const std::vector<IterDomain*>& input_ids) {
      NVF_ERROR(
          std::ranges::all_of(
              input_ids,
              [](IterDomain* input_id) {
                return input_id->isA<RaggedIterDomain>();
              }),
          "All input iter domains must be RaggedIterDomain");

      NVF_ERROR(!input_ids.empty());
    Contributor

    Missing validation: mixed ragged/non-ragged inputs not checked

    If input_ids contains both RaggedIterDomain and regular IterDomain, this silently uses only the first ragged ID. Consider:

    • Adding validation that either all or none are ragged
    • Or documenting why mixed inputs are allowed

    Comment on lines 349 to 359
    bool has_ragged =
        std::any_of(input_ids.begin(), input_ids.end(), [](IterDomain* id) {
          return id->isA<RaggedIterDomain>();
        });

    if (has_ragged) {
      NVF_ERROR(
          !force_iter_type.has_value(),
          "force_iter_type not supported for RaggedIterDomain");
      return newOutputRaggedIterDomain(input_ids);
    }
    Contributor

    Inconsistent logic: any ragged forces all ragged output

    When has_ragged=true, calls newOutputRaggedIterDomain which requires all inputs to be ragged (line 320-325). This creates a logical inconsistency:

    • Line 348-351: Accepts if any input is ragged
    • Line 320-325: Requires all inputs to be ragged

    The std::any_of check will pass with mixed inputs, then newOutputRaggedIterDomain will fail with an error.

    Comment on lines 48 to +55
    is_rfactor_domain_(id->isRFactorProduct()),
    is_padded_dimension_(id->hasPaddingToMultipleOfWarp()),
    is_clustered_dimension_(id->isClusteredBlockDim()),
    padded_to_size_(id->getMaybeSizeAfterPadding()) {}
    padded_to_size_(id->getMaybeSizeAfterPadding()) {
    if (id->isA<RaggedIterDomain>()) {
    ragged_extents_ = id->as<RaggedIterDomain>()->extents();
    }
    }
    Contributor

    Consider using IterDomainBuilder consistently

    The new RaggedIterDomain::cloneWithoutRFactor (line 985-997) doesn't use IterDomainBuilder but directly calls IrBuilder::create<RaggedIterDomain>, bypassing the builder pattern established here. For consistency, both paths should use the same construction method.

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
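
To make the builder path concrete, here is a hedged sketch of a builder that captures ragged extents from its source ID and lets `build()` choose the output type. All names (`ToyExtents`, `ToyDomain`, `ToyRaggedDomain`, `ToyDomainBuilder`) are hypothetical simplifications, with a single `extent` field standing in for the many properties the real builder copies.

```cpp
#include <cassert>
#include <memory>

struct ToyExtents {};  // hypothetical placeholder for an extents tensor

struct ToyDomain {
  int extent = 0;
  virtual ~ToyDomain() = default;
  // Non-ragged domains report no ragged extents.
  virtual const ToyExtents* raggedExtents() const { return nullptr; }
};

struct ToyRaggedDomain : ToyDomain {
  const ToyExtents* extents = nullptr;
  const ToyExtents* raggedExtents() const override { return extents; }
};

class ToyDomainBuilder {
 public:
  // Copies properties from the source, including ragged extents if present.
  explicit ToyDomainBuilder(const ToyDomain& id)
      : extent_(id.extent), ragged_extents_(id.raggedExtents()) {}

  ToyDomainBuilder& ragged_extents(const ToyExtents* extents) {
    ragged_extents_ = extents;
    return *this;
  }

  // Builds a ragged domain when ragged extents are set, else a regular one.
  std::unique_ptr<ToyDomain> build() const {
    if (ragged_extents_ != nullptr) {
      auto out = std::make_unique<ToyRaggedDomain>();
      out->extent = extent_;
      out->extents = ragged_extents_;
      return out;
    }
    auto out = std::make_unique<ToyDomain>();
    out->extent = extent_;
    return out;
  }

 private:
  int extent_;
  const ToyExtents* ragged_extents_;
};
```

Routing every construction through one builder keeps the "ragged in, ragged out" rule in a single place, which is the consistency argument made in the comment above.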

    @greptile-apps greptile-apps bot left a comment

    Greptile Overview

    Greptile Summary

    Extends IterDomainBuilder and RaggedIterDomain::cloneWithoutRFactor to support cloning ragged domains, enabling ops like set, add, and other unary/binary operations to correctly propagate RaggedIterDomain to output tensors. Adds comprehensive validation across arithmetic, alias, and indexing ops to reject unsupported operations on ragged tensors.

    Confidence Score: 3/5

    • Safe to merge with minor typo fix recommended
    • The implementation correctly extends cloning support for RaggedIterDomain through IterDomainBuilder and adds appropriate validation checks. One syntax error exists (typo in error message at csrc/ops/utils.cpp:361). The cloneWithoutRFactor mapping feature is intentionally deferred with a TODO and explicit NVF_THROW, which is acceptable for this phase. Test coverage is comprehensive with 485 new lines validating ragged domain propagation and error conditions.
    • csrc/ops/utils.cpp requires typo fix in error message

    Important Files Changed

    File Analysis

    | Filename | Score | Overview |
    | --- | --- | --- |
    | csrc/ops/utils.cpp | 2/5 | Adds ragged domain handling with typo in error message "All of none" should be "All or none" |
    | csrc/ir/internal_base_nodes.cpp | 3/5 | Implements RaggedIterDomain cloning via IterDomainBuilder constructor and cloneWithoutRFactor; mapping support deferred with TODO |
    | csrc/ir/internal_base_nodes.h | 4/5 | Adds ragged_extents builder parameter, virtual cloneWithoutRFactor, and hasRaggedIterDomain helper |

    Sequence Diagram

    sequenceDiagram
        participant Op as Unary/Binary Op
        participant Utils as ops::newOutputIterDomain
        participant Builder as IterDomainBuilder
        participant RID as RaggedIterDomain
    
        Op->>Utils: newOutputIterDomain(input_ids)
        Utils->>Utils: Check if any input is ragged
        alt Has ragged inputs
            Utils->>Utils: Validate ALL inputs are ragged
            Utils->>Utils: newOutputRaggedIterDomain(input_ids)
            Utils->>Builder: IterDomainBuilder(first_ragged_id)
            Note over Builder: Copies all properties including<br/>ragged_extents from input
            Builder->>RID: build() creates RaggedIterDomain
            RID-->>Op: Return cloned RaggedIterDomain
        else No ragged inputs
            Utils->>Builder: Create regular IterDomain
            Builder-->>Op: Return regular IterDomain
        end
    

    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
    @naoyam
    Collaborator Author

    naoyam commented Jan 9, 2026

    !test

    @greptile-apps greptile-apps bot left a comment

    Greptile Overview

    Greptile Summary

    This PR extends IterDomainBuilder and RaggedIterDomain::cloneWithoutRFactor to enable proper ragged domain propagation through unary and binary operations.

    Key Changes:

    • IterDomainBuilder enhancement: Adds ragged_extents_ field and corresponding setter method to preserve ragged extent information during cloning
    • RaggedIterDomain cloning: Implements cloneWithoutRFactor override that uses the builder pattern to create proper ragged clones (with TODO for mapping support)
    • Output domain generation: Adds newOutputRaggedIterDomain helper and extends newOutputIterDomain with ragged detection logic that validates all-or-none ragged inputs
    • Operation validation: Adds comprehensive checks across alias, arithmetic, and indexing operations to explicitly reject unsupported ragged operations with clear error messages

    Implementation Pattern:
    When operations like add, set, or neg process ragged tensors, the logic flows through newOutputIterDomain → detects ragged inputs → validates consistency → delegates to newOutputRaggedIterDomain → uses IterDomainBuilder(ragged_id) → creates output RaggedIterDomain with preserved extents.

    Scope:
    This PR focuses on correct Fusion IR construction only—no lowering or code generation. Operations in arith.h, alias.h, and indexing.h now either generate valid ragged tensors or fail immediately with clear errors. Operations in other files are not yet addressed.

    Confidence Score: 5/5

    • Safe to merge - clean implementation of ragged domain cloning with comprehensive validation
    • The implementation follows established patterns (IterDomainBuilder, virtual override), includes extensive validation to prevent misuse, and has comprehensive test coverage. The logic correctly preserves ragged properties through the builder pattern, and all validation checks properly reject unsupported operations with informative errors. The TODO in cloneWithoutRFactor is acceptable since mapping is only needed when map_with_original=true, which appears to be an advanced use case.
    • No files require special attention

    Important Files Changed

    File Analysis

    | Filename | Score | Overview |
    | --- | --- | --- |
    | csrc/ir/internal_base_nodes.h | 5/5 | Adds ragged_extents_ field to IterDomainBuilder and makes cloneWithoutRFactor virtual. Changes are clean and follow existing patterns. |
    | csrc/ir/internal_base_nodes.cpp | 5/5 | Implements RaggedIterDomain constructor from IterDomainBuilder and cloneWithoutRFactor override. Logic correctly preserves ragged properties through the builder pattern. |
    | csrc/ops/utils.h | 5/5 | Adds newOutputRaggedIterDomain helper function declaration with clear documentation. |
    | csrc/ops/utils.cpp | 5/5 | Implements ragged domain detection and output generation. Logic correctly validates all-or-none ragged inputs and delegates to specialized helper. |
    | csrc/ops/alias.cpp | 5/5 | Adds validation checks to reject unsupported operations on ragged tensors (reshape, flatten, expand, repeat, etc.). Properly prevents multiple nesting levels. |
    | csrc/ops/arith.cpp | 5/5 | Adds validation to prevent reduction of RaggedIterDomain dimensions with clear error message. |
    | csrc/ops/indexing.cpp | 5/5 | Adds validation checks across select, indexSelect, gather, and scatter operations to reject ragged tensors with informative error messages. |
    | tests/cpp/test_ragged_iter_domain.cpp | 5/5 | Comprehensive test coverage validates ragged domain propagation through unary/binary ops, mixed input errors, and unsupported operations. |

    Sequence Diagram

    sequenceDiagram
        participant Op as Binary/Unary Op
        participant Utils as ops::newOutputTV
        participant NewOut as newOutputIterDomain
        participant Builder as IterDomainBuilder
        participant RaggedOut as newOutputRaggedIterDomain
        participant RaggedID as RaggedIterDomain
    
        Op->>Utils: create output tensor
        Utils->>NewOut: for each logical ID pair
        
        alt Input has RaggedIterDomain
            NewOut->>NewOut: std::any_of checks ragged
            NewOut->>NewOut: std::all_of validates all ragged
            NewOut->>RaggedOut: create ragged output
            RaggedOut->>Builder: IterDomainBuilder(ref_ragged_id)
            Note over Builder: Copies ragged_extents_ field
            Builder->>Builder: build()
            Builder->>RaggedID: create RaggedIterDomain(args)
            RaggedID-->>RaggedOut: new ragged ID
            RaggedOut-->>NewOut: return ragged output
        else Normal IterDomain
            NewOut->>NewOut: compute extent/offsets
            NewOut->>Builder: create normal ID
            Builder-->>NewOut: normal ID
        end
        
        NewOut-->>Utils: output ID
        Utils-->>Op: output TensorView
    