Refactor examples to use run_example helper #225

Merged: 1 commit, Jun 30, 2025
11 changes: 2 additions & 9 deletions examples/add.py
@@ -3,6 +3,7 @@
 import torch
 
 import helion
+from helion._testing import run_example
 import helion.language as hl
@@ -23,17 +24,9 @@ def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
 
 def check(m: int, n: int) -> None:
-    from triton.testing import do_bench
-
     x = torch.randn([m, n], device="cuda", dtype=torch.float16)
     y = torch.randn([m, n], device="cuda", dtype=torch.float16)
-    result = add(x, y)
-    torch.testing.assert_close(result, x + y, rtol=1e-2, atol=1e-1)
-    sec = do_bench(lambda: add(x, y))
-    baseline_sec = do_bench(lambda: torch.add(x, y))
-    print(
-        f"Helion time: {sec:.4f}ms, torch time: {baseline_sec:.4f}, speedup: {baseline_sec / sec:.2f}x"
-    )
+    run_example(add, torch.add, (x, y))
 
 
 def main() -> None:
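
Note for readers without the repo handy: run_example bundles the assert_close correctness check and the do_bench timing loop that each example previously inlined. The real helper lives in helion/_testing.py; the sketch below is only an approximation inferred from the call sites in this PR, so the parameter names, defaults, and output format are assumptions rather than the actual source.

# Hypothetical sketch of a run_example-style helper, inferred from the call
# sites in this PR; the real implementation is in helion/_testing.py.
from __future__ import annotations

from typing import Callable

import torch
from triton.testing import do_bench


def run_example(
    kernel: Callable[..., torch.Tensor] | dict[str, Callable[..., torch.Tensor]],
    baseline: Callable[..., torch.Tensor] | dict[str, Callable[..., torch.Tensor]],
    args: tuple[object, ...],
    rtol: float = 1e-2,
    atol: float = 1e-1,
) -> None:
    # Both sides accept either a single callable or a dict of named callables.
    kernels = kernel if isinstance(kernel, dict) else {"helion": kernel}
    baselines = baseline if isinstance(baseline, dict) else {"torch": baseline}
    # The first baseline defines the expected output.
    expected = next(iter(baselines.values()))(*args)
    for name, fn in kernels.items():
        torch.testing.assert_close(
            fn(*args), expected, rtol=rtol, atol=atol, msg=f"{name} mismatch"
        )
    # Benchmark every kernel and baseline with triton's do_bench (returns ms).
    for name, fn in {**kernels, **baselines}.items():
        print(f"{name}: {do_bench(lambda fn=fn: fn(*args)):.4f}ms")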
57 changes: 25 additions & 32 deletions examples/attention.py
@@ -1,21 +1,26 @@
 from __future__ import annotations
 
 import math
+from typing import Callable
+from typing import cast
 
 import torch
 from torch.nn.attention.flex_attention import flex_attention
 
 import helion
+from helion._testing import run_example
 import helion.language as hl
 
 
 @helion.kernel(
     config=helion.Config(
-        # This config was autotuned on a 3090, it won't be fast for other architectures
-        block_sizes=[128, 64],
-        num_warps=4,
+        # This config was autotuned on a 5090, it won't be fast for other cards
+        block_sizes=[128, 16],
+        loop_orders=[[0, 1]],
+        l2_groupings=[2],
+        num_warps=2,
         num_stages=3,
-        indexing="block_ptr",
+        indexing="pointer",
     ),
     # Static shapes provides a speedup for attention
     static_shapes=True,
@@ -82,36 +87,24 @@ def test(
         for _ in range(3)
     ]
 
-    # reference implementation
-    p = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
-    p = torch.softmax(p.float(), dim=-1).to(dtype)
-    ref_out = torch.matmul(p, v)
+    def ref_attention(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> torch.Tensor:
+        """Reference manual attention implementation"""
+        p = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
+        p = torch.softmax(p.float(), dim=-1).to(dtype)
+        return torch.matmul(p, v)
 
-    # flex attention version
     # TODO(jansel): turn the above kernel into a flex attention kernel
-    flex_compiled = torch.compile(flex_attention, fullgraph=True)
-    flex_out = flex_compiled(q, k, v)
-    torch.testing.assert_close(flex_out, ref_out, atol=1e-2, rtol=1e-2)
-
-    # sdpa version
-    sdpa_out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
-    torch.testing.assert_close(sdpa_out, ref_out, atol=1e-2, rtol=1e-2)
-
-    # helion version
-    hl_out = attention(q, k, v)
-    torch.testing.assert_close(hl_out, ref_out, atol=1e-2, rtol=1e-2)
-
-    # benchmark
-    from triton.testing import do_bench
-
-    spda_sec = do_bench(
-        lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v)
-    )
-    flex_sec = do_bench(lambda: flex_compiled(q, k, v))
-    helion_sec = do_bench(lambda: attention(q, k, v))
-    print(
-        f"Helion time: {helion_sec:.4f}ms, flex time: {flex_sec:.4f}, torch time: {spda_sec:.4f}"
+    flex_compiled = cast(
+        "Callable[..., torch.Tensor]", torch.compile(flex_attention, fullgraph=True)
     )
+    baselines = {
+        "torch": torch.nn.functional.scaled_dot_product_attention,
+        "flex": flex_compiled,
+        "ref": ref_attention,
+    }
 
+    run_example(attention, baselines, (q, k, v))
 
 
 def main() -> None:
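
One detail worth calling out in the attention rewrite: torch.compile returns a loosely typed callable, so the example narrows it with typing.cast to Callable[..., torch.Tensor] before storing it in the baselines dict; the string form of the type argument avoids evaluating the annotation at runtime. The dict form also lets one run_example call compare the Helion kernel against three baselines (sdpa, compiled flex attention, and the manual reference) instead of three separate assert/do_bench blocks.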
11 changes: 2 additions & 9 deletions examples/bmm.py
@@ -3,6 +3,7 @@
 import torch
 
 import helion
+from helion._testing import run_example
 import helion.language as hl
@@ -26,17 +27,9 @@ def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
 
 
 def check(b: int, m: int, k: int, n: int) -> None:
-    from triton.testing import do_bench
-
     x = torch.randn([b, m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([b, k, n], device="cuda", dtype=torch.float16)
-    result = bmm(x, y)
-    torch.testing.assert_close(result, x @ y, rtol=1e-2, atol=1e-1)
-    sec = do_bench(lambda: bmm(x, y))
-    baseline_sec = do_bench(lambda: torch.bmm(x, y))
-    print(
-        f"Helion time: {sec:.4f}ms, torch time: {baseline_sec:.4f}, speedup: {baseline_sec / sec:.2f}x"
-    )
+    run_example(bmm, torch.bmm, (x, y))
 
 
 def main() -> None:
12 changes: 2 additions & 10 deletions examples/concatenate.py
@@ -3,6 +3,7 @@
 import torch
 
 import helion
+from helion._testing import run_example
 import helion.language as hl
@@ -31,18 +32,9 @@ def concat2d_dim1(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
 
 def main() -> None:
-    from triton.testing import do_bench
-
     x = torch.randn([1500, 400], device="cuda")
     y = torch.randn([1500, 600], device="cuda")
-    result = concat2d_dim1(x, y)
-    expected = torch.cat([x, y], dim=1)
-    torch.testing.assert_close(result, expected)
-    sec = do_bench(lambda: concat2d_dim1(x, y))
-    baseline_sec = do_bench(lambda: torch.cat([x, y], dim=1))
-    print(
-        f"Helion time: {sec:.4f}ms, torch time: {baseline_sec:.4f}, speedup: {baseline_sec / sec:.2f}x"
-    )
+    run_example(concat2d_dim1, lambda x, y: torch.cat([x, y], dim=1), (x, y))
 
 
 if __name__ == "__main__":
11 changes: 3 additions & 8 deletions examples/embedding.py
@@ -3,6 +3,7 @@
 import torch
 
 import helion
+from helion._testing import run_example
 import helion.language as hl
@@ -24,17 +25,11 @@ def embedding(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
 
 
 def main() -> None:
-    from triton.testing import do_bench
-
     num_embeddings, embedding_dim = 16, 64
     x = torch.randint(0, num_embeddings, [256, 32], device="cuda", dtype=torch.int32)
     weight = torch.randn([num_embeddings, embedding_dim], device="cuda")
-    result = embedding(x, weight)
-    torch.testing.assert_close(result, torch.nn.functional.embedding(x, weight))
-    sec = do_bench(lambda: embedding(x, weight))
-    baseline_sec = do_bench(lambda: torch.nn.functional.embedding(x, weight))
-    print(
-        f"Helion time: {sec:.4f}ms, torch time: {baseline_sec:.4f}, speedup: {baseline_sec / sec:.2f}x"
+    run_example(
+        embedding, torch.nn.functional.embedding, (x, weight), atol=0.0, rtol=0.0
     )
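
The atol=0.0, rtol=0.0 override here is deliberate: an embedding lookup is a pure gather with no floating-point accumulation, so the Helion kernel is expected to match torch.nn.functional.embedding bit-for-bit, and the old call already asserted with default (tight) tolerances.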
10 changes: 6 additions & 4 deletions examples/jagged_dense_add.py
@@ -3,6 +3,7 @@
 import torch
 
 import helion
+from helion._testing import run_example
 import helion.language as hl
 
 """
@@ -110,11 +111,12 @@ def random_jagged_2d(
 
 def main() -> None:
     rows, cols = 256, 5000
-    x = random_jagged_2d(rows, cols, device="cuda")
+    x_data, x_offsets = random_jagged_2d(rows, cols, device="cuda")
     y = torch.randn([rows, cols], device="cuda")
-    result = jagged_dense_add_2d(*x, y)
-    expected = jagged_dense_add_2d_reference(*x, y)
-    torch.testing.assert_close(result, expected)
+
+    run_example(
+        jagged_dense_add_2d, jagged_dense_add_2d_reference, (x_data, x_offsets, y)
+    )
 
 
 if __name__ == "__main__":
30 changes: 8 additions & 22 deletions examples/long_sum.py
@@ -3,6 +3,7 @@
 import torch
 
 import helion
+from helion._testing import run_example
 import helion.language as hl
@@ -72,31 +73,16 @@ def longsum_manual(x: torch.Tensor) -> torch.Tensor:
 
 
 def check(m: int, n: int) -> None:
-    from triton.testing import do_bench
-
     x = torch.randn([m, n], device="cuda", dtype=torch.float32)
 
-    helion_out = longsum(x)
-    torch.testing.assert_close(helion_out, baseline_sum(x), rtol=1e-2, atol=1e-1)
-    print("✅ Results Match ✅ naive reduction")
+    # Test all three kernel variants against the baseline
+    kernels = {
+        "helion naive": longsum,
+        "helion loop": longsum_w_red_loop,
+        "helion manual": longsum_manual,
+    }
 
-    helion_red_loop_out = longsum_w_red_loop(x)
-    torch.testing.assert_close(
-        helion_red_loop_out, baseline_sum(x), rtol=1e-2, atol=1e-1
-    )
-    print("✅ Results Match ✅ Reduction Loop")
-
-    helion_manual_out = longsum_manual(x)
-    torch.testing.assert_close(helion_manual_out, baseline_sum(x), rtol=1e-2, atol=1e-1)
-    print("✅ Results Match ✅ Manual Reduction Loop")
-
-    sec = do_bench(lambda: longsum(x))
-    loop_sec = do_bench(lambda: longsum_w_red_loop(x))
-    manual_loop_sec = do_bench(lambda: longsum_manual(x))
-    baseline_sec = do_bench(lambda: baseline_sum(x))
-    print(
-        f"Helion Naive time: {sec:.4f}ms, Helion Looped Time: {loop_sec:.4f}, Helion Manual Loop Time: {manual_loop_sec:.4f} torch time: {baseline_sec:.4f}, speedup: {baseline_sec / sec:.2f}x {baseline_sec / loop_sec:.2f}x {baseline_sec / manual_loop_sec:.2f}x"
-    )
+    run_example(kernels, baseline_sum, (x,))
 
 
 def main() -> None:
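
long_sum.py exercises the other side of the helper: the first argument can itself be a dict, so several kernel variants are checked and timed against one baseline in a single call. Under the same assumed signature as the sketch above, the pattern generalizes to any family of implementations; the variants below are stand-ins, not real Helion kernels.

# Hypothetical usage of the dict-of-kernels form; relies on the run_example
# sketch above, and the variants are plain PyTorch stand-ins.
import torch

def baseline_sum(x: torch.Tensor) -> torch.Tensor:
    return x.sum(-1)

variants = {
    "one_pass": lambda x: x.sum(-1),
    # Same reduction split into chunks of 128, mimicking a looped reduction.
    "two_pass": lambda x: x.reshape(x.shape[0], -1, 128).sum(-1).sum(-1),
}
x = torch.randn([256, 4096], device="cuda")
run_example(variants, baseline_sum, (x,))  # checks, then benchmarks each variant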
11 changes: 2 additions & 9 deletions examples/matmul.py
@@ -3,6 +3,7 @@
 import torch
 
 import helion
+from helion._testing import run_example
 import helion.language as hl
@@ -24,17 +25,9 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
 
 def check(m: int, k: int, n: int) -> None:
-    from triton.testing import do_bench
-
     x = torch.randn([m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([k, n], device="cuda", dtype=torch.float16)
-    result = matmul(x, y)
-    torch.testing.assert_close(result, x @ y, rtol=1e-2, atol=1e-1)
-    sec = do_bench(lambda: matmul(x, y))
-    baseline_sec = do_bench(lambda: torch.matmul(x, y))
-    print(
-        f"Helion time: {sec:.4f}ms, torch time: {baseline_sec:.4f}, speedup: {baseline_sec / sec:.2f}x"
-    )
+    run_example(matmul, torch.matmul, (x, y))
 
 
 def main() -> None:
12 changes: 2 additions & 10 deletions examples/matmul_layernorm.py
@@ -3,6 +3,7 @@
 import torch
 
 import helion
+from helion._testing import run_example
 import helion.language as hl
@@ -51,20 +52,11 @@ def matmul_layernorm_pytorch(
 
 
 def check(m: int, k: int, n: int) -> None:
-    from triton.testing import do_bench
-
     x = torch.randn([m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([k, n], device="cuda", dtype=torch.float16)
     weight = torch.randn([n], device="cuda", dtype=torch.float16)
     bias = torch.randn([n], device="cuda", dtype=torch.float16)
-    result = matmul_layernorm(x, y, weight, bias)
-    expected = matmul_layernorm_pytorch(x, y, weight, bias)
-    torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-1)
-    sec = do_bench(lambda: matmul_layernorm(x, y, weight, bias))
-    baseline_sec = do_bench(lambda: matmul_layernorm_pytorch(x, y, weight, bias))
-    print(
-        f"Helion time: {sec:.4f}s, torch time: {baseline_sec:.4f}, speedup: {baseline_sec / sec:.2f}x"
-    )
+    run_example(matmul_layernorm, matmul_layernorm_pytorch, (x, y, weight, bias))
 
 
 def main() -> None:
11 changes: 2 additions & 9 deletions examples/matmul_split_k.py
@@ -3,6 +3,7 @@
 import torch
 
 import helion
+from helion._testing import run_example
 from helion.autotuner import PowerOfTwoFragment
 import helion.language as hl
@@ -27,17 +28,9 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
 
 def check(m: int, k: int, n: int) -> None:
-    from triton.testing import do_bench
-
     x = torch.randn([m, k], device="cuda", dtype=torch.float16)
     y = torch.randn([k, n], device="cuda", dtype=torch.float16)
-    result = matmul_split_k(x, y)
-    torch.testing.assert_close(result, x @ y, rtol=1e-2, atol=1)
-    sec = do_bench(lambda: matmul_split_k(x, y))
-    baseline_sec = do_bench(lambda: torch.matmul(x, y))
-    print(
-        f"Helion time: {sec:.4f}ms, torch time: {baseline_sec:.4f}, speedup: {baseline_sec / sec:.2f}x"
-    )
+    run_example(matmul_split_k, torch.matmul, (x, y), atol=1)
 
 
 def main() -> None:
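
The loose atol=1 is carried over from the old assert_close call rather than introduced here: split-k partitions the reduction over k and accumulates the partial sums in a different order than torch.matmul, and with fp16 inputs that reordering changes the rounding, so exact agreement is not expected.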