Changes from all commits (26 commits):
54ee909  Cleaned up All Reduce. (neoblizz, Oct 9, 2025)
25e204e  Apply Ruff auto-fixes (github-actions[bot], Oct 9, 2025)
c4e03a0  Naming convention follows all_scatter (neoblizz, Oct 9, 2025)
425bc33  Init impl. (neoblizz, Oct 9, 2025)
00de58c  Initial impl. (neoblizz, Oct 9, 2025)
de4315c  Apply Ruff auto-fixes (github-actions[bot], Oct 9, 2025)
a8acd4b  Proper acq/rel semantics. (neoblizz, Oct 9, 2025)
7def2eb  Merge. (neoblizz, Oct 9, 2025)
d8386a7  Apply Ruff auto-fixes (github-actions[bot], Oct 9, 2025)
6a1f4e2  Hoist `num_xcds` query (#205) (mawad-amd, Oct 9, 2025)
236edec  Fix reg/reg spill reporting. (neoblizz, Oct 9, 2025)
1756006  Apply Ruff auto-fixes (github-actions[bot], Oct 9, 2025)
53f25d8  Functional. (neoblizz, Oct 10, 2025)
d8ead1e  Apply Ruff auto-fixes (github-actions[bot], Oct 10, 2025)
7e05db7  Merge. (neoblizz, Oct 10, 2025)
5f5ae50  Apply Ruff auto-fixes (github-actions[bot], Oct 10, 2025)
d7c24a5  Fix correctness issue in persistent_all_reduce: initialize accumulato… (Copilot, Oct 10, 2025)
ded7317  2-kernel working. (neoblizz, Oct 10, 2025)
e1d3419  Apply Ruff auto-fixes (github-actions[bot], Oct 10, 2025)
e6c7ecf  Separate all-reduce. (neoblizz, Oct 10, 2025)
0178b16  [Performance]: Implement optimized ring-based reduce-scatter algorith… (Copilot, Oct 10, 2025)
3288c2d  Implement all_reduce-only validation for example 16 (#221) (Copilot, Oct 10, 2025)
d3eb56a  Apply Ruff auto-fixes (github-actions[bot], Oct 10, 2025)
3820e62  Correct algo, almost functional. (neoblizz, Oct 11, 2025)
6797056  Apply Ruff auto-fixes (github-actions[bot], Oct 11, 2025)
875838a  Fix race condition in ring-based all_reduce causing intermittent vali… (Copilot, Oct 11, 2025)
@@ -60,18 +60,11 @@ def parse_args():

# Best to try 1, 6 or 8
parser.add_argument("--gsize_m", type=int, default=6, help="Grid size M")
parser.add_argument("--two_tiles", type=str, default="True", help="Use two tiles")
parser.add_argument("--num_stages", type=int, default=1, help="Number of stages")
parser.add_argument("--num_warps", type=int, default=8, help="Number of warps")
parser.add_argument("--waves_per_eu", type=int, default=0, help="Waves per execution unit")
parser.add_argument("--mfmaInstrSize", type=int, default=16, help="MFMA instruction size")
parser.add_argument("--kpack", type=int, default=2, help="K packing size")
parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size")

# For All Scatter, use: 288
# For One Shot, use: 256
parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for GEMM")
parser.add_argument("--total_sms", type=int, default=304, help="Total number of SMs")
parser.add_argument("--gemm_sms", type=int, default=304, help="Number of SMs for GEMM")
parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")

return vars(parser.parse_args())
@@ -107,7 +100,6 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):

A = shmem.randn(args["m"], args["k"], device="cuda", dtype=datatype)
B = shmem.randn(args["n"], args["k"], device="cuda", dtype=datatype).T
C = shmem.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype)

args["M"] = args["m"]
args["N"] = args["n"]
@@ -134,19 +126,6 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
total_blocks_N = triton.cdiv(args["n"], args["BLK_N"])
total_tiles = total_blocks_M * total_blocks_N

if args["gemm_sms"] >= args["total_sms"]:
print(f"Invalid number of GEMM SMs. {args['gemm_sms']} >= {args['total_sms']}")
exit(1)

tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)

locks = shmem.zeros((args["gemm_sms"],), device="cuda", dtype=torch.int32)

P = shmem.zeros(
(args["gemm_sms"], args["BLK_M"] * args["BLK_N"]),
device="cuda",
dtype=torch.float32,
)
bias = None

gemm_stream = torch.cuda.Stream()
@@ -165,11 +144,6 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
# Timestamps
timestamps = Timestamps(num_tiles=total_tiles)

def preamble():
shmem.barrier()
tile_completed.zero_()
shmem.barrier()

def run_experiment():
nonlocal local_C
nonlocal global_C
@@ -190,24 +164,15 @@ def run_experiment():
local_C,
global_C,
bias,
P,
locks,
tile_completed,
rank,
world_size,
args["gemm_sms"],
args["BLK_M"],
args["BLK_N"],
args["BLK_K"],
args["gsize_m"],
args["two_tiles"],
args["num_stages"],
args["num_warps"],
args["waves_per_eu"],
args["mfmaInstrSize"],
args["kpack"],
shmem.get_heap_bases(),
cu_count,
"gfx942",
args["trace_tiles"],
timestamps.mm_begin_timestamp,
timestamps.mm_end_timestamp,
@@ -228,44 +193,44 @@ def run_experiment():
# Warmup
run_experiment()

shmem.barrier()
preamble()
shmem.barrier()

for k in ["gemm"]:
kernel_timing[k]["ms"] = 0
kernel_timing[k]["experiments"] = 0

if not is_triton_interpret_set():
gemm_registers = matmul.streamk_registers
gemm_spills = matmul.streamk_spills

json_writer.add_field("gemm_registers", gemm_registers)
json_writer.add_field("gemm_spills", gemm_spills)

if args["validate"]:
shmem.info("Validating...")

matmul.set_debug(False)
matmul.set_debug(True)
# Validate global result
success = validate_gemm(A, B, global_C, shmem, atol=2)
passed_str = "passed" if success else "failed"
shmem.info(f"Final C validation {passed_str}.")

# Wait for all to finish validation
shmem.barrier()
json_writer.add_field("success", success)
shmem.info("Validation completed")

json_writer.add_field("success", success)

if not is_triton_interpret_set():
gemm_registers = matmul.streamk_registers
gemm_spills = matmul.streamk_spills

json_writer.add_field("gemm_registers", gemm_registers)
json_writer.add_field("gemm_spills", gemm_spills)

if args["benchmark"]:
matmul.set_debug(False)
shmem.info("Benchmarking...")
perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3)
triton_ms = iris.do_bench(run_experiment, shmem.barrier, preamble)
triton_ms = iris.do_bench(run_experiment, shmem.barrier)
triton_tflops = perf(triton_ms)
shmem.info(f"tile matmul + all_reduce (grid={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops")
algo_string = "all_reduce"
shmem.info(f"tile matmul + {algo_string} (grid={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops")

json_writer.add_field("triton_tflops", triton_tflops)
json_writer.add_field("triton_ms", triton_ms)
json_writer.add_field("tflops", triton_tflops)
json_writer.add_field("total_ms", triton_ms)

for k in ["gemm"]:
json_writer.add_field(k + "_ms", kernel_timing[k]["ms"] / kernel_timing[k]["experiments"])
@@ -280,7 +245,8 @@ def run_experiment():

if args["trace_tiles"] and rank == 0:
gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3
filename = f"gemm_all_reduce_tiles_trace_rank{rank}.json"
algo_string = "all_reduce"
filename = f"gemm_tiles_{algo_string}_trace_rank{rank}.json"
timestamps.to_json(filename, gpu_freq)

shmem.barrier()
examples/08_gemm_all_reduce_atomics/gemm_all_reduce_atomics.py (new file: 151 additions, 0 deletions)
@@ -0,0 +1,151 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import triton
import triton.language as tl
from examples.common.utils import read_realtime

import sys
import os

import iris


@triton.jit()
def persistent_gemm_all_reduce(
A,
B,
C,
c_global,
bias_ptr,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_cm_global,
stride_cn_global,
stride_bias,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
NUM_SMS: tl.constexpr,
NUM_XCDS: tl.constexpr,
BIAS: tl.constexpr,
EVEN_K: tl.constexpr,
heap_bases: tl.tensor,
cur_rank: tl.constexpr,
world_size: tl.constexpr,
COLLECT_TIMESTAMPS: tl.constexpr = False,
mm_begin_timestamp_ptr: tl.tensor = None,
mm_end_timestamp_ptr: tl.tensor = None,
):
pid = tl.program_id(0)

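    # Programs are issued round-robin across the XCDs (chiplets); remap the flat
    # program id so ids within one XCD become contiguous, keeping neighboring
    # tiles on the same XCD's cache.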
if NUM_XCDS != 1:
pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
total_tiles = num_pid_m * num_pid_n

tl.assume(stride_am > 0)
tl.assume(stride_ak > 0)
tl.assume(stride_bn > 0)
tl.assume(stride_bk > 0)
tl.assume(stride_cm > 0)
tl.assume(stride_cn > 0)

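    # Accumulate in fp32, except for int8 inputs, which accumulate in int32.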
acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32

for tile_id in range(pid, total_tiles, NUM_SMS):
if COLLECT_TIMESTAMPS:
timestamp = read_realtime()
tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp)

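        # Grouped tile ordering: sweep GROUP_SIZE_M rows of tiles at a time so
        # consecutive programs share operand tiles and get better cache reuse.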
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m

rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N

rk = tl.arange(0, BLOCK_SIZE_K)
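        # Compiler hints: the index vectors are contiguous, aligned multiples of
        # the block size, which enables vectorized global loads.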
rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn

        # Program ids are non-negative.
        tl.assume(pid_m >= 0)
        tl.assume(pid_n >= 0)

loop_k = tl.cdiv(K, BLOCK_SIZE_K)
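        # When K is not a multiple of BLOCK_SIZE_K, peel the last iteration and
        # handle it below with a masked (bounds-checked) load.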
if not EVEN_K:
loop_k -= 1

acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
for k in range(0, loop_k):
a = tl.load(tl.multiple_of(A_BASE, (1, 16)))
b = tl.load(tl.multiple_of(B_BASE, (16, 1)))
acc += tl.dot(a, b)
A_BASE += BLOCK_SIZE_K * stride_ak
B_BASE += BLOCK_SIZE_K * stride_bk

if not EVEN_K:
k = loop_k
rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
A_BASE = tl.multiple_of(A_BASE, (1, 16))
B_BASE = tl.multiple_of(B_BASE, (16, 1))
a = tl.load(A_BASE, mask=rk[None, :] < K, other=0.0)
b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0)
acc += tl.dot(a, b)

        # Convert the accumulator to the output element type before storing
c = acc.to(C.type.element_ty)

rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N

# Add compiler hints
rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)

# Define the C-mask (BLOCK_SIZE_M, 1) x (1, BLOCK_SIZE_N)
sub_mask = (rm[:, None] < M) & (rn[None, :] < N)

        # Compute the tile's offset into the global C buffer (identical on every rank).
        # Note that each GPU produces the full M x N output, but only its partial-K sum.
global_offset = rm[:, None] * stride_cm_global + rn[None, :] * stride_cn_global

# Timestamp for GEMM before store
if COLLECT_TIMESTAMPS:
timestamp = read_realtime()
tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp)

        # All-reduce the tile via atomics: every rank accumulates its partial-K
        # product into each rank's c_global, so the full sum forms without a
        # separate reduction kernel.
for remote_rank in range(world_size):
if remote_rank == cur_rank:
                # For the current rank, add directly into the local c_global
tl.atomic_add(c_global + global_offset, c, mask=sub_mask)
else:
iris.atomic_add(
c_global + global_offset,
c,
cur_rank,
remote_rank,
heap_bases,
mask=sub_mask,
)

if COLLECT_TIMESTAMPS:
timestamp = read_realtime()
tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp)
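
For orientation, here is a minimal host-side launch sketch for the new kernel (not part of the diff). It assumes an Iris context `shmem` built as in `_worker` above and that the function lives alongside the kernel definition; the problem shape, block sizes, grid size, and `NUM_XCDS=8` below are illustrative assumptions, not values fixed by this PR.

# Hypothetical usage sketch -- illustrative only, not part of this PR.
import torch


def launch_gemm_all_reduce(shmem, rank: int, world_size: int, num_sms: int = 304):
    # Assumed shapes and meta-parameters; tune for the target GPU.
    M, N, K = 4096, 4096, 4096
    BLK_M, BLK_N, BLK_K, GROUP_M = 128, 128, 64, 6

    A = shmem.randn(M, K, device="cuda", dtype=torch.float16)
    B = shmem.randn(N, K, device="cuda", dtype=torch.float16).T
    local_C = shmem.zeros((M, N), device="cuda", dtype=A.dtype)
    global_C = shmem.zeros((M, N), device="cuda", dtype=A.dtype)
    shmem.barrier()  # every rank's global_C must be zeroed before remote atomics arrive

    grid = (num_sms,)  # persistent kernel: one program per SM
    persistent_gemm_all_reduce[grid](
        A, B, local_C, global_C, None,  # bias_ptr unused when BIAS=False
        M, N, K,
        A.stride(0), A.stride(1),
        B.stride(0), B.stride(1),
        local_C.stride(0), local_C.stride(1),
        global_C.stride(0), global_C.stride(1),
        0,  # stride_bias (unused when BIAS=False)
        BLOCK_SIZE_M=BLK_M, BLOCK_SIZE_N=BLK_N, BLOCK_SIZE_K=BLK_K,
        GROUP_SIZE_M=GROUP_M, NUM_SMS=num_sms, NUM_XCDS=8,  # 8 XCDs assumed (e.g. MI300X)
        BIAS=False, EVEN_K=(K % BLK_K == 0),
        heap_bases=shmem.get_heap_bases(),
        cur_rank=rank, world_size=world_size,
    )
    shmem.barrier()  # all ranks' atomics must land before global_C is read
    return global_C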