
Commit 93a405c

Add tl.range flatten to autotuner

stack-info: PR: #229, branch: jansel/stack/73
1 parent: 9d94a97

File tree: 7 files changed, +94 −10 lines

README.md (5 additions, 0 deletions)

```diff
@@ -204,6 +204,11 @@ Contains one entry per loop dimension, controlling the `disallow_acc_multi_buffe
   parameter for `tl.range()` calls. `True` allows multi-buffer (sets `disallow_acc_multi_buffer=False`),
   `False` disallows multi-buffer (sets `disallow_acc_multi_buffer=True`), and `None` omits the parameter.
 
+* **range\_flattens** (`list[bool | None]`):
+  Contains one entry per loop dimension, controlling the `flatten`
+  parameter for `tl.range()` calls. `True` sets `flatten=True`,
+  `False` sets `flatten=False`, and `None` omits the parameter.
+
 * **reduction\_loops** (`list[int | None]`):
   Contains one entry per reduction dimension (see
   `examples/softmax.py`). Using `None` triggers a persistent reduction,
```
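For reference, a minimal sketch of the tri-state semantics from the user's side, using the `helion.Config` constructor extended later in this commit (the block sizes here are illustrative, not prescriptive):

```python
import helion

# One entry per device-loop dimension; here a single device loop.
helion.Config(block_sizes=[32, 16], range_flattens=[None])   # omits flatten from tl.range()
helion.Config(block_sizes=[32, 16], range_flattens=[True])   # emits flatten=True
helion.Config(block_sizes=[32, 16], range_flattens=[False])  # emits flatten=False
```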

helion/_compiler/tile_strategy.py (6 additions, 0 deletions)

```diff
@@ -146,6 +146,12 @@ def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> str:
         if range_multi_buffer is not None:
             kwargs.append(f"disallow_acc_multi_buffer={not range_multi_buffer}")
 
+        range_flatten = env.config_spec.range_flattens.config_get(
+            state.config.range_flattens, block_idx, None
+        )
+        if range_flatten is not None:
+            kwargs.append(f"flatten={range_flatten}")
+
         if kwargs:
             return f", {', '.join(kwargs)}"
         return ""
```

helion/autotuner/config_spec.py (24 additions, 0 deletions)

```diff
@@ -39,6 +39,7 @@
     "range_unroll_factors",
     "range_num_stages",
     "range_multi_buffers",
+    "range_flattens",
     "num_warps",
     "num_stages",
     "use_yz_grid",
@@ -73,6 +74,9 @@ class ConfigSpec:
     range_multi_buffers: BlockIdSequence[RangeMultiBufferSpec] = dataclasses.field(
         default_factory=BlockIdSequence
     )
+    range_flattens: BlockIdSequence[RangeFlattenSpec] = dataclasses.field(
+        default_factory=BlockIdSequence
+    )
     user_defined_tunables: dict[str, ConfigSpecFragment] = dataclasses.field(
         default_factory=dict
     )
@@ -85,6 +89,7 @@ def _remove_duplicates(self) -> None:
         self.range_unroll_factors._remove_duplicates()
         self.range_num_stages._remove_duplicates()
         self.range_multi_buffers._remove_duplicates()
+        self.range_flattens._remove_duplicates()
 
     def normalize(self, config: helion.Config | dict[str, object]) -> None:
         """Normalize the config to match the block_sizes and validate the config."""
@@ -101,6 +106,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
             "range_unroll_factor",
             "range_num_stage",
             "range_multi_buffer",
+            "range_flatten",
         ):
             if name in config:
                 names = f"{name}s"
@@ -117,6 +123,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
             ("range_unroll_factors", self.range_unroll_factors, True),
             ("range_num_stages", self.range_num_stages, True),
             ("range_multi_buffers", self.range_multi_buffers, True),
+            ("range_flattens", self.range_flattens, True),
         ]:
             config[name] = mapping._normalize(
                 name, config.get(name, ()), flatten=flatten
@@ -130,6 +137,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
             "range_unroll_factors",
             "range_num_stages",
             "range_multi_buffers",
+            "range_flattens",
         ):
             if not config[name]:
                 config.pop(name)
@@ -162,6 +170,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
             "range_unroll_factors": self.range_unroll_factors._flat_config(self, fn),
             "range_num_stages": self.range_num_stages._flat_config(self, fn),
             "range_multi_buffers": self.range_multi_buffers._flat_config(self, fn),
+            "range_flattens": self.range_flattens._flat_config(self, fn),
             "num_warps": fn(NumWarpsFragment(1, 32, DEFAULT_NUM_WARPS)),
             "num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)),
             "indexing": fn(
@@ -191,6 +200,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
             "range_unroll_factors",
             "range_num_stages",
             "range_multi_buffers",
+            "range_flattens",
         ):
             if not config[name]:
                 config.pop(name)
@@ -374,6 +384,20 @@ def _fill_missing(self) -> None:
         return None
 
 
+class RangeFlattenSpec(_BlockIdItem):
+    def _fragment(self, base: ConfigSpec) -> EnumFragment:
+        return EnumFragment((None, False, True))
+
+    def _normalize(self, name: str, value: object) -> bool | None:
+        if value is not None and not isinstance(value, bool):
+            raise InvalidConfig(f"{name} must be a boolean or None, got {value!r}")
+        return value
+
+    def _fill_missing(self) -> None:
+        """Provide a value when not provided by the user."""
+        return
+
+
 def _product(seq: Sequence[int]) -> int:
     """Return the product of the elements in the sequence."""
     return functools.reduce(operator.mul, seq, 1)
```
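Two behaviors worth noting in the hunks above: the autotuner samples each loop's flatten setting from the tri-state `EnumFragment((None, False, True))`, and `normalize` expands a user-supplied singular `range_flatten` key into the plural `range_flattens` list via `names = f"{name}s"`. A standalone sketch of the sampling, assuming one independent choice per device loop:

```python
import random

# Mimics the search-space choice RangeFlattenSpec exposes via
# EnumFragment((None, False, True)) -- one tri-state value per loop.
FLATTEN_CHOICES: tuple[bool | None, ...] = (None, False, True)

def sample_range_flattens(num_loops: int) -> list[bool | None]:
    """Pick one flatten setting per device loop, as a random search would."""
    return [random.choice(FLATTEN_CHOICES) for _ in range(num_loops)]

print(sample_range_flattens(2))  # e.g. [True, None]
```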

helion/language/loops.py (2 additions, 0 deletions)

```diff
@@ -26,6 +26,7 @@
 from ..autotuner.config_spec import FlattenLoopSpec
 from ..autotuner.config_spec import L2GroupingSpec
 from ..autotuner.config_spec import LoopOrderSpec
+from ..autotuner.config_spec import RangeFlattenSpec
 from ..autotuner.config_spec import RangeMultiBufferSpec
 from ..autotuner.config_spec import RangeNumStagesSpec
 from ..autotuner.config_spec import RangeUnrollFactorSpec
@@ -251,6 +252,7 @@ def _add_config_choices(
     config_spec.range_unroll_factors.append(RangeUnrollFactorSpec([block_id]))
     config_spec.range_num_stages.append(RangeNumStagesSpec([block_id]))
     config_spec.range_multi_buffers.append(RangeMultiBufferSpec([block_id]))
+    config_spec.range_flattens.append(RangeFlattenSpec([block_id]))
 
 
 def _allow_use_yz_grid(config_spec: ConfigSpec, block_ids: list[int]) -> bool:
```
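Since `_add_config_choices` appends one `RangeFlattenSpec` per loop `block_id`, a kernel with two tiled device loops exposes a two-entry list; for example (hypothetical values):

```python
import helion

# Hypothetical: a kernel with two device loops gets one tri-state entry per loop.
helion.Config(range_flattens=[True, None])  # flatten the first loop, omit for the second
```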

helion/runtime/config.py (7 additions, 0 deletions)

```diff
@@ -28,6 +28,7 @@ def __init__(
         range_unroll_factors: list[int] | None = None,
         range_num_stages: list[int] | None = None,
         range_multi_buffers: list[bool | None] | None = None,
+        range_flattens: list[bool | None] | None = None,
         num_warps: int | None = None,
         num_stages: int | None = None,
         use_yz_grid: bool | None = None,
@@ -46,6 +47,7 @@ def __init__(
             range_unroll_factors: Loop unroll factors for tl.range calls.
             range_num_stages: Number of stages for tl.range calls.
             range_multi_buffers: Controls disallow_acc_multi_buffer for tl.range calls.
+            range_flattens: Controls flatten parameter for tl.range calls.
             num_warps: Number of warps per block.
             num_stages: Number of stages for software pipelining.
             use_yz_grid: Whether to use yz grid dimensions.
@@ -62,6 +64,7 @@ def __init__(
             "range_unroll_factors": range_unroll_factors,
             "range_num_stages": range_num_stages,
             "range_multi_buffers": range_multi_buffers,
+            "range_flattens": range_flattens,
             "num_warps": num_warps,
             "num_stages": num_stages,
             "indexing": indexing,
@@ -159,6 +162,10 @@ def range_num_stages(self) -> list[int]:
     def range_multi_buffers(self) -> list[bool | None]:
         return cast("list[bool | None]", self.config.get("range_multi_buffers", []))
 
+    @property
+    def range_flattens(self) -> list[bool | None]:
+        return cast("list[bool | None]", self.config.get("range_flattens", []))
+
     @property
     def indexing(self) -> IndexingLiteral:
         return self.config.get("indexing", "pointer")  # type: ignore
```
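Because the property reads the key with a `[]` default, round-tripping should behave as in this sketch (assuming unset options are omitted from the underlying config dict, as the default suggests):

```python
import helion

cfg = helion.Config(block_sizes=[32, 16])
print(cfg.range_flattens)  # [] -- key absent, property falls back to an empty list

cfg = helion.Config(block_sizes=[32, 16], range_flattens=[True])
print(cfg.range_flattens)  # [True]
```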

test/test_autotuner.py (10 additions, 10 deletions)

```diff
@@ -44,16 +44,16 @@ def test_config_fragment0(self):
         self.assertExpectedInline(
             "\n".join(map(repr, configs)),
             """\
-helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[None], num_warps=4, num_stages=3, indexing='pointer')
-helion.Config(block_sizes=[16, 64, 32], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_num_stages=[2], range_multi_buffers=[None], num_warps=16, num_stages=8, indexing='pointer')
-helion.Config(block_sizes=[64, 16, 128], loop_orders=[[1, 0]], l2_groupings=[4], range_unroll_factors=[2], range_num_stages=[1], range_multi_buffers=[None], num_warps=32, num_stages=3, indexing='tensor_descriptor')
-helion.Config(block_sizes=[64, 16, 32], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[True], num_warps=2, num_stages=5, indexing='block_ptr')
-helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[4], range_num_stages=[3], range_multi_buffers=[None], num_warps=4, num_stages=7, indexing='tensor_descriptor')
-helion.Config(block_sizes=[16, 64, 32], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[4], range_num_stages=[2], range_multi_buffers=[None], num_warps=16, num_stages=3, indexing='block_ptr')
-helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[4], range_num_stages=[0], range_multi_buffers=[True], num_warps=1, num_stages=6, indexing='pointer')
-helion.Config(block_sizes=[16, 32, 64], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[4], range_multi_buffers=[True], num_warps=32, num_stages=1, indexing='tensor_descriptor')
-helion.Config(block_sizes=[32, 32, 64], loop_orders=[[1, 0]], l2_groupings=[64], range_unroll_factors=[4], range_num_stages=[1], range_multi_buffers=[True], num_warps=4, num_stages=2, indexing='pointer')
-helion.Config(block_sizes=[16, 64, 16], loop_orders=[[1, 0]], l2_groupings=[2], range_unroll_factors=[2], range_num_stages=[3], range_multi_buffers=[True], num_warps=16, num_stages=4, indexing='block_ptr')""",
+helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=4, num_stages=3, indexing='pointer')
+helion.Config(block_sizes=[16, 64, 32], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[True], num_warps=8, num_stages=1, indexing='block_ptr')
+helion.Config(block_sizes=[16, 128, 32], loop_orders=[[1, 0]], l2_groupings=[4], range_unroll_factors=[1], range_num_stages=[1], range_multi_buffers=[True], range_flattens=[True], num_warps=32, num_stages=3, indexing='tensor_descriptor')
+helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[None], num_warps=4, num_stages=7, indexing='tensor_descriptor')
+helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[32], range_unroll_factors=[4], range_num_stages=[3], range_multi_buffers=[None], range_flattens=[False], num_warps=32, num_stages=7, indexing='tensor_descriptor')
+helion.Config(block_sizes=[16, 64, 32], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[4], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[True], num_warps=32, num_stages=3, indexing='block_ptr')
+helion.Config(block_sizes=[16, 32, 64], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[4], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[None], num_warps=16, num_stages=6, indexing='pointer')
+helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[4], range_multi_buffers=[True], range_flattens=[True], num_warps=1, num_stages=6, indexing='block_ptr')
+helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[64], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[None], num_warps=32, num_stages=7, indexing='block_ptr')
+helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[3], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[None], num_warps=8, num_stages=6, indexing='block_ptr')""",
         )
 
     @patch.object(_compat, "_supports_tensor_descriptor", lambda: True)
```

test/test_loops.py (40 additions, 0 deletions)

```diff
@@ -1713,6 +1713,46 @@ def nested_loop_kernel(x: torch.Tensor) -> torch.Tensor:
         self.assertIn("disallow_acc_multi_buffer=False", code_true)
         self.assertIn("disallow_acc_multi_buffer=True", code_false)
 
+    def test_range_flatten(self):
+        @helion.kernel()
+        def nested_loop_kernel(x: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(x)
+            # Outer loop becomes grid (no tl.range)
+            for tile_outer in hl.tile(x.size(0)):
+                # Inner loop becomes device loop with tl.range
+                for tile_inner in hl.tile(x.size(1)):
+                    out[tile_outer, tile_inner] = x[tile_outer, tile_inner] + 1
+            return out
+
+        # Test configuration validation - that range_flatten works
+        args = (torch.randn([64, 32], device=DEVICE),)
+
+        # Test with range_flattens = [None] (default, no flatten parameter)
+        code_none, result_none = code_and_output(
+            nested_loop_kernel, args, block_sizes=[32, 16], range_flattens=[None]
+        )
+
+        # Test with range_flattens = [True] (flatten=True for device loop)
+        code_true, result_true = code_and_output(
+            nested_loop_kernel, args, block_sizes=[32, 16], range_flattens=[True]
+        )
+
+        # Test with range_flattens = [False] (flatten=False for device loop)
+        code_false, result_false = code_and_output(
+            nested_loop_kernel, args, block_sizes=[32, 16], range_flattens=[False]
+        )
+
+        torch.testing.assert_close(result_none, result_true)
+        torch.testing.assert_close(result_none, result_false)
+        torch.testing.assert_close(result_none, args[0] + 1)
+        self.assertNotEqual(code_none, code_true)
+        self.assertNotEqual(code_none, code_false)
+        self.assertNotEqual(code_true, code_false)
+        # Check that flatten parameter appears in tl.range call
+        self.assertNotIn("flatten", code_none)
+        self.assertIn("flatten=True", code_true)
+        self.assertIn("flatten=False", code_false)
+
 
 if __name__ == "__main__":
     unittest.main()
```
