Skip to content

Commit 9d94a97

Browse files
committed
Add tl.range disallow_acc_multi_buffer to autotuner
stack-info: PR: #228, branch: jansel/stack/72
1 parent 46789b9 commit 9d94a97

File tree

8 files changed

+96
-12
lines changed

8 files changed

+96
-12
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,11 @@ Contains one entry per loop dimension, specifying the unroll factor for
199199
Contains one entry per loop dimension, specifying the number of stages for
200200
`tl.range()` calls. Values less than 1 omit the `num_stages` parameter.
201201

202+
* **range\_multi\_buffers** (`list[bool | None]`):
203+
Contains one entry per loop dimension, controlling the `disallow_acc_multi_buffer`
204+
parameter for `tl.range()` calls. `True` allows multi-buffer (sets `disallow_acc_multi_buffer=False`),
205+
`False` disallows multi-buffer (sets `disallow_acc_multi_buffer=True`), and `None` omits the parameter.
206+
202207
* **reduction\_loops** (`list[int | None]`):
203208
Contains one entry per reduction dimension (see
204209
`examples/softmax.py`). Using `None` triggers a persistent reduction,

helion/_compiler/tile_strategy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,12 @@ def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> str:
140140
if range_num_stages > 0:
141141
kwargs.append(f"num_stages={range_num_stages}")
142142

143+
range_multi_buffer = env.config_spec.range_multi_buffers.config_get(
144+
state.config.range_multi_buffers, block_idx, None
145+
)
146+
if range_multi_buffer is not None:
147+
kwargs.append(f"disallow_acc_multi_buffer={not range_multi_buffer}")
148+
143149
if kwargs:
144150
return f", {', '.join(kwargs)}"
145151
return ""

helion/autotuner/config_spec.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
"flatten_loops",
3939
"range_unroll_factors",
4040
"range_num_stages",
41+
"range_multi_buffers",
4142
"num_warps",
4243
"num_stages",
4344
"use_yz_grid",
@@ -69,6 +70,9 @@ class ConfigSpec:
6970
range_num_stages: BlockIdSequence[RangeNumStagesSpec] = dataclasses.field(
7071
default_factory=BlockIdSequence
7172
)
73+
range_multi_buffers: BlockIdSequence[RangeMultiBufferSpec] = dataclasses.field(
74+
default_factory=BlockIdSequence
75+
)
7276
user_defined_tunables: dict[str, ConfigSpecFragment] = dataclasses.field(
7377
default_factory=dict
7478
)
@@ -80,6 +84,7 @@ def _remove_duplicates(self) -> None:
8084
self.flatten_loops._remove_duplicates()
8185
self.range_unroll_factors._remove_duplicates()
8286
self.range_num_stages._remove_duplicates()
87+
self.range_multi_buffers._remove_duplicates()
8388

8489
def normalize(self, config: helion.Config | dict[str, object]) -> None:
8590
"""Normalize the config to match the block_sizes and validate the config."""
@@ -95,6 +100,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
95100
"flatten_loop",
96101
"range_unroll_factor",
97102
"range_num_stage",
103+
"range_multi_buffer",
98104
):
99105
if name in config:
100106
names = f"{name}s"
@@ -110,6 +116,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
110116
("reduction_loops", self.reduction_loops, True),
111117
("range_unroll_factors", self.range_unroll_factors, True),
112118
("range_num_stages", self.range_num_stages, True),
119+
("range_multi_buffers", self.range_multi_buffers, True),
113120
]:
114121
config[name] = mapping._normalize(
115122
name, config.get(name, ()), flatten=flatten
@@ -122,6 +129,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
122129
"reduction_loops",
123130
"range_unroll_factors",
124131
"range_num_stages",
132+
"range_multi_buffers",
125133
):
126134
if not config[name]:
127135
config.pop(name)
@@ -153,6 +161,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
153161
"reduction_loops": self.reduction_loops._flat_config(self, fn),
154162
"range_unroll_factors": self.range_unroll_factors._flat_config(self, fn),
155163
"range_num_stages": self.range_num_stages._flat_config(self, fn),
164+
"range_multi_buffers": self.range_multi_buffers._flat_config(self, fn),
156165
"num_warps": fn(NumWarpsFragment(1, 32, DEFAULT_NUM_WARPS)),
157166
"num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)),
158167
"indexing": fn(
@@ -181,6 +190,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
181190
"l2_groupings",
182191
"range_unroll_factors",
183192
"range_num_stages",
193+
"range_multi_buffers",
184194
):
185195
if not config[name]:
186196
config.pop(name)
@@ -350,6 +360,20 @@ def _fill_missing(self) -> int:
350360
return 0
351361

352362

363+
class RangeMultiBufferSpec(_BlockIdItem):
364+
def _fragment(self, base: ConfigSpec) -> EnumFragment:
365+
return EnumFragment((None, False, True))
366+
367+
def _normalize(self, name: str, value: object) -> bool | None:
368+
if value is not None and not isinstance(value, bool):
369+
raise InvalidConfig(f"{name} must be a boolean or None, got {value!r}")
370+
return value
371+
372+
def _fill_missing(self) -> None:
373+
"""Return the default used when the user does not supply a value."""
374+
return None
375+
376+
353377
def _product(seq: Sequence[int]) -> int:
354378
"""Return the product of the elements in the sequence."""
355379
return functools.reduce(operator.mul, seq, 1)

helion/language/loops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ..autotuner.config_spec import FlattenLoopSpec
2727
from ..autotuner.config_spec import L2GroupingSpec
2828
from ..autotuner.config_spec import LoopOrderSpec
29+
from ..autotuner.config_spec import RangeMultiBufferSpec
2930
from ..autotuner.config_spec import RangeNumStagesSpec
3031
from ..autotuner.config_spec import RangeUnrollFactorSpec
3132
from . import _decorators
@@ -249,6 +250,7 @@ def _add_config_choices(
249250
for block_id in block_ids:
250251
config_spec.range_unroll_factors.append(RangeUnrollFactorSpec([block_id]))
251252
config_spec.range_num_stages.append(RangeNumStagesSpec([block_id]))
253+
config_spec.range_multi_buffers.append(RangeMultiBufferSpec([block_id]))
252254

253255

254256
def _allow_use_yz_grid(config_spec: ConfigSpec, block_ids: list[int]) -> bool:

helion/runtime/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
reduction_loops: list[int | None] | None = None,
2828
range_unroll_factors: list[int] | None = None,
2929
range_num_stages: list[int] | None = None,
30+
range_multi_buffers: list[bool | None] | None = None,
3031
num_warps: int | None = None,
3132
num_stages: int | None = None,
3233
use_yz_grid: bool | None = None,
@@ -44,6 +45,7 @@ def __init__(
4445
reduction_loops: Configures reduction loop behavior.
4546
range_unroll_factors: Loop unroll factors for tl.range calls.
4647
range_num_stages: Number of stages for tl.range calls.
48+
range_multi_buffers: Controls disallow_acc_multi_buffer for tl.range calls.
4749
num_warps: Number of warps per block.
4850
num_stages: Number of stages for software pipelining.
4951
use_yz_grid: Whether to use yz grid dimensions.
@@ -59,6 +61,7 @@ def __init__(
5961
"reduction_loops": reduction_loops,
6062
"range_unroll_factors": range_unroll_factors,
6163
"range_num_stages": range_num_stages,
64+
"range_multi_buffers": range_multi_buffers,
6265
"num_warps": num_warps,
6366
"num_stages": num_stages,
6467
"indexing": indexing,
@@ -152,6 +155,10 @@ def range_unroll_factors(self) -> list[int]:
152155
def range_num_stages(self) -> list[int]:
153156
return cast("list[int]", self.config.get("range_num_stages", []))
154157

158+
@property
159+
def range_multi_buffers(self) -> list[bool | None]:
160+
return cast("list[bool | None]", self.config.get("range_multi_buffers", []))
161+
155162
@property
156163
def indexing(self) -> IndexingLiteral:
157164
return self.config.get("indexing", "pointer") # type: ignore

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ select = [
5858
"TD004", "TRY002", "TRY203", "TRY401", "UP", "W", "YTT",
5959
]
6060
ignore = [
61-
"C409", "C419", "COM812", "E501", "ERA001", "FURB189", "G004", "PERF203",
62-
"PERF401", "SIM102", "SIM108", "SIM115", "UP035", "UP038",
61+
"C409", "C419", "COM812", "E501", "ERA001", "FURB189", "G004", "PERF203", "PERF401",
62+
"RET501", "SIM102", "SIM108", "SIM115", "UP035", "UP038",
6363
]
6464
extend-safe-fixes = ["TC", "UP045", "RUF013", "RSE102"]
6565
preview = true

test/test_autotuner.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,16 @@ def test_config_fragment0(self):
4444
self.assertExpectedInline(
4545
"\n".join(map(repr, configs)),
4646
"""\
47-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], num_warps=4, num_stages=3, indexing='pointer')
48-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_num_stages=[2], num_warps=2, num_stages=8, indexing='pointer')
49-
helion.Config(block_sizes=[16, 16, 32], loop_orders=[[1, 0]], l2_groupings=[4], range_unroll_factors=[2], range_num_stages=[1], num_warps=2, num_stages=3, indexing='tensor_descriptor')
50-
helion.Config(block_sizes=[64, 16, 32], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], num_warps=32, num_stages=4, indexing='block_ptr')
51-
helion.Config(block_sizes=[128, 16, 64], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[2], range_num_stages=[4], num_warps=8, num_stages=1, indexing='block_ptr')
52-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[2], range_unroll_factors=[4], range_num_stages=[4], num_warps=1, num_stages=4, indexing='tensor_descriptor')
53-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[8], range_unroll_factors=[1], range_num_stages=[4], num_warps=32, num_stages=3, indexing='tensor_descriptor')
54-
helion.Config(block_sizes=[16, 64, 64], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[2], range_num_stages=[0], num_warps=8, num_stages=2, indexing='block_ptr')
55-
helion.Config(block_sizes=[64, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[0], range_num_stages=[2], num_warps=4, num_stages=5, indexing='block_ptr')
56-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[4], range_num_stages=[1], num_warps=32, num_stages=7, indexing='block_ptr')""",
47+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[None], num_warps=4, num_stages=3, indexing='pointer')
48+
helion.Config(block_sizes=[16, 64, 32], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_num_stages=[2], range_multi_buffers=[None], num_warps=16, num_stages=8, indexing='pointer')
49+
helion.Config(block_sizes=[64, 16, 128], loop_orders=[[1, 0]], l2_groupings=[4], range_unroll_factors=[2], range_num_stages=[1], range_multi_buffers=[None], num_warps=32, num_stages=3, indexing='tensor_descriptor')
50+
helion.Config(block_sizes=[64, 16, 32], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[True], num_warps=2, num_stages=5, indexing='block_ptr')
51+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[4], range_num_stages=[3], range_multi_buffers=[None], num_warps=4, num_stages=7, indexing='tensor_descriptor')
52+
helion.Config(block_sizes=[16, 64, 32], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[4], range_num_stages=[2], range_multi_buffers=[None], num_warps=16, num_stages=3, indexing='block_ptr')
53+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[4], range_num_stages=[0], range_multi_buffers=[True], num_warps=1, num_stages=6, indexing='pointer')
54+
helion.Config(block_sizes=[16, 32, 64], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[4], range_multi_buffers=[True], num_warps=32, num_stages=1, indexing='tensor_descriptor')
55+
helion.Config(block_sizes=[32, 32, 64], loop_orders=[[1, 0]], l2_groupings=[64], range_unroll_factors=[4], range_num_stages=[1], range_multi_buffers=[True], num_warps=4, num_stages=2, indexing='pointer')
56+
helion.Config(block_sizes=[16, 64, 16], loop_orders=[[1, 0]], l2_groupings=[2], range_unroll_factors=[2], range_num_stages=[3], range_multi_buffers=[True], num_warps=16, num_stages=4, indexing='block_ptr')""",
5757
)
5858

5959
@patch.object(_compat, "_supports_tensor_descriptor", lambda: True)

test/test_loops.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1673,6 +1673,46 @@ def nested_loop_kernel(x: torch.Tensor) -> torch.Tensor:
16731673
"tl.range(0, x_size_1.to(tl.int32), _BLOCK_SIZE_1, num_stages=3)", code3
16741674
)
16751675

1676+
def test_range_multi_buffers(self):
1677+
@helion.kernel()
1678+
def nested_loop_kernel(x: torch.Tensor) -> torch.Tensor:
1679+
out = torch.empty_like(x)
1680+
# Outer loop becomes the launch grid (no tl.range generated)
1681+
for tile_outer in hl.tile(x.size(0)):
1682+
# Inner loop becomes device loop with tl.range
1683+
for tile_inner in hl.tile(x.size(1)):
1684+
out[tile_outer, tile_inner] = x[tile_outer, tile_inner] + 1
1685+
return out
1686+
1687+
# Validate that the range_multi_buffers config option is accepted and changes the generated code
1688+
args = (torch.randn([64, 32], device=DEVICE),)
1689+
1690+
# Test with range_multi_buffers = [None] (no disallow_acc_multi_buffer for device loop)
1691+
code_none, result_none = code_and_output(
1692+
nested_loop_kernel, args, block_sizes=[32, 16], range_multi_buffers=[None]
1693+
)
1694+
1695+
# Test with range_multi_buffers = [True] (disallow_acc_multi_buffer=False for device loop)
1696+
code_true, result_true = code_and_output(
1697+
nested_loop_kernel, args, block_sizes=[32, 16], range_multi_buffers=[True]
1698+
)
1699+
1700+
# Test with range_multi_buffers = [False] (disallow_acc_multi_buffer=True for device loop)
1701+
code_false, result_false = code_and_output(
1702+
nested_loop_kernel, args, block_sizes=[32, 16], range_multi_buffers=[False]
1703+
)
1704+
1705+
torch.testing.assert_close(result_none, result_true)
1706+
torch.testing.assert_close(result_none, result_false)
1707+
torch.testing.assert_close(result_none, args[0] + 1)
1708+
self.assertNotEqual(code_none, code_true)
1709+
self.assertNotEqual(code_none, code_false)
1710+
self.assertNotEqual(code_true, code_false)
1711+
# Check that the disallow_acc_multi_buffer parameter appears in the generated tl.range call
1712+
self.assertNotIn("disallow_acc_multi_buffer", code_none)
1713+
self.assertIn("disallow_acc_multi_buffer=False", code_true)
1714+
self.assertIn("disallow_acc_multi_buffer=True", code_false)
1715+
16761716

16771717
if __name__ == "__main__":
16781718
unittest.main()

0 commit comments

Comments
 (0)