Add hl.wait & AllGather Matmul example (via hl_ext helper). #189

Draft · wants to merge 1 commit into base: joydddd/stack/9
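This PR adds an hl.wait language op for spin-waiting on flags in global memory, backed by a Triton helper in helion._triton_ext (exposed to generated code as hl_ext via output_header.py), plus an all-gather + matmul example that overlaps a copy-engine all-gather with a Helion matmul gated on per-chunk progress flags. A minimal sketch of the hl.wait call pattern, assuming a 1-D kernel gated on a single producer-set flag (the kernel and tensor names below are illustrative, not part of this PR):

import torch

import helion
import helion.language as hl


@helion.jit(static_shapes=True)
def consumer(x: torch.Tensor, progress: torch.Tensor) -> torch.Tensor:
    n = x.size(0)
    out = torch.empty([n], dtype=x.dtype, device=x.device)
    for tile in hl.tile(n):
        # Spin until a producer sets progress[0] to 1 (acquire load, GPU scope),
        # mirroring the hl.wait call in examples/all_gather_matmul.py below.
        hl.wait(progress, [0], signal=1, update=None, op="ld", scope="gpu", sem="acquire")
        out[tile] = x[tile]
    return out

Here progress would be a flag tensor (e.g. torch.zeros(1, dtype=torch.uint32, device="cuda")) that a producer, such as a copy running on another stream, sets to 1 once x is ready.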
232 changes: 232 additions & 0 deletions examples/all_gather_matmul.py
@@ -0,0 +1,232 @@
from __future__ import annotations

import os
from typing import Any

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

import helion
import helion.language as hl


def copy_engine_all_gather_w_progress(
output: torch.Tensor,
    inp: torch.Tensor,  # must be a symmetric memory tensor
progress: torch.Tensor,
splits_per_rank: int,
backend_stream: torch.cuda.Stream | None = None,
) -> torch.cuda.Stream:
    if backend_stream is None:
        backend_stream = symm_mem._get_backend_stream(priority=-1)
assert inp.is_contiguous()
symm_mem_group = dist.group.WORLD
if symm_mem_group is None:
raise RuntimeError("No symmetric memory group available")
symm_mem_hdl = symm_mem.rendezvous(inp, group=symm_mem_group)
assert symm_mem_hdl is not None

rank = symm_mem_hdl.rank
world_size = symm_mem_hdl.world_size

assert inp.numel() % splits_per_rank == 0
assert progress.numel() >= world_size * splits_per_rank

output_shape = list(inp.shape)
output_shape[0] *= world_size
assert list(output.shape) == output_shape, (list(output.shape), output_shape)

chunks = output.chunk(world_size * splits_per_rank)

symm_mem_hdl.barrier()
backend_stream.wait_stream(torch.cuda.current_stream())

with torch.cuda.stream(backend_stream):
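        # Pull each peer's shard in ring order (starting from the next rank) on the
        # backend stream, setting the matching progress flag to 1 once the chunk has landed.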
for step in range(world_size):
src_rank = (rank + step + 1) % world_size
for split_id in range(splits_per_rank):
src_buf = symm_mem_hdl.get_buffer(
src_rank, chunks[0].shape, inp.dtype, chunks[0].numel() * split_id
)
chunks[src_rank * splits_per_rank + split_id].copy_(src_buf)
# cuStreamWriteValue32 issues a system level fence before the write
symm_mem_hdl.stream_write_value32(
progress,
offset=src_rank * splits_per_rank + split_id,
val=1,
)
symm_mem_hdl.barrier()

return backend_stream


@helion.jit(
config=helion.Config(
block_sizes=[128, 256, 64],
num_warps=8,
num_stages=3,
indexing="block_ptr",
),
static_shapes=True,
)
def helion_matmul_w_progress(
a: torch.Tensor,
a_shared: torch.Tensor,
b: torch.Tensor,
progress: torch.Tensor,
SPLITS_PER_RANK: int,
RANK: int,
) -> torch.Tensor:
M, K = a.size()
K2, N = b.size()
assert K2 == K, f"size mismatch {K2} != {K}"

out = torch.empty(
[M, N], dtype=torch.promote_types(a.dtype, b.dtype), device=a.device
)

M_per_rank = a_shared.size(0)

for tile_m, tile_n in hl.tile([M, N]):
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
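        # Block until the copy engine has published the gathered chunk that holds
        # this tile's rows (progress flag set to 1 by copy_engine_all_gather_w_progress).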
hl.wait(
progress,
[
tile_m.begin // (M_per_rank // SPLITS_PER_RANK),
],
signal=1,
update=None,
op="ld",
scope="gpu",
sem="acquire",
)
for tile_k in hl.tile(K):
            # TODO(joydddd): use a_shared and skip the barrier when the data is already available on the local rank.
# if tile_k.begin // M_per_rank == RANK:
# acc = torch.addmm(acc, a_shared[tile_m.index - RANK * M_per_rank, tile_k], b[tile_k, tile_n])
# else:
# hl.wait(progress, [tile_m.begin // (M_per_rank // SPLITS_PER_RANK),], signal=1, update=None, op="ld", scope="gpu", sem="acquire")
acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
out[tile_m, tile_n] = acc
return out


def helion_all_gather_matmul(
a_shared: torch.Tensor,
b: torch.Tensor,
a_out: torch.Tensor | None = None,
progress: torch.Tensor | None = None,
**kwargs: Any,
) -> tuple[torch.Tensor, torch.Tensor]:
configs = {
"SPLITS_PER_RANK": kwargs.get("splits_per_rank", 1),
"BLOCK_SIZE_M": kwargs.get("block_size_m", 128),
"BLOCK_SIZE_N": kwargs.get("block_size_n", 256),
"BLOCK_SIZE_K": kwargs.get("block_size_k", 64),
"GROUP_SIZE_M": kwargs.get("group_size_m", 4),
"num_stages": kwargs.get("num_stages", 3),
"num_warps": kwargs.get("num_warps", 8),
}

symm_mem_group = dist.group.WORLD
if symm_mem_group is None:
raise RuntimeError("No symmetric memory group available")

symm_mem_hdl = symm_mem.rendezvous(a_shared, group=symm_mem_group)

a_shape = list(a_shared.shape)
a_shape[0] *= symm_mem_hdl.world_size

configs["RANK"] = symm_mem_hdl.rank
configs["WORLD_SIZE"] = symm_mem_hdl.world_size
if (
configs["SPLITS_PER_RANK"]
* configs["WORLD_SIZE"]
* configs["BLOCK_SIZE_M"]
* configs["GROUP_SIZE_M"]
> a_shape[0]
):
configs["GROUP_SIZE_M"] = 1
configs["SPLITS_PER_RANK"] = 1

configs["COMM_BLOCK_SIZE_M"] = (
a_shape[0] // configs["WORLD_SIZE"] // configs["SPLITS_PER_RANK"]
)
assert (
configs["COMM_BLOCK_SIZE_M"]
% (configs["BLOCK_SIZE_M"] * configs["GROUP_SIZE_M"])
== 0
)

if a_out is None:
a_out = torch.empty(a_shape, dtype=a_shared.dtype, device=a_shared.device)

if progress is None:
progress = torch.zeros(
symm_mem_hdl.world_size * configs["SPLITS_PER_RANK"],
dtype=torch.uint32,
device=a_shared.device,
)
else:
progress.fill_(
0
) # Reset progress to 0. Maybe we should reset inside the kernel using cas?

backend_stream = copy_engine_all_gather_w_progress(
a_out, a_shared, progress, configs["SPLITS_PER_RANK"]
)

c = helion_matmul_w_progress(
a_out,
a_shared,
b,
progress,
SPLITS_PER_RANK=configs["SPLITS_PER_RANK"],
RANK=configs["RANK"],
)
assert type(c) is torch.Tensor

torch.cuda.current_stream().wait_stream(backend_stream)

return a_out, c


def test(M: int, N: int, K: int, world_size: int, device: torch.device) -> None:
a_shared = symm_mem.empty(
M // world_size, K, dtype=torch.bfloat16, device=device
).normal_()
    b = torch.randn((K, N), device=device, dtype=torch.bfloat16).T.contiguous().T

a_out, c = helion_all_gather_matmul(a_shared, b)

golden_a = a_shared.clone()
dist_group = dist.group.WORLD
if dist_group is None:
raise RuntimeError("No distributed group available")
ag_golden, mm_golden = torch.ops.symm_mem.fused_all_gather_matmul(
golden_a, [b], gather_dim=0, group_name=dist_group.group_name
)
torch.testing.assert_close(c, mm_golden[0], rtol=1e-1, atol=1e-1)
torch.testing.assert_close(a_out, ag_golden)


def main() -> None:
rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
torch.manual_seed(42 + rank)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
dist.init_process_group("nccl")
test(4096, 6656, 16384, world_size, device)

dist.destroy_process_group()


if __name__ == "__main__":
"""
torchrun \
--nnodes 1 --nproc-per-node 8 \
--rdzv-backend c10d --rdzv-endpoint localhost:0 \
--no_python python3 examples/all_gather_matmul.py
"""
main()
1 change: 1 addition & 0 deletions helion/_compiler/output_header.py
@@ -21,6 +21,7 @@
"triton_helpers": "from torch._inductor.runtime import triton_helpers",
"tl_math": "from torch._inductor.runtime.triton_helpers import math as tl_math",
"libdevice": "from torch._inductor.runtime.triton_compat import libdevice",
"hl_ext": "from helion import _triton_ext as hl_ext",
}

disallowed_names: dict[str, None] = dict.fromkeys(
6 changes: 6 additions & 0 deletions helion/_triton_ext/__init__.py
@@ -0,0 +1,6 @@
from __future__ import annotations

from .gmem_barrier import _triton_wait_multiple_signal
from .gmem_barrier import _triton_wait_signal

__all__ = ["_triton_wait_multiple_signal", "_triton_wait_signal"]
85 changes: 85 additions & 0 deletions helion/_triton_ext/gmem_barrier.py
@@ -0,0 +1,85 @@
# pyre-ignore-all-errors[2] # ignore Missing parameter annotation
from __future__ import annotations

import triton
import triton.language as tl


@triton.jit
def _triton_wait_signal(
addr,
expect: tl.constexpr, # wait until lock is set to expect
    update: tl.constexpr,  # update the lock once it is acquired.
sem: tl.constexpr,
scope: tl.constexpr,
op: tl.constexpr,
skip_sync: tl.constexpr,
) -> None:
"""
Wait for a global memory barrier to reach the expected state.

This function implements a spin-wait loop that continuously checks a memory location
until it reaches the expected value, providing synchronization across GPU threads.

Args:
addr: Memory address of the barrier to wait on (Must be a scalar)
expect: Expected value to wait for
        update: Value to write to the barrier once it is acquired
        sem: Memory semantics for the atomic operation. Options: "acquire", "relaxed".
        scope: Scope of the atomic operation. Options: "gpu", "sys"
        op: Atomic operation type: "ld", "atomic_cas"
        skip_sync: If True, skip the block-level bar.sync after the wait completes
"""
tl.static_assert(
addr.type.is_ptr(),
"Barrier address must be a scalar. Do you want to use '_triton_wait_multiple_signal'? ",
)

tl.static_assert(
sem == "acquire" or sem == "relaxed",
"Invalid memory semantic. options: 'acquire', 'relaxed'. ",
)
tl.static_assert(
scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu', 'sys'. "
)
tl.static_assert(
op == "ld" or op == "atomic_cas",
"Invalid op. options: 'ld', 'atomic_cas'. ",
)

# Spin-wait loop:
# Uses atomic_add with update=0 for ld.global.{sem}.{scope}
    # Triton emits an smem broadcast of the tl.atomic_add return value in PTX,
    # but ptxas optimizes it away in SASS, so there is no performance overhead.
if op == "ld":
tl.static_assert(
update == 0, "ld wait on gmem_barriers cannot update the lock. "
)
while tl.atomic_add(addr, 0, sem=sem, scope=scope) != expect:
pass
elif op == "atomic_cas":
while tl.atomic_cas(addr, expect, update, sem=sem, scope=scope) != expect:
pass
else:
raise NotImplementedError(
f"Unsupported op '{op}' for wait signal on gmem barrier. "
)

if not skip_sync:
tl.inline_asm_elementwise(
"bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
)
    # tl.debug_barrier() causes a significant performance loss (perhaps it breaks Triton prefetching?).


@triton.jit
def _triton_wait_multiple_signal(
addr,
expect: tl.constexpr, # wait until lock is set to expect
    update: tl.constexpr,  # update the lock once it is acquired.
sem: tl.constexpr,
scope: tl.constexpr,
op: tl.constexpr,
skip_sync: tl.constexpr,
) -> None:
raise NotImplementedError("Waiting on multiple barriers is not implemented yet. ")
    # TODO(joydddd): wait on multiple barriers at the same time, where each thread waits on a different barrier.
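For reference, a hypothetical sketch (not part of this diff) of how a generated Triton kernel might reach this helper through the hl_ext alias registered in output_header.py; the kernel, pointer, and launch names below are made up for illustration:

import triton
import triton.language as tl

from helion import _triton_ext as hl_ext


@triton.jit
def _consumer_kernel(x_ptr, flag_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    # Positional arguments: (addr, expect, update, sem, scope, op, skip_sync).
    # Spin on the flag with an acquire load before touching the data it guards.
    hl_ext._triton_wait_signal(flag_ptr, 1, 0, "acquire", "gpu", "ld", False)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(x_ptr + offs, x + 1, mask=mask)

# Launch sketch: a producer must set the flag to 1 for the kernel to make progress, e.g.
#   _consumer_kernel[(triton.cdiv(n, 128),)](buf, flag, n, BLOCK=128)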
1 change: 1 addition & 0 deletions helion/language/__init__.py
@@ -11,6 +11,7 @@
from .memory_ops import atomic_add as atomic_add
from .memory_ops import load as load
from .memory_ops import store as store
from .signal_wait import wait as wait
from .tile_ops import tile_begin as tile_begin
from .tile_ops import tile_block_size as tile_block_size
from .tile_ops import tile_end as tile_end