x[i] returns scalar when i=scalar #223

Merged
1 commit merged on Jun 27, 2025
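In user-kernel terms, this PR makes indexing with a scalar index (for example a `hl.grid` index) produce a scalar rather than a one-element tensor, so the generated Triton code no longer wraps the index in `tl.full([1], ...)` and the loaded value can drive control flow directly. A minimal sketch mirroring the `test_if_arg_indexed_scalar` test below (assumes `helion` and `torch` are installed and a supported GPU is available; the kernel name is illustrative):

```python
import torch
import helion
import helion.language as hl


@helion.kernel
def scale_nonzero(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Illustrative kernel; mirrors test_if_arg_indexed_scalar below.
    output = torch.zeros_like(x)
    for idx in hl.grid(x.shape[0]):
        # y[idx] is now a scalar, so comparing it against 0 yields a scalar
        # that can be used directly as an `if` condition.
        if y[idx] != 0:
            output[idx] = x[idx] * 2
        else:
            output[idx] = x[idx]
    return output


# x = torch.tensor([1.0, 2.0, 3.0, 4.0], device="cuda")
# y = torch.tensor([0, 1, 0, 1], device="cuda", dtype=torch.int32)
# scale_nonzero(x, y)  # -> tensor([1., 4., 3., 8.], device='cuda:0')
```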
8 changes: 0 additions & 8 deletions helion/_compiler/device_ir.py
@@ -37,7 +37,6 @@
from .ast_extension import LoopType
from .ast_extension import NodeVisitor
from .ast_extension import create
from .ast_extension import expr_from_string
from .ast_read_writes import ReadWrites
from .compile_environment import CompileEnvironment
from .host_function import HostFunction
@@ -239,13 +238,6 @@ def name(self) -> str:
def codegen(self, state: CodegenState) -> list[object]:
test = state.ast_arg(0)

test_proxy = state.proxy_arg(0)
if isinstance(test_proxy, torch.Tensor) and test_proxy.numel() == 1:
# Triton does not support `if one_elem_tensor:` but supports `if scalar:`,
# so we need to use tl.sum to extract the scalar.
test_code = ast.unparse(test)
test = expr_from_string(f"tl.sum({test_code})")

args = state.ast_args[2]
assert isinstance(args, list)
assert all(isinstance(x, ast.AST) for x in args)
4 changes: 2 additions & 2 deletions helion/_compiler/indexing_strategy.py
@@ -274,9 +274,9 @@ def create(
mask_values.setdefault(f"({mask}){expand}")
output_idx += 1
else:
expand = tile_strategy.expand_str(output_size, output_idx)
# When the index is a scalar (no BlockSizeOrigin), the corresponding dim is eliminated.
val = state.device_function.literal_expr(k)
index_values.append(f"tl.full([1], {val}, {dtype}){expand}")
index_values.append(f"({val})")
elif isinstance(k, slice) and str(k) == "slice(None, None, None)":
expand = tile_strategy.expand_str(output_size, output_idx)
size = fake_value.size(len(index_values))
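A note on the `indexing_strategy.py` change above: a scalar subscript used to be wrapped in `tl.full([1], val, dtype)` plus a broadcast expansion, which introduced a one-element dimension; it is now emitted as a bare scalar expression and the dimension is dropped. A simplified, pure-Python sketch of the two code paths (the helper names here are illustrative only, not Helion's actual API):

```python
# Simplified sketch of the before/after codegen for a scalar subscript.
# These helpers are illustrative; the real logic lives in
# helion/_compiler/indexing_strategy.py.

def old_scalar_index(val: str, dtype: str, expand: str) -> str:
    # Old behavior: wrap the scalar in a one-element tensor, then broadcast.
    return f"tl.full([1], {val}, {dtype}){expand}"


def new_scalar_index(val: str) -> str:
    # New behavior: use the scalar directly; the indexed dim is eliminated.
    return f"({val})"


print(old_scalar_index("offset_0", "tl.int32", "[None, :]"))
# tl.full([1], offset_0, tl.int32)[None, :]
print(new_scalar_index("offset_0"))
# (offset_0)
```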
2 changes: 1 addition & 1 deletion test/test_broadcasting.py
@@ -330,7 +330,7 @@ def _fn_kernel(a, out0, out1, out2, a_size_0, a_size_1, a_stride_0, a_stride_1,
v_1 = load_2 + subscript
tl.store(out1 + (indices_0[:, None] * out1_stride_0 + indices_1[None, :] * out1_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
load_4 = tl.load(a + (indices_0[:, None] * a_stride_0 + indices_1[None, :] * a_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
load_5 = tl.load(a + (indices_0[:, None] * a_stride_0 + tl.full([1], idx1, tl.int32)[None, :] * a_stride_1), mask_0[:, None], other=0)
load_5 = tl.load(a + (indices_0[:, None] * a_stride_0 + idx1 * a_stride_1), mask_0[:, None], other=0)
v_2 = load_4 + load_5
tl.store(out2 + (indices_0[:, None] * out2_stride_0 + indices_1[None, :] * out2_stride_1), v_2, mask_0[:, None] & mask_1[None, :])

53 changes: 38 additions & 15 deletions test/test_control_flow.py
@@ -86,18 +86,16 @@ def _fn_make_precompiler(x, v):
return make_precompiler(_fn_kernel)(x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), v, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
)

def test_if_arg_one_element_tensor(self):
def test_if_arg_indexed_scalar(self):
@helion.kernel
def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
output = torch.zeros_like(x)

for idx in hl.grid(x.shape[0]):
# Since `y[idx]` is a one-element tensor, comparing it against 0 will also create a one-element tensor.
# Since `y[idx]` is a scalar, comparing it against 0 will also create a scalar.
if y[idx] != 0:
output[idx] = x[idx] * 2
if (
y[idx] == 0
): # TODO(yf225): `else:` raises MLIR error in Triton, so we use a second if.
else:
output[idx] = x[idx]

return output
@@ -123,20 +121,18 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
def _fn_kernel(x, y, output, output_stride_0, x_stride_0, y_stride_0):
pid_0 = tl.program_id(0)
offset_0 = pid_0
load = tl.load(y + tl.full([1], offset_0, tl.int32) * y_stride_0, None)
load = tl.load(y + offset_0 * y_stride_0, None)
v_0 = tl.full([], 0, tl.int32)
v_1 = load != v_0
if tl.sum(v_1):
load_1 = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
if v_1:
load_1 = tl.load(x + offset_0 * x_stride_0, None)
v_2 = 2.0
v_3 = load_1 * v_2
tl.store(output + tl.full([1], offset_0, tl.int32) * output_stride_0, v_3, None)
load_2 = tl.load(y + tl.full([1], offset_0, tl.int32) * y_stride_0, None)
v_4 = tl.full([], 0, tl.int32)
v_5 = load_2 == v_4
if tl.sum(v_5):
load_3 = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
tl.store(output + tl.full([1], offset_0, tl.int32) * output_stride_0, load_3, None)
tl.store(output + offset_0 * output_stride_0, v_3, None)
_not = not v_1
if _not:
load_2 = tl.load(x + offset_0 * x_stride_0, None)
tl.store(output + offset_0 * output_stride_0, load_2, None)

def fn(x: torch.Tensor, y: torch.Tensor):
output = torch.zeros_like(x)
@@ -149,6 +145,33 @@ def _fn_make_precompiler(x: torch.Tensor, y: torch.Tensor):
return make_precompiler(_fn_kernel)(x, y, output, output.stride(0), x.stride(0), y.stride(0), num_warps=4, num_stages=3)""",
)

def test_if_arg_tensor_sum(self):
@helion.kernel
def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
output = torch.zeros_like(x)

for tile in hl.tile(x.shape[0]):
# Since `y[tile]` is a tensor, comparing it against 0 will also create a tensor.
# The `if` condition must take a scalar, so we call .sum() to reduce the tensor to a scalar.
if (y[tile] != 0).sum():
output[tile] = x[tile] * 2
if (
y[tile] == 0
).sum(): # TODO(yf225): `else:` raises MLIR error in Triton, so we use a second if.
output[tile] = x[tile]

return output

x = torch.tensor([1.0, 2.0, 3.0, 4.0], device=DEVICE)
y = torch.tensor([0, 1, 0, 1], device=DEVICE, dtype=torch.int32)
expected = torch.tensor([1.0, 4.0, 3.0, 8.0], device=DEVICE)
code, result = code_and_output(
fn,
(x, y),
block_size=1,
)
torch.testing.assert_close(result, expected)

def test_constant_true(self):
@helion.kernel(
config={
8 changes: 4 additions & 4 deletions test/test_examples.py
@@ -1692,11 +1692,11 @@ def test_moe_matmul_ogs(self):
def _moe_matmul_ogs_kernel(expert_token_offsets, expert_token_counts, sorted_to_orig_token_idx, A, W, C, A_stride_0, A_stride_1, C_stride_0, C_stride_1, W_stride_0, W_stride_1, W_stride_2, expert_token_counts_stride_0, expert_token_offsets_stride_0, sorted_to_orig_token_idx_stride_0, max_T_per_expert, N, K, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0
start = tl.load(expert_token_offsets + tl.full([1], offset_0, tl.int32) * expert_token_offsets_stride_0, None)
num_tokens = tl.load(expert_token_counts + tl.full([1], offset_0, tl.int32) * expert_token_counts_stride_0, None)
start = tl.load(expert_token_offsets + offset_0 * expert_token_offsets_stride_0, None)
num_tokens = tl.load(expert_token_counts + offset_0 * expert_token_counts_stride_0, None)
v_0 = tl.full([], 0, tl.int32)
v_1 = num_tokens != v_0
if tl.sum(v_1):
if v_1:
num_tokens_copy = num_tokens
start_copy = start
num_tokens_copy_0 = num_tokens_copy
@@ -1729,7 +1729,7 @@ def _moe_matmul_ogs_kernel(expert_token_offsets, expert_token_counts, sorted_to_
expert_orig_token_indices_copy_0 = expert_orig_token_indices_copy
acc_copy_0 = acc_copy
A_frag = tl.load(A + (expert_orig_token_indices_copy_0[:, None] * A_stride_0 + indices_3[None, :] * A_stride_1), mask_1[:, None] & mask_3[None, :], other=0)
W_frag = tl.load(W + (tl.full([1], offset_0, tl.int32)[:, None] * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0)
W_frag = tl.load(W + (offset_0 * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0)
acc = tl.dot(A_frag, W_frag, acc=acc_copy_0, input_precision='tf32')
existing_values = tl.load(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
view = tl.reshape(v_3, [_BLOCK_SIZE_1, 1])
32 changes: 16 additions & 16 deletions test/test_grid.py
@@ -87,11 +87,11 @@ def _grid_1d_kernel(x, y, out, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.co
indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
acc_copy = acc
acc_copy_0 = acc_copy
load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None)
load = tl.load(x + (offset_0 * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None)
load_1 = tl.load(y + (indices_3[:, None] * 4 + indices_2[None, :] * 1), mask_2[None, :], other=0)
acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
v_0 = acc.to(tl.float16)
tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :])
tl.store(out + (offset_0 * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :])

def grid_1d(x: torch.Tensor, y: torch.Tensor):
b, m, k = x.size()
@@ -225,11 +225,11 @@ def _grid_2d_idx_list_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE
indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
acc_copy = acc
acc_copy_0 = acc_copy
load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 8192 + tl.full([1], offset_1, tl.int32)[:, None] * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
load = tl.load(x + (offset_0 * 8192 + offset_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
v_0 = acc.to(tl.float16)
tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 4096 + tl.full([1], offset_1, tl.int32)[:, None] * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
tl.store(out + (offset_0 * 4096 + offset_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)

def grid_2d_idx_list(x: torch.Tensor, y: torch.Tensor):
bi, bj, m, k = x.size()
@@ -363,11 +363,11 @@ def _grid_2d_idx_nested_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SI
indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
acc_copy = acc
acc_copy_0 = acc_copy
load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 8192 + tl.full([1], offset_1, tl.int32)[:, None] * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
load = tl.load(x + (offset_0 * 8192 + offset_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
v_0 = acc.to(tl.float16)
tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 4096 + tl.full([1], offset_1, tl.int32)[:, None] * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
tl.store(out + (offset_0 * 4096 + offset_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)

def grid_2d_idx_nested(x: torch.Tensor, y: torch.Tensor):
bi, bj, m, k = x.size()
@@ -425,10 +425,10 @@ def _grid_begin_end_kernel(x, out, out_stride_0, x_stride_0):
pid_0 = tl.program_id(0)
begin_0 = 2
offset_0 = begin_0 + pid_0
load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
load = tl.load(x + offset_0 * x_stride_0, None)
v_0 = 2.0
v_1 = load * v_0
tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None)
tl.store(out + offset_0 * out_stride_0, v_1, None)

def grid_begin_end(x: torch.Tensor):
n = x.size(0)
@@ -475,10 +475,10 @@ def grid_begin_end_step_pytorch(x: torch.Tensor) -> torch.Tensor:
def _grid_begin_end_step_kernel(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
load = tl.load(x + offset_0 * x_stride_0, None)
v_0 = 2.0
v_1 = load * v_0
tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None)
tl.store(out + offset_0 * out_stride_0, v_1, None)

def grid_begin_end_step(x: torch.Tensor):
n = x.size(0)
@@ -527,10 +527,10 @@ def grid_end_step_kwarg_pytorch(x: torch.Tensor) -> torch.Tensor:
def _grid_end_step_kwarg_kernel(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
load = tl.load(x + offset_0 * x_stride_0, None)
v_0 = 2.0
v_1 = load * v_0
tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None)
tl.store(out + offset_0 * out_stride_0, v_1, None)

def grid_end_step_kwarg(x: torch.Tensor):
n = x.size(0)
@@ -587,10 +587,10 @@ def _grid_multidim_begin_end_kernel(x, out, out_stride_0, out_stride_1, x_stride
offset_0 = begin_0 + pid_0
begin_1 = 1
offset_1 = begin_1 + pid_1
load = tl.load(x + (tl.full([1], offset_0, tl.int32) * x_stride_0 + tl.full([1], offset_1, tl.int32) * x_stride_1), None)
load = tl.load(x + (offset_0 * x_stride_0 + offset_1 * x_stride_1), None)
v_0 = 2.0
v_1 = load * v_0
tl.store(out + (tl.full([1], offset_0, tl.int32) * out_stride_0 + tl.full([1], offset_1, tl.int32) * out_stride_1), v_1, None)
tl.store(out + (offset_0 * out_stride_0 + offset_1 * out_stride_1), v_1, None)

def grid_multidim_begin_end(x: torch.Tensor):
m, n = x.size()
@@ -643,10 +643,10 @@ def _grid_multidim_begin_end_step_kernel(x, out, out_stride_0, out_stride_1, x_s
pid_1 = tl.program_id(0) // num_blocks_0
offset_0 = pid_0 * _BLOCK_SIZE_0
offset_1 = pid_1 * _BLOCK_SIZE_1
load = tl.load(x + (tl.full([1], offset_0, tl.int32) * x_stride_0 + tl.full([1], offset_1, tl.int32) * x_stride_1), None)
load = tl.load(x + (offset_0 * x_stride_0 + offset_1 * x_stride_1), None)
v_0 = 2.0
v_1 = load * v_0
tl.store(out + (tl.full([1], offset_0, tl.int32) * out_stride_0 + tl.full([1], offset_1, tl.int32) * out_stride_1), v_1, None)
tl.store(out + (offset_0 * out_stride_0 + offset_1 * out_stride_1), v_1, None)

def grid_multidim_begin_end_step(x: torch.Tensor):
m, n = x.size()
6 changes: 6 additions & 0 deletions test/test_masking.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import unittest

from expecttest import TestCase
import torch

@@ -332,3 +334,7 @@ def _fn_make_precompiler(x):
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_fn_kernel)(x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), n, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
)


if __name__ == "__main__":
unittest.main()
8 changes: 7 additions & 1 deletion test/test_register_tunable.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import unittest

from expecttest import TestCase
import torch

@@ -187,7 +189,7 @@ def _fn_kernel(x, partial, partial_stride_0, x_stride_0, m, _BLOCK_SIZE_0: tl.co
load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
sum_1 = tl.sum(load, 0)
floordiv = triton_helpers.div_floor_integer(offset_0, _BLOCK_SIZE_0)
tl.store(partial + tl.full([1], floordiv, tl.int32) * partial_stride_0, sum_1, None)
tl.store(partial + floordiv * partial_stride_0, sum_1, None)

def fn(x: torch.Tensor):
m = x.size(0)
@@ -317,3 +319,7 @@ def _matmul_split_k_make_precompiler(x: torch.Tensor, y: torch.Tensor):
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_matmul_split_k_kernel)(x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), n, k, m, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, _BLOCK_SIZE_3, num_warps=16, num_stages=8)""",
)


if __name__ == "__main__":
unittest.main()