Skip to content

Add tl.range loop_unroll_factor to autotuner #226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ allowing you to permute the iteration order of the tiles.
Contains one entry per `hl.tile` call with two or more dimensions,
allowing you to flatten the iteration space into a single dimension.

* **range\_unroll\_factors** (`list[int]`):
Contains one entry per loop dimension, specifying the unroll factor for
`tl.range()` calls. Values less than 1 omit the `loop_unroll_factor` parameter.

* **reduction\_loops** (`list[int | None]`):
Contains one entry per reduction dimension (see
`examples/softmax.py`). Using `None` triggers a persistent reduction,
Expand Down
4 changes: 3 additions & 1 deletion helion/_compiler/reduction_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,13 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
f"{mask_var} = {index_var} < {state.sympy_expr(numel)}"
)
)

range_extra = self.get_tl_range_kwargs(state, self.block_index)
for_node = create(
ast.For,
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
iter=expr_from_string(
f"range(0, ({state.sympy_expr(numel)}), {block_size_var})"
f"tl.range(0, ({state.sympy_expr(numel)}), {block_size_var}{range_extra})"
),
body=body,
orelse=[],
Expand Down
17 changes: 15 additions & 2 deletions helion/_compiler/tile_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,16 @@ def mask_var(self, block_idx: int) -> str | None:
def block_size_var(self, block_idx: int) -> str | None:
return self.fn.block_size_var_cache.get((block_idx,))

def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> str:
"""Get the range_extra string for loop unroll factor based on config."""
env = CompileEnvironment.current()
range_unroll_factor = env.config_spec.range_unroll_factors.config_get(
state.config.range_unroll_factors, block_idx, 0
)
if range_unroll_factor > 0:
return f", loop_unroll_factor={range_unroll_factor}"
return ""

def user_size(self, block_index: int) -> sympy.Expr:
raise NotImplementedError

Expand Down Expand Up @@ -360,11 +370,12 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
)
dtype = CompileEnvironment.current().triton_index_type()
lid = self.new_var("lid")
range_extra = self.get_tl_range_kwargs(state, self.block_ids[0])
for_node = create(
ast.For,
target=create(ast.Name, id=lid, ctx=ast.Store()),
iter=expr_from_string(
f"range(tl.cdiv({state.sympy_expr(total_numel)}, {block_size_var}))"
f"tl.range(tl.cdiv({state.sympy_expr(total_numel)}, {block_size_var}){range_extra})"
),
body=(
body := [
Expand Down Expand Up @@ -568,11 +579,13 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
end_var_name=end_var_name,
end_expr=self._fold_tile_end_op(state, proxy_end, block_size),
)

range_extra = self.get_tl_range_kwargs(state, block_idx)
for_node = create(
ast.For,
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
iter=expr_from_string(
f"range(begin, end, {block_size_var})",
f"tl.range(begin, end, {block_size_var}{range_extra})",
begin=self._to_ast(begin, to_dtype=dtype),
end=self._to_ast(end, to_dtype=dtype),
),
Expand Down
38 changes: 36 additions & 2 deletions helion/autotuner/config_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"l2_groupings",
"reduction_loops",
"flatten_loops",
"range_unroll_factors",
"num_warps",
"num_stages",
"use_yz_grid",
Expand All @@ -61,6 +62,9 @@ class ConfigSpec:
reduction_loops: BlockIdSequence[ReductionLoopSpec] = dataclasses.field(
default_factory=BlockIdSequence
)
range_unroll_factors: BlockIdSequence[RangeUnrollFactorSpec] = dataclasses.field(
default_factory=BlockIdSequence
)
user_defined_tunables: dict[str, ConfigSpecFragment] = dataclasses.field(
default_factory=dict
)
Expand All @@ -70,6 +74,7 @@ def _remove_duplicates(self) -> None:
self.loop_orders._remove_duplicates()
self.l2_groupings._remove_duplicates()
self.flatten_loops._remove_duplicates()
self.range_unroll_factors._remove_duplicates()

def normalize(self, config: helion.Config | dict[str, object]) -> None:
"""Normalize the config to match the block_sizes and validate the config."""
Expand All @@ -83,6 +88,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
"reduction_loop",
"l2_grouping",
"flatten_loop",
"range_unroll_factor",
):
if name in config:
names = f"{name}s"
Expand All @@ -96,12 +102,19 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
("l2_groupings", self.l2_groupings, True),
("loop_orders", self.loop_orders, False),
("reduction_loops", self.reduction_loops, True),
("range_unroll_factors", self.range_unroll_factors, True),
]:
config[name] = mapping._normalize(
name, config.get(name, ()), flatten=flatten
)

for name in ("loop_orders", "l2_groupings", "flatten_loops", "reduction_loops"):
for name in (
"loop_orders",
"l2_groupings",
"flatten_loops",
"reduction_loops",
"range_unroll_factors",
):
if not config[name]:
config.pop(name)

Expand Down Expand Up @@ -130,6 +143,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
"flatten_loops": self.flatten_loops._flat_config(self, fn),
"l2_groupings": self.l2_groupings._flat_config(self, fn),
"reduction_loops": self.reduction_loops._flat_config(self, fn),
"range_unroll_factors": self.range_unroll_factors._flat_config(self, fn),
"num_warps": fn(NumWarpsFragment(1, 32, DEFAULT_NUM_WARPS)),
"num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)),
"indexing": fn(
Expand All @@ -151,7 +165,13 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
not config["flatten_loops"] or not config["flatten_loops"][0]
):
config["use_yz_grid"] = use_yz_grid
for name in ("loop_orders", "flatten_loops", "reduction_loops", "l2_groupings"):
for name in (
"loop_orders",
"flatten_loops",
"reduction_loops",
"l2_groupings",
"range_unroll_factors",
):
if not config[name]:
config.pop(name)
return helion.Config(**config) # pyre-ignore[6]
Expand Down Expand Up @@ -292,6 +312,20 @@ def _fill_missing(self) -> None:
return None


class RangeUnrollFactorSpec(_BlockIdItem):
def _fragment(self, base: ConfigSpec) -> IntegerFragment:
return IntegerFragment(0, 4, 0)

def _normalize(self, name: str, value: object) -> int:
if not isinstance(value, int):
raise InvalidConfig(f"{name} must be an integer, got {value!r}")
return value

def _fill_missing(self) -> int:
"""Provide a value when not provided by the user."""
return 0


def _product(seq: Sequence[int]) -> int:
"""Return the product of the elements in the sequence."""
return functools.reduce(operator.mul, seq, 1)
8 changes: 7 additions & 1 deletion helion/language/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ..autotuner.config_spec import FlattenLoopSpec
from ..autotuner.config_spec import L2GroupingSpec
from ..autotuner.config_spec import LoopOrderSpec
from ..autotuner.config_spec import RangeUnrollFactorSpec
from . import _decorators
from helion.language.tile_proxy import Tile

Expand Down Expand Up @@ -230,17 +231,22 @@ def _add_config_choices(
block_ids: list[int], *, is_tile: bool = False, has_begin: bool = False
) -> None:
config_spec = CompileEnvironment.current().config_spec

if len(block_ids) > 1:
# Add loop reordering choice
config_spec.loop_orders.append(LoopOrderSpec(block_ids))
if is_tile and not has_begin:
config_spec.flatten_loops.append(FlattenLoopSpec(block_ids))

if all(x._loop_type != LoopType.GRID for x in ExtendedAST.current()): # is_grid
is_grid = all(x._loop_type != LoopType.GRID for x in ExtendedAST.current())
if is_grid:
if len(block_ids) == 2:
# TODO(jansel): support L2 grouping with 3+ dims (and maybe non-grids?)
config_spec.l2_groupings.append(L2GroupingSpec(block_ids))
config_spec.allow_use_yz_grid = _allow_use_yz_grid(config_spec, block_ids)
else:
for block_id in block_ids:
config_spec.range_unroll_factors.append(RangeUnrollFactorSpec([block_id]))


def _allow_use_yz_grid(config_spec: ConfigSpec, block_ids: list[int]) -> bool:
Expand Down
7 changes: 7 additions & 0 deletions helion/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
flatten_loops: list[bool] | None = None,
l2_groupings: list[int] | None = None,
reduction_loops: list[int | None] | None = None,
range_unroll_factors: list[int] | None = None,
num_warps: int | None = None,
num_stages: int | None = None,
use_yz_grid: bool | None = None,
Expand All @@ -40,6 +41,7 @@ def __init__(
loop_orders: Permutes iteration order of tiles.
l2_groupings: Reorders program IDs for L2 cache locality.
reduction_loops: Configures reduction loop behavior.
range_unroll_factors: Loop unroll factors for tl.range calls.
num_warps: Number of warps per block.
num_stages: Number of stages for software pipelining.
use_yz_grid: Whether to use yz grid dimensions.
Expand All @@ -53,6 +55,7 @@ def __init__(
"flatten_loops": flatten_loops,
"l2_groupings": l2_groupings,
"reduction_loops": reduction_loops,
"range_unroll_factors": range_unroll_factors,
"num_warps": num_warps,
"num_stages": num_stages,
"indexing": indexing,
Expand Down Expand Up @@ -138,6 +141,10 @@ def l2_groupings(self) -> list[int]:
def use_yz_grid(self) -> bool:
return cast("bool", self.config.get("use_yz_grid", False))

@property
def range_unroll_factors(self) -> list[int]:
return cast("list[int]", self.config.get("range_unroll_factors", []))

@property
def indexing(self) -> IndexingLiteral:
return self.config.get("indexing", "pointer") # type: ignore
Expand Down
20 changes: 10 additions & 10 deletions test/test_autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,16 @@ def test_config_fragment0(self):
self.assertExpectedInline(
"\n".join(map(repr, configs)),
"""\
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=4, num_stages=3, indexing='pointer')
helion.Config(block_sizes=[16, 32, 16], loop_orders=[[1, 0]], l2_groupings=[8], num_warps=32, num_stages=3, indexing='block_ptr')
helion.Config(block_sizes=[32, 16, 16], loop_orders=[[1, 0]], l2_groupings=[32], num_warps=8, num_stages=8, indexing='block_ptr')
helion.Config(block_sizes=[16, 16, 32], loop_orders=[[0, 1]], l2_groupings=[16], num_warps=4, num_stages=7, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[4], num_warps=8, num_stages=2, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[64], num_warps=4, num_stages=7, indexing='tensor_descriptor')
helion.Config(block_sizes=[32, 128, 64], loop_orders=[[0, 1]], l2_groupings=[2], num_warps=16, num_stages=5, indexing='pointer')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[2], num_warps=16, num_stages=3, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], num_warps=4, num_stages=2, indexing='block_ptr')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=1, num_stages=1, indexing='tensor_descriptor')""",
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], num_warps=4, num_stages=3, indexing='pointer')
helion.Config(block_sizes=[16, 32, 16], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], num_warps=4, num_stages=4, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 64, 128], loop_orders=[[0, 1]], l2_groupings=[8], range_unroll_factors=[2], num_warps=4, num_stages=4, indexing='pointer')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[2], num_warps=8, num_stages=7, indexing='block_ptr')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[8], range_unroll_factors=[0], num_warps=16, num_stages=1, indexing='pointer')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[64], range_unroll_factors=[2], num_warps=32, num_stages=7, indexing='tensor_descriptor')
helion.Config(block_sizes=[32, 128, 64], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[4], num_warps=4, num_stages=4, indexing='tensor_descriptor')
helion.Config(block_sizes=[256, 32, 128], loop_orders=[[1, 0]], l2_groupings=[16], range_unroll_factors=[1], num_warps=32, num_stages=2, indexing='block_ptr')
helion.Config(block_sizes=[32, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[3], num_warps=1, num_stages=8, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 32, 16], loop_orders=[[0, 1]], l2_groupings=[32], range_unroll_factors=[2], num_warps=4, num_stages=5, indexing='block_ptr')""",
)

@patch.object(_compat, "_supports_tensor_descriptor", lambda: True)
Expand Down
Loading
Loading