-
Notifications
You must be signed in to change notification settings - Fork 74
Stream-parallelize loops #5751
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Stream-parallelize loops #5751
Changes from all commits
776a102
2a86e82
2ddc323
f595406
1095598
2588deb
b09ab67
7ae4e9e
b4adbf6
bb87f6e
4868c59
9fee753
72b0c67
0ddbea1
28824b2
c431253
951537d
0c3014f
30a835c
fe26a6d
6ad8d0a
215a166
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,64 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // clang-format off | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /* | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * SPDX-License-Identifier: BSD-3-Clause | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // clang-format on | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include "host_ir/assign_streams.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include "host_ir/container.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include "ir/builder.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| namespace nvfuser::hir { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void AssignStreams::runPass(Fusion* fusion) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto* hic = dynamic_cast<HostIrContainer*>(fusion); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| NVF_CHECK(hic != nullptr); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| FusionGuard fg(hic); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (auto it = hic->topLevel().exprs().begin(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| it != hic->topLevel().exprs().end();) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto next_it = std::next(it); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto* for_loop = dynamic_cast<ForLoop*>(*it); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (for_loop == nullptr) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| it = next_it; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| continue; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // We should check that the loop is stream-parallel. This is not necessary | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // at this moment because all loops are stream-parallel. This is also hard | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not strictly for this PR, but similar to kir::ForLoop, hir::ForLoop can hold the source iterdomain for this check |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // to do because hir::ForLoop doesn't point to the source IterDomain. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+31
to
+33
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do ALL hir::ForLoops stream-parallelize? Is there no case where we want to sequentially loop in hir? or is this pass triggered by some other condition I'm not seeing?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yes at this moment. I'm considering separating ParallelType::Stream and ParallelType::HostSerial. The latter doesn't exist today. That's when we'll have to look at the parallel type of the loop index.
Comment on lines
+31
to
+33
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The comment acknowledges this validation is skipped, but consider adding a TODO or assertion to track this technical debt. Without validation, non-stream-parallel loops could be incorrectly transformed, potentially leading to incorrect synchronization patterns. At minimum, add a Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time! |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Stream* main_stream = IrBuilder::create<Stream>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hic->topLevel().insert( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| it, IrBuilder::create<GetCurrentStream>(main_stream)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // At the beginning of each iteration: set stream and synchronize with main | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // stream | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto* worker_stream = IrBuilder::create<Stream>(for_loop->index()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto* set_stream = IrBuilder::create<SetCurrentStream>(worker_stream); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto* sync_main = IrBuilder::create<Synchronize>(main_stream); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto old_begin = for_loop->body().exprs().begin(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Priya2698 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for_loop->body().insert(old_begin, set_stream); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for_loop->body().insert(old_begin, sync_main); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+44
to
+46
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: Insertion order creates reversed execution order. The second
Suggested change
Comment on lines
+45
to
+46
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: Wrong insertion order - both inserts use Current order: Sync(main) → SetStream(worker) Fix by inserting
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This doesn't make sense. Maybe it missed that Scope contains a linked list?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're absolutely right! I completely missed that Looking at the code, both insertions use the same iterator
This creates the wrong order: The fix is to update the iterator after the first insertion:
Suggested change
Should be:
Suggested change
Wait, that's still wrong. Let me fix this properly:
Suggested change
Actually, the correct fix is:
Suggested change
Or more simply, insert
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. insert(old_begin, ...) inserts before old_begin, doesn't it?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're absolutely right! for_loop->body().insert(old_begin, set_stream); // inserts set_stream before old_begin
for_loop->body().insert(old_begin, sync_main); // inserts sync_main before old_beginSince both insertions use the same This creates the order: But we want: The fix is to update the iterator after the first insertion:
Suggested change
This way
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Comment on lines
+44
to
+46
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The insertion order places This creates the sequence:
Suggested change
Comment on lines
+44
to
+46
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
This makes the intent clearer and avoids potential issues with iterator invalidation. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // After the loop: create a joining loop to synchronize all worker streams | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you plan on merging this with the above for-loop?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I haven't convinced myself that will work: http://nv/e-d |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hic->topLevel().insert( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| next_it, IrBuilder::create<SetCurrentStream>(main_stream)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto* join_loop = IrBuilder::create<ForLoop>( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for_loop->index(), for_loop->start(), for_loop->stop()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+51
to
+52
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. reusing |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hic->topLevel().insert(next_it, join_loop); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+49
to
+53
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Critical logic error: The insertion order will result in incorrect code generation. Both statements use The correct order should be SetCurrentStream THEN join_loop, because:
This matches the reference implementation The fix is to save the iterator returned from the first insert and use it for the second insert:
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // In the joining loop: synchronize each worker stream | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto* join_worker_stream = IrBuilder::create<Stream>(join_loop->index()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto* sync_worker = IrBuilder::create<Synchronize>(join_worker_stream); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| join_loop->body().pushBack(sync_worker); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| it = next_it; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } // namespace nvfuser::hir | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| // clang-format off | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. | ||
wujingyue marked this conversation as resolved.
Show resolved
Hide resolved
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. copyright year is 2026 (future year) - should be 2025 |
||
| * All rights reserved. | ||
| * SPDX-License-Identifier: BSD-3-Clause | ||
| */ | ||
| // clang-format on | ||
| #pragma once | ||
|
|
||
| #include "optimization_pass.h" | ||
|
|
||
| namespace nvfuser::hir { | ||
|
|
||
| // A host IR pass that assigns streams to stream-parallel loops. | ||
| class AssignStreams : public OptimizationPass<AssignStreams> { | ||
| friend class OptimizationPass<AssignStreams>; | ||
|
|
||
| protected: | ||
| static void runPass(Fusion* fusion); | ||
|
|
||
| static constexpr std::string_view name() { | ||
| return "AssignStreams"; | ||
| } | ||
| }; | ||
|
|
||
| } // namespace nvfuser::hir | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,7 +7,7 @@ | |
| from nvfuser_direct import FusionDefinition, ParallelType, DataType | ||
|
|
||
|
|
||
| def test_matmul(nvfuser_direct_test): | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The tests didn't use the nvfuser_direct_test fixture. |
||
| def test_matmul(): | ||
| c = 3 | ||
|
|
||
| with FusionDefinition() as fd: | ||
|
|
@@ -46,7 +46,7 @@ def test_matmul(nvfuser_direct_test): | |
| assert event.input_shapes == [[5, 7], [7, 2], [5, 2]] | ||
|
|
||
|
|
||
| def test_two_matmuls_inlinable(nvfuser_direct_test): | ||
| def test_two_matmuls_inlinable(): | ||
| c = 3 | ||
|
|
||
| with FusionDefinition() as fd: | ||
|
|
@@ -97,7 +97,7 @@ def test_two_matmuls_inlinable(nvfuser_direct_test): | |
| assert event.input_shapes[0][0] == 2 | ||
|
|
||
|
|
||
| def test_two_matmuls_not_inlinable(nvfuser_direct_test): | ||
| def test_two_matmuls_not_inlinable(): | ||
| c = 3 | ||
|
|
||
| with FusionDefinition() as fd: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.