Skip to content

Add tl.range num_stages to autotuner #227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ allowing you to flatten the iteration space into a single dimension.
Contains one entry per loop dimension, specifying the unroll factor for
`tl.range()` calls. Values less than 1 omit the `loop_unroll_factor` parameter.

* **range\_num\_stages** (`list[int]`):
Contains one entry per loop dimension, specifying the number of stages for
`tl.range()` calls. Values less than 1 omit the `num_stages` parameter.

* **reduction\_loops** (`list[int | None]`):
Contains one entry per reduction dimension (see
`examples/softmax.py`). Using `None` triggers a persistent reduction,
Expand Down
15 changes: 13 additions & 2 deletions helion/_compiler/tile_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,24 @@ def block_size_var(self, block_idx: int) -> str | None:
return self.fn.block_size_var_cache.get((block_idx,))

def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> str:
    """Return the extra-kwargs suffix appended to generated ``tl.range(...)`` calls.

    Builds ``loop_unroll_factor=N`` and/or ``num_stages=M`` from the config
    entries for this loop's block index; a value <= 0 omits that kwarg.

    Args:
        state: Codegen state providing the active ``helion.Config``.
        block_idx: Block index identifying which device loop is being emitted.

    Returns:
        A string starting with ``", "`` ready to splice after the positional
        arguments of ``tl.range(...)``, or ``""`` when neither option is set.
    """
    env = CompileEnvironment.current()
    kwargs = []

    range_unroll_factor = env.config_spec.range_unroll_factors.config_get(
        state.config.range_unroll_factors, block_idx, 0
    )
    # NOTE(review): the diff residue here kept an unconditional early
    # `return` from the old version, which would have skipped the
    # num_stages handling below; the appended-kwargs form is the intended
    # post-change behavior.
    if range_unroll_factor > 0:
        kwargs.append(f"loop_unroll_factor={range_unroll_factor}")

    range_num_stages = env.config_spec.range_num_stages.config_get(
        state.config.range_num_stages, block_idx, 0
    )
    if range_num_stages > 0:
        kwargs.append(f"num_stages={range_num_stages}")

    if kwargs:
        # Leading ", " lets callers concatenate directly after tl.range's args.
        return f", {', '.join(kwargs)}"
    return ""

def user_size(self, block_index: int) -> sympy.Expr:
Expand Down
24 changes: 24 additions & 0 deletions helion/autotuner/config_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"reduction_loops",
"flatten_loops",
"range_unroll_factors",
"range_num_stages",
"num_warps",
"num_stages",
"use_yz_grid",
Expand Down Expand Up @@ -65,6 +66,9 @@ class ConfigSpec:
range_unroll_factors: BlockIdSequence[RangeUnrollFactorSpec] = dataclasses.field(
default_factory=BlockIdSequence
)
range_num_stages: BlockIdSequence[RangeNumStagesSpec] = dataclasses.field(
default_factory=BlockIdSequence
)
user_defined_tunables: dict[str, ConfigSpecFragment] = dataclasses.field(
default_factory=dict
)
Expand All @@ -75,6 +79,7 @@ def _remove_duplicates(self) -> None:
self.l2_groupings._remove_duplicates()
self.flatten_loops._remove_duplicates()
self.range_unroll_factors._remove_duplicates()
self.range_num_stages._remove_duplicates()

def normalize(self, config: helion.Config | dict[str, object]) -> None:
"""Normalize the config to match the block_sizes and validate the config."""
Expand All @@ -89,6 +94,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
"l2_grouping",
"flatten_loop",
"range_unroll_factor",
"range_num_stage",
):
if name in config:
names = f"{name}s"
Expand All @@ -103,6 +109,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
("loop_orders", self.loop_orders, False),
("reduction_loops", self.reduction_loops, True),
("range_unroll_factors", self.range_unroll_factors, True),
("range_num_stages", self.range_num_stages, True),
]:
config[name] = mapping._normalize(
name, config.get(name, ()), flatten=flatten
Expand All @@ -114,6 +121,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
"flatten_loops",
"reduction_loops",
"range_unroll_factors",
"range_num_stages",
):
if not config[name]:
config.pop(name)
Expand Down Expand Up @@ -144,6 +152,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
"l2_groupings": self.l2_groupings._flat_config(self, fn),
"reduction_loops": self.reduction_loops._flat_config(self, fn),
"range_unroll_factors": self.range_unroll_factors._flat_config(self, fn),
"range_num_stages": self.range_num_stages._flat_config(self, fn),
"num_warps": fn(NumWarpsFragment(1, 32, DEFAULT_NUM_WARPS)),
"num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)),
"indexing": fn(
Expand Down Expand Up @@ -171,6 +180,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
"reduction_loops",
"l2_groupings",
"range_unroll_factors",
"range_num_stages",
):
if not config[name]:
config.pop(name)
Expand Down Expand Up @@ -326,6 +336,20 @@ def _fill_missing(self) -> int:
return 0


