18 changes: 18 additions & 0 deletions benchmarks/run.py
@@ -493,6 +493,15 @@ class RunResult:
"helion_addmm_tritonbench-speedup": "helion_speedup",
"helion_addmm_tritonbench-accuracy": "helion_accuracy",
},
"addmm-bwd": {
"aten_addmm": "baseline",
"triton_addmm-speedup": "triton_speedup",
"triton_addmm-accuracy": "triton_accuracy",
"pt2_addmm_maxautotune-speedup": "torch_compile_speedup",
"pt2_addmm_maxautotune-accuracy": "torch_compile_accuracy",
"helion_addmm_tritonbench-speedup": "helion_speedup",
"helion_addmm_tritonbench-accuracy": "helion_accuracy",
},
# "ragged_attention": {
# "triton_ragged_attention-speedup": "triton_speedup",
# "triton_ragged_attention-accuracy": "triton_accuracy",
@@ -562,6 +571,15 @@ class RunResult:
"helion_matmul_tritonbench-speedup": "helion_speedup",
"helion_matmul_tritonbench-accuracy": "helion_accuracy",
},
"gemm-bwd": {
"aten_matmul": "baseline",
"triton_tutorial_matmul-speedup": "triton_speedup",
"triton_tutorial_matmul-accuracy": "triton_accuracy",
"pt2_triton_matmul-speedup": "torch_compile_speedup",
"pt2_triton_matmul-accuracy": "torch_compile_accuracy",
"helion_matmul_tritonbench-speedup": "helion_speedup",
"helion_matmul_tritonbench-accuracy": "helion_accuracy",
},
"fp8_gemm": {
"torch_fp8_gemm": "baseline",
f"{'blackwell_persistent_tma' if IS_B200 else 'triton_tma_persistent'}_fp8_gemm-speedup": "triton_speedup",
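The new "addmm-bwd" and "gemm-bwd" entries map raw tritonbench metric names onto the standardized column names used elsewhere in run.py. As a minimal sketch of how such a mapping might be consumed, assuming tritonbench reports one value per metric key (the helper name and row format below are hypothetical and not part of run.py):

# Illustrative only -- not part of this diff. Remaps raw tritonbench metric
# names to the standardized column names used in the mapping above.
from typing import Mapping

GEMM_BWD_COLUMNS: Mapping[str, str] = {
    "aten_matmul": "baseline",
    "triton_tutorial_matmul-speedup": "triton_speedup",
    "triton_tutorial_matmul-accuracy": "triton_accuracy",
    "pt2_triton_matmul-speedup": "torch_compile_speedup",
    "pt2_triton_matmul-accuracy": "torch_compile_accuracy",
    "helion_matmul_tritonbench-speedup": "helion_speedup",
    "helion_matmul_tritonbench-accuracy": "helion_accuracy",
}

def remap_metrics(raw: Mapping[str, float]) -> dict[str, float]:
    # Keep only the tracked metrics and rename them to their standard columns.
    return {GEMM_BWD_COLUMNS[k]: v for k, v in raw.items() if k in GEMM_BWD_COLUMNS}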
170 changes: 39 additions & 131 deletions examples/matmul.py
@@ -58,130 +58,6 @@ def matmul(
return out


@helion.kernel
def matmul_bwd(
grad_out: Tensor, # [m, n] gradient w.r.t output
mat1: Tensor, # [m, k] first matrix
mat2: Tensor, # [k, n] second matrix
) -> tuple[Tensor, Tensor]:
"""
Backward pass for matrix multiplication following Triton reference pattern.

For C = A @ B, given grad_C, computes:
- grad_A = grad_C @ B.T
- grad_B = A.T @ grad_C

Args:
grad_out: Gradient w.r.t output [m, n]
mat1: First matrix [m, k]
mat2: Second matrix [k, n]

Returns:
tuple[Tensor, Tensor]: (grad_mat1, grad_mat2)
"""
# Get all dimensions first
m, n = grad_out.size()
m2, k = mat1.size()
k2, n2 = mat2.size()

# All assertions at the top
assert m == m2 and n == n2 and k == k2, "Size mismatch in matmul backward"

# Declare ALL output tensors at the top before any loops
grad_mat1 = torch.empty_like(mat1)
grad_mat2 = torch.empty_like(mat2)

# First loop block: compute grad_mat1 = grad_out @ mat2.T
for tile_m1, tile_k1 in hl.tile([m, k]):
acc1 = hl.zeros([tile_m1, tile_k1], dtype=torch.float32)
for tile_n1 in hl.tile(n):
# Need mat2.T: mat2 is [k, n], so mat2[tile_k, tile_n].T gives [tile_n, tile_k]
acc1 = torch.addmm(
acc1, grad_out[tile_m1, tile_n1], mat2[tile_k1, tile_n1].T
)
grad_mat1[tile_m1, tile_k1] = acc1.to(mat1.dtype)

# Second loop block: compute grad_mat2 = mat1.T @ grad_out
for tile_k2, tile_n2 in hl.tile([k, n]):
acc2 = hl.zeros([tile_k2, tile_n2], dtype=torch.float32)
for tile_m2 in hl.tile(m):
# Need mat1.T: mat1 is [m, k], so mat1[tile_m, tile_k].T gives [tile_k, tile_m]
acc2 = torch.addmm(
acc2, mat1[tile_m2, tile_k2].T, grad_out[tile_m2, tile_n2]
)
grad_mat2[tile_k2, tile_n2] = acc2.to(mat2.dtype)

return grad_mat1, grad_mat2


@helion.kernel
def addmm_bwd(
grad_out: Tensor, # [m, n] gradient w.r.t output
bias: Tensor, # [m, n] or broadcastable bias tensor
mat1: Tensor, # [m, k] first matrix
mat2: Tensor, # [k, n] second matrix
alpha: float = 1.0, # scalar multiplier for matmul
beta: float = 1.0, # scalar multiplier for bias
) -> tuple[Tensor, Tensor, Tensor]:
"""
Backward pass for addmm operation following Triton reference pattern.

Forward: output = beta * bias + alpha * (mat1 @ mat2)

Based on the Triton kernel analysis:
- grad_input = beta * grad_out (with proper reduction for broadcasting)
- grad_mat1 = alpha * (grad_out @ mat2.T)
- grad_mat2 = alpha * (mat1.T @ grad_out)

Args:
grad_out: Gradient w.r.t output [m, n]
bias: Bias tensor [m, n] (or broadcastable)
mat1: First matrix [m, k]
mat2: Second matrix [k, n]
alpha: Scalar multiplier for matmul
beta: Scalar multiplier for bias

Returns:
tuple[Tensor, Tensor, Tensor]: (grad_input, grad_mat1, grad_mat2)
"""
# Get all dimensions first
m, n = grad_out.size()
m2, k = mat1.size()
k2, n2 = mat2.size()

# All assertions at the top
assert m == m2 and n == n2 and k == k2, "Size mismatch in addmm backward"

# Declare ALL output tensors at the top before any loops
grad_input = torch.empty_like(bias)
grad_mat1 = torch.empty_like(mat1)
grad_mat2 = torch.empty_like(mat2)

# Handle grad_input = beta * grad_out (assuming same shape for now)
for tile_m3, tile_n3 in hl.tile([m, n]):
grad_input[tile_m3, tile_n3] = beta * grad_out[tile_m3, tile_n3]

# First loop block: compute grad_mat1 = alpha * (grad_out @ mat2.T)
for tile_m1, tile_k1 in hl.tile([m, k]):
acc1 = hl.zeros([tile_m1, tile_k1], dtype=torch.float32)
for tile_n1 in hl.tile(n):
acc1 = torch.addmm(
acc1, grad_out[tile_m1, tile_n1], mat2[tile_k1, tile_n1].T
)
grad_mat1[tile_m1, tile_k1] = (alpha * acc1).to(mat1.dtype)

# Second loop block: compute grad_mat2 = alpha * (mat1.T @ grad_out)
for tile_k2, tile_n2 in hl.tile([k, n]):
acc2 = hl.zeros([tile_k2, tile_n2], dtype=torch.float32)
for tile_m2 in hl.tile(m):
acc2 = torch.addmm(
acc2, mat1[tile_m2, tile_k2].T, grad_out[tile_m2, tile_n2]
)
grad_mat2[tile_k2, tile_n2] = (alpha * acc2).to(mat2.dtype)

return grad_input, grad_mat1, grad_mat2


# %%
class MatMulFunction(torch.autograd.Function):
@staticmethod
@@ -200,10 +76,24 @@ def backward(
ctx: Any, # noqa: ANN401
*grad_outputs: Tensor,
) -> tuple[Tensor | None, Tensor | None]:
"""Backward pass for matrix multiplication."""
"""
Backward pass for matrix multiplication.

For C = A @ B, given grad_C:
- grad_A = grad_C @ B.T
- grad_B = A.T @ grad_C

We reuse the forward matmul kernel for both computations.
"""
grad_out = grad_outputs[0]
mat1, mat2 = ctx.saved_tensors
grad_mat1, grad_mat2 = matmul_bwd(grad_out, mat1, mat2)

# grad_mat1 = grad_out @ mat2.T
grad_mat1 = matmul(grad_out, mat2.T)

# grad_mat2 = mat1.T @ grad_out
grad_mat2 = matmul(mat1.T, grad_out)

return grad_mat1, grad_mat2
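
The backward now reuses the forward matmul kernel instead of a dedicated matmul_bwd kernel. Below is a quick CPU-only sanity check of the two gradient formulas, with plain eager PyTorch standing in for the Helion kernel; sizes are arbitrary and the snippet is illustrative, not part of this diff.

# Numerical check of grad_A = grad_C @ B.T and grad_B = A.T @ grad_C,
# using eager PyTorch in place of the Helion matmul kernel.
import torch

m, k, n = 64, 32, 48
mat1 = torch.randn(m, k, requires_grad=True)
mat2 = torch.randn(k, n, requires_grad=True)

out = mat1 @ mat2
grad_out = torch.randn_like(out)
out.backward(grad_out)

# Compare autograd's gradients against the closed-form expressions
# used by MatMulFunction.backward.
assert torch.allclose(mat1.grad, grad_out @ mat2.T, atol=1e-4)
assert torch.allclose(mat2.grad, mat1.T @ grad_out, atol=1e-4)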


@@ -242,15 +132,33 @@ def backward(
ctx: Any, # noqa: ANN401
*grad_outputs: Tensor,
) -> tuple[Tensor | None, Tensor | None, Tensor | None, None, None]:
"""Backward pass for addmm operation."""
"""
Backward pass for addmm operation.

Forward: output = beta * bias + alpha * (mat1 @ mat2)

Given grad_out:
- grad_bias = beta * grad_out
- grad_mat1 = alpha * (grad_out @ mat2.T)
- grad_mat2 = alpha * (mat1.T @ grad_out)

We reuse the forward matmul kernel for both matrix gradient computations.
"""
grad_out = grad_outputs[0]
bias, mat1, mat2 = ctx.saved_tensors
alpha = ctx.alpha
beta = ctx.beta
grad_input, grad_mat1, grad_mat2 = addmm_bwd(
grad_out, bias, mat1, mat2, alpha, beta
)
return grad_input, grad_mat1, grad_mat2, None, None

# grad_bias = beta * grad_out
grad_bias = beta * grad_out

# grad_mat1 = alpha * (grad_out @ mat2.T)
grad_mat1 = alpha * matmul(grad_out, mat2.T)

# grad_mat2 = alpha * (mat1.T @ grad_out)
grad_mat2 = alpha * matmul(mat1.T, grad_out)

return grad_bias, grad_mat1, grad_mat2, None, None
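
As with MatMulFunction, the dedicated addmm_bwd kernel is removed and the gradients are composed from the forward matmul kernel plus elementwise scaling. Below is a CPU-only check of the three formulas, with plain PyTorch standing in for the Helion kernels and a same-shape bias as assumed above; alpha and beta values are arbitrary, and the snippet is illustrative rather than part of this diff.

# Numerical check of grad_bias, grad_mat1, grad_mat2 for
# out = beta * bias + alpha * (mat1 @ mat2), with bias of shape [m, n].
import torch

m, k, n = 64, 32, 48
alpha, beta = 1.7, 0.5
bias = torch.randn(m, n, requires_grad=True)
mat1 = torch.randn(m, k, requires_grad=True)
mat2 = torch.randn(k, n, requires_grad=True)

out = torch.addmm(bias, mat1, mat2, alpha=alpha, beta=beta)
grad_out = torch.randn_like(out)
out.backward(grad_out)

# Compare autograd's gradients against the closed-form expressions
# used by AddMMFunction.backward.
assert torch.allclose(bias.grad, beta * grad_out, atol=1e-4)
assert torch.allclose(mat1.grad, alpha * (grad_out @ mat2.T), atol=1e-4)
assert torch.allclose(mat2.grad, alpha * (mat1.T @ grad_out), atol=1e-4)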


def addmm_autograd(