Add tl.range flatten to autotuner #229

Merged 1 commit on Jul 1, 2025
5 changes: 5 additions & 0 deletions README.md
@@ -204,6 +204,11 @@ Contains one entry per loop dimension, controlling the `disallow_acc_multi_buffer`
parameter for `tl.range()` calls. `True` allows multi-buffer (sets `disallow_acc_multi_buffer=False`),
`False` disallows multi-buffer (sets `disallow_acc_multi_buffer=True`), and `None` omits the parameter.

* **range\_flattens** (`list[bool | None]`):
Contains one entry per loop dimension, controlling the `flatten`
parameter for `tl.range()` calls. `True` sets `flatten=True`,
`False` sets `flatten=False`, and `None` omits the parameter.

* **reduction\_loops** (`list[int | None]`):
Contains one entry per reduction dimension (see
`examples/softmax.py`). Using `None` triggers a persistent reduction,
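
The README entry above documents the new config knob; here is a hedged sketch of how its three values map onto the generated `tl.range()` loop. The kernel is adapted from `test/test_loops.py` in this PR; the generated code shape shown in the comments is illustrative, not verbatim output.

```python
import torch
import helion
import helion.language as hl

@helion.kernel()
def add_one(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile_outer in hl.tile(x.size(0)):      # outer loop becomes the launch grid
        for tile_inner in hl.tile(x.size(1)):  # inner loop lowers to tl.range
            out[tile_outer, tile_inner] = x[tile_outer, tile_inner] + 1
    return out

# One entry per device-loop dimension:
#   range_flattens=[None]  -> tl.range(...)                 (flatten omitted)
#   range_flattens=[True]  -> tl.range(..., flatten=True)
#   range_flattens=[False] -> tl.range(..., flatten=False)
config = helion.Config(block_sizes=[32, 16], range_flattens=[True])
```
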
6 changes: 6 additions & 0 deletions helion/_compiler/tile_strategy.py
@@ -146,6 +146,12 @@ def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> str:
if range_multi_buffer is not None:
kwargs.append(f"disallow_acc_multi_buffer={not range_multi_buffer}")

range_flatten = env.config_spec.range_flattens.config_get(
state.config.range_flattens, block_idx, None
)
if range_flatten is not None:
kwargs.append(f"flatten={range_flatten}")

if kwargs:
return f", {', '.join(kwargs)}"
return ""
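
A minimal standalone sketch of the string this helper contributes to the generated loop header. The config values here are assumed for illustration; the real method reads them through `env.config_spec` as shown in the hunk above.

```python
# Mirrors the kwargs-joining logic in get_tl_range_kwargs.
range_multi_buffer = False  # assumed config value (False disallows multi-buffer)
range_flatten = True        # assumed config value

kwargs: list[str] = []
if range_multi_buffer is not None:
    # Note the inversion: the config stores "allow multi-buffer",
    # while Triton's kwarg is the negation.
    kwargs.append(f"disallow_acc_multi_buffer={not range_multi_buffer}")
if range_flatten is not None:
    kwargs.append(f"flatten={range_flatten}")

suffix = f", {', '.join(kwargs)}" if kwargs else ""
print(suffix)  # -> ", disallow_acc_multi_buffer=True, flatten=True"
# The suffix is spliced into the emitted loop header, roughly:
#   for offset in tl.range(begin, end, step, disallow_acc_multi_buffer=True, flatten=True):
```
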
24 changes: 24 additions & 0 deletions helion/autotuner/config_spec.py
@@ -39,6 +39,7 @@
"range_unroll_factors",
"range_num_stages",
"range_multi_buffers",
"range_flattens",
"num_warps",
"num_stages",
"use_yz_grid",
@@ -73,6 +74,9 @@ class ConfigSpec:
range_multi_buffers: BlockIdSequence[RangeMultiBufferSpec] = dataclasses.field(
default_factory=BlockIdSequence
)
range_flattens: BlockIdSequence[RangeFlattenSpec] = dataclasses.field(
default_factory=BlockIdSequence
)
user_defined_tunables: dict[str, ConfigSpecFragment] = dataclasses.field(
default_factory=dict
)
@@ -85,6 +89,7 @@ def _remove_duplicates(self) -> None:
self.range_unroll_factors._remove_duplicates()
self.range_num_stages._remove_duplicates()
self.range_multi_buffers._remove_duplicates()
self.range_flattens._remove_duplicates()

def normalize(self, config: helion.Config | dict[str, object]) -> None:
"""Normalize the config to match the block_sizes and validate the config."""
@@ -101,6 +106,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
"range_unroll_factor",
"range_num_stage",
"range_multi_buffer",
"range_flatten",
):
if name in config:
names = f"{name}s"
@@ -117,6 +123,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
("range_unroll_factors", self.range_unroll_factors, True),
("range_num_stages", self.range_num_stages, True),
("range_multi_buffers", self.range_multi_buffers, True),
("range_flattens", self.range_flattens, True),
]:
config[name] = mapping._normalize(
name, config.get(name, ()), flatten=flatten
@@ -130,6 +137,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
"range_unroll_factors",
"range_num_stages",
"range_multi_buffers",
"range_flattens",
):
if not config[name]:
config.pop(name)
@@ -162,6 +170,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
"range_unroll_factors": self.range_unroll_factors._flat_config(self, fn),
"range_num_stages": self.range_num_stages._flat_config(self, fn),
"range_multi_buffers": self.range_multi_buffers._flat_config(self, fn),
"range_flattens": self.range_flattens._flat_config(self, fn),
"num_warps": fn(NumWarpsFragment(1, 32, DEFAULT_NUM_WARPS)),
"num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)),
"indexing": fn(
@@ -191,6 +200,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
"range_unroll_factors",
"range_num_stages",
"range_multi_buffers",
"range_flattens",
):
if not config[name]:
config.pop(name)
@@ -374,6 +384,20 @@ def _fill_missing(self) -> None:
return None


class RangeFlattenSpec(_BlockIdItem):
def _fragment(self, base: ConfigSpec) -> EnumFragment:
return EnumFragment((None, False, True))

def _normalize(self, name: str, value: object) -> bool | None:
if value is not None and not isinstance(value, bool):
raise InvalidConfig(f"{name} must be a boolean or None, got {value!r}")
return value

def _fill_missing(self) -> None:
"""Provide a value when not provided by the user."""
return


def _product(seq: Sequence[int]) -> int:
"""Return the product of the elements in the sequence."""
return functools.reduce(operator.mul, seq, 1)
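
`RangeFlattenSpec._fragment` above declares a three-way `EnumFragment`, so each device loop adds one three-valued choice to the autotuner's search space. A small illustrative sketch of that space (the cross-product enumeration is mine, not the library's API):

```python
import itertools

# The choices RangeFlattenSpec exposes per loop dimension:
FLATTEN_CHOICES = (None, False, True)

# For a kernel with two device loops, candidate range_flattens values
# come from the 3 x 3 cross product:
space = [list(combo) for combo in itertools.product(FLATTEN_CHOICES, repeat=2)]
assert len(space) == 9
assert [True, None] in space  # e.g. flatten only the first loop
```
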
2 changes: 2 additions & 0 deletions helion/language/loops.py
@@ -26,6 +26,7 @@
from ..autotuner.config_spec import FlattenLoopSpec
from ..autotuner.config_spec import L2GroupingSpec
from ..autotuner.config_spec import LoopOrderSpec
from ..autotuner.config_spec import RangeFlattenSpec
from ..autotuner.config_spec import RangeMultiBufferSpec
from ..autotuner.config_spec import RangeNumStagesSpec
from ..autotuner.config_spec import RangeUnrollFactorSpec
@@ -251,6 +252,7 @@ def _add_config_choices(
config_spec.range_unroll_factors.append(RangeUnrollFactorSpec([block_id]))
config_spec.range_num_stages.append(RangeNumStagesSpec([block_id]))
config_spec.range_multi_buffers.append(RangeMultiBufferSpec([block_id]))
config_spec.range_flattens.append(RangeFlattenSpec([block_id]))


def _allow_use_yz_grid(config_spec: ConfigSpec, block_ids: list[int]) -> bool:
7 changes: 7 additions & 0 deletions helion/runtime/config.py
@@ -28,6 +28,7 @@ def __init__(
range_unroll_factors: list[int] | None = None,
range_num_stages: list[int] | None = None,
range_multi_buffers: list[bool | None] | None = None,
range_flattens: list[bool | None] | None = None,
num_warps: int | None = None,
num_stages: int | None = None,
use_yz_grid: bool | None = None,
@@ -46,6 +47,7 @@ def __init__(
range_unroll_factors: Loop unroll factors for tl.range calls.
range_num_stages: Number of stages for tl.range calls.
range_multi_buffers: Controls disallow_acc_multi_buffer for tl.range calls.
range_flattens: Controls flatten parameter for tl.range calls.
num_warps: Number of warps per block.
num_stages: Number of stages for software pipelining.
use_yz_grid: Whether to use yz grid dimensions.
@@ -62,6 +64,7 @@ def __init__(
"range_unroll_factors": range_unroll_factors,
"range_num_stages": range_num_stages,
"range_multi_buffers": range_multi_buffers,
"range_flattens": range_flattens,
"num_warps": num_warps,
"num_stages": num_stages,
"indexing": indexing,
@@ -159,6 +162,10 @@ def range_num_stages(self) -> list[int]:
def range_multi_buffers(self) -> list[bool | None]:
return cast("list[bool | None]", self.config.get("range_multi_buffers", []))

@property
def range_flattens(self) -> list[bool | None]:
return cast("list[bool | None]", self.config.get("range_flattens", []))

@property
def indexing(self) -> IndexingLiteral:
return self.config.get("indexing", "pointer") # type: ignore
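
A brief end-to-end usage sketch for the new keyword, assuming (as the `.get(..., [])` default above suggests) that unset keywords are dropped from the underlying config dict:

```python
import helion

cfg = helion.Config(
    block_sizes=[32, 16],
    range_flattens=[True],  # one entry per device loop
)
assert cfg.range_flattens == [True]

# When the keyword is omitted, the property falls back to an empty list:
assert helion.Config(block_sizes=[32, 16]).range_flattens == []
```
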
20 changes: 10 additions & 10 deletions test/test_autotuner.py
@@ -44,16 +44,16 @@ def test_config_fragment0(self):
self.assertExpectedInline(
"\n".join(map(repr, configs)),
"""\
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[None], num_warps=4, num_stages=3, indexing='pointer')
helion.Config(block_sizes=[16, 64, 32], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_num_stages=[2], range_multi_buffers=[None], num_warps=16, num_stages=8, indexing='pointer')
helion.Config(block_sizes=[64, 16, 128], loop_orders=[[1, 0]], l2_groupings=[4], range_unroll_factors=[2], range_num_stages=[1], range_multi_buffers=[None], num_warps=32, num_stages=3, indexing='tensor_descriptor')
helion.Config(block_sizes=[64, 16, 32], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[True], num_warps=2, num_stages=5, indexing='block_ptr')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[4], range_num_stages=[3], range_multi_buffers=[None], num_warps=4, num_stages=7, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 64, 32], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[4], range_num_stages=[2], range_multi_buffers=[None], num_warps=16, num_stages=3, indexing='block_ptr')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[4], range_num_stages=[0], range_multi_buffers=[True], num_warps=1, num_stages=6, indexing='pointer')
helion.Config(block_sizes=[16, 32, 64], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[4], range_multi_buffers=[True], num_warps=32, num_stages=1, indexing='tensor_descriptor')
helion.Config(block_sizes=[32, 32, 64], loop_orders=[[1, 0]], l2_groupings=[64], range_unroll_factors=[4], range_num_stages=[1], range_multi_buffers=[True], num_warps=4, num_stages=2, indexing='pointer')
helion.Config(block_sizes=[16, 64, 16], loop_orders=[[1, 0]], l2_groupings=[2], range_unroll_factors=[2], range_num_stages=[3], range_multi_buffers=[True], num_warps=16, num_stages=4, indexing='block_ptr')""",
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=4, num_stages=3, indexing='pointer')
helion.Config(block_sizes=[16, 64, 32], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[True], num_warps=8, num_stages=1, indexing='block_ptr')
helion.Config(block_sizes=[16, 128, 32], loop_orders=[[1, 0]], l2_groupings=[4], range_unroll_factors=[1], range_num_stages=[1], range_multi_buffers=[True], range_flattens=[True], num_warps=32, num_stages=3, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[None], num_warps=4, num_stages=7, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[32], range_unroll_factors=[4], range_num_stages=[3], range_multi_buffers=[None], range_flattens=[False], num_warps=32, num_stages=7, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 64, 32], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[4], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[True], num_warps=32, num_stages=3, indexing='block_ptr')
helion.Config(block_sizes=[16, 32, 64], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[4], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[None], num_warps=16, num_stages=6, indexing='pointer')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[4], range_multi_buffers=[True], range_flattens=[True], num_warps=1, num_stages=6, indexing='block_ptr')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[64], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[None], num_warps=32, num_stages=7, indexing='block_ptr')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[3], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[None], num_warps=8, num_stages=6, indexing='block_ptr')""",
)

@patch.object(_compat, "_supports_tensor_descriptor", lambda: True)
40 changes: 40 additions & 0 deletions test/test_loops.py
@@ -1713,6 +1713,46 @@ def nested_loop_kernel(x: torch.Tensor) -> torch.Tensor:
self.assertIn("disallow_acc_multi_buffer=False", code_true)
self.assertIn("disallow_acc_multi_buffer=True", code_false)

def test_range_flatten(self):
@helion.kernel()
def nested_loop_kernel(x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
# Outer loop becomes grid (no tl.range)
for tile_outer in hl.tile(x.size(0)):
# Inner loop becomes device loop with tl.range
for tile_inner in hl.tile(x.size(1)):
out[tile_outer, tile_inner] = x[tile_outer, tile_inner] + 1
return out

# Validate that the range_flattens config option is honored end-to-end
args = (torch.randn([64, 32], device=DEVICE),)

# Test with range_flattens = [None] (default, no flatten parameter)
code_none, result_none = code_and_output(
nested_loop_kernel, args, block_sizes=[32, 16], range_flattens=[None]
)

# Test with range_flattens = [True] (flatten=True for device loop)
code_true, result_true = code_and_output(
nested_loop_kernel, args, block_sizes=[32, 16], range_flattens=[True]
)

# Test with range_flattens = [False] (flatten=False for device loop)
code_false, result_false = code_and_output(
nested_loop_kernel, args, block_sizes=[32, 16], range_flattens=[False]
)

torch.testing.assert_close(result_none, result_true)
torch.testing.assert_close(result_none, result_false)
torch.testing.assert_close(result_none, args[0] + 1)
self.assertNotEqual(code_none, code_true)
self.assertNotEqual(code_none, code_false)
self.assertNotEqual(code_true, code_false)
# Check that flatten parameter appears in tl.range call
self.assertNotIn("flatten", code_none)
self.assertIn("flatten=True", code_true)
self.assertIn("flatten=False", code_false)


if __name__ == "__main__":
unittest.main()