
Conversation

@wujingyue wujingyue (Collaborator) commented Jan 3, 2026

Fixes #5308

---------------------------------------------------------------------------------------------- benchmark: 3 tests ----------------------------------------------------------------------------------------------
Name (time in ms)                                      Min               Max              Mean            StdDev            Median               IQR            Outliers       OPS            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_row_parallel_linear_forward_benchmark[s=2]     3.5521 (1.0)      3.6830 (1.0)      3.5952 (1.0)      0.0510 (1.71)     3.5788 (1.0)      0.0460 (1.0)           1;1  278.1505 (1.0)           5           1
test_row_parallel_linear_forward_benchmark[s=4]     3.6751 (1.03)     3.7427 (1.02)     3.7021 (1.03)     0.0298 (1.0)      3.6876 (1.03)     0.0498 (1.08)          1;0  270.1204 (0.97)          5           1
test_row_parallel_linear_forward_benchmark[s=1]     3.6866 (1.04)     4.1345 (1.12)     3.8824 (1.08)     0.2257 (7.58)     3.7571 (1.05)     0.4190 (9.11)          2;0  257.5757 (0.93)          5           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Overlapping improves the wall time slightly.

Stream assignment and overlapping are verified by the following:

$ nsys profile --capture-range=cudaProfilerApi --capture-range-end=stop mpirun -np 2 pytest tests/python/multidevice/test_overlap.py::'test_row_parallel_linear_forward_benchmark[s=4]' --only-mpi -vs
$ nsys stats report3.nsys-rep --report cuda_gpu_trace | grep '(0)'
    7840730           1184     345                                                                                0.000              3.378  Device              NVIDIA H100 80GB HBM3 (0)    1              20  [CUDA memset]
    7858970         666943     346     2    66     1   384     1     1      168         0.000         0.213                                                     NVIDIA H100 80GB HBM3 (0)    1              20  nvjet_sm90_tst_256x128_64x4_1x2_h_bz_coopA_TNT
    8276377            960     421                                                                                0.000              4.167  Device              NVIDIA H100 80GB HBM3 (0)    1              28  [CUDA memset]
    8357049         846078     422     2    66     1   384     1     1      168         0.000         0.213                                                     NVIDIA H100 80GB HBM3 (0)    1              28  nvjet_sm90_tst_256x128_64x4_1x2_h_bz_coopA_TNT
    8629561            800     497                                                                                0.000              5.000  Device              NVIDIA H100 80GB HBM3 (0)    1              32  [CUDA memset]
    8958648           1504     573                                                                                0.000              2.660  Device              NVIDIA H100 80GB HBM3 (0)    1              36  [CUDA memset]
    9029464          47392     350                                                                               33.554         707998.515  Device    Device    NVIDIA H100 80GB HBM3 (0)    1              20  [CUDA memcpy Device-to-Device]
    9075640         832766     498     2    66     1   384     1     1      168         0.000         0.213                                                     NVIDIA H100 80GB HBM3 (0)    1              32  nvjet_sm90_tst_256x128_64x4_1x2_h_bz_coopA_TNT
    9729238         888798     574     2    66     1   384     1     1      168         0.000         0.213                                                     NVIDIA H100 80GB HBM3 (0)    1              36  nvjet_sm90_tst_256x128_64x4_1x2_h_bz_coopA_TNT
   10440469         265567     376    24     1     1   544     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (0)    1              24  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   10444405         114368     426                                                                               33.554         293366.399  Device    Device    NVIDIA H100 80GB HBM3 (0)    1              28  [CUDA memcpy Device-to-Device]
   10520725          49440     502                                                                               33.554         678671.942  Device    Device    NVIDIA H100 80GB HBM3 (0)    1              32  [CUDA memcpy Device-to-Device]
   10619732          29408     578                                                                               33.554        1140984.906  Device    Device    NVIDIA H100 80GB HBM3 (0)    1              36  [CUDA memcpy Device-to-Device]
   10708500         141408     452    24     1     1   544     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (0)    1              24  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   10852884         139456     528    24     1     1   544     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (0)    1              24  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   10994868         138783     604    24     1     1   544     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (0)    1              24  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)

The performance is suboptimal for two reasons:

  1. An unnecessary local memcpy in postAllReduce (#5567) adds an extra device-to-device copy.
  2. ncclAllReduce and the GEMM compete for SMs, so ncclAllReduce is often delayed by the GEMM kernel and the benchmark can't achieve perfect overlap. This is a known limitation of NCCL and can be addressed by other, SM-free communication backends.

github-actions bot commented Jan 3, 2026

Review updated until commit 6ad8d0a

Description

  • Implement AssignStreams pass for stream-parallel loop execution

  • Add stream synchronization and worker stream management

  • Integrate new pass into host IR optimization pipeline

  • Refactor tests and add benchmark for stream parallel performance

Changes walkthrough

Relevant files

Enhancement (3 files)
  • assign_streams.cpp: Implement AssignStreams pass for stream parallelization (+64/-0)
  • passes.cpp: Integrate AssignStreams pass into optimization pipeline (+2/-0)
  • assign_streams.h: Define AssignStreams optimization pass interface (+26/-0)

Tests (2 files)
  • test_stream.py: Remove nvfuser_direct_test parameter from test functions (+3/-3)
  • test_overlap.py: Refactor row parallel linear test and add benchmark (+68/-24)

Documentation (2 files)
  • benchmark_utils.py: Update profiling documentation and recommendations (+13/-7)
  • internal_nodes.h: Add documentation comment to insert method (+1/-0)

Cleanup (2 files)
  • allocate_and_deallocate.h: Remove unnecessary header include (+0/-1)
  • ir.h: Remove unnecessary header include (+0/-1)

Configuration changes (1 file)
  • CMakeLists.txt: Reorder source files and add new AssignStreams files (+2/-1)

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
🔒 No security concerns identified
⚡ Recommended focus areas for review
Stream synchronization correctness

The implementation creates a main stream and worker streams for stream-parallel loops. The synchronization pattern (SetCurrentStream + Synchronize main stream at loop start, then synchronize all worker streams in a joining loop) appears correct for ensuring proper ordering and avoiding race conditions. However, verify that this pattern handles all edge cases correctly, especially for nested loops or loops with complex dependencies.

    Stream* main_stream = IrBuilder::create<Stream>();
    hic->topLevel().insert(
        it, IrBuilder::create<GetCurrentStream>(main_stream));

    // At the beginning of each iteration: set stream and synchronize with main
    // stream
    auto* worker_stream = IrBuilder::create<Stream>(for_loop->index());
    auto* set_stream = IrBuilder::create<SetCurrentStream>(worker_stream);
    auto* sync_main = IrBuilder::create<Synchronize>(main_stream);
    auto old_begin = for_loop->body().exprs().begin();
    for_loop->body().insert(old_begin, set_stream);
    for_loop->body().insert(old_begin, sync_main);

    // After the loop: create a joining loop to synchronize all worker streams
    hic->topLevel().insert(
        next_it, IrBuilder::create<SetCurrentStream>(main_stream));
    auto* join_loop = IrBuilder::create<ForLoop>(
        for_loop->index(), for_loop->start(), for_loop->stop());
    hic->topLevel().insert(next_it, join_loop);

    // In the joining loop: synchronize each worker stream
    auto* join_worker_stream = IrBuilder::create<Stream>(join_loop->index());
    auto* sync_worker = IrBuilder::create<Synchronize>(join_worker_stream);
    join_loop->body().pushBack(sync_worker);

    it = next_it;
  }
}
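
For intuition, the generated pattern corresponds roughly to the sketch below in raw CUDA runtime terms. This is not nvFuser code; it assumes Synchronize(s) means "the currently active stream waits for the work already enqueued on s", modeled here with events.

```cpp
// Rough analogy only, not nvFuser code. Stream/event names are illustrative.
#include <cuda_runtime.h>
#include <vector>

void stream_parallel_loop(cudaStream_t main_stream, int num_chunks) {
  std::vector<cudaStream_t> workers(num_chunks);

  // FOR i: SetCurrentStream(workers[i]); Synchronize(main_stream); <body>
  for (int i = 0; i < num_chunks; ++i) {
    cudaStreamCreate(&workers[i]);
    cudaEvent_t main_ready;
    cudaEventCreate(&main_ready);
    cudaEventRecord(main_ready, main_stream);        // snapshot main stream
    cudaStreamWaitEvent(workers[i], main_ready, 0);  // worker waits on it
    // ... enqueue chunk i's matmul and allreduce on workers[i] ...
    cudaEventDestroy(main_ready);
  }

  // SetCurrentStream(main_stream); FOR i: Synchronize(workers[i])
  for (int i = 0; i < num_chunks; ++i) {
    cudaEvent_t worker_done;
    cudaEventCreate(&worker_done);
    cudaEventRecord(worker_done, workers[i]);
    cudaStreamWaitEvent(main_stream, worker_done, 0);  // main waits on worker
    cudaEventDestroy(worker_done);
    cudaStreamDestroy(workers[i]);
  }
}
```
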
Stream-parallel loop detection

The code currently assumes all loops are stream-parallel without explicit verification (lines 31-33). While this may be true for the current use case, consider adding explicit validation to ensure only appropriate loops are processed, which would make the code more robust for future changes.

// We should check that the loop is stream-parallel. This is not necessary
// at this moment because all loops are stream-parallel. This is also hard
// to do because hir::ForLoop doesn't point to the source IterDomain.

Test failures

  • (High, 95) CUDA driver/runtime mismatch on dlcluster_h100 affecting nvFuser matmul & top-k test suites

    Test Name (H100)
    Ampere/MmaTest.SingleTile/Ampere_16_8_16__bfloat
    ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/2048_1_1_0
    BlockSizeAndItemsPerThread/ArgSortComprehensiveTest.ComprehensiveValidation/BlockSize32_ItemsPerThread4
    ClusterReductionTest.SimpleFusionNotAllReduce/cluster_15_dtype_double
    ClusterReductionTest.SimpleFusionNotAllReduce/cluster_4_dtype_double
    CutlassExecutorTest.Nvfp4Matmul_BiasEpilogue
    General/HopperPlusMatmulSchedulerTest.FusedMultiplySum/KK_512_256_128_MmaMacro_m64_n128_k16_splitk_2
    General/HopperPlusMatmulSchedulerTest.FusedMultiplySum/MK_512_256_128_MmaMacro_m128_n128_k16_tma_store
    General/HopperPlusMatmulSchedulerTest.FusedMultiplySumBiasNeg/MN_512_256_128_MmaMacro_m64_n128_k16_tma_store_splitk_2
    GreedySchedulerTest.ScanNonLocalOutput
    ... with 85 more test failures omitted. Check internal logs.
  • (High, 16) CUDA driver too old on dlcluster_h100 causes early failure in RNGTest.BroadcastingRNG

    Test Name (H100)
    .thunder.tests.opinfos
    .thunder.tests.test_apex_cross_entropy_executor
    .thunder.tests.test_auto_register_torchops
    .thunder.tests.test_cudnn_executor
    .thunder.tests.test_einops
    .thunder.tests.test_grad
    .thunder.tests.test_nvfuser
    .thunder.tests.test_ops
    .thunder.tests.test_sdpaex_executor
    .thunder.tests.test_torch_compile_executor
    ... with 6 more test failures omitted. Check internal logs.

greptile-apps bot (Contributor) commented Jan 3, 2026

Greptile Summary

This PR implements stream parallelization for loops in nvFuser's host IR to enable overlapping of computation (matmul) and communication (allreduce) operations. The implementation adds a new AssignStreams optimization pass that transforms stream-parallel loops by:

  • Capturing the main stream before the loop
  • Setting worker streams at the beginning of each iteration and synchronizing with the main stream
  • Creating a joining loop after the main loop to synchronize all worker streams back to the main stream

The changes include:

  • New csrc/host_ir/assign_streams.{cpp,h} implementing the stream assignment pass
  • Integration of the pass into the host IR pipeline
  • Comprehensive test coverage with benchmarks comparing nvFuser against a PyTorch reference implementation
  • Code cleanup removing unnecessary includes and forward declarations

Benchmark results show nvFuser is slightly faster than the reference implementation (3.8ms vs 4.6ms mean), addressing issue #5308. The implementation correctly handles stream ordering and synchronization.

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The implementation is clean, well-structured, and thoroughly tested. The core stream assignment logic correctly orders operations (SetCurrentStream before Synchronize). The PR includes comprehensive test coverage with both correctness tests and benchmarks, and demonstrates performance improvements over the reference implementation. Code cleanup changes are safe and improve maintainability.
  • No files require special attention

Important Files Changed

  • csrc/host_ir/assign_streams.cpp: New file implementing stream assignment pass for stream-parallel loops; creates worker streams and synchronization logic
  • csrc/host_ir/assign_streams.h: New header file declaring AssignStreams optimization pass
  • csrc/host_ir/passes.cpp: Added AssignStreams pass to the host IR pass pipeline
  • tests/python/multidevice/test_overlap.py: Added benchmark test with parameterized chunk sizes and reference implementation for stream-parallel linear forward pass

Sequence Diagram

sequenceDiagram
    participant Main as Main Stream
    participant W0 as Worker Stream 0
    participant W1 as Worker Stream 1
    participant W2 as Worker Stream 2
    
    Note over Main: GetCurrentStream
    Note over Main: Start ForLoop (i=0..2)
    
    Main->>W0: SetCurrentStream(0)
    W0->>Main: Synchronize Main Stream
    Note over W0: ShardByStream (chunk 0)
    Note over W0: Linear (matmul chunk 0)
    Note over W0: AllReduce (async)
    
    par Parallel Execution
        Main->>W1: SetCurrentStream(1)
        W1->>Main: Synchronize Main Stream
        Note over W1: ShardByStream (chunk 1)
        Note over W1: Linear (matmul chunk 1)
        Note over W1: AllReduce (async)
    and
        Main->>W2: SetCurrentStream(2)
        W2->>Main: Synchronize Main Stream
        Note over W2: ShardByStream (chunk 2)
        Note over W2: Linear (matmul chunk 2)
        Note over W2: AllReduce (async)
    end
    
    Note over Main: End ForLoop
    Note over Main: SetCurrentStream(Main)
    Note over Main: Start Join Loop (i=0..2)
    
    Main->>W0: Synchronize Worker Stream 0
    Main->>W1: Synchronize Worker Stream 1
    Main->>W2: Synchronize Worker Stream 2
    
    Note over Main: All chunks complete

@greptile-apps greptile-apps bot left a comment

15 files reviewed, 1 comment

Comment on lines +44 to +46
auto old_begin = for_loop->body().exprs().begin();
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);
greptile-apps bot:

logic: Insertion order creates reversed execution order. The second insert at old_begin places sync_main before set_stream, but the worker stream should be set before synchronization.

Suggested change

 auto old_begin = for_loop->body().exprs().begin();
-for_loop->body().insert(old_begin, set_stream);
-for_loop->body().insert(old_begin, sync_main);
+for_loop->body().insert(old_begin, sync_main);
+for_loop->body().insert(old_begin, set_stream);

wujingyue (Collaborator Author):

Skip this file. Subsumed by another PR.

namespace nvfuser {

class ViewTransform;
class Scope;
wujingyue (Collaborator Author):

Not necessary -- class Scope is defined in the same file.

wujingyue (Collaborator Author):

Skip this file -- subsumed by another PR

from nvfuser_direct import FusionDefinition, ParallelType, DataType


def test_matmul(nvfuser_direct_test):
wujingyue (Collaborator Author):

The tests didn't use the nvfuser_direct_test fixture.

@greptile-apps greptile-apps bot left a comment

15 files reviewed, 1 comment

Comment on lines +45 to +46
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);
greptile-apps bot:

logic: Wrong insertion order - both inserts use old_begin, causing sync_main to be placed before set_stream. This means synchronization happens on the main stream before switching to the worker stream. The worker stream should be set first, then synchronize.

Current order: Sync(main) → SetStream(worker)
Required order: SetStream(worker) → Sync(main)

Fix by inserting sync_main at the updated iterator after inserting set_stream

wujingyue (Collaborator Author):

This doesn't make sense. Maybe it missed that Scope contains a linked list?

greptile-apps bot:

You're absolutely right! I completely missed that Scope uses a linked list structure. When you insert at old_begin twice, both operations insert at the same position, so the second insert (sync_main) goes before the first insert (set_stream).

Looking at the code, both insertions use the same iterator old_begin, which means:

  1. First insert: set_stream goes at old_begin
  2. Second insert: sync_main goes at the same old_begin position, pushing set_stream forward

This creates the wrong order: sync_main → set_stream

The fix is to update the iterator after the first insertion:

Suggested change

-for_loop->body().insert(old_begin, set_stream);
-for_loop->body().insert(old_begin, sync_main);
+for_loop->body().insert(old_begin, set_stream);
+for_loop->body().insert(old_begin, sync_main);

Should be:

Suggested change

-for_loop->body().insert(old_begin, set_stream);
-for_loop->body().insert(old_begin, sync_main);
+auto old_begin = for_loop->body().exprs().begin();
+for_loop->body().insert(old_begin, set_stream);
+for_loop->body().insert(old_begin, sync_main);

Wait, that's still wrong. Let me fix this properly:

Suggested change

-for_loop->body().insert(old_begin, set_stream);
-for_loop->body().insert(old_begin, sync_main);
+auto old_begin = for_loop->body().exprs().begin();
+for_loop->body().insert(old_begin, set_stream);
+for_loop->body().insert(old_begin, sync_main);

Actually, the correct fix is:

Suggested change

-for_loop->body().insert(old_begin, set_stream);
-for_loop->body().insert(old_begin, sync_main);
+auto old_begin = for_loop->body().exprs().begin();
+for_loop->body().insert(old_begin, set_stream);
+auto new_pos = for_loop->body().exprs().begin();
+std::advance(new_pos, 1);
+for_loop->body().insert(new_pos, sync_main);

Or more simply, insert sync_main after inserting set_stream by using the updated iterator position.

wujingyue (Collaborator Author):

insert(old_begin, ...) inserts before old_begin, doesn't it?

greptile-apps bot:

You're absolutely right! insert(old_begin, ...) inserts before old_begin. So the current code:

for_loop->body().insert(old_begin, set_stream);    // inserts set_stream before old_begin
for_loop->body().insert(old_begin, sync_main);     // inserts sync_main before old_begin

Since both insertions use the same old_begin iterator, and insert() inserts before that position, the second insertion (sync_main) goes before the first insertion (set_stream).

This creates the order: sync_main → set_stream → (original body)

But we want: set_stream → sync_main → (original body)

The fix is to update the iterator after the first insertion:

Suggested change

-for_loop->body().insert(old_begin, set_stream);
-for_loop->body().insert(old_begin, sync_main);
+auto old_begin = for_loop->body().exprs().begin();
+for_loop->body().insert(old_begin, set_stream);
+for_loop->body().insert(for_loop->body().exprs().begin(), sync_main);

This way set_stream gets inserted first, then sync_main gets inserted at the new beginning (which is now after set_stream).

wujingyue (Collaborator Author):

mpirun -np 2 -x NVFUSER_DUMP=host_ir pytest tests/python/multidevice/test_overlap.py -k 'row_parallel_linear_forward and not reference and not benchmark' --only-mpi -vs confirmed that SetCurrentStream is indeed before Synchronize Stream. The code as is is correct.

%HostIrContainer { (T0_g___bfloat[istreamIdx7{3}, ideviceIdx.x9{2}, iS8{( ceilDiv(i0, 3) )}, iS10{4}] (DeviceMesh{0 1}), T1_g___bfloat[ideviceIdx.x11{2}, iS2{2}, iS12{4}] (DeviceMesh{0 1})) -> (T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1})) :
  T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}), mem_type=global, size=( i0 * 2 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0x3df7aa70
  FOR i535 from 0 to 3:
    SetCurrentStream to Stream i535
    Synchronize Stream 0x3df7aa70
    T4_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, ideviceIdx.x35{2}, iS36{4}] (DeviceMesh{0 1}) = ShardByStream(T0_g___bfloat[istreamIdx7{3}, ideviceIdx.x9{2}, iS8{( ceilDiv(i0, 3) )}, iS10{4}] (DeviceMesh{0 1}), stream_index = i535)
    T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}) = ALLOCATE(buffer=T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}), mem_type=global, size=( ( ceilDiv(i0, 3) ) * 12 ), zero_init=false, resets_to_zero=false)
    T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1})
       = linear(T4_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, ideviceIdx.x35{2}, iS36{4}] (DeviceMesh{0 1}),
                T1_g___bfloat[ideviceIdx.x11{2}, iS2{2}, iS12{4}] (DeviceMesh{0 1})      )
    T5_l___bfloat[istreamIdx41{3}, iS42{( ceilDiv(i0, 3) )}, iS40{2}] (DeviceMesh{0 1}) = ShardByStream(T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}), stream_index = i535)
    Communication 272 (type=Allreduce, team=(0 1), input=T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}), output=T5_l___bfloat[istreamIdx41{3}, iS42{( ceilDiv(i0, 3) )}, iS40{2}] (DeviceMesh{0 1}), backend=NCCL)
    Wait Communication 272
  SetCurrentStream to Stream 0x3df7aa70
  FOR i535 from 0 to 3:
    Synchronize Stream i535
} // %HostIrContainer
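
For anyone hitting the same confusion: assuming Scope stores its expressions in a std::list-like container where insert(pos, x) places x immediately before pos, two inserts against the same iterator come out in insertion order, which matches the dump above. A minimal standalone sketch:

```cpp
#include <iostream>
#include <list>
#include <string>

int main() {
  // Stand-in for the loop body: a single placeholder expression.
  std::list<std::string> body{"linear(...)"};

  // Mirrors the pass: both inserts target the same "old begin" iterator.
  auto old_begin = body.begin();
  body.insert(old_begin, "SetCurrentStream(worker)");  // before "linear(...)"
  body.insert(old_begin, "Synchronize(main)");  // also before "linear(...)", hence after SetCurrentStream

  for (const auto& e : body) {
    std::cout << e << "\n";
  }
  // Prints:
  //   SetCurrentStream(worker)
  //   Synchronize(main)
  //   linear(...)
  return 0;
}
```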


// We should check that the loop is stream-parallel. This is not necessary
// at this moment because all loops are stream-parallel. This is also hard
// to do because hir::ForLoop doesn't point to the source IterDomain.
Collaborator:

To check my understanding: the IterDomain in the ForLoop is not stream-parallel, unlike the IterDomain of the output's loop domain?

wujingyue (Collaborator Author):

Unlike kir::ForLoop, hir::ForLoop doesn't have the IterDomain as an input. Maybe it should.

Collaborator:

Is it as simple as addInput(id) as is done for kir::ForLoop?

wujingyue (Collaborator Author):

That's plausible. I'm debating with myself how much host IR should depend on nvFuser specifics like IterDomain. I could go with adding a ParallelType as an attribute. 🤷

# benchmark iterations using
# ```bash
-# mpirun -np 1 nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:<iterations> pytest tests/python/multidevice/<test_file>.py -k <filter> --only-mpi : -np <processes - 1> pytest tests/python/multidevice/<test_file>.py -k <filter> --only-mpi
+# nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:<iterations> mpirun -np <processes> pytest tests/python/multidevice/<test_file>.py -k <filter> --only-mpi
Collaborator:

Why is this update needed? The former comment gave me more consistent timings for communication than the updated comment.

For transformer forward tensor parallel on 8 H100:

Updated comment:

Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     53.3         14580255         16  911265.9  693407.0    254975   2909549     795575.7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)             
     14.2          3870002         17  227647.2  107807.0    104287    370304     133392.2  nvjet_sm90_tst_256x128_64x4_1x2_h_bz_coopA_TNT                                                      
     11.0          3014100          8  376762.5  376510.0    370173    380863       3346.5  nvjet_sm90_tst_256x128_64x4_1x2_h_bz_coopA_bias_TNT                                                 
      9.5          2591701          9  287966.8  287742.0    284573    291007       1958.7  nvjet_sm90_tst_192x192_64x4_2x1_v_bz_coopB_bias_TNN                                                 
      4.1          1114011          9  123779.0  123839.0    122495    125152        920.8  void pytorch_flash::flash_fwd_kernel<Flash_fwd_kernel_traits<(int)128, (int)128, (int)32, (int)4, (…
      4.0          1087676          8  135959.5  135807.5    134335    138272       1126.1  nvf::nvfuser_inner_persistent_f0_c1_r0_g11(nvf::Tensor<nvf::__bfloat, (int)1, (int)1>, nvf::Tensor<…
      2.3           618591          8   77323.9   77424.0     76576     77600        344.5  nvf::nvfuser_pointwise_f0_c1_r0_g10(nvf::Tensor<nvf::__bfloat, (int)1, (int)1>, nvf::Tensor<nvf::__…
      1.2           339453          9   37717.0   37536.0     37311     38368        362.7  nvf::nvfuser_inner_persistent_f0_c1_r0_g12(nvf::Tensor<nvf::__bfloat, (int)3, (int)3>, nvf::Tensor<…
      0.5           126303          8   15787.9   15824.0     15616     15936        108.2  nvf::nvfuser_pointwise_f0_c1_r0_g7(nvf::Tensor<nvf::__bfloat, (int)3, (int)4>, nvf::Tensor<nvf::__b…

Previous comment:

 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     31.1           688351          2  344175.5  344175.5    252928    435423     129043.5  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)             
     21.2           469823          2  234911.5  234911.5    105760    364063     182647.8  nvjet_sm90_tst_256x128_64x4_1x2_h_bz_coopA_TNT                                                      
     17.0           376127          1  376127.0  376127.0    376127    376127          0.0  nvjet_sm90_tst_256x128_64x4_1x2_h_bz_coopA_bias_TNT                                                 
     13.1           291392          1  291392.0  291392.0    291392    291392          0.0  nvjet_sm90_tst_192x192_64x4_2x1_v_bz_coopB_bias_TNN                                                 
      6.1           136064          1  136064.0  136064.0    136064    136064          0.0  nvf::nvfuser_inner_persistent_f0_c1_r0_g11(nvf::Tensor<nvf::__bfloat, (int)1, (int)1>, nvf::Tensor<…
      5.6           123456          1  123456.0  123456.0    123456    123456          0.0  void pytorch_flash::flash_fwd_kernel<Flash_fwd_kernel_traits<(int)128, (int)128, (int)32, (int)4, (…
      3.5            77695          1   77695.0   77695.0     77695     77695          0.0  nvf::nvfuser_pointwise_f0_c1_r0_g10(nvf::Tensor<nvf::__bfloat, (int)1, (int)1>, nvf::Tensor<nvf::__…
      1.7            37632          1   37632.0   37632.0     37632     37632          0.0  nvf::nvfuser_inner_persistent_f0_c1_r0_g12(nvf::Tensor<nvf::__bfloat, (int)3, (int)3>, nvf::Tensor<…
      0.7            15648          1   15648.0   15648.0     15648     15648          0.0  nvf::nvfuser_pointwise_f0_c1_r0_g7(nvf::Tensor<nvf::__bfloat, (int)3, (int)4>, nvf::Tensor<nvf::__b…

Collaborator:

I had this issue here: #4844 (comment) and the former comment gave me the nsys profile which matched my wall-clock measurements most closely.

Have you noticed this discrepancy?

wujingyue (Collaborator Author):

> the former comment gave me the nsys profile which matched my wall-clock measurements most closely.

The attached image seems to say the reverse -- putting nsys profile at the front gave you closer-to-wall time. Am I missing something?

wujingyue (Collaborator Author):

https://docs.nvidia.com/nsight-systems/UserGuide/index.html?utm_source=chatgpt.com#handling-application-launchers-mpirun-deepspeed-etc doesn't have a strong opinion between the two for single-node. For me, the updated command line is more convenient -- it's shorter and gives me the timing of all GPUs in one file so it's easier to process. But let me double check the reported timing is trustworthy.

Priya2698 (Collaborator) commented Jan 6, 2026:

My bad, let me clarify.

I was comparing the two comments in the diff:

  • Updated command (same as 2 in the above image): # nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:<iterations> mpirun -np <processes> pytest tests/python/multidevice/<test_file>.py -k <filter> --only-mpi. Collects data across all ranks.

  • Previous command: # mpirun -np 1 nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:<iterations> pytest tests/python/multidevice/<test_file>.py -k <filter> --only-mpi : -np <processes - 1> pytest tests/python/multidevice/<test_file>.py -k <filter> --only-mpi. Only one rank collects data.

The nsys profiles I pasted above correspond to these two commands.

While (2) in the above image was closer than (1) in the same image, the numbers for communication were still unstable as compared to the previous command where only 1 rank uses nsys profile.

Collaborator:

> https://docs.nvidia.com/nsight-systems/UserGuide/index.html?utm_source=chatgpt.com#handling-application-launchers-mpirun-deepspeed-etc doesn't have a strong opinion between the two for single-node. For me, the updated command line is more convenient -- it's shorter and gives me the timing of all GPUs in one file so it's easier to process. But let me double check the reported timing is trustworthy.

I only got very high numbers for communication (higher than wall-time), everything else looked okay.

wujingyue (Collaborator Author):

I tried three configs:

  1. nsys profile ... mpirun -np 8 ...
  2. mpirun -np 8 nsys profile ...
  3. mpirun -np 1 nsys profile ... : -np 7 ...

I prefer config 1 overall because only config 1 tells me:

  1. The actual run time of either allreduce operation¹ is around 260us. Recall that the kernels for the same allreduce end at roughly the same time, so a kernel that starts earlier looks slower. Therefore, the actual run time of an allreduce operation, by my definition, is the run time of the fastest kernel corresponding to that allreduce.
  2. Regardless of their start times, allreduce kernels across GPUs end at about the same time. The first allreduce ends at around 20,310us and the second at around 37,012us.

Config 2 leads to large variance, and there's no way to align different GPUs because each runs a separate nsys profile. However, like config 1, it does show the 260us correctly.

Config 3 has the lowest variance, likely because the process being profiled tends to run slowest and launches its allreduce kernels last. Even so, the first allreduce is shown to run for 441us, which is inaccurate.

Config 1

$ nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:1 mpirun -np 8 pytest tests/python/multidevice/test_transformer_engine.py -k 'test_transformer_layer[nonoverlap-tp-forward]' -vs --only-mpi

$ nsys stats report3.nsys-rep --report cuda_gpu_trace | grep nccl
   13077591        7232222    1847     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (1)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   13364015        6946483    1847     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (2)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   13957463        6353057    1847     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (7)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   15182508        5127666    1847     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (6)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   15584437        4725800    1847     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (0)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   15749866        4560473    1847     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (4)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   16052423        4258616    1847     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (3)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   20044166         266689    1847     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (5)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   21276737       15735928    2066     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (6)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   21392976       15620182    2066     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (1)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   21410685       15601858    2068     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (3)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   22583001       14430372    2066     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (7)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   22691330       14321563    2066     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (0)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   24994300       12019087    2066     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (4)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   31899021        5114358    2066     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (2)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   36751238         261377    2072     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (5)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)

Config 2

$ mpirun -np 8 nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:1 pytest tests/python/multidevice/test_transformer_engine.py -k 'test_transformer_layer[nonoverlap-tp-forward]' -vs --only-mpi

$ for i in {4..11}; do nsys stats report$i.nsys-rep --report cuda_gpu_trace; done | grep nccl
   10424681       40915927    1847     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (2)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   52313758        1030942    2066     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (2)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   10484621         264926    1847     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (1)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   12491523         261631    2066     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (1)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   10588716       14568360    1847     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (3)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   26119026        1042526    2066     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (3)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   13385933      167946450    1847     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (6)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
  182302274        1033604    2072     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (6)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   34146687      756926958    1847     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (4)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
  792038764        1038974    2066     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (4)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   21311074      121867114    1847     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (7)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
  144132973        1048896    2066     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (7)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   12634286      470272575    1847     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (5)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
  483859345        1050852    2066     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (5)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   10716413       52947804    1847     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (0)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   64613978        1054210    2066     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (0)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)

Config 3

$ mpirun -np 1 nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:1 pytest tests/python/multidevice/test_transformer_engine.py -k 'test_transformer_layer[nonoverlap-tp-forward]' -vs --only-mpi : -np 7 pytest tests/python/multidevice/test_transformer_engine.py -k 'test_transformer_layer[nonoverlap-tp-forward]' -vs --only-mpi

$ nsys stats report12.nsys-rep --report cuda_gpu_trace | grep nccl
   10208430         441441    1847     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (0)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
   12226226         270945    2066     16     1     1   640     1     1       96         0.037         0.082                                                     NVIDIA H100 80GB HBM3 (0)    1               7  ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)

(Note: you'll need to change report indices according to your generated files.)

Footnotes

  1. A TP transformer layer forward runs two allreduces, one in MHA and the other in MLP. They transfer data of the same size, b * s * h.

@wujingyue wujingyue requested a review from Priya2698 January 6, 2026 04:50
Base automatically changed from wjy/ref to main January 6, 2026 07:10
wujingyue (Collaborator Author):

!test

Comment on lines +31 to +33
// We should check that the loop is stream-parallel. This is not necessary
// at this moment because all loops are stream-parallel. This is also hard
// to do because hir::ForLoop doesn't point to the source IterDomain.
Collaborator:

Do ALL hir::ForLoops stream-parallelize? Is there no case where we want to loop sequentially in hir? Or is this pass triggered by some other condition I'm not seeing?

wujingyue (Collaborator Author):

> Do ALL hir::ForLoops stream-parallelize?

Yes at this moment.

I'm considering separating ParallelType::Stream and ParallelType::HostSerial. The latter doesn't exist today. That's when we'll have to look at the parallel type of the loop index.
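
If ParallelType::HostSerial is added later, the skipped check could look something like the hypothetical sketch below. hir::ForLoop doesn't currently expose its source IterDomain, so iterDomain() here is made up for illustration.

```cpp
// Hypothetical sketch only; hir::ForLoop has no iterDomain() accessor today.
if (IterDomain* id = for_loop->iterDomain();
    id == nullptr || id->getParallelType() != ParallelType::Stream) {
  continue;  // leave host-serial (or unannotated) loops untouched
}
```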

@pytest.mark.benchmark
def test_row_parallel_linear_forward_benchmark(multidevice_test, benchmark):
# This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward.
h, s, t = 8192, 2, 8192
Collaborator:

Consider making s a benchmark parameter covering s=2/4. For the reference implementation, s=4 had better overlap.

Priya2698 (Collaborator) commented Jan 6, 2026:

Can you also add nsys profiles for s=4 in the PR description?

wujingyue (Collaborator Author):

Done.

wujingyue (Collaborator Author):

s=4 is overall slower than s=2 according to my updated results, but it seems to show better overlap as you said. Recall that there's overhead in decomposing matmuls.

@Priya2698 Priya2698 (Collaborator) left a comment

LGTM. Let me know if you plan to test merging the for-loops in this PR.

}

// We should check that the loop is stream-parallel. This is not necessary
// at this moment because all loops are stream-parallel. This is also hard
Priya2698 (Collaborator) commented Jan 9, 2026:

Not strictly for this PR, but similar to kir::ForLoop, hir::ForLoop could hold the source IterDomain for this check.

for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);

// After the loop: create a joining loop to synchronize all worker streams
Collaborator:

Do you plan on merging this with the above for-loop?

wujingyue (Collaborator Author):

I haven't convinced myself that will work: http://nv/e-d

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

Implements stream parallelization for loops in nvFuser's host IR to enable overlapping computation and communication operations. Adds new AssignStreams optimization pass that transforms stream-parallel loops by capturing the main stream, assigning worker streams at loop iteration start with synchronization, and creating a joining loop afterward to synchronize all worker streams back. Includes comprehensive test coverage with benchmarks showing nvFuser slightly outperforms the PyTorch reference implementation (3.6ms vs 4.6ms mean).

Confidence Score: 4/5

  • Safe to merge with minor improvement opportunities
  • The implementation is well-structured with proper synchronization patterns, comprehensive tests, and correct integration into the host IR pipeline. The main concern is the acknowledged missing validation for loop stream-parallelization (line 31-33 in assign_streams.cpp), which could theoretically transform non-stream-parallel loops incorrectly, though the comment indicates all current loops are stream-parallel. The code follows established patterns, includes thorough test coverage, and demonstrates correct behavior via benchmarks and profiling.
  • csrc/host_ir/assign_streams.cpp - consider adding validation or assertion for stream-parallel loop check

Important Files Changed

File Analysis

  • csrc/host_ir/assign_streams.cpp (4/5): Implements AssignStreams pass to transform stream-parallel loops with proper synchronization; well-structured but lacks validation for loop stream-parallelization
  • csrc/host_ir/passes.cpp (5/5): Integrates AssignStreams pass into host IR pipeline after AllocateAndDeallocate; correct ordering
  • tests/python/direct/test_stream.py (5/5): Comprehensive tests for stream parallelization with matmul operations; validates correct kernel count and shapes

Sequence Diagram

sequenceDiagram
    participant Main as Main Stream
    participant W0 as Worker Stream 0
    participant W1 as Worker Stream 1
    participant W2 as Worker Stream 2
    
    Note over Main: GetCurrentStream (capture main)
    Note over Main: FOR i=0 to 3
    
    Main->>W0: SetCurrentStream(0)
    W0->>Main: Synchronize(main_stream)
    Note over W0: Execute loop body (matmul/allreduce)
    
    Main->>W1: SetCurrentStream(1)
    W1->>Main: Synchronize(main_stream)
    Note over W1: Execute loop body (matmul/allreduce)
    
    Main->>W2: SetCurrentStream(2)
    W2->>Main: Synchronize(main_stream)
    Note over W2: Execute loop body (matmul/allreduce)
    
    Note over Main: SetCurrentStream(main_stream)
    Note over Main: FOR i=0 to 3 (joining loop)
    
    Main->>W0: Synchronize(worker_stream_0)
    Main->>W1: Synchronize(worker_stream_1)
    Main->>W2: Synchronize(worker_stream_2)
    
    Note over Main: All workers synchronized back

Comment on lines +31 to +33
// We should check that the loop is stream-parallel. This is not necessary
// at this moment because all loops are stream-parallel. This is also hard
// to do because hir::ForLoop doesn't point to the source IterDomain.
greptile-apps bot:

The comment acknowledges this validation is skipped, but consider adding a TODO or assertion to track this technical debt. Without validation, non-stream-parallel loops could be incorrectly transformed, potentially leading to incorrect synchronization patterns. At minimum, add a NVF_CHECK that verifies the loop meets basic requirements (e.g., has a valid index, start, and stop).

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR implements stream parallelization for loops in nvFuser's host IR to enable overlapping of computation and communication operations. The core implementation adds a new AssignStreams optimization pass that transforms stream-parallel loops by assigning worker streams to each iteration and adding proper synchronization.

Key Changes:

  • New AssignStreams pass in csrc/host_ir/assign_streams.{cpp,h} that transforms loops to use worker streams
  • Integration into host IR pipeline via csrc/host_ir/passes.cpp
  • Comprehensive test coverage in tests/python/multidevice/test_overlap.py with benchmarks
  • Code cleanup: removed unnecessary includes in allocate_and_deallocate.h, ir.h, and internal_nodes.h

Transformation Pattern:
For each loop, the pass:

  1. Captures the main stream before the loop
  2. At the start of each iteration: switches to a worker stream and synchronizes with the main stream
  3. After the loop: creates a joining loop that synchronizes all worker streams back to main

Issues Found:

  • Copyright year is 2026 in both new files (should be 2025)
  • Missing validation that loops are actually stream-parallel (acknowledged in code comment but not implemented)

The implementation correctly follows the stream synchronization pattern demonstrated in the PyTorch reference implementation. Benchmark results show nvFuser achieves slight performance improvements over the baseline.

Confidence Score: 4/5

  • This PR is safe to merge with minor corrections needed for copyright years
  • The implementation is technically sound with correct synchronization logic matching the reference implementation. The only actual errors are copyright year mistakes (2026 instead of 2025). The missing stream-parallel validation is acknowledged in comments and appears to be a known limitation rather than an oversight. Comprehensive tests provide good coverage.
  • Pay attention to the copyright years in csrc/host_ir/assign_streams.cpp and csrc/host_ir/assign_streams.h which need correction from 2026 to 2025

Important Files Changed

File Analysis

  • csrc/host_ir/assign_streams.cpp (4/5): New file implementing stream assignment pass for stream-parallel loops. Transforms loops to use worker streams with proper synchronization. Copyright year is incorrect (2026 instead of 2025). Logic is sound but lacks validation that loops are actually stream-parallel.
  • csrc/host_ir/assign_streams.h (4/5): New header file declaring AssignStreams optimization pass. Copyright year is incorrect (2026 instead of 2025). Clean interface design.
  • csrc/host_ir/passes.cpp (5/5): Added AssignStreams pass to the host IR pipeline after AllocateAndDeallocate. Simple integration, no issues found.
  • tests/python/multidevice/test_overlap.py (5/5): Added comprehensive tests for stream-parallelized row-parallel linear forward pass with benchmarks. Includes reference implementation matching the transformation pattern. Well-structured tests with proper validation.

Sequence Diagram

sequenceDiagram
    participant Main as Main Stream
    participant Worker as Worker Stream i
    participant Ops as Operations
    
    Note over Main: Save main stream
    Main->>Main: GetCurrentStream
    
    Note over Main,Worker: Loop iteration starts
    Main->>Worker: SetCurrentStream worker_stream i
    Worker->>Main: Synchronize main_stream
    Note over Worker: Wait for main work to complete
    Worker->>Ops: Execute matmul
    Worker->>Ops: Execute allreduce
    Note over Worker: Work continues in parallel
    
    Note over Main: After loop completes
    Worker->>Main: SetCurrentStream main_stream
    
    Note over Main: Join loop synchronizes workers
    Main->>Worker: Synchronize worker_stream 0
    Main->>Worker: Synchronize worker_stream 1
    Main->>Worker: Synchronize worker_stream N
    Note over Main: All workers synchronized
    
    Note over Main: Continue on main stream

@greptile-apps greptile-apps bot left a comment

Greptile Summary

Overview

This PR implements stream parallelization for loops in nvFuser's host IR, enabling overlapping of computation (matmul) and communication (allreduce) operations to improve performance in distributed workloads.

Key Changes

New Stream Assignment Pass (csrc/host_ir/assign_streams.cpp):

  • Transforms stream-parallel loops to use multiple CUDA streams
  • Captures the main stream before the loop
  • Sets worker streams at the beginning of each iteration with synchronization
  • Creates a joining loop after the main loop to synchronize all worker streams back to the main stream

Integration:

  • Added AssignStreams pass to the host IR optimization pipeline (runs after AllocateAndDeallocate)
  • Updated CMakeLists.txt to include the new source file

Tests & Benchmarks:

  • Refactored test_row_parallel_linear_forward to extract reusable fusion definition
  • Added parametrized benchmark test_row_parallel_linear_forward_benchmark with varying chunk counts (s=1,2,4)
  • Includes reference PyTorch implementation for validation

Code Cleanup:

  • Removed unnecessary includes from header files
  • Improved documentation for Scope::insert method
  • Fixed test function signatures

Architecture

The stream assignment creates a pattern where:

  1. Main stream is captured before the loop
  2. Each loop iteration executes on a dedicated worker stream
  3. Worker streams synchronize with the main stream before starting work
  4. After the main loop, a joining loop makes the main stream wait for all workers

This enables concurrent execution of loop iterations on different streams, particularly beneficial for overlapping matmul operations with collective communication.

Performance Impact

Benchmark results show slight improvement with overlapping (3.6ms vs 3.7-3.9ms for s=1), with the implementation verified via nsys profiling to correctly assign different streams to different operations.

Confidence Score: 2/5

  • This PR contains a critical logic error that will cause incorrect stream synchronization behavior at runtime.
  • Score reflects a critical bug in csrc/host_ir/assign_streams.cpp (lines 49-53) where the insertion order of SetCurrentStream and join_loop is inverted, causing the join loop to execute before setting the current stream to main_stream. This means synchronizations will happen on the wrong stream (last worker stream instead of main stream), breaking the intended stream parallelization semantics. Additionally, the copyright year is incorrect (2026 instead of 2025).
  • csrc/host_ir/assign_streams.cpp requires immediate attention to fix the insertion order bug on lines 49-53. The expected Host IR comment in tests/python/multidevice/test_overlap.py (lines 74-75) should also be updated to include the missing SetCurrentStream statement.

Important Files Changed

File Analysis

Filename Score Overview
csrc/host_ir/assign_streams.cpp 2/5 New file implementing stream assignment pass. Contains critical logic error in insertion order (lines 49-53) causing incorrect stream synchronization sequence. Also has incorrect copyright year (2026).
csrc/host_ir/assign_streams.h 5/5 New header file declaring AssignStreams optimization pass. Clean interface, no issues detected.
csrc/host_ir/passes.cpp 5/5 Integrates new AssignStreams pass into pipeline. Correct ordering after AllocateAndDeallocate pass.
tests/python/multidevice/test_overlap.py 4/5 Refactors test into reusable function, adds benchmark test, updates to use outer_split instead of split. Expected Host IR comment may be missing SetCurrentStream statement between loops.

Sequence Diagram

sequenceDiagram
    participant Main as Main Stream
    participant W0 as Worker Stream 0
    participant W1 as Worker Stream 1
    participant W2 as Worker Stream 2
    
    Note over Main: GetCurrentStream
    
    rect rgb(200, 220, 255)
        Note over Main,W2: Main Loop (stream-parallel iterations)
        Main->>W0: SetCurrentStream(worker_0)
        W0->>Main: Synchronize(main_stream)
        Note over W0: Compute chunk 0<br/>(matmul + allreduce)
        
        Main->>W1: SetCurrentStream(worker_1)
        W1->>Main: Synchronize(main_stream)
        Note over W1: Compute chunk 1<br/>(matmul + allreduce)
        
        Main->>W2: SetCurrentStream(worker_2)
        W2->>Main: Synchronize(main_stream)
        Note over W2: Compute chunk 2<br/>(matmul + allreduce)
    end
    
    Note over Main: SetCurrentStream(main_stream)
    
    rect rgb(255, 220, 200)
        Note over Main,W2: Join Loop (synchronize workers)
        W0->>Main: Synchronize(worker_0)
        W1->>Main: Synchronize(worker_1)
        W2->>Main: Synchronize(worker_2)
    end
    
    Note over Main: Continue on main stream

@@ -0,0 +1,64 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES.
Contributor

The copyright year is set to 2026, which appears to be incorrect. This should likely be 2025 or earlier.

Suggested change
- * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES.
+ * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.

Comment on lines +49 to +53
hic->topLevel().insert(
next_it, IrBuilder::create<SetCurrentStream>(main_stream));
auto* join_loop = IrBuilder::create<ForLoop>(
for_loop->index(), for_loop->start(), for_loop->stop());
hic->topLevel().insert(next_it, join_loop);
Contributor

Critical logic error: The insertion order will result in incorrect code generation. Both statements use insert(next_it, ...) which inserts BEFORE the iterator position. Since both use the same next_it, the join_loop (inserted second) will be placed BEFORE the SetCurrentStream (inserted first), resulting in:

FOR loop (original)
join_loop               ← wrong position
SetCurrentStream        ← wrong position  

The correct order should be SetCurrentStream THEN join_loop, because:

  1. After the first FOR loop exits, the current stream is the last worker stream (not the main stream)
  2. The join_loop synchronizes worker streams on the current stream
  3. Therefore, we must set the current stream to main_stream BEFORE the join_loop

This matches the reference implementation row_parallel_linear_forward_reference in test_overlap.py (lines 182-194) where main_stream.wait_stream(worker_stream) happens on the main_stream.

The fix is to save the iterator returned from the first insert and use it for the second insert:

Suggested change
- hic->topLevel().insert(
-     next_it, IrBuilder::create<SetCurrentStream>(main_stream));
- auto* join_loop = IrBuilder::create<ForLoop>(
-     for_loop->index(), for_loop->start(), for_loop->stop());
- hic->topLevel().insert(next_it, join_loop);
+ // After the loop: create a joining loop to synchronize all worker streams
+ auto set_main_it = hic->topLevel().insert(
+     next_it, IrBuilder::create<SetCurrentStream>(main_stream));
+ auto* join_loop = IrBuilder::create<ForLoop>(
+     for_loop->index(), for_loop->start(), for_loop->stop());
+ hic->topLevel().insert(std::next(set_main_it), join_loop);

Comment on lines +74 to +75
# FOR i535 from 0 to 3:
# Synchronize Stream i535
Contributor

The expected Host IR output in this comment appears to be missing a SetCurrentStream statement between the two FOR loops. Based on the implementation in assign_streams.cpp (lines 49-50), there should be a SetCurrentStream to Stream 0x174e5c80 before the second FOR loop.

The correct expected output should be:

#   FOR i535 from 0 to 3:
#     SetCurrentStream to Stream i535
#     Synchronize Stream 0x174e5c80
#     ... (computations)
#   SetCurrentStream to Stream 0x174e5c80
#   FOR i535 from 0 to 3:
#     Synchronize Stream i535

This matches the reference implementation row_parallel_linear_forward_reference where the main stream waits for all worker streams after they complete.

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
@greptile-apps greptile-apps bot left a comment

Greptile Summary

This PR implements stream parallelization for loops in nvFuser's host IR to enable overlapping of computation (matmul) and communication (allreduce) operations. The core implementation adds an AssignStreams optimization pass that transforms stream-parallel loops by capturing the main stream, assigning worker streams to each iteration with proper synchronization, and creating a joining loop to synchronize all worker streams back to the main stream.

Key Changes:

  • New csrc/host_ir/assign_streams.{cpp,h} implementing the stream assignment pass
  • Integration of the pass into the host IR optimization pipeline
  • Refactored test structure with new benchmark tests
  • Code cleanup removing unnecessary includes

Issues Found:

  • The joining loop reuses the main loop's index variable, which could cause bugs if code after the loops depends on the index value
  • Copyright year typo (2026 instead of 2025) in the new header file

Confidence Score: 3/5

  • This PR has one moderate logic issue that should be addressed before merging
  • The implementation is generally sound with proper synchronization primitives and good test coverage. However, there's a potential bug where the joining loop reuses the main loop's index variable (line 51 in assign_streams.cpp), which could cause issues if any code depends on the loop index value after the loop completes. This needs verification or fixing. The copyright year typo is a minor style issue.
  • Pay close attention to csrc/host_ir/assign_streams.cpp - specifically the index variable reuse in the joining loop

Important Files Changed

File Analysis

Filename Score Overview
csrc/host_ir/assign_streams.cpp 3/5 New file implementing stream parallelization pass. Contains potential bug with loop index reuse in joining loop.
csrc/host_ir/assign_streams.h 4/5 Header file with copyright year typo (2026 instead of 2025).
tests/python/multidevice/test_overlap.py 4/5 Refactored test to extract fusion definition, added benchmark test. Changed from torch.randint to torch.testing.make_tensor and torch.randn for test data generation.

Sequence Diagram

sequenceDiagram
    participant Main as Main Stream
    participant Pass as AssignStreams Pass
    participant ForLoop as Stream Parallel ForLoop
    participant W0 as Worker Stream 0
    participant W1 as Worker Stream 1
    participant WN as Worker Stream N
    
    Pass->>Main: GetCurrentStream
    Note over Pass: Capture main stream before loop
    Pass->>ForLoop: Transform loop structure
    
    Note over ForLoop,WN: Main Loop Execution
    ForLoop->>W0: SetCurrentStream worker 0
    W0->>Main: Synchronize with main stream
    W0->>W0: Execute iteration 0 body
    
    ForLoop->>W1: SetCurrentStream worker 1
    W1->>Main: Synchronize with main stream
    W1->>W1: Execute iteration 1 body
    
    ForLoop->>WN: SetCurrentStream worker N
    WN->>Main: Synchronize with main stream
    WN->>WN: Execute iteration N body
    
    Pass->>Main: SetCurrentStream back to main
    
    Note over ForLoop,WN: Joining Loop Execution
    Main->>W0: Synchronize worker stream 0
    Main->>W1: Synchronize worker stream 1
    Main->>WN: Synchronize worker stream N
    
    Note over Main: All workers joined to main stream

@@ -0,0 +1,26 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES.
Contributor

The copyright year is 2026 (a future year); it should be 2025.

Comment on lines +51 to +52
auto* join_loop = IrBuilder::create<ForLoop>(
for_loop->index(), for_loop->start(), for_loop->stop());
Contributor

Reusing for_loop->index() for the joining loop could cause issues: the index variable from the main loop will be overwritten during the join loop's execution. If any code after the loops depends on the loop index value, this could lead to bugs. Consider creating a new index variable for the joining loop.
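
A plain C++ analogy of the concern, assuming (hypothetically) that something after the loops still refers to the index; this is not code from the PR:

int i = 0;
for (i = 0; i < 3; ++i) { /* main loop: compute chunk i on worker stream i */ }
for (i = 0; i < 3; ++i) { /* joining loop: synchronize worker stream i */ }
// With identical bounds the final value of i happens to coincide, but the index
// now has two writers: anything inserted between the loops, or any later pass
// reasoning about where i is defined, must account for both. A fresh index for
// the joining loop removes the aliasing.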

@wujingyue
Collaborator Author

!test

@greptile-apps greptile-apps bot left a comment

Greptile Summary

Implements stream parallelization for loops in nvFuser's host IR to enable overlapping of computation and communication operations. Adds an AssignStreams optimization pass that transforms loops by capturing the main stream, assigning worker streams to each iteration with proper synchronization, and creating a joining loop to synchronize all worker streams back to the main stream. Includes comprehensive test coverage with benchmarks showing slight performance improvements from overlapping.

Confidence Score: 4/5

  • Safe to merge with minor style improvement suggested for iterator handling
  • The implementation correctly achieves stream parallelization with proper synchronization semantics, as verified by tests. The core logic is sound: capturing the main stream, assigning worker streams per iteration, and synchronizing back via a joining loop. One style issue was noted regarding iterator handling (lines 44-46 in assign_streams.cpp): the pattern works in practice but could be written more clearly. No functional bugs or security concerns were identified.
  • csrc/host_ir/assign_streams.cpp for iterator handling style improvement

Important Files Changed

File Analysis

Filename Score Overview
csrc/host_ir/assign_streams.cpp 4/5 Adds stream assignment pass to parallelize loops across streams for overlapping computation and communication
csrc/host_ir/passes.cpp 5/5 Integrates AssignStreams pass into the host IR optimization pipeline
tests/python/multidevice/test_overlap.py 4/5 Refactors row parallel linear forward test, adds benchmark tests and updates expected host IR comments

Sequence Diagram

sequenceDiagram
    participant Main as Main Stream
    participant W0 as Worker Stream 0
    participant W1 as Worker Stream 1
    participant W2 as Worker Stream 2
    
    Note over Main: GetCurrentStream
    
    rect rgb(200, 220, 255)
        Note over Main,W2: Main Loop (3 iterations)
        
        Main->>W0: SetCurrentStream(0)
        W0->>Main: Synchronize(main)
        Note over W0: Compute + AllReduce
        
        par Parallel Execution
            Main->>W1: SetCurrentStream(1)
            W1->>Main: Synchronize(main)
            Note over W1: Compute + AllReduce
        and
            Note over W0: Continue async work
        end
        
        par Parallel Execution
            Main->>W2: SetCurrentStream(2)
            W2->>Main: Synchronize(main)
            Note over W2: Compute + AllReduce
        and
            Note over W0,W1: Continue async work
        end
    end
    
    rect rgb(255, 220, 200)
        Note over Main,W2: Joining Loop (synchronize workers)
        Main->>W0: Synchronize(worker 0)
        Main->>W1: Synchronize(worker 1)
        Main->>W2: Synchronize(worker 2)
    end
    
    Main->>Main: SetCurrentStream(main)

Comment on lines +44 to +46
auto old_begin = for_loop->body().exprs().begin();
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);
Contributor

The insertion order places sync_main before set_stream due to how insert works with old_begin. After the first insert at old_begin, set_stream is inserted. When the second insert happens at the same old_begin (which still points to the original beginning), sync_main gets inserted before set_stream.

This creates the sequence: sync_main → set_stream → rest of body. However, the correct order should be set_stream → sync_main, because you need to set the worker stream first before synchronizing with the main stream on that worker stream.

Suggested change
- auto old_begin = for_loop->body().exprs().begin();
- for_loop->body().insert(old_begin, set_stream);
- for_loop->body().insert(old_begin, sync_main);
+ auto* worker_stream = IrBuilder::create<Stream>(for_loop->index());
+ auto* set_stream = IrBuilder::create<SetCurrentStream>(worker_stream);
+ auto* sync_main = IrBuilder::create<Synchronize>(main_stream);
+ auto old_begin = for_loop->body().exprs().begin();
+ for_loop->body().insert(old_begin, set_stream);
+ auto new_begin = for_loop->body().exprs().begin();
+ for_loop->body().insert(std::next(new_begin), sync_main);

Comment on lines +44 to +46
auto old_begin = for_loop->body().exprs().begin();
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);
Contributor

Using old_begin after the first insert relies on iterator invalidation rules that may lead to undefined behavior. While this works in practice (confirmed by tests), consider capturing the return value for clarity:

Suggested change
- auto old_begin = for_loop->body().exprs().begin();
- for_loop->body().insert(old_begin, set_stream);
- for_loop->body().insert(old_begin, sync_main);
+ auto old_begin = for_loop->body().exprs().begin();
+ auto it = for_loop->body().insert(old_begin, set_stream);
+ for_loop->body().insert(std::next(it), sync_main);

This makes the intent clearer and avoids potential issues with iterator invalidation.
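
For intuition, here is a tiny standalone example (plain std::list, not nvFuser's Scope, so the container choice is an assumption) showing how using the returned iterator pins the relative order of the two inserts regardless of the container's insertion semantics:

#include <cassert>
#include <iterator>
#include <list>

int main() {
  std::list<int> body = {42};           // stand-in for the loop body's existing expressions
  auto old_begin = body.begin();

  auto it = body.insert(old_begin, 1);  // 1 plays the role of set_stream
  body.insert(std::next(it), 2);        // 2 plays the role of sync_main, guaranteed to follow 1

  assert((body == std::list<int>{1, 2, 42}));
  return 0;
}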


Development

Successfully merging this pull request may close these issues.

Support circular buffering in host IR lowering to overlap matmul and allreduce