class RangeNumStagesSpec(_BlockIdItem):
    """Tunable `num_stages` value passed to a device loop's `tl.range` call."""

    def _fragment(self, base: ConfigSpec) -> IntegerFragment:
        # Search space 0..4; 0 (the default) omits num_stages from tl.range.
        return IntegerFragment(0, 4, 0)

    def _normalize(self, name: str, value: object) -> int:
        if isinstance(value, int):
            return value
        raise InvalidConfig(f"{name} must be an integer, got {value!r}")

    def _fill_missing(self) -> int:
        """Provide a value when not provided by the user."""
        return 0


def _product(seq: Sequence[int]) -> int:
"""Return the product of the elements in the sequence."""
return functools.reduce(operator.mul, seq, 1)
2 changes: 2 additions & 0 deletions helion/language/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ..autotuner.config_spec import FlattenLoopSpec
from ..autotuner.config_spec import L2GroupingSpec
from ..autotuner.config_spec import LoopOrderSpec
from ..autotuner.config_spec import RangeNumStagesSpec
from ..autotuner.config_spec import RangeUnrollFactorSpec
from . import _decorators
from helion.language.tile_proxy import Tile
Expand Down Expand Up @@ -247,6 +248,7 @@ def _add_config_choices(
else:
for block_id in block_ids:
config_spec.range_unroll_factors.append(RangeUnrollFactorSpec([block_id]))
config_spec.range_num_stages.append(RangeNumStagesSpec([block_id]))


def _allow_use_yz_grid(config_spec: ConfigSpec, block_ids: list[int]) -> bool:
Expand Down
7 changes: 7 additions & 0 deletions helion/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
l2_groupings: list[int] | None = None,
reduction_loops: list[int | None] | None = None,
range_unroll_factors: list[int] | None = None,
range_num_stages: list[int] | None = None,
num_warps: int | None = None,
num_stages: int | None = None,
use_yz_grid: bool | None = None,
Expand All @@ -42,6 +43,7 @@ def __init__(
l2_groupings: Reorders program IDs for L2 cache locality.
reduction_loops: Configures reduction loop behavior.
range_unroll_factors: Loop unroll factors for tl.range calls.
range_num_stages: Number of stages for tl.range calls.
num_warps: Number of warps per block.
num_stages: Number of stages for software pipelining.
use_yz_grid: Whether to use yz grid dimensions.
Expand All @@ -56,6 +58,7 @@ def __init__(
"l2_groupings": l2_groupings,
"reduction_loops": reduction_loops,
"range_unroll_factors": range_unroll_factors,
"range_num_stages": range_num_stages,
"num_warps": num_warps,
"num_stages": num_stages,
"indexing": indexing,
Expand Down Expand Up @@ -145,6 +148,10 @@ def use_yz_grid(self) -> bool:
def range_unroll_factors(self) -> list[int]:
return cast("list[int]", self.config.get("range_unroll_factors", []))

@property
def range_num_stages(self) -> list[int]:
    """Per-device-loop `num_stages` values for `tl.range`; `[]` when unset."""
    values = self.config.get("range_num_stages", [])
    return cast("list[int]", values)

@property
def indexing(self) -> IndexingLiteral:
return self.config.get("indexing", "pointer") # type: ignore
Expand Down
20 changes: 10 additions & 10 deletions test/test_autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,16 @@ def test_config_fragment0(self):
self.assertExpectedInline(
"\n".join(map(repr, configs)),
"""\
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], num_warps=4, num_stages=3, indexing='pointer')
helion.Config(block_sizes=[16, 32, 16], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], num_warps=4, num_stages=4, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 64, 128], loop_orders=[[0, 1]], l2_groupings=[8], range_unroll_factors=[2], num_warps=4, num_stages=4, indexing='pointer')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[2], num_warps=8, num_stages=7, indexing='block_ptr')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[8], range_unroll_factors=[0], num_warps=16, num_stages=1, indexing='pointer')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[64], range_unroll_factors=[2], num_warps=32, num_stages=7, indexing='tensor_descriptor')
helion.Config(block_sizes=[32, 128, 64], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[4], num_warps=4, num_stages=4, indexing='tensor_descriptor')
helion.Config(block_sizes=[256, 32, 128], loop_orders=[[1, 0]], l2_groupings=[16], range_unroll_factors=[1], num_warps=32, num_stages=2, indexing='block_ptr')
helion.Config(block_sizes=[32, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[3], num_warps=1, num_stages=8, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 32, 16], loop_orders=[[0, 1]], l2_groupings=[32], range_unroll_factors=[2], num_warps=4, num_stages=5, indexing='block_ptr')""",
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], num_warps=4, num_stages=3, indexing='pointer')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_num_stages=[2], num_warps=2, num_stages=8, indexing='pointer')
helion.Config(block_sizes=[16, 16, 32], loop_orders=[[1, 0]], l2_groupings=[4], range_unroll_factors=[2], range_num_stages=[1], num_warps=2, num_stages=3, indexing='tensor_descriptor')
helion.Config(block_sizes=[64, 16, 32], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], num_warps=32, num_stages=4, indexing='block_ptr')
helion.Config(block_sizes=[128, 16, 64], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[2], range_num_stages=[4], num_warps=8, num_stages=1, indexing='block_ptr')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[2], range_unroll_factors=[4], range_num_stages=[4], num_warps=1, num_stages=4, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[8], range_unroll_factors=[1], range_num_stages=[4], num_warps=32, num_stages=3, indexing='tensor_descriptor')
helion.Config(block_sizes=[16, 64, 64], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[2], range_num_stages=[0], num_warps=8, num_stages=2, indexing='block_ptr')
helion.Config(block_sizes=[64, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[0], range_num_stages=[2], num_warps=4, num_stages=5, indexing='block_ptr')
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[4], range_num_stages=[1], num_warps=32, num_stages=7, indexing='block_ptr')""",
)

@patch.object(_compat, "_supports_tensor_descriptor", lambda: True)
Expand Down
35 changes: 35 additions & 0 deletions test/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1638,6 +1638,41 @@ def _nested_loop_kernel_make_precompiler(x: torch.Tensor):
return make_precompiler(_nested_loop_kernel_kernel)(x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
)

def test_range_num_stages(self):
    """range_num_stages must alter the generated tl.range call without changing results."""
    @helion.kernel()
    def nested_loop_kernel(x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        # Outer loop becomes grid (no tl.range)
        for tile_outer in hl.tile(x.size(0)):
            # Inner loop becomes device loop with tl.range
            for tile_inner in hl.tile(x.size(1)):
                out[tile_outer, tile_inner] = x[tile_outer, tile_inner] + 1
        return out

    # Test configuration validation - that range_num_stages works
    args = (torch.randn([64, 32], device=DEVICE),)

    # Test with range_num_stages = [0] (no num_stages for device loop)
    code0, result0 = code_and_output(
        nested_loop_kernel, args, block_sizes=[32, 16], range_num_stages=[0]
    )

    # Test with range_num_stages = [3] (num_stages=3 for device loop)
    code3, result3 = code_and_output(
        nested_loop_kernel, args, block_sizes=[32, 16], range_num_stages=[3]
    )

    # Both configs must compute the same, correct output...
    torch.testing.assert_close(result0, result3)
    torch.testing.assert_close(result0, args[0] + 1)
    # ...while producing different generated source.
    self.assertNotEqual(code0, code3)
    # Check that range_num_stages parameter appears in tl.range call
    self.assertNotIn(
        "tl.range(0, x_size_1.to(tl.int32), _BLOCK_SIZE_1, num_stages=", code0
    )
    self.assertIn(
        "tl.range(0, x_size_1.to(tl.int32), _BLOCK_SIZE_1, num_stages=3)", code3
    )


if __name__ == "__main__":
unittest.main()