3 changes: 2 additions & 1 deletion CMakeLists.txt
@@ -349,9 +349,10 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/preseg_passes/remove_empty.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/reorder_sharded_axis.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/segment_inplace_update.cpp
${NVFUSER_SRCS_DIR}/host_ir/allocate_and_deallocate.cpp
${NVFUSER_SRCS_DIR}/host_ir/assign_streams.cpp
${NVFUSER_SRCS_DIR}/host_ir/pass/convert_op_to_communication.cpp
${NVFUSER_SRCS_DIR}/host_ir/pass/stream_parallel_type.cpp
${NVFUSER_SRCS_DIR}/host_ir/allocate_and_deallocate.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/translate_no_reduction_matmul_to_mul_squeeze.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/translate_repeat_to_expand.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/translate_scatter_accumulate.cpp
1 change: 0 additions & 1 deletion csrc/host_ir/allocate_and_deallocate.h
@@ -7,7 +7,6 @@
// clang-format on
#pragma once

#include "host_ir/container.h"
#include "optimization_pass.h"

namespace nvfuser::hir {
64 changes: 64 additions & 0 deletions csrc/host_ir/assign_streams.cpp
@@ -0,0 +1,64 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on

#include "host_ir/assign_streams.h"

#include "host_ir/container.h"
#include "ir/builder.h"

namespace nvfuser::hir {

void AssignStreams::runPass(Fusion* fusion) {
auto* hic = dynamic_cast<HostIrContainer*>(fusion);
NVF_CHECK(hic != nullptr);
FusionGuard fg(hic);

for (auto it = hic->topLevel().exprs().begin();
it != hic->topLevel().exprs().end();) {
auto next_it = std::next(it);

auto* for_loop = dynamic_cast<ForLoop*>(*it);
if (for_loop == nullptr) {
it = next_it;
continue;
}

// We should check that the loop is stream-parallel. This is not necessary
// at this moment because all loops are stream-parallel. This is also hard
// to do because hir::ForLoop doesn't point to the source IterDomain.

Collaborator @Priya2698 (Jan 9, 2026)

Not strictly for this PR, but similar to kir::ForLoop, hir::ForLoop can hold the source IterDomain for this check.

Comment on lines +31 to +33
Collaborator

Do ALL hir::ForLoops stream-parallelize? Is there no case where we want to sequentially loop in hir? Or is this pass triggered by some other condition I'm not seeing?

Collaborator Author

Do ALL hir::ForLoops stream-parallelize?

Yes, at this moment.

I'm considering separating ParallelType::Stream and ParallelType::HostSerial. The latter doesn't exist today. That's when we'll have to look at the parallel type of the loop index.

Comment on lines +31 to +33
Contributor

The comment acknowledges this validation is skipped, but consider adding a TODO or assertion to track this technical debt. Without validation, non-stream-parallel loops could be incorrectly transformed, potentially leading to incorrect synchronization patterns. At a minimum, add an NVF_CHECK that verifies the loop meets basic requirements (e.g., has a valid index, start, and stop).

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Stream* main_stream = IrBuilder::create<Stream>();
hic->topLevel().insert(
it, IrBuilder::create<GetCurrentStream>(main_stream));

// At the beginning of each iteration: set stream and synchronize with main
// stream
auto* worker_stream = IrBuilder::create<Stream>(for_loop->index());
auto* set_stream = IrBuilder::create<SetCurrentStream>(worker_stream);
auto* sync_main = IrBuilder::create<Synchronize>(main_stream);
auto old_begin = for_loop->body().exprs().begin();
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);
Comment on lines +44 to +46
Contributor

logic: Insertion order creates reversed execution order. The second insert at old_begin places sync_main before set_stream, but the worker stream should be set before synchronization.

Suggested change
auto old_begin = for_loop->body().exprs().begin();
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);
for_loop->body().insert(old_begin, sync_main);
for_loop->body().insert(old_begin, set_stream);

Comment on lines +45 to +46
Contributor

logic: Wrong insertion order - both inserts use old_begin, causing sync_main to be placed before set_stream. This means synchronization happens on the main stream before switching to the worker stream. The worker stream should be set first, then synchronize.

Current order: Sync(main) → SetStream(worker)
Required order: SetStream(worker) → Sync(main)

Fix by inserting sync_main at the updated iterator after inserting set_stream

Collaborator Author

This doesn't make sense. Maybe it missed that Scope contains a linked list?

Contributor

You're absolutely right! I completely missed that Scope uses a linked list structure. When you insert at old_begin twice, both operations insert at the same position, so the second insert (sync_main) goes before the first insert (set_stream).

Looking at the code, both insertions use the same iterator old_begin, which means:

  1. First insert: set_stream goes at old_begin
  2. Second insert: sync_main goes at the same old_begin position, pushing set_stream forward

This creates the wrong order: sync_main → set_stream

The fix is to update the iterator after the first insertion:

Suggested change
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);

Should be:

Suggested change
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);
auto old_begin = for_loop->body().exprs().begin();
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);

Wait, that's still wrong. Let me fix this properly:

Suggested change
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);
auto old_begin = for_loop->body().exprs().begin();
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);

Actually, the correct fix is:

Suggested change
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);
auto old_begin = for_loop->body().exprs().begin();
for_loop->body().insert(old_begin, set_stream);
auto new_pos = for_loop->body().exprs().begin();
std::advance(new_pos, 1);
for_loop->body().insert(new_pos, sync_main);

Or more simply, insert sync_main after inserting set_stream by using the updated iterator position.

Collaborator Author

insert(old_begin, ...) inserts before old_begin, doesn't it?

Contributor

You're absolutely right! insert(old_begin, ...) inserts before old_begin. So the current code:

for_loop->body().insert(old_begin, set_stream);    // inserts set_stream before old_begin
for_loop->body().insert(old_begin, sync_main);     // inserts sync_main before old_begin

Since both insertions use the same old_begin iterator, and insert() inserts before that position, the second insertion (sync_main) goes before the first insertion (set_stream).

This creates the order: sync_main → set_stream → (original body)

But we want: set_stream → sync_main → (original body)

The fix is to update the iterator after the first insertion:

Suggested change
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);
auto old_begin = for_loop->body().exprs().begin();
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(for_loop->body().exprs().begin(), sync_main);

This way set_stream gets inserted first, then sync_main gets inserted at the new beginning (which is now after set_stream).

Collaborator Author

Running mpirun -np 2 -x NVFUSER_DUMP=host_ir pytest tests/python/multidevice/test_overlap.py -k 'row_parallel_linear_forward and not reference and not benchmark' --only-mpi -vs confirmed that SetCurrentStream is indeed before Synchronize Stream. The code as-is is correct.

%HostIrContainer { (T0_g___bfloat[istreamIdx7{3}, ideviceIdx.x9{2}, iS8{( ceilDiv(i0, 3) )}, iS10{4}] (DeviceMesh{0 1}), T1_g___bfloat[ideviceIdx.x11{2}, iS2{2}, iS12{4}] (DeviceMesh{0 1})) -> (T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1})) :
  T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}), mem_type=global, size=( i0 * 2 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0x3df7aa70
  FOR i535 from 0 to 3:
    SetCurrentStream to Stream i535
    Synchronize Stream 0x3df7aa70
    T4_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, ideviceIdx.x35{2}, iS36{4}] (DeviceMesh{0 1}) = ShardByStream(T0_g___bfloat[istreamIdx7{3}, ideviceIdx.x9{2}, iS8{( ceilDiv(i0, 3) )}, iS10{4}] (DeviceMesh{0 1}), stream_index = i535)
    T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}) = ALLOCATE(buffer=T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}), mem_type=global, size=( ( ceilDiv(i0, 3) ) * 12 ), zero_init=false, resets_to_zero=false)
    T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1})
       = linear(T4_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, ideviceIdx.x35{2}, iS36{4}] (DeviceMesh{0 1}),
                T1_g___bfloat[ideviceIdx.x11{2}, iS2{2}, iS12{4}] (DeviceMesh{0 1})      )
    T5_l___bfloat[istreamIdx41{3}, iS42{( ceilDiv(i0, 3) )}, iS40{2}] (DeviceMesh{0 1}) = ShardByStream(T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}), stream_index = i535)
    Communication 272 (type=Allreduce, team=(0 1), input=T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}), output=T5_l___bfloat[istreamIdx41{3}, iS42{( ceilDiv(i0, 3) )}, iS40{2}] (DeviceMesh{0 1}), backend=NCCL)
    Wait Communication 272
  SetCurrentStream to Stream 0x3df7aa70
  FOR i535 from 0 to 3:
    Synchronize Stream i535
} // %HostIrContainer
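
For reference, a minimal standalone sketch (not part of this PR) of the linked-list insert semantics the author relies on. Assuming Scope stores its expressions in a std::list-like container, as stated above, inserting twice before the same saved iterator preserves insertion order and does not invalidate that iterator, so both the loop-body insertions (set_stream, then sync_main) and the later top-level insertions (SetCurrentStream, then join_loop) end up in the order they were inserted, which matches the dump.

#include <cassert>
#include <list>
#include <string>

int main() {
  // Strings stand in for the Expr* nodes in the loop body.
  std::list<std::string> body{"original_body"};

  auto old_begin = body.begin();         // points at "original_body"
  body.insert(old_begin, "set_stream");  // {set_stream, original_body}
  body.insert(old_begin, "sync_main");   // {set_stream, sync_main, original_body}

  // Insertion order is preserved: set_stream comes before sync_main.
  assert((body ==
          std::list<std::string>{"set_stream", "sync_main", "original_body"}));
  return 0;
}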

Comment on lines +44 to +46
Contributor

The insertion order places sync_main before set_stream due to how insert works with old_begin. After the first insert at old_begin, set_stream is inserted. When the second insert happens at the same old_begin (which still points to the original beginning), sync_main gets inserted before set_stream.

This creates the sequence: sync_main → set_stream → rest of body. However, the correct order should be set_stream → sync_main because you need to set the worker stream first before synchronizing with the main stream on that worker stream.

Suggested change
auto old_begin = for_loop->body().exprs().begin();
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);
auto* worker_stream = IrBuilder::create<Stream>(for_loop->index());
auto* set_stream = IrBuilder::create<SetCurrentStream>(worker_stream);
auto* sync_main = IrBuilder::create<Synchronize>(main_stream);
auto old_begin = for_loop->body().exprs().begin();
for_loop->body().insert(old_begin, set_stream);
auto new_begin = for_loop->body().exprs().begin();
for_loop->body().insert(std::next(new_begin), sync_main);

Comment on lines +44 to +46
Contributor

Using old_begin after the first insert relies on iterator invalidation rules that may lead to undefined behavior. While this works in practice (confirmed by tests), consider capturing the return value for clarity:

Suggested change
auto old_begin = for_loop->body().exprs().begin();
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);
auto old_begin = for_loop->body().exprs().begin();
auto it = for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(std::next(it), sync_main);

This makes the intent clearer and avoids potential issues with iterator invalidation.


// After the loop: create a joining loop to synchronize all worker streams
Collaborator

Do you plan on merging this with the above for-loop?

Collaborator Author

I haven't convinced myself that will work: http://nv/e-d

hic->topLevel().insert(
next_it, IrBuilder::create<SetCurrentStream>(main_stream));
auto* join_loop = IrBuilder::create<ForLoop>(
for_loop->index(), for_loop->start(), for_loop->stop());
Comment on lines +51 to +52
Contributor

Reusing for_loop->index() for the joining loop could cause issues. The index variable from the main loop will be overwritten during the join-loop execution. If any code after the loops depends on the loop index value, this could lead to bugs. Consider creating a new index variable for the joining loop.

hic->topLevel().insert(next_it, join_loop);
Comment on lines +49 to +53
Contributor

Critical logic error: The insertion order will result in incorrect code generation. Both statements use insert(next_it, ...) which inserts BEFORE the iterator position. Since both use the same next_it, the join_loop (inserted second) will be placed BEFORE the SetCurrentStream (inserted first), resulting in:

FOR loop (original)
join_loop               ← wrong position
SetCurrentStream        ← wrong position  

The correct order should be SetCurrentStream THEN join_loop, because:

  1. After the first FOR loop exits, the current stream is the last worker stream (not the main stream)
  2. The join_loop synchronizes worker streams on the current stream
  3. Therefore, we must set the current stream to main_stream BEFORE the join_loop

This matches the reference implementation row_parallel_linear_forward_reference in test_overlap.py (lines 182-194) where main_stream.wait_stream(worker_stream) happens on the main_stream.

The fix is to save the iterator returned from the first insert and use it for the second insert:

Suggested change
hic->topLevel().insert(
next_it, IrBuilder::create<SetCurrentStream>(main_stream));
auto* join_loop = IrBuilder::create<ForLoop>(
for_loop->index(), for_loop->start(), for_loop->stop());
hic->topLevel().insert(next_it, join_loop);
// After the loop: create a joining loop to synchronize all worker streams
auto set_main_it = hic->topLevel().insert(
next_it, IrBuilder::create<SetCurrentStream>(main_stream));
auto* join_loop = IrBuilder::create<ForLoop>(
for_loop->index(), for_loop->start(), for_loop->stop());
hic->topLevel().insert(std::next(set_main_it), join_loop);


// In the joining loop: synchronize each worker stream
auto* join_worker_stream = IrBuilder::create<Stream>(join_loop->index());
auto* sync_worker = IrBuilder::create<Synchronize>(join_worker_stream);
join_loop->body().pushBack(sync_worker);

it = next_it;
}
}

} // namespace nvfuser::hir
26 changes: 26 additions & 0 deletions csrc/host_ir/assign_streams.h
@@ -0,0 +1,26 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on

Contributor

Copyright year is 2026 (a future year); it should be 2025.

#pragma once

#include "optimization_pass.h"

namespace nvfuser::hir {

// A host IR pass that assigns streams to stream-parallel loops.
class AssignStreams : public OptimizationPass<AssignStreams> {
friend class OptimizationPass<AssignStreams>;

protected:
static void runPass(Fusion* fusion);

static constexpr std::string_view name() {
return "AssignStreams";
}
};

} // namespace nvfuser::hir
1 change: 0 additions & 1 deletion csrc/host_ir/ir.h
@@ -15,7 +15,6 @@
#include "ir/base_nodes.h"
#include "ir/builder.h"
#include "multidevice/communication.h"
#include "scheduler/heuristic.h"

namespace nvfuser {
// This works around a circular dependency: compiled_kernel.h ==>
2 changes: 2 additions & 0 deletions csrc/host_ir/passes.cpp
@@ -9,11 +9,13 @@
#include "host_ir/passes.h"

#include "host_ir/allocate_and_deallocate.h"
#include "host_ir/assign_streams.h"

namespace nvfuser::hir {

void runPasses(HostIrContainer& hic) {
OptimizationPass<hir::AllocateAndDeallocate>::runPass(&hic);
OptimizationPass<hir::AssignStreams>::runPass(&hic);
}

} // namespace nvfuser::hir
1 change: 1 addition & 0 deletions csrc/ir/internal_nodes.h
@@ -77,6 +77,7 @@ class Scope {
return std::ssize(exprs_);
}

// Returns an iterator pointing to the inserted expression.
Iterator insert(Iterator pos, Expr* expr);

Iterator pushBack(Expr* e) {
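
Related to the new doc comment above and to the reviewer suggestion about capturing insert's return value: a minimal sketch, again using std::list as a stand-in for Scope (an assumption, since Scope's container type is not shown in this diff), of the equivalent pattern that makes the intended ordering explicit by chaining off the returned iterator.

#include <cassert>
#include <iterator>
#include <list>
#include <string>

int main() {
  std::list<std::string> body{"original_body"};

  // insert() returns an iterator to the newly inserted element, so the next
  // insertion can be placed explicitly after it.
  auto it = body.insert(body.begin(), "set_stream");
  body.insert(std::next(it), "sync_main");

  // Same resulting order as inserting twice before the saved begin iterator.
  assert((body ==
          std::list<std::string>{"set_stream", "sync_main", "original_body"}));
  return 0;
}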
6 changes: 3 additions & 3 deletions tests/python/direct/test_stream.py
@@ -7,7 +7,7 @@
from nvfuser_direct import FusionDefinition, ParallelType, DataType


def test_matmul(nvfuser_direct_test):
Collaborator Author

The tests didn't use the nvfuser_direct_test fixture.

def test_matmul():
c = 3

with FusionDefinition() as fd:
@@ -46,7 +46,7 @@ def test_matmul(nvfuser_direct_test):
assert event.input_shapes == [[5, 7], [7, 2], [5, 2]]


def test_two_matmuls_inlinable(nvfuser_direct_test):
def test_two_matmuls_inlinable():
c = 3

with FusionDefinition() as fd:
@@ -97,7 +97,7 @@ def test_two_matmuls_inlinable(nvfuser_direct_test):
assert event.input_shapes[0][0] == 2


def test_two_matmuls_not_inlinable(nvfuser_direct_test):
def test_two_matmuls_not_inlinable():
c = 3

with FusionDefinition() as fd:
20 changes: 13 additions & 7 deletions tests/python/multidevice/benchmark_utils.py
@@ -28,12 +28,18 @@ def wrapper(*args, **kwargs):

# Returns two functors, the first with profiler off and the second with profiler
# on. The first functor is usually used for warmup and the second for actual
# benchmarking. This way, one
# can collect stats of the first few non-warmup benchmark iterations using
# ```bash
# mpirun -np 1 nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:<iterations> pytest tests/python/multidevice/<test_file>.py -k <filter> --only-mpi : -np <processes - 1> pytest tests/python/multidevice/<test_file>.py -k <filter> --only-mpi
# ```
# and then display the stats using e.g. `nsys stats --report=cuda_gpu_kern_sum
# report1.nsys-rep`.
# benchmarking. This way, one can collect stats of non-warmup
# benchmark iterations using `nsys profile --capture-range=cudaProfilerApi`.
#
# https://docs.nvidia.com/nsight-systems/UserGuide/index.html#handling-application-launchers-mpirun-deepspeed-etc
# has described several ways to profile multi-process applications.
#
# For single-node profiling, I recommend putting `nsys profile` before
# `mpirun`, e.g., `nsys profile ... mpirun -np 8 ...` instead of `mpirun -np 8
# nsys profile ...` or `mpirun -np 1 nsys profile ... : -np 7 ...`. This config
# tries to collect and align traces on different GPUs so it gives the most
# complete picture. See
# https://github.com/NVIDIA/Fuser/pull/5751/files#r2663586669 for my
# experiment.
def get_benchmark_fns(func):
return get_benchmark_fn(func, profile=False), get_benchmark_fn(func, profile=True)
93 changes: 69 additions & 24 deletions tests/python/multidevice/test_overlap.py
@@ -14,19 +14,7 @@
from benchmark_utils import get_benchmark_fns


@pytest.mark.mpi
def test_row_parallel_linear_forward(multidevice_test):
# This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward.
h, s, t = 2, 3, 6
d = multidevice_test.size
if (h * 4) % d != 0:
pytest.skip(
f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
)
assert t % s == 0

mesh = nvfuser.multidevice.DeviceMesh(range(d))

def row_parallel_linear_forward(h, mesh, num_chunks):
with FusionDefinition() as fd:
inp = fd.define_tensor(
shape=[-1, h * 4], contiguity=True, dtype=DataType.BFloat16
@@ -40,11 +40,11 @@ def test_row_parallel_linear_forward(multidevice_test):
for tv in (inp, weight):
tv.set_device_mesh(mesh)

inp.split(0, s, inner_split=False)
inp.outer_split(0, num_chunks)
inp.axis(0).parallelize(nvfuser.ParallelType.stream)
inp.split(2, d, inner_split=False)
inp.outer_split(2, mesh.size)
inp.axis(2).parallelize(nvfuser.ParallelType.mesh_x)
weight.split(1, d, inner_split=False)
weight.outer_split(1, mesh.size)
weight.axis(1).parallelize(nvfuser.ParallelType.mesh_x)

# Expected pre-segmentation IR:
@@ -67,22 +67,50 @@
# /\.
# s*

# Expected host IR:
# The host IR dumped with NVFUSER_DUMP=host_ir is similar to `row_parallel_linear_forward_reference`:
#
# %HostIrContainer { (T0_g___bfloat[istreamIdx7{3}, ideviceIdx.x9{2}, iS8{( ceilDiv(i0, 3) )}, iS10{4}] (DeviceMesh{0 1}), T1_g___bfloat[ideviceIdx.x11{2}, iS2{2}, iS12{4}] (DeviceMesh{0 1})) -> (T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1})) :
# T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}), mem_type=global, size=( i0 * 2 ), zero_init=false, resets_to_zero=false)
# Stream 0x174e5c80 = GetCurrentStream()
# FOR i535 from 0 to 3:
# T4_l___bfloat[istreamIdx31{3}, ideviceIdx.x33{2}, iS32{( ceilDiv(i0, 3) )}, iS34{4}] (DeviceMesh{0 1}) = ShardByStream(T0_g___bfloat[istreamIdx7{3}, ideviceIdx.x9{2}, iS8{( ceilDiv(i0, 3) )}, iS10{4}] (DeviceMesh{0 1}), stream_index = i535)
# SetCurrentStream(Stream i535)
# Synchronize(Stream 0x174e5c80)
# T4_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, ideviceIdx.x35{2}, iS36{4}] (DeviceMesh{0 1}) = ShardByStream(T0_g___bfloat[istreamIdx7{3}, ideviceIdx.x9{2}, iS8{( ceilDiv(i0, 3) )}, iS10{4}] (DeviceMesh{0 1}), stream_index = i535)
# T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}) = ALLOCATE(buffer=T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}), mem_type=global, size=( ( ceilDiv(i0, 3) ) * 12 ), zero_init=false, resets_to_zero=false)
# T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1})
# = linear(T4_l___bfloat[istreamIdx31{3}, ideviceIdx.x33{2}, iS32{( ceilDiv(i0, 3) )}, iS34{4}] (DeviceMesh{0 1}),
# = linear(T4_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, ideviceIdx.x35{2}, iS36{4}] (DeviceMesh{0 1}),
# T1_g___bfloat[ideviceIdx.x11{2}, iS2{2}, iS12{4}] (DeviceMesh{0 1}) )
# T5_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, iS36{2}] (DeviceMesh{0 1}) = ShardByStream(T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}), stream_index = i535)
# Communication 250 (type=Allreduce, team=(0 1), input=T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}), output=T5_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, iS36{2}] (DeviceMesh{0 1}), backend=NCCL)
# Wait Communication 250
# T5_l___bfloat[istreamIdx41{3}, iS42{( ceilDiv(i0, 3) )}, iS40{2}] (DeviceMesh{0 1}) = ShardByStream(T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}), stream_index = i535)
# Communication 272 (type=Allreduce, team=(0 1), input=T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}), output=T5_l___bfloat[istreamIdx41{3}, iS42{( ceilDiv(i0, 3) )}, iS40{2}] (DeviceMesh{0 1}), backend=NCCL)
# Wait(Communication 272)
# SetCurrentStream(Stream 0x174e5c80)
# FOR i535 from 0 to 3:
# Synchronize(Stream i535)
# } // %HostIrContainer

inp_ref = torch.randint(-2, 3, (t, h * 4), dtype=torch.int32).to(torch.bfloat16)
weight_ref = torch.randint(-2, 3, (h, h * 4), dtype=torch.int32).to(torch.bfloat16)
return fd


@pytest.mark.mpi
def test_row_parallel_linear_forward(multidevice_test):
# This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward.
h, s, t = 2, 3, 6
d = multidevice_test.size
if (h * 4) % d != 0:
pytest.skip(
f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
)
assert t % s == 0

mesh = nvfuser.multidevice.DeviceMesh(range(d))
fd = row_parallel_linear_forward(h, mesh, s)

inp_ref = torch.testing.make_tensor(t, h * 4, dtype=torch.int32, device="cpu").to(
torch.bfloat16
)
weight_ref = torch.testing.make_tensor(
h, h * 4, dtype=torch.int32, device="cpu"
).to(torch.bfloat16)
out_ref = torch.nn.functional.linear(inp_ref, weight_ref)

inp = multidevice_test.shard_tensor(inp_ref, -1, mesh)
@@ -105,6 +121,35 @@ def test_row_parallel_linear_forward(multidevice_test):
assert event.input_shapes == [[m, k], [k, n], [m, n]]


@pytest.mark.mpi
@pytest.mark.benchmark
@pytest.mark.parametrize("s", [1, 2, 4])
def test_row_parallel_linear_forward_benchmark(multidevice_test, benchmark, s):
# This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward.
h, t = 8192, 8192
d = multidevice_test.size
if (h * 4) % d != 0:
pytest.skip(
f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
)
assert t % s == 0

mesh = nvfuser.multidevice.DeviceMesh(range(d))
fd = row_parallel_linear_forward(h, mesh, s)

inp_ref = torch.randn(t, h * 4, dtype=torch.bfloat16, device="cpu")
weight_ref = torch.randn(h, h * 4, dtype=torch.bfloat16, device="cpu")

inp = multidevice_test.shard_tensor(inp_ref, -1, mesh)
weight = multidevice_test.shard_tensor(weight_ref, -1, mesh)

warmup_fn, benchmark_fn = get_benchmark_fns(
lambda: fd.execute([inp, weight], _enable_options=["host_ir_lowering"])
)
warmup_fn()
benchmark.pedantic(benchmark_fn, rounds=5)


# The caching allocator in PyTorch can't cache buffers across streams, so we
# have to reuse streams to avoid repeated cudaMalloc. torch.cuda.Stream() is
# backed by a stream pool as well but I failed to find a way to set its size.