Skip to content

Commit 2959a8d

Browse files
authored
Add tl.range num_stages to autotuner (#227)
1 parent 8392666 commit 2959a8d

File tree

7 files changed

+95
-12
lines changed

7 files changed

+95
-12
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,10 @@ allowing you to flatten the iteration space into a single dimension.
195195
Contains one entry per loop dimension, specifying the unroll factor for
196196
`tl.range()` calls. Values less than 1 omit the `loop_unroll_factor` parameter.
197197

198+
* **range\_num\_stages** (`list[int]`):
199+
Contains one entry per loop dimension, specifying the number of stages for
200+
`tl.range()` calls. Values less than 1 omit the `num_stages` parameter.
201+
198202
* **reduction\_loops** (`list[int | None]`):
199203
Contains one entry per reduction dimension (see
200204
`examples/softmax.py`). Using `None` triggers a persistent reduction,

helion/_compiler/tile_strategy.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,24 @@ def block_size_var(self, block_idx: int) -> str | None:
124124
return self.fn.block_size_var_cache.get((block_idx,))
125125

126126
def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> str:
127-
"""Get the range_extra string for loop unroll factor based on config."""
127+
"""Get the range_extra string for loop unroll factor and num_stages based on config."""
128128
env = CompileEnvironment.current()
129+
kwargs = []
130+
129131
range_unroll_factor = env.config_spec.range_unroll_factors.config_get(
130132
state.config.range_unroll_factors, block_idx, 0
131133
)
132134
if range_unroll_factor > 0:
133-
return f", loop_unroll_factor={range_unroll_factor}"
135+
kwargs.append(f"loop_unroll_factor={range_unroll_factor}")
136+
137+
range_num_stages = env.config_spec.range_num_stages.config_get(
138+
state.config.range_num_stages, block_idx, 0
139+
)
140+
if range_num_stages > 0:
141+
kwargs.append(f"num_stages={range_num_stages}")
142+
143+
if kwargs:
144+
return f", {', '.join(kwargs)}"
134145
return ""
135146

136147
def user_size(self, block_index: int) -> sympy.Expr:

helion/autotuner/config_spec.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
"reduction_loops",
3838
"flatten_loops",
3939
"range_unroll_factors",
40+
"range_num_stages",
4041
"num_warps",
4142
"num_stages",
4243
"use_yz_grid",
@@ -65,6 +66,9 @@ class ConfigSpec:
6566
range_unroll_factors: BlockIdSequence[RangeUnrollFactorSpec] = dataclasses.field(
6667
default_factory=BlockIdSequence
6768
)
69+
range_num_stages: BlockIdSequence[RangeNumStagesSpec] = dataclasses.field(
70+
default_factory=BlockIdSequence
71+
)
6872
user_defined_tunables: dict[str, ConfigSpecFragment] = dataclasses.field(
6973
default_factory=dict
7074
)
@@ -75,6 +79,7 @@ def _remove_duplicates(self) -> None:
7579
self.l2_groupings._remove_duplicates()
7680
self.flatten_loops._remove_duplicates()
7781
self.range_unroll_factors._remove_duplicates()
82+
self.range_num_stages._remove_duplicates()
7883

7984
def normalize(self, config: helion.Config | dict[str, object]) -> None:
8085
"""Normalize the config to match the block_sizes and validate the config."""
@@ -89,6 +94,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
8994
"l2_grouping",
9095
"flatten_loop",
9196
"range_unroll_factor",
97+
"range_num_stage",
9298
):
9399
if name in config:
94100
names = f"{name}s"
@@ -103,6 +109,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
103109
("loop_orders", self.loop_orders, False),
104110
("reduction_loops", self.reduction_loops, True),
105111
("range_unroll_factors", self.range_unroll_factors, True),
112+
("range_num_stages", self.range_num_stages, True),
106113
]:
107114
config[name] = mapping._normalize(
108115
name, config.get(name, ()), flatten=flatten
@@ -114,6 +121,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
114121
"flatten_loops",
115122
"reduction_loops",
116123
"range_unroll_factors",
124+
"range_num_stages",
117125
):
118126
if not config[name]:
119127
config.pop(name)
@@ -144,6 +152,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
144152
"l2_groupings": self.l2_groupings._flat_config(self, fn),
145153
"reduction_loops": self.reduction_loops._flat_config(self, fn),
146154
"range_unroll_factors": self.range_unroll_factors._flat_config(self, fn),
155+
"range_num_stages": self.range_num_stages._flat_config(self, fn),
147156
"num_warps": fn(NumWarpsFragment(1, 32, DEFAULT_NUM_WARPS)),
148157
"num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)),
149158
"indexing": fn(
@@ -171,6 +180,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
171180
"reduction_loops",
172181
"l2_groupings",
173182
"range_unroll_factors",
183+
"range_num_stages",
174184
):
175185
if not config[name]:
176186
config.pop(name)
@@ -326,6 +336,20 @@ def _fill_missing(self) -> int:
326336
return 0
327337

