Let's make Iris' All-reduce the best! #204
Conversation
Pull Request Overview
This PR refactors the Iris all-reduce implementation by introducing a ring-based all-reduce algorithm alongside the existing atomic-based approach. The changes reorganize existing examples into separate directories and implement a cleaner ring-based communication pattern.
- Splits the atomic-based all-reduce into its own directory (`08_gemm_all_reduce_atomics`)
- Creates a new ring-based all-reduce implementation (`15_gemm_all_reduce_ring_based`)
- Simplifies the API by removing many configuration parameters and hardcoding architecture-specific values
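For context, a ring-based all-reduce splits each rank's buffer into chunks and circulates them around a logical ring: a reduce-scatter pass builds the full sum of one chunk per rank, then an all-gather pass distributes the completed chunks. Below is a minimal NumPy simulation of that pattern; it is illustrative only, not Iris' kernel, and all names in it are assumptions.

```python
import numpy as np

def ring_all_reduce(data):
    """Simulate a chunked ring all-reduce. `data` holds one equal-length
    array per rank; returns the fully reduced array for every rank.
    Phase 1 (reduce-scatter): partial sums travel around the ring until
    rank r owns the complete sum of chunk (r + 1) % world.
    Phase 2 (all-gather): the completed chunks circulate once more."""
    world = len(data)
    chunks = [np.array_split(np.asarray(d, dtype=float), world) for d in data]
    for s in range(world - 1):  # reduce-scatter: send, receive, accumulate
        sends = [(r, (r - s) % world, chunks[r][(r - s) % world].copy())
                 for r in range(world)]
        for r, c, buf in sends:
            chunks[(r + 1) % world][c] = chunks[(r + 1) % world][c] + buf
    for s in range(world - 1):  # all-gather: circulate completed chunks
        sends = [(r, (r + 1 - s) % world, chunks[r][(r + 1 - s) % world].copy())
                 for r in range(world)]
        for r, c, buf in sends:
            chunks[(r + 1) % world][c] = buf
    return [np.concatenate(ch) for ch in chunks]
```

For example, `ring_all_reduce([np.ones(6), 2 * np.ones(6), 3 * np.ones(6)])` returns three identical arrays of 6.0. Iris' kernel-side version additionally overlaps this communication with the GEMM and synchronizes via per-tile locks, per the review comments below.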
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.
File | Description
---|---
`examples/15_gemm_all_reduce_ring_based/matmul_wrapper.py` | Updates the wrapper to use the ring-based kernel with a simplified parameter interface
`examples/15_gemm_all_reduce_ring_based/gemm_all_reduce_ring_based.py` | New ring-based all-reduce kernel implementation using lock-synchronized communication
`examples/15_gemm_all_reduce_ring_based/benchmark.py` | Benchmark script for the new ring-based implementation
`examples/08_gemm_atomics_all_reduce/gemm_atomics_all_reduce.py` | Removes the original atomic-based implementation
`examples/08_gemm_all_reduce_atomics/matmul_wrapper.py` | New wrapper for the refactored atomic-based implementation
`examples/08_gemm_all_reduce_atomics/gemm_all_reduce_atomics.py` | Simplified atomic-based all-reduce implementation
`examples/08_gemm_all_reduce_atomics/benchmark.py` | Updated benchmark script for the atomic-based approach
Comments suppressed due to low confidence (1)

`examples/15_gemm_all_reduce_ring_based/gemm_all_reduce_ring_based.py:1`

- Hardcoded architecture-specific values should be replaced with a proper architecture detection and configuration system. Consider creating a lookup table or configuration class based on the `arch` parameter.
```python
# 2) Wait for PREV rank to signal our local flag for this tile
# Spin; single-lane uniform load is fine here.
while tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) != 1:
    pass
```
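(The `.cv` cache modifier and `volatile=True` force each poll to re-read memory rather than a stale cached value; this spin is presumably meant to pair with the release-scoped `iris.atomic_xchg` signal shown further down in the review.)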
Copilot AI · Oct 9, 2025
This busy-wait loop could consume significant GPU cycles while spinning. Consider adding a small delay or a more efficient synchronization mechanism to reduce wasted work during the wait period.
Suggested change:

```diff
- pass
+ tl.device_sleep(100)  # Sleep for 100 nanoseconds to reduce busy-wait CPU usage
```
```python
# TODO: Use arch-specific values.
num_stages = 2
num_warps = 8
waves_per_eu = 0
mfma = 16
kpack = 1
```
Copilot AI · Oct 9, 2025
This TODO comment indicates an incomplete implementation. The hardcoded values should be replaced with architecture-specific configuration based on the `arch` parameter to ensure optimal performance across different GPU architectures.
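One way to address this is the lookup table the earlier suppressed comment suggests. A minimal sketch, assuming `arch` is an AMD GCN architecture string; the table entries and tuning numbers below are illustrative placeholders, not validated configurations:

```python
# Hypothetical per-architecture tuning table; every value here is a
# placeholder assumption, not a measured or validated configuration.
GEMM_TUNING = {
    "gfx942": dict(num_stages=2, num_warps=8, waves_per_eu=0, mfma=16, kpack=1),
    "gfx90a": dict(num_stages=2, num_warps=4, waves_per_eu=2, mfma=32, kpack=2),
}
# Fall back to the values currently hardcoded in the kernel.
DEFAULT_TUNING = dict(num_stages=2, num_warps=8, waves_per_eu=0, mfma=16, kpack=1)

def get_gemm_tuning(arch: str) -> dict:
    """Return tuning parameters for `arch`, defaulting when unknown."""
    return GEMM_TUNING.get(arch, DEFAULT_TUNING)
```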
@copilot Fix the following issue:
Pull Request Overview
Copilot reviewed 12 out of 12 changed files in this pull request and generated 2 comments.
Pull Request Overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.
Comments suppressed due to low confidence (1)

`examples/08_gemm_all_reduce_atomics/matmul_wrapper.py:1`

- The attribute names `streamk_registers` and `streamk_spills` are inconsistent with the new API that uses the `get_matmul_registers()` and `get_matmul_spills()` methods. Either update the call sites to use the getter methods, or rename the attributes to match the new pattern.
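For illustration, the getter pattern the comment refers to might look like the sketch below; the class shape and attribute layout are assumptions based only on the names in the comment.

```python
class MatmulWrapper:
    """Hypothetical wrapper shape; only the getter names come from the review."""

    def __init__(self):
        # Populated after kernel compilation (assumed).
        self._matmul_registers = 0
        self._matmul_spills = 0

    def get_matmul_registers(self) -> int:
        return self._matmul_registers

    def get_matmul_spills(self) -> int:
        return self._matmul_spills
```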
```python
tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
```
Copilot AI · Oct 9, 2025
The `locks` tensor is allocated with size `total_tiles`, but it should be allocated based on the number of tiles that will actually be processed by each rank. This could lead to memory waste or insufficient locks for proper synchronization.
Suggested change:

```diff
- locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
+ local_blocks_M = total_blocks_M
+ local_blocks_N = total_blocks_N
+ local_tiles = local_blocks_M * local_blocks_N
+ locks = shmem.zeros((local_tiles,), device="cuda", dtype=torch.int32)
```
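For reference, the lock count in question is one flag per output tile; assuming the usual tile accounting (the sizes and names below are illustrative, not from the PR):

```python
import math

# Illustrative tile accounting; M, N and the block sizes are made-up numbers.
M, N = 8192, 8192
BLOCK_SIZE_M, BLOCK_SIZE_N = 256, 128
total_blocks_M = math.ceil(M / BLOCK_SIZE_M)
total_blocks_N = math.ceil(N / BLOCK_SIZE_N)
total_tiles = total_blocks_M * total_blocks_N  # one lock/flag per output tile
```

Note that, as written, the suggestion appears to set `local_tiles` to `total_blocks_M * total_blocks_N`, which equals `total_tiles`, so it renames the quantity without actually shrinking the allocation.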
```python
tl.debug_barrier()

# Signal "ready" by setting NEXT rank's flag for this tile to 1
iris.atomic_xchg(locks + tile_id, 1, cur_rank, next_rank, heap_bases, sem="release", scope="sys")
```
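Taken together with the spin-wait shown earlier, this forms a per-tile ring handoff: wait for the previous rank, accumulate, then signal the next rank. A simplified host-side sketch of that protocol follows; the helper names are hypothetical, and the real synchronization happens inside the Triton kernel via `iris.atomic_xchg`.

```python
# Simplified per-tile ring handoff (illustrative, not the kernel code).
def process_tile(tile_id, partial, buffer, locks, signal_next_rank):
    while locks[tile_id] != 1:    # spin until PREV rank signals our flag
        pass
    buffer[tile_id] += partial    # fold in this rank's partial result
    locks[tile_id] = 0            # reset our flag for a later round
    signal_next_rank(tile_id)     # mirrors iris.atomic_xchg(..., next_rank, ...)
```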
Copilot AI · Oct 9, 2025
The atomic exchange operation uses `locks + tile_id`, but the locks array may not have sufficient elements for all possible `tile_id` values across different ranks, potentially causing out-of-bounds access.
- …r with local data (#217) Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: neoblizz <[email protected]>
- …m for all_reduce (#225) Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: neoblizz <[email protected]>
- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: neoblizz <[email protected]> Co-authored-by: Muhammad Osama <[email protected]>
- …dation failures (#228) Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: neoblizz <[email protected]>
This pull request introduces a new distributed GEMM benchmark example using all-reduce atomics, refactors and simplifies the matmul wrapper, and adds a Triton kernel for persistent GEMM with all-reduce atomics. The changes improve the clarity, usability, and maintainability of the distributed GEMM benchmarking code.

New distributed GEMM benchmark and kernel:

- Adds `benchmark.py` for the all-reduce atomics example, supporting validation, benchmarking, and tile tracing, with comprehensive argument parsing and logging.
- Adds a Triton kernel `persistent_gemm_all_reduce` in `gemm_all_reduce_atomics.py` implementing persistent GEMM with all-reduce atomics across ranks, including timestamp tracing and atomic operations for distributed accumulation.

Refactoring and simplification of the matmul wrapper:

- Refactors `matmul_wrapper.py` to remove unused parameters and logic related to streamK and synchronization primitives, simplifying function signatures and internal calculations. [1] [2] [3] [4] [5] [6]
- Replaces the old kernel import (`gemm_atomics_all_reduce`) with the new `gemm_all_reduce_atomics` kernel import in `matmul_wrapper.py`.
- Moves the example files into `examples/08_gemm_all_reduce_atomics/`, updating the file structure for clarity. [1] [2]

These changes collectively provide a cleaner, more maintainable, and feature-rich distributed GEMM benchmarking example using Triton and atomics.
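As a rough illustration of the atomic accumulation pattern the kernel's name implies, the sketch below atomically adds one rank's partial GEMM result into a shared output; summed across ranks, this yields the all-reduced product. This is a standalone sketch, not the PR's `persistent_gemm_all_reduce` kernel, and its name and signature are assumptions.

```python
import triton
import triton.language as tl

@triton.jit
def atomic_accumulate(out_ptr, partial_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program atomically adds one block of this rank's partial result
    # into the shared output buffer visible to all ranks.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    vals = tl.load(partial_ptr + offs, mask=mask)
    tl.atomic_add(out_ptr + offs, vals, mask=mask)
```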