Skip to content

Commit f14e43c

Browse files
committed
Add tl.range loop_unroll_factor to autotuner
stack-info: PR: #226, branch: jansel/stack/70
1 parent 04ecef6 commit f14e43c

15 files changed

+229
-97
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,10 @@ allowing you to permute the iteration order of the tiles.
191191
Contains one entry per `hl.tile` call with two or more dimensions,
192192
allowing you to flatten the iteration space into a single dimension.
193193

194+
* **range\_unroll\_factors** (`list[int]`):
195+
Contains one entry per loop dimension, specifying the unroll factor for
196+
`tl.range()` calls. Values less than 1 omit the `loop_unroll_factor` parameter.
197+
194198
* **reduction\_loops** (`list[int | None]`):
195199
Contains one entry per reduction dimension (see
196200
`examples/softmax.py`). Using `None` triggers a persistent reduction,

helion/_compiler/reduction_strategy.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,11 +252,13 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
252252
f"{mask_var} = {index_var} < {state.sympy_expr(numel)}"
253253
)
254254
)
255+
256+
range_extra = self.get_tl_range_kwargs(state, self.block_index)
255257
for_node = create(
256258
ast.For,
257259
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
258260
iter=expr_from_string(
259-
f"range(0, ({state.sympy_expr(numel)}), {block_size_var})"
261+
f"tl.range(0, ({state.sympy_expr(numel)}), {block_size_var}{range_extra})"
260262
),
261263
body=body,
262264
orelse=[],

helion/_compiler/tile_strategy.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,16 @@ def mask_var(self, block_idx: int) -> str | None:
123123
def block_size_var(self, block_idx: int) -> str | None:
124124
return self.fn.block_size_var_cache.get((block_idx,))
125125

126+
def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> str:
127+
"""Get the range_extra string for loop unroll factor based on config."""
128+
env = CompileEnvironment.current()
129+
range_unroll_factor = env.config_spec.range_unroll_factors.config_get(
130+
state.config.range_unroll_factors, block_idx, 0
131+
)
132+
if range_unroll_factor > 0:
133+
return f", loop_unroll_factor={range_unroll_factor}"
134+
return ""
135+
126136
def user_size(self, block_index: int) -> sympy.Expr:
127137
raise NotImplementedError
128138

@@ -360,11 +370,12 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
360370
)
361371
dtype = CompileEnvironment.current().triton_index_type()
362372
lid = self.new_var("lid")
373+
range_extra = self.get_tl_range_kwargs(state, self.block_ids[0])
363374
for_node = create(
364375
ast.For,
365376
target=create(ast.Name, id=lid, ctx=ast.Store()),
366377
iter=expr_from_string(
367-
f"range(tl.cdiv({state.sympy_expr(total_numel)}, {block_size_var}))"
378+
f"tl.range(tl.cdiv({state.sympy_expr(total_numel)}, {block_size_var}){range_extra})"
368379
),
369380
body=(
370381
body := [
@@ -568,11 +579,13 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
568579
end_var_name=end_var_name,
569580
end_expr=self._fold_tile_end_op(state, proxy_end, block_size),
570581
)
582+
583+
range_extra = self.get_tl_range_kwargs(state, block_idx)
571584
for_node = create(
572585
ast.For,
573586
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
574587
iter=expr_from_string(
575-
f"range(begin, end, {block_size_var})",
588+
f"tl.range(begin, end, {block_size_var}{range_extra})",
576589
begin=self._to_ast(begin, to_dtype=dtype),
577590
end=self._to_ast(end, to_dtype=dtype),
578591
),

helion/autotuner/config_spec.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
"l2_groupings",
3737
"reduction_loops",
3838
"flatten_loops",
39+
"range_unroll_factors",
3940
"num_warps",
4041
"num_stages",
4142
"use_yz_grid",
@@ -61,6 +62,9 @@ class ConfigSpec:
6162
reduction_loops: BlockIdSequence[ReductionLoopSpec] = dataclasses.field(
6263
default_factory=BlockIdSequence
6364
)
65+
range_unroll_factors: BlockIdSequence[RangeUnrollFactorSpec] = dataclasses.field(
66+
default_factory=BlockIdSequence
67+
)
6468
user_defined_tunables: dict[str, ConfigSpecFragment] = dataclasses.field(
6569
default_factory=dict
6670
)
@@ -70,6 +74,7 @@ def _remove_duplicates(self) -> None:
7074
self.loop_orders._remove_duplicates()
7175
self.l2_groupings._remove_duplicates()
7276
self.flatten_loops._remove_duplicates()
77+
self.range_unroll_factors._remove_duplicates()
7378

7479
def normalize(self, config: helion.Config | dict[str, object]) -> None:
7580
"""Normalize the config to match the block_sizes and validate the config."""
@@ -83,6 +88,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
8388
"reduction_loop",
8489
"l2_grouping",
8590
"flatten_loop",
91+
"range_unroll_factor",
8692
):
8793
if name in config:
8894
names = f"{name}s"
@@ -96,12 +102,19 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
96102
("l2_groupings", self.l2_groupings, True),
97103
("loop_orders", self.loop_orders, False),
98104
("reduction_loops", self.reduction_loops, True),
105+
("range_unroll_factors", self.range_unroll_factors, True),
99106
]:
100107
config[name] = mapping._normalize(
101108
name, config.get(name, ()), flatten=flatten
102109
)
103110

104-
for name in ("loop_orders", "l2_groupings", "flatten_loops", "reduction_loops"):
111+
for name in (
112+
"loop_orders",
113+
"l2_groupings",
114+
"flatten_loops",
115+
"reduction_loops",
116+
"range_unroll_factors",
117+
):
105118
if not config[name]:
106119
config.pop(name)
107120

@@ -130,6 +143,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
130143
"flatten_loops": self.flatten_loops._flat_config(self, fn),
131144
"l2_groupings": self.l2_groupings._flat_config(self, fn),
132145
"reduction_loops": self.reduction_loops._flat_config(self, fn),
146+
"range_unroll_factors": self.range_unroll_factors._flat_config(self, fn),
133147
"num_warps": fn(NumWarpsFragment(1, 32, DEFAULT_NUM_WARPS)),
134148
"num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)),
135149
"indexing": fn(
@@ -151,7 +165,13 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
151165
not config["flatten_loops"] or not config["flatten_loops"][0]
152166
):
153167
config["use_yz_grid"] = use_yz_grid
154-
for name in ("loop_orders", "flatten_loops", "reduction_loops", "l2_groupings"):
168+
for name in (
169+
"loop_orders",
170+
"flatten_loops",
171+
"reduction_loops",
172+
"l2_groupings",
173+
"range_unroll_factors",
174+
):
155175
if not config[name]:
156176
config.pop(name)
157177
return helion.Config(**config) # pyre-ignore[6]
@@ -292,6 +312,20 @@ def _fill_missing(self) -> None:
292312
return None
293313

294314

315+
class RangeUnrollFactorSpec(_BlockIdItem):
316+
def _fragment(self, base: ConfigSpec) -> IntegerFragment:
317+
return IntegerFragment(0, 4, 0)
318+
319+
def _normalize(self, name: str, value: object) -> int:
320+
if not isinstance(value, int):
321+
raise InvalidConfig(f"{name} must be an integer, got {value!r}")
322+
return value
323+
324+
def _fill_missing(self) -> int:
325+
"""Provide a value when not provided by the user."""
326+
return 0
327+
328+
295329
def _product(seq: Sequence[int]) -> int:
296330
"""Return the product of the elements in the sequence."""
297331
return functools.reduce(operator.mul, seq, 1)

helion/language/loops.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ..autotuner.config_spec import FlattenLoopSpec
2727
from ..autotuner.config_spec import L2GroupingSpec
2828
from ..autotuner.config_spec import LoopOrderSpec
29+
from ..autotuner.config_spec import RangeUnrollFactorSpec
2930
from . import _decorators
3031
from helion.language.tile_proxy import Tile
3132

@@ -230,17 +231,22 @@ def _add_config_choices(
230231
block_ids: list[int], *, is_tile: bool = False, has_begin: bool = False
231232
) -> None:
232233
config_spec = CompileEnvironment.current().config_spec
234+
233235
if len(block_ids) > 1:
234236
# Add loop reordering choice
235237
config_spec.loop_orders.append(LoopOrderSpec(block_ids))
236238
if is_tile and not has_begin:
237239
config_spec.flatten_loops.append(FlattenLoopSpec(block_ids))
238240

239-
if all(x._loop_type != LoopType.GRID for x in ExtendedAST.current()): # is_grid
241+
is_grid = all(x._loop_type != LoopType.GRID for x in ExtendedAST.current())
242+
if is_grid:
240243
if len(block_ids) == 2:
241244
# TODO(jansel): support L2 grouping with 3+ dims (and maybe non-grids?)
242245
config_spec.l2_groupings.append(L2GroupingSpec(block_ids))
243246
config_spec.allow_use_yz_grid = _allow_use_yz_grid(config_spec, block_ids)
247+
else:
248+
for block_id in block_ids:
249+
config_spec.range_unroll_factors.append(RangeUnrollFactorSpec([block_id]))
244250

245251

246252
def _allow_use_yz_grid(config_spec: ConfigSpec, block_ids: list[int]) -> bool:

helion/runtime/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(
2525
flatten_loops: list[bool] | None = None,
2626
l2_groupings: list[int] | None = None,
2727
reduction_loops: list[int | None] | None = None,
28+
range_unroll_factors: list[int] | None = None,
2829
num_warps: int | None = None,
2930
num_stages: int | None = None,
3031
use_yz_grid: bool | None = None,
@@ -40,6 +41,7 @@ def __init__(
4041
loop_orders: Permutes iteration order of tiles.
4142
l2_groupings: Reorders program IDs for L2 cache locality.
4243
reduction_loops: Configures reduction loop behavior.
44+
range_unroll_factors: Loop unroll factors for tl.range calls.
4345
num_warps: Number of warps per block.
4446
num_stages: Number of stages for software pipelining.
4547
use_yz_grid: Whether to use yz grid dimensions.
@@ -53,6 +55,7 @@ def __init__(
5355
"flatten_loops": flatten_loops,
5456
"l2_groupings": l2_groupings,
5557
"reduction_loops": reduction_loops,
58+
"range_unroll_factors": range_unroll_factors,
5659
"num_warps": num_warps,
5760
"num_stages": num_stages,
5861
"indexing": indexing,
@@ -138,6 +141,10 @@ def l2_groupings(self) -> list[int]:
138141
def use_yz_grid(self) -> bool:
139142
return cast("bool", self.config.get("use_yz_grid", False))
140143

144+
@property
145+
def range_unroll_factors(self) -> list[int]:
146+
return cast("list[int]", self.config.get("range_unroll_factors", []))
147+
141148
@property
142149
def indexing(self) -> IndexingLiteral:
143150
return self.config.get("indexing", "pointer") # type: ignore

test/test_autotuner.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,16 @@ def test_config_fragment0(self):
4444
self.assertExpectedInline(
4545
"\n".join(map(repr, configs)),
4646
"""\
47-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=4, num_stages=3, indexing='pointer')
48-
helion.Config(block_sizes=[16, 32, 16], loop_orders=[[1, 0]], l2_groupings=[8], num_warps=32, num_stages=3, indexing='block_ptr')
49-
helion.Config(block_sizes=[32, 16, 16], loop_orders=[[1, 0]], l2_groupings=[32], num_warps=8, num_stages=8, indexing='block_ptr')
50-
helion.Config(block_sizes=[16, 16, 32], loop_orders=[[0, 1]], l2_groupings=[16], num_warps=4, num_stages=7, indexing='tensor_descriptor')
51-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[4], num_warps=8, num_stages=2, indexing='tensor_descriptor')
52-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[64], num_warps=4, num_stages=7, indexing='tensor_descriptor')
53-
helion.Config(block_sizes=[32, 128, 64], loop_orders=[[0, 1]], l2_groupings=[2], num_warps=16, num_stages=5, indexing='pointer')
54-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[2], num_warps=16, num_stages=3, indexing='tensor_descriptor')
55-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], num_warps=4, num_stages=2, indexing='block_ptr')
56-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=1, num_stages=1, indexing='tensor_descriptor')""",
47+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], num_warps=4, num_stages=3, indexing='pointer')
48+
helion.Config(block_sizes=[16, 32, 16], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], num_warps=4, num_stages=4, indexing='tensor_descriptor')
49+
helion.Config(block_sizes=[16, 64, 128], loop_orders=[[0, 1]], l2_groupings=[8], range_unroll_factors=[2], num_warps=4, num_stages=4, indexing='pointer')
50+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[2], num_warps=8, num_stages=7, indexing='block_ptr')
51+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[8], range_unroll_factors=[0], num_warps=16, num_stages=1, indexing='pointer')
52+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[64], range_unroll_factors=[2], num_warps=32, num_stages=7, indexing='tensor_descriptor')
53+
helion.Config(block_sizes=[32, 128, 64], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[4], num_warps=4, num_stages=4, indexing='tensor_descriptor')
54+
helion.Config(block_sizes=[256, 32, 128], loop_orders=[[1, 0]], l2_groupings=[16], range_unroll_factors=[1], num_warps=32, num_stages=2, indexing='block_ptr')
55+
helion.Config(block_sizes=[32, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[3], num_warps=1, num_stages=8, indexing='tensor_descriptor')
56+
helion.Config(block_sizes=[16, 32, 16], loop_orders=[[0, 1]], l2_groupings=[32], range_unroll_factors=[2], num_warps=4, num_stages=5, indexing='block_ptr')""",
5757
)
5858

5959
@patch.object(_compat, "_supports_tensor_descriptor", lambda: True)

0 commit comments

Comments
 (0)