328338

339+
class RangeNumStagesSpec(_BlockIdItem):
340+
def _fragment(self, base: ConfigSpec) -> IntegerFragment:
341+
return IntegerFragment(0, 4, 0)
342+
343+
def _normalize(self, name: str, value: object) -> int:
344+
if not isinstance(value, int):
345+
raise InvalidConfig(f"{name} must be an integer, got {value!r}")
346+
return value
347+
348+
def _fill_missing(self) -> int:
349+
"""Provide a value when not provided by the user."""
350+
return 0
351+
352+
329353
def _product(seq: Sequence[int]) -> int:
330354
"""Return the product of the elements in the sequence."""
331355
return functools.reduce(operator.mul, seq, 1)

helion/language/loops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ..autotuner.config_spec import FlattenLoopSpec
2727
from ..autotuner.config_spec import L2GroupingSpec
2828
from ..autotuner.config_spec import LoopOrderSpec
29+
from ..autotuner.config_spec import RangeNumStagesSpec
2930
from ..autotuner.config_spec import RangeUnrollFactorSpec
3031
from . import _decorators
3132
from helion.language.tile_proxy import Tile
@@ -247,6 +248,7 @@ def _add_config_choices(
247248
else:
248249
for block_id in block_ids:
249250
config_spec.range_unroll_factors.append(RangeUnrollFactorSpec([block_id]))
251+
config_spec.range_num_stages.append(RangeNumStagesSpec([block_id]))
250252

251253

252254
def _allow_use_yz_grid(config_spec: ConfigSpec, block_ids: list[int]) -> bool:

helion/runtime/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(
2626
l2_groupings: list[int] | None = None,
2727
reduction_loops: list[int | None] | None = None,
2828
range_unroll_factors: list[int] | None = None,
29+
range_num_stages: list[int] | None = None,
2930
num_warps: int | None = None,
3031
num_stages: int | None = None,
3132
use_yz_grid: bool | None = None,
@@ -42,6 +43,7 @@ def __init__(
4243
l2_groupings: Reorders program IDs for L2 cache locality.
4344
reduction_loops: Configures reduction loop behavior.
4445
range_unroll_factors: Loop unroll factors for tl.range calls.
46+
range_num_stages: Number of stages for tl.range calls, one entry per loop dimension; values less than 1 omit num_stages.
4547
num_warps: Number of warps per block.
4648
num_stages: Number of stages for software pipelining.
4749
use_yz_grid: Whether to use yz grid dimensions.
@@ -56,6 +58,7 @@ def __init__(
5658
"l2_groupings": l2_groupings,
5759
"reduction_loops": reduction_loops,
5860
"range_unroll_factors": range_unroll_factors,
61+
"range_num_stages": range_num_stages,
5962
"num_warps": num_warps,
6063
"num_stages": num_stages,
6164
"indexing": indexing,
@@ -145,6 +148,10 @@ def use_yz_grid(self) -> bool:
145148
def range_unroll_factors(self) -> list[int]:
146149
return cast("list[int]", self.config.get("range_unroll_factors", []))
147150

151+
@property
152+
def range_num_stages(self) -> list[int]:
153+
return cast("list[int]", self.config.get("range_num_stages", []))
154+
148155
@property
149156
def indexing(self) -> IndexingLiteral:
150157
return self.config.get("indexing", "pointer") # type: ignore

test/test_autotuner.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,16 @@ def test_config_fragment0(self):
4444
self.assertExpectedInline(
4545
"\n".join(map(repr, configs)),
4646
"""\
47-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], num_warps=4, num_stages=3, indexing='pointer')
48-
helion.Config(block_sizes=[16, 32, 16], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], num_warps=4, num_stages=4, indexing='tensor_descriptor')
49-
helion.Config(block_sizes=[16, 64, 128], loop_orders=[[0, 1]], l2_groupings=[8], range_unroll_factors=[2], num_warps=4, num_stages=4, indexing='pointer')
50-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[2], num_warps=8, num_stages=7, indexing='block_ptr')
51-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[8], range_unroll_factors=[0], num_warps=16, num_stages=1, indexing='pointer')
52-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[64], range_unroll_factors=[2], num_warps=32, num_stages=7, indexing='tensor_descriptor')
53-
helion.Config(block_sizes=[32, 128, 64], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[4], num_warps=4, num_stages=4, indexing='tensor_descriptor')
54-
helion.Config(block_sizes=[256, 32, 128], loop_orders=[[1, 0]], l2_groupings=[16], range_unroll_factors=[1], num_warps=32, num_stages=2, indexing='block_ptr')
55-
helion.Config(block_sizes=[32, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[3], num_warps=1, num_stages=8, indexing='tensor_descriptor')
56-
helion.Config(block_sizes=[16, 32, 16], loop_orders=[[0, 1]], l2_groupings=[32], range_unroll_factors=[2], num_warps=4, num_stages=5, indexing='block_ptr')""",
47+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], num_warps=4, num_stages=3, indexing='pointer')
48+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_num_stages=[2], num_warps=2, num_stages=8, indexing='pointer')
49+
helion.Config(block_sizes=[16, 16, 32], loop_orders=[[1, 0]], l2_groupings=[4], range_unroll_factors=[2], range_num_stages=[1], num_warps=2, num_stages=3, indexing='tensor_descriptor')
50+
helion.Config(block_sizes=[64, 16, 32], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], num_warps=32, num_stages=4, indexing='block_ptr')
51+
helion.Config(block_sizes=[128, 16, 64], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[2], range_num_stages=[4], num_warps=8, num_stages=1, indexing='block_ptr')
52+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[2], range_unroll_factors=[4], range_num_stages=[4], num_warps=1, num_stages=4, indexing='tensor_descriptor')
53+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[8], range_unroll_factors=[1], range_num_stages=[4], num_warps=32, num_stages=3, indexing='tensor_descriptor')
54+
helion.Config(block_sizes=[16, 64, 64], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[2], range_num_stages=[0], num_warps=8, num_stages=2, indexing='block_ptr')
55+
helion.Config(block_sizes=[64, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[0], range_num_stages=[2], num_warps=4, num_stages=5, indexing='block_ptr')
56+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[4], range_num_stages=[1], num_warps=32, num_stages=7, indexing='block_ptr')""",
5757
)
5858

5959
@patch.object(_compat, "_supports_tensor_descriptor", lambda: True)

test/test_loops.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1638,6 +1638,41 @@ def _nested_loop_kernel_make_precompiler(x: torch.Tensor):
16381638
return make_precompiler(_nested_loop_kernel_kernel)(x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
16391639
)
16401640

1641+
def test_range_num_stages(self):
1642+
@helion.kernel()
1643+
def nested_loop_kernel(x: torch.Tensor) -> torch.Tensor:
1644+
out = torch.empty_like(x)
1645+
# Outer loop becomes grid (no tl.range)
1646+
for tile_outer in hl.tile(x.size(0)):
1647+
# Inner loop becomes device loop with tl.range
1648+
for tile_inner in hl.tile(x.size(1)):
1649+
out[tile_outer, tile_inner] = x[tile_outer, tile_inner] + 1
1650+
return out
1651+
1652+
# Verify that the range_num_stages config option is accepted and changes the generated code
1653+
args = (torch.randn([64, 32], device=DEVICE),)
1654+
1655+
# Test with range_num_stages = [0] (no num_stages for device loop)
1656+
code0, result0 = code_and_output(
1657+
nested_loop_kernel, args, block_sizes=[32, 16], range_num_stages=[0]
1658+
)
1659+
1660+
# Test with range_num_stages = [3] (num_stages=3 for device loop)
1661+
code3, result3 = code_and_output(
1662+
nested_loop_kernel, args, block_sizes=[32, 16], range_num_stages=[3]
1663+
)
1664+
1665+
torch.testing.assert_close(result0, result3)
1666+
torch.testing.assert_close(result0, args[0] + 1)
1667+
self.assertNotEqual(code0, code3)
1668+
# Check that num_stages appears in the tl.range call only when range_num_stages > 0
1669+
self.assertNotIn(
1670+
"tl.range(0, x_size_1.to(tl.int32), _BLOCK_SIZE_1, num_stages=", code0
1671+
)
1672+
self.assertIn(
1673+
"tl.range(0, x_size_1.to(tl.int32), _BLOCK_SIZE_1, num_stages=3)", code3
1674+
)
1675+
16411676

16421677
if __name__ == "__main__":
16431678
unittest.main()

0 commit comments

Comments
 (0)