
Conversation

@liqiangxl
Collaborator

No description provided.

@github-actions

github-actions bot commented Jan 6, 2026

Review updated until commit 7397013

Description

  • Add LaunchParams support to GpuLower for dynamic shape handling

  • Update warp specialization logic to work with dynamic dimensions

  • Modify reduction scheduling to remove static bdimx requirements

  • Pass launch constraints through compilation pipeline

  • Update tests to use symbolic tensors for dynamic shape validation
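
The last bullet is easiest to see in a condensed form of the kernel-reuse check (a sketch assembled from the test snippet quoted later in this thread; the fusion body here is a plain sum rather than the persistent-kernel pattern the real test uses, and the input shapes {2048, 4096} and {2048 + 8, 4096} are taken from the review discussion below, so treat the details as illustrative only):

auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

// Symbolic extents keep the scheduled kernel independent of the exact sizes.
auto tv0 = makeSymbolicTensor(2, DataType::Float);
fusion->addInput(tv0);
fusion->addOutput(sum(tv0, {1}));

FusionExecutorCache executor_cache(std::move(fusion));
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

executor_cache.runFusionWithInputs({at::randn({2048, 4096}, options)});
FusionKernelRuntime* first_runtime =
    executor_cache.getMostRecentKernelRuntime();

// Different outer extent: with dynamic shape support the cached kernel
// should be reused instead of triggering a recompilation.
executor_cache.runFusionWithInputs({at::randn({2048 + 8, 4096}, options)});
FusionKernelRuntime* second_runtime =
    executor_cache.getMostRecentKernelRuntime();

EXPECT_EQ(first_runtime, second_runtime)
    << "Dynamic shape support should allow reuse across outer extents";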

Changes walkthrough

Relevant files

Enhancement (8 files)

  • fused_reduction.cpp: Add dynamic shape support for warp reduction checks (+18/-5)
  • lower2device.cpp: Add LaunchParams to GpuLower constructor (+6/-2)
  • parallel_dimension_map.cpp: Handle dynamic dimensions with launch parameters (+9/-4)
  • compiled_kernel.cpp: Pass LaunchParams through compilation pipeline (+5/-1)
  • executor.cpp: Include launch constraints in kernel compilation (+2/-0)
  • reduction_utils.cpp: Remove static bdimx requirement for warp specialization (+9/-3)
  • lower2device.h: Update GpuLower interface for LaunchParams (+7/-1)
  • compiled_kernel.h: Update CompiledKernel interface for LaunchParams (+2/-0)

Tests (2 files)

  • test_combined_inner_outer_reduction.cpp: Use symbolic tensors for dynamic shape testing (+5/-2)
  • test_persistent_buffer.cpp: Update tests for dynamic shapes and add kernel reuse test (+86/-5)

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review
Warp reduction logic complexity

The logic for determining static warp reductions has become more complex with the addition of dynamic shape support. The nested conditions that check has_static_bdimx and has_warp_specialization, together with the extent validation, should be carefully reviewed to ensure correctness across all edge cases.

auto is_static_warp_reduction = [](TensorView* out,
                                   bool has_warp_specialization) {
  // Check if bdimx is statically known in launch params
  bool has_static_bdimx = GpuLower::hasCurrent() &&
      GpuLower::current()->launchParams().hasDim(ParallelType::TIDx);

  if (!has_warp_specialization && !has_static_bdimx) {
    return false;
  }

  constexpr int64_t kThreadsPerWarp = 32L;
  int reduction_count = 0;
  bool has_valid_tidx_reduction = false;
  for (auto ld : out->getLoopDomain()) {
    if (ld->isReduction()) {
      reduction_count++;
      if (ld->getParallelType() == ParallelType::TIDx) {
        // Get extent either from launch params or from the const extent
        std::optional<int64_t> extent;
        if (has_static_bdimx) {
          extent = GpuLower::current()->launchParams().getDim(
              ParallelType::TIDx);
        } else if (ld->extent()->isConst()) {
          extent = ld->extent()->value().as<int64_t>();
        }

        if (extent.has_value() && extent.value() % kThreadsPerWarp == 0) {
          has_valid_tidx_reduction = true;
        }
      }
    }
  }
Scheduler configuration changes

The warp specialized scheduler now uses dynamic split/parallelize instead of static inner_parallel_static. This changes the fundamental scheduling approach and should be validated to ensure it produces equivalent or better performance while maintaining correctness.

reduction_tv->split(
    inner_reduce_axis, rparams->batches_per_block_inner_reduction, false);
reduction_tv->axis(inner_reduce_axis + 1)->parallelize(ParallelType::TIDx);
reduction_tv->axis(inner_reduce_axis + 1)->padToMultipleOfWarp();

// // static bdimx is required for TMA warp specialization
// int64_t compute_bdimx = getComputeBdimx(option,
// rparams->lparams.bdimx()); inner_parallel_static(inner_reduce_axis,
// ParallelType::TIDx, compute_bdimx);
Performance regression risk

The quoted comment reports 59.7% SOL with concrete inputs versus 59.1% SOL with symbolic inputs, a roughly 0.6 percentage point drop when using dynamic shapes. The performance impact should be quantified and justified.

// For case contig_1_dtype_float_batch_2048_hidden_8192
// the performance is 59.7% SOL uisng concrete inputs
// for symbolic inputs, the performance is 59.1% SOL

Test failures

  • (Medium, 26) nvFuser validation mismatches in CombinedSchedulerTest and Gpu1Test suites:

    CombinedSchedulerTest.IllegalSizeToUseTMA
    CombinedSchedulerTest.LayerNormBackward/dtype___half_batch_216_hidden_1024
    CombinedSchedulerTest.LayerNormBackward/dtype___half_batch_216_hidden_768
    CombinedSchedulerTest.LayerNormBackward/dtype_double_batch_216_hidden_1024
    CombinedSchedulerTest.LayerNormBackward/dtype_double_batch_216_hidden_576
    CombinedSchedulerTest.LayerNormBackward/dtype_double_batch_216_hidden_768
    CombinedSchedulerTest.LayerNormBackward/dtype_float_batch_216_hidden_1024
    CombinedSchedulerTest.LayerNormBackward/dtype_float_batch_216_hidden_768
    Gpu1Test.FusionMagicSchedulerRMSNormBackward_CUDA

  • (Medium, 10) DistributedTransformer backward fp16/bf16 numerical mismatches across runners:

    DistributedTransformerTest.Backward/__bfloat
    DistributedTransformerTest.Backward/__half

  • (Medium, 3) Gradient mismatch in nanoGPT CUDAGraphs nvFuser test (thunder/tests/test_networks):

    thunder.tests.test_networks.test_nanogpt_complete_cudagraphs_autograd_nvfuser_cuda_thunder.dtypes.float32

  • (Medium, 3) Thunder NVFuser gradient mismatches in test_grad::test_populate_grads_block:

    thunder.tests.test_grad.test_populate_grads_block_nvfuser_cuda_thunder.dtypes.float32

@greptile-apps
Contributor

greptile-apps bot commented Jan 6, 2026

Greptile Summary

This PR enables dynamic shape support in the warp specialized inner-outer persistent scheduler by threading LaunchParams through the lowering pipeline. The key changes are:

  • Threading LaunchParams: LaunchParams is now passed to GpuLower constructor and stored as a member, allowing lowering passes to access runtime launch configuration
  • Dynamic splits: Replaces static bdimx splits with dynamic splits using batches_per_block_inner_reduction and padToMultipleOfWarp() in scheduler code
  • Optimization enablement: parallel_dimension_map.cpp and fused_reduction.cpp now query launch params to enable register sharing and warp reduction optimizations even with dynamic dimensions
  • Test coverage: Converts tests to symbolic inputs and adds kernel reuse verification

The approach maintains correctness by computing launch params early during compilation (before lowering) and making them available throughout the lowering pipeline.
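
To make the threading concrete, here is a self-contained toy sketch of the pattern being described: launch params accepted at construction, stored as a member, and queried by lowering passes through a current-lowering pointer. The class and method names mirror those quoted in this review (GpuLower, LaunchParams, launchParams(), hasDim(), getDim()), but everything below is a simplified stand-in rather than the actual nvFuser implementation:

#include <cassert>
#include <cstdint>
#include <optional>

enum class ParallelType { TIDx, TIDy, TIDz };

// Minimal stand-in for the launch configuration known at compile time.
struct LaunchParams {
  std::optional<int64_t> bdimx;  // unset means the dimension is dynamic
  bool hasDim(ParallelType pt) const {
    return pt == ParallelType::TIDx && bdimx.has_value();
  }
  int64_t getDim(ParallelType pt) const {
    assert(hasDim(pt));
    return *bdimx;
  }
};

class GpuLower {
 public:
  // Launch params are now taken at construction and kept for later queries.
  explicit GpuLower(const LaunchParams& lparams = LaunchParams())
      : lparams_(lparams) {
    current_ = this;
  }
  ~GpuLower() { current_ = nullptr; }

  static bool hasCurrent() { return current_ != nullptr; }
  static GpuLower* current() { return current_; }
  const LaunchParams& launchParams() const { return lparams_; }

 private:
  LaunchParams lparams_;
  inline static thread_local GpuLower* current_ = nullptr;
};

// A lowering-time check can now treat a dynamic TIDx extent as known
// whenever the launch configuration pins it, e.g. to keep the warp
// reduction path enabled for symbolic inputs.
bool bdimxKnownAtLowering() {
  return GpuLower::hasCurrent() &&
      GpuLower::current()->launchParams().hasDim(ParallelType::TIDx);
}

The real constructor additionally takes the fusion and compile params; the point of the sketch is only the store-and-query pattern that lets analyses such as fused_reduction.cpp call GpuLower::current()->launchParams().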

Confidence Score: 4/5

  • This PR is generally safe to merge with one minor concern about the deserialization path
  • The implementation is well-structured and the changes are localized. The deserialization path passes an empty LaunchParams(), which may disable these optimizations for deserialized kernels but won't cause correctness issues
  • Pay attention to csrc/runtime/executor.cpp: the deserialization path uses an empty LaunchParams()

Important Files Changed

  • csrc/device_lower/lower2device.h: Adds a LaunchParams parameter to the GpuLower constructor and stores it as a member variable for dynamic shape support
  • csrc/device_lower/lower2device.cpp: Updates the GpuLower constructor implementation to accept and initialize the lparams_ member
  • csrc/runtime/compiled_kernel.h: Adds a LaunchParams parameter to both CompiledKernel constructor overloads
  • csrc/runtime/compiled_kernel.cpp: Passes launch_params through to GpuLower in both constructor implementations
  • csrc/runtime/executor.cpp: Passes launch_constraints to CompiledKernel in the compile path, but uses an empty LaunchParams() in the deserialization path
  • csrc/device_lower/analysis/fused_reduction.cpp: Enables the warp reduction optimization when launch params provide a static bdimx, checking divisibility by the warp size (32)
  • csrc/parallel_dimension_map.cpp: Uses launch params for dynamic dimensions to enable register sharing when the thread count is known at lowering time
  • csrc/scheduler/normalization_inner_outer_tma_ws.cpp: Replaces the static bdimx split with a dynamic split using batches_per_block_inner_reduction for dynamic shape support
  • csrc/scheduler/reduction_utils.cpp: Uses a dynamic split with batches_per_block_inner_reduction and padToMultipleOfWarp() instead of the static inner_parallel_static()
  • tests/cpp/test_persistent_buffer.cpp: Changes the RMS norm test to symbolic inputs and adds a new test verifying kernel reuse with different dynamic shapes
  • tests/cpp/test_combined_inner_outer_reduction.cpp: Converts a test from concrete to symbolic tensor inputs to exercise dynamic shape code paths

Sequence Diagram

sequenceDiagram
    participant User
    participant KernelExecutor
    participant CompiledKernel
    participant GpuLower
    participant Scheduler
    participant Lowering

    User->>KernelExecutor: compile(launch_constraints)
    KernelExecutor->>CompiledKernel: new CompiledKernel(launch_params)
    CompiledKernel->>GpuLower: new GpuLower(cparams, lparams)
    Note over GpuLower: Stores lparams_ as member
    
    GpuLower->>Scheduler: Schedule tensors
    Note over Scheduler: Uses batches_per_block_inner_reduction<br/>for dynamic split instead of static bdimx
    
    GpuLower->>Lowering: Lower to device code
    Lowering->>GpuLower: launchParams().getDim(TIDx)
    Note over Lowering: parallel_dimension_map and<br/>fused_reduction check GpuLower::current()<br/>->launchParams() for dynamic dimensions
    
    Lowering-->>GpuLower: Use launch param value
    GpuLower-->>CompiledKernel: Lowered kernel
    CompiledKernel-->>KernelExecutor: Compiled kernel
    
    Note over KernelExecutor: Runtime execution uses<br/>actual launch params for<br/>warp reduction optimization

@greptile-apps greptile-apps bot left a comment

12 files reviewed, 1 comment

Comment on lines +2370 to +2376
EXPECT_EQ(numRuntimes(), 1)
    << "Same dimensions should reuse the existing kernel";

FusionKernelRuntime* second_runtime =
    executor_cache.getMostRecentKernelRuntime();
EXPECT_EQ(first_runtime, second_runtime)
    << "Should reuse the same runtime for identical shapes";

style: the comments are misleading here. Line 2370 says "Same dimensions should reuse the existing kernel" but the shapes are different ({2048, 4096} vs {2048 + 8, 4096}). Line 2376 says "Should reuse the same runtime for identical shapes" but the shapes are not identical. Consider clarifying that the outer dimension changes but should still reuse the kernel due to dynamic shape support.

Suggested change

Original:

EXPECT_EQ(numRuntimes(), 1)
    << "Same dimensions should reuse the existing kernel";
FusionKernelRuntime* second_runtime =
    executor_cache.getMostRecentKernelRuntime();
EXPECT_EQ(first_runtime, second_runtime)
    << "Should reuse the same runtime for identical shapes";

Suggested:

EXPECT_EQ(numRuntimes(), 1)
    << "Different outer dimension should reuse the existing kernel due to dynamic shape support";
FusionKernelRuntime* second_runtime =
    executor_cache.getMostRecentKernelRuntime();
EXPECT_EQ(first_runtime, second_runtime)
    << "Should reuse the same runtime despite different outer dimension";

@liqiangxl force-pushed the llu/ws_dynamic_shape branch from 3993211 to 3b68cf3 on January 6, 2026 at 19:30
@liqiangxl
Collaborator Author

!test

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR enables dynamic shape support in the warp specialized inner-outer persistent scheduler by threading LaunchParams through the compilation pipeline to GpuLower.

  • Core mechanism: The PR passes launch parameters from KernelExecutor::compile() through CompiledKernel to GpuLower, making them accessible during lowering analysis
  • Dynamic shape detection: FusedReduction analysis now checks GpuLower::launchParams() to determine if bdimx is statically known, enabling warp reduction optimizations even when tensor dimensions are symbolic
  • Register sharing: ParallelDimensionMap::getThreadCountInDim() now consults launch parameters for dynamic dimensions, allowing register sharing when thread counts are known at compile time (see the sketch after this list)
  • Scheduler change: Replaced static inner_parallel_static() split with dynamic split + padToMultipleOfWarp() for TIDx in TMA warp specialized path
  • Test updates: Converted tests from concrete to symbolic tensors to validate dynamic shape support, added kernel reuse test
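
The ParallelDimensionMap fallback mentioned above might look roughly like the helper below. This reuses the toy GpuLower/LaunchParams stand-ins from the earlier sketch; the real getThreadCountInDim() works on nvFuser IR values, so its signature here is an assumption:

// Prefer a constant extent from the fusion IR; otherwise fall back to the
// launch params threaded into GpuLower; otherwise report "unknown" (-1),
// matching the behavior described in the sequence diagram below.
int64_t getThreadCountInDim(
    std::optional<int64_t> const_extent,  // set when the IR extent is constant
    ParallelType pt) {
  if (const_extent.has_value()) {
    return *const_extent;
  }
  if (GpuLower::hasCurrent() &&
      GpuLower::current()->launchParams().hasDim(pt)) {
    // New behavior: a dynamic dimension still counts as known when the
    // launch configuration fixes it, which keeps register sharing enabled.
    return GpuLower::current()->launchParams().getDim(pt);
  }
  return -1;  // unknown at lowering time
}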

Issues found:

  • Comment formatting broken in csrc/scheduler/reduction_utils.cpp lines 137-142
  • Typo "uisng" → "using" in test comment
  • Misleading test comment claiming "same dimensions" when dimensions actually differ

Confidence Score: 4/5

  • This PR is safe to merge with minor formatting fixes needed
  • The implementation follows a clean architectural pattern by threading launch params through the compilation stack. The logic changes are sound and enable an important optimization. However, formatting issues in comments (syntax errors) require fixing before merge, preventing a score of 5.
  • Fix syntax errors in csrc/scheduler/reduction_utils.cpp lines 137-142 before merging

Important Files Changed

  • csrc/device_lower/analysis/fused_reduction.cpp (4/5): Enhanced warp reduction detection to support dynamic shapes by checking launch params when bdimx is not statically known
  • csrc/parallel_dimension_map.cpp (4/5): Updated getThreadCountInDim to use launch params for dynamic dimensions when available, enabling register sharing
  • csrc/runtime/compiled_kernel.cpp (5/5): Added a LaunchParams parameter to the constructors and passed it to GpuLower for dynamic shape support
  • csrc/scheduler/reduction_utils.cpp (3/5): Replaced the static split with a dynamic split and padToMultipleOfWarp() for TIDx parallelization; has formatting issues in comments
  • tests/cpp/test_persistent_buffer.cpp (3/5): Mixed changes: switched tensors between concrete and symbolic, and added a kernel reuse test with a misleading comment on line 2371

Sequence Diagram

sequenceDiagram
    participant Executor as KernelExecutor
    participant CK as CompiledKernel
    participant GL as GpuLower
    participant FR as FusedReduction
    participant PDM as ParallelDimensionMap
    participant Sched as ScheduleReductionTV

    Note over Executor: compile() called with launch_constraints
    Executor->>CK: new CompiledKernel(fusion, compile_params, launch_constraints)
    CK->>GL: new GpuLower(fusion, compile_params, launch_params)
    Note over GL: Store launch_params in lparams_ member
    
    GL->>FR: Analyze fused reductions
    FR->>GL: Check launchParams().hasDim(TIDx)
    Note over FR: If bdimx available in launch params,<br/>use it to validate warp reduction
    FR->>FR: Determine if static warp reduction is valid
    
    GL->>Sched: scheduleReductionTV() for TMA warp specialized
    Note over Sched: Use dynamic split + padToMultipleOfWarp()<br/>instead of static split
    
    GL->>PDM: getThreadCountInDim(ParallelType)
    PDM->>GL: Check launchParams().hasDim(pt)
    Note over PDM: Return actual launch param value<br/>for dynamic dimensions when available
    PDM-->>GL: Return thread count or -1
    
    Note over GL: Lowering complete with dynamic shape support

Comment on lines +137 to +140
// // static bdimx is required for TMA warp specialization
// int64_t compute_bdimx = getComputeBdimx(option,
// rparams->lparams.bdimx()); inner_parallel_static(inner_reduce_axis,
// ParallelType::TIDx, compute_bdimx);

Comment formatting is broken: the commented-out statements were re-wrapped so that the end of one statement and the beginning of the next share a line.

Suggested change

Original:

// // static bdimx is required for TMA warp specialization
// int64_t compute_bdimx = getComputeBdimx(option,
// rparams->lparams.bdimx()); inner_parallel_static(inner_reduce_axis,
// ParallelType::TIDx, compute_bdimx);

Suggested:

// static bdimx is required for TMA warp specialization
// int64_t compute_bdimx = getComputeBdimx(option,
// rparams->lparams.bdimx());
// inner_parallel_static(inner_reduce_axis, ParallelType::TIDx, compute_bdimx);

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

// rparams->lparams.bdimx()); inner_parallel_static(inner_reduce_axis,
// ParallelType::TIDx, compute_bdimx);

/ Iteration: [I/Unroll/BIDy, BIDy, Unroll]

Wrong comment prefix - should be // not /.

Suggested change

Original:

/ Iteration: [I/Unroll/BIDy, BIDy, Unroll]

Suggested:

// Iteration: [I/Unroll/BIDy, BIDy, Unroll]

auto tv0 = makeContigConcreteTensor({dim0, dim1}, dtype);
auto tv1 = makeContigConcreteTensor({dim0, dim1}, dtype);
// For case contig_1_dtype_float_batch_2048_hidden_8192
// the performance is 59.7% SOL uisng concrete inputs

Typo: 'uisng' should be 'using'.

Suggested change

Original:

// the performance is 59.7% SOL uisng concrete inputs

Suggested:

// the performance is 59.7% SOL using concrete inputs

Comment on lines +2370 to +2371
EXPECT_EQ(numRuntimes(), 1)
    << "Same dimensions should reuse the existing kernel";

Comment says "Same dimensions should reuse" but the test uses different outer dimension (2048 + 8 vs 2048) - comment is misleading.

@liqiangxl
Collaborator Author

!test
