Add tl.range warp_specialize to autotuner #230

Merged (1 commit, Jul 1, 2025)
README.md (7 additions, 0 deletions)
@@ -209,6 +209,13 @@ Contains one entry per loop dimension, controlling the `flatten`
   parameter for `tl.range()` calls. `True` sets `flatten=True`,
   `False` sets `flatten=False`, and `None` omits the parameter.
 
+* **range\_warp\_specializes** (`list[bool | None]`):
+  Contains one entry per loop dimension, controlling the `warp_specialize`
+  parameter for `tl.range()` calls. `True` sets `warp_specialize=True`,
+  `False` sets `warp_specialize=False`, and `None` omits the parameter.
+  Only available on CUDA devices with Blackwell or newer architectures
+  when the `allow_warp_specialize` setting is enabled.
+
 * **reduction\_loops** (`list[int | None]`):
   Contains one entry per reduction dimension (see
   `examples/softmax.py`). Using `None` triggers a persistent reduction,
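
A minimal usage sketch for the new option (values illustrative, not taken from this PR): pinning warp specialization on a kernel's single device loop through a manual config. The `num_warps` floor relates to the workaround in the next file.

    import helion

    # Hedged sketch: range_warp_specializes takes one entry per device loop.
    # True/False pin warp_specialize explicitly; None omits the kwarg.
    config = helion.Config(
        block_sizes=[64, 64, 32],
        range_warp_specializes=[True],
        num_warps=4,  # warp_specialize currently requires at least 4 warps
        num_stages=3,
    )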

helion/_compiler/device_function.py (8 additions, 1 deletion)
@@ -359,9 +359,16 @@ def codegen_function_def(self) -> ast.FunctionDef:
 
     def codegen_function_call(self) -> ast.AST:
         args = [arg.host_str() for arg in self.sorted_args()]
+
+        # Workaround for triton bug: warp_specialize requires at least 4 warps
+        # See: https://github.com/triton-lang/triton/issues/7354
+        num_warps = self.config.num_warps
+        if any(self.config.range_warp_specializes):
+            num_warps = max(4, num_warps)
+
         args.extend(
             [
-                f"num_warps={self.config.num_warps}",
+                f"num_warps={num_warps}",
                 f"num_stages={self.config.num_stages}",
             ]
         )

helion/_compiler/tile_strategy.py (6 additions, 0 deletions)
@@ -134,6 +134,12 @@ def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> str:
         if range_unroll_factor > 0:
             kwargs.append(f"loop_unroll_factor={range_unroll_factor}")
 
+        range_warp_specialize = env.config_spec.range_warp_specialize.config_get(
+            state.config.range_warp_specializes, block_idx, None
+        )
+        if range_warp_specialize is not None:
+            kwargs.append(f"warp_specialize={range_warp_specialize}")
+
         range_num_stages = env.config_spec.range_num_stages.config_get(
             state.config.range_num_stages, block_idx, 0
         )
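
For reference, a sketch of where these kwargs end up; `offset`, `n`, and `_BLOCK_SIZE` are placeholder names, and this is generated device code rather than something to run directly:

    # Illustrative only: with range_warp_specializes=[True] for this loop's
    # block index, the emitted Triton loop picks up the extra keyword.
    for offset in tl.range(0, n, _BLOCK_SIZE, warp_specialize=True):
        ...  # loop body
    # A None entry omits warp_specialize from the call entirely.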

helion/autotuner/config_spec.py (30 additions, 29 deletions)
@@ -37,6 +37,7 @@
     "reduction_loops",
     "flatten_loops",
     "range_unroll_factors",
+    "range_warp_specializes",
     "range_num_stages",
     "range_multi_buffers",
     "range_flattens",
@@ -68,6 +69,9 @@ class ConfigSpec:
     range_unroll_factors: BlockIdSequence[RangeUnrollFactorSpec] = dataclasses.field(
         default_factory=BlockIdSequence
     )
+    range_warp_specialize: BlockIdSequence[RangeWarpSpecializeSpec] = dataclasses.field(
+        default_factory=BlockIdSequence
+    )
     range_num_stages: BlockIdSequence[RangeNumStagesSpec] = dataclasses.field(
         default_factory=BlockIdSequence
     )
@@ -87,6 +91,7 @@ def _remove_duplicates(self) -> None:
         self.l2_groupings._remove_duplicates()
         self.flatten_loops._remove_duplicates()
         self.range_unroll_factors._remove_duplicates()
+        self.range_warp_specialize._remove_duplicates()
         self.range_num_stages._remove_duplicates()
         self.range_multi_buffers._remove_duplicates()
         self.range_flattens._remove_duplicates()
@@ -104,6 +109,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
             "l2_grouping",
             "flatten_loop",
             "range_unroll_factor",
+            "range_warp_specialize",
             "range_num_stage",
             "range_multi_buffer",
             "range_flatten",
@@ -121,6 +127,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
             ("loop_orders", self.loop_orders, False),
             ("reduction_loops", self.reduction_loops, True),
             ("range_unroll_factors", self.range_unroll_factors, True),
+            ("range_warp_specializes", self.range_warp_specialize, True),
             ("range_num_stages", self.range_num_stages, True),
             ("range_multi_buffers", self.range_multi_buffers, True),
             ("range_flattens", self.range_flattens, True),
@@ -135,6 +142,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
             "flatten_loops",
             "reduction_loops",
             "range_unroll_factors",
+            "range_warp_specializes",
             "range_num_stages",
             "range_multi_buffers",
             "range_flattens",
@@ -168,6 +176,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Config:
             "l2_groupings": self.l2_groupings._flat_config(self, fn),
             "reduction_loops": self.reduction_loops._flat_config(self, fn),
             "range_unroll_factors": self.range_unroll_factors._flat_config(self, fn),
+            "range_warp_specializes": self.range_warp_specialize._flat_config(self, fn),
             "range_num_stages": self.range_num_stages._flat_config(self, fn),
             "range_multi_buffers": self.range_multi_buffers._flat_config(self, fn),
             "range_flattens": self.range_flattens._flat_config(self, fn),
@@ -198,6 +207,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Config:
             "reduction_loops",
             "l2_groupings",
             "range_unroll_factors",
+            "range_warp_specializes",
             "range_num_stages",
             "range_multi_buffers",
             "range_flattens",
@@ -342,24 +352,7 @@ def _fill_missing(self) -> None:
         return None
 
 
-class RangeUnrollFactorSpec(_BlockIdItem):
-    def _fragment(self, base: ConfigSpec) -> IntegerFragment:
-        return IntegerFragment(0, 4, 0)
-
-    def _normalize(self, name: str, value: object) -> int:
-        if not isinstance(value, int):
-            raise InvalidConfig(f"{name} must be an integer, got {value!r}")
-        return value
-
-    def _fill_missing(self) -> int:
-        """Provide a value when not provided by the user."""
-        return 0
-
-
-class RangeNumStagesSpec(_BlockIdItem):
-    def _fragment(self, base: ConfigSpec) -> IntegerFragment:
-        return IntegerFragment(0, 4, 0)
-
+class _OptionalIntSpec(_BlockIdItem):
     def _normalize(self, name: str, value: object) -> int:
         if not isinstance(value, int):
             raise InvalidConfig(f"{name} must be an integer, got {value!r}")
@@ -370,7 +363,7 @@ def _fill_missing(self) -> int:
         """Provide a value when not provided by the user."""
         return 0
 
 
-class RangeMultiBufferSpec(_BlockIdItem):
+class _OptionalBoolSpec(_BlockIdItem):
     def _fragment(self, base: ConfigSpec) -> EnumFragment:
         return EnumFragment((None, False, True))
@@ -384,18 +377,26 @@ def _fill_missing(self) -> None:
         return None
 
 
-class RangeFlattenSpec(_BlockIdItem):
-    def _fragment(self, base: ConfigSpec) -> EnumFragment:
-        return EnumFragment((None, False, True))
+class RangeUnrollFactorSpec(_OptionalIntSpec):
+    def _fragment(self, base: ConfigSpec) -> IntegerFragment:
+        return IntegerFragment(0, 4, 0)
 
-    def _normalize(self, name: str, value: object) -> bool | None:
-        if value is not None and not isinstance(value, bool):
-            raise InvalidConfig(f"{name} must be a boolean or None, got {value!r}")
-        return value
-
-    def _fill_missing(self) -> None:
-        """Provide a value when not provided by the user."""
-        return
+
+class RangeWarpSpecializeSpec(_OptionalBoolSpec):
+    pass
+
+
+class RangeNumStagesSpec(_OptionalIntSpec):
+    def _fragment(self, base: ConfigSpec) -> IntegerFragment:
+        return IntegerFragment(0, 4, 0)
+
+
+class RangeMultiBufferSpec(_OptionalBoolSpec):
+    pass
+
+
+class RangeFlattenSpec(_OptionalBoolSpec):
+    pass
 
 
 def _product(seq: Sequence[int]) -> int:
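
Every boolean range knob now shares the same tri-state search space. A standalone sketch of the semantics (not Helion's actual sampling code):

    import random

    # EnumFragment((None, False, True)): None means "omit the kwarg from the
    # generated tl.range call"; False/True pass it explicitly.
    choice = random.choice((None, False, True))
    kwargs = [] if choice is None else [f"warp_specialize={choice}"]
    print(kwargs)  # e.g. [] or ['warp_specialize=True']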

helion/language/loops.py (26 additions, 4 deletions)
@@ -2,6 +2,7 @@
 
 import ast
 import builtins
+import inspect
 from typing import TYPE_CHECKING
 from typing import Iterator
 from typing import Sequence
@@ -10,6 +11,7 @@
 
 import torch
 from torch._inductor.runtime.triton_heuristics import get_max_y_grid
+import triton.language
 
 from .. import exc
 from .._compiler.ast_extension import ExtendedAST
@@ -30,6 +32,7 @@
 from ..autotuner.config_spec import RangeMultiBufferSpec
 from ..autotuner.config_spec import RangeNumStagesSpec
 from ..autotuner.config_spec import RangeUnrollFactorSpec
+from ..autotuner.config_spec import RangeWarpSpecializeSpec
 from . import _decorators
 from helion.language.tile_proxy import Tile
 
@@ -248,11 +251,30 @@ def _add_config_choices(
         config_spec.l2_groupings.append(L2GroupingSpec(block_ids))
         config_spec.allow_use_yz_grid = _allow_use_yz_grid(config_spec, block_ids)
     else:
+        params = inspect.signature(triton.language.range).parameters
         for block_id in block_ids:
-            config_spec.range_unroll_factors.append(RangeUnrollFactorSpec([block_id]))
-            config_spec.range_num_stages.append(RangeNumStagesSpec([block_id]))
-            config_spec.range_multi_buffers.append(RangeMultiBufferSpec([block_id]))
-            config_spec.range_flattens.append(RangeFlattenSpec([block_id]))
+            if "loop_unroll_factor" in params:
+                config_spec.range_unroll_factors.append(
+                    RangeUnrollFactorSpec([block_id])
+                )
+            if _supports_warp_specialize() and "warp_specialize" in params:
+                config_spec.range_warp_specialize.append(
+                    RangeWarpSpecializeSpec([block_id])
+                )
+            if "num_stages" in params:
+                config_spec.range_num_stages.append(RangeNumStagesSpec([block_id]))
+            if "disallow_acc_multi_buffer" in params:
+                config_spec.range_multi_buffers.append(RangeMultiBufferSpec([block_id]))
+            if "flatten" in params:
+                config_spec.range_flattens.append(RangeFlattenSpec([block_id]))
+
+
+def _supports_warp_specialize() -> bool:
+    """Check if the current device supports warp specialization."""
+    env = CompileEnvironment.current()
+    if env.device.type != "cuda" or not env.settings.allow_warp_specialize:
+        return False
+    return torch.cuda.get_device_capability() >= (12, 0)
 
 
 def _allow_use_yz_grid(config_spec: ConfigSpec, block_ids: list[int]) -> bool:
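
To check whether a given machine passes the device gate above, something like the following works (a sketch under stated assumptions: the same capability threshold as `_supports_warp_specialize`, with the Settings check left out):

    import torch

    def device_gate() -> bool:
        # CUDA only, and compute capability >= (12, 0), matching the helper.
        return (
            torch.cuda.is_available()
            and torch.cuda.get_device_capability() >= (12, 0)
        )

    print(device_gate())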

helion/runtime/config.py (7 additions, 0 deletions)
@@ -26,6 +26,7 @@ def __init__(
         l2_groupings: list[int] | None = None,
         reduction_loops: list[int | None] | None = None,
         range_unroll_factors: list[int] | None = None,
+        range_warp_specializes: list[bool | None] | None = None,
         range_num_stages: list[int] | None = None,
         range_multi_buffers: list[bool | None] | None = None,
         range_flattens: list[bool | None] | None = None,
@@ -45,6 +46,7 @@ def __init__(
             l2_groupings: Reorders program IDs for L2 cache locality.
             reduction_loops: Configures reduction loop behavior.
             range_unroll_factors: Loop unroll factors for tl.range calls.
+            range_warp_specializes: Warp specialization for tl.range calls.
             range_num_stages: Number of stages for tl.range calls.
             range_multi_buffers: Controls disallow_acc_multi_buffer for tl.range calls.
             range_flattens: Controls flatten parameter for tl.range calls.
@@ -62,6 +64,7 @@ def __init__(
             "l2_groupings": l2_groupings,
             "reduction_loops": reduction_loops,
             "range_unroll_factors": range_unroll_factors,
+            "range_warp_specializes": range_warp_specializes,
             "range_num_stages": range_num_stages,
             "range_multi_buffers": range_multi_buffers,
             "range_flattens": range_flattens,
@@ -154,6 +157,10 @@ def use_yz_grid(self) -> bool:
     def range_unroll_factors(self) -> list[int]:
         return cast("list[int]", self.config.get("range_unroll_factors", []))
 
+    @property
+    def range_warp_specializes(self) -> list[bool | None]:
+        return cast("list[bool | None]", self.config.get("range_warp_specializes", []))
+
     @property
     def range_num_stages(self) -> list[int]:
         return cast("list[int]", self.config.get("range_num_stages", []))

helion/runtime/settings.py (4 additions, 0 deletions)
@@ -66,6 +66,9 @@ class _Settings:
     autotune_precompile: bool = sys.platform != "win32"
     print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1"
     force_autotune: bool = os.environ.get("HELION_FORCE_AUTOTUNE", "0") == "1"
+    allow_warp_specialize: bool = (
+        os.environ.get("HELION_ALLOW_WARP_SPECIALIZE", "1") == "1"
+    )
 
 
 class Settings(_Settings):
@@ -85,6 +88,7 @@ class Settings(_Settings):
         "autotune_precompile": "If True, precompile the kernel before autotuning. Requires fork-safe environment.",
         "print_output_code": "If True, print the output code of the kernel to stderr.",
         "force_autotune": "If True, force autotuning even if a config is provided.",
+        "allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.",
     }
     assert __slots__.keys() == {field.name for field in dataclasses.fields(_Settings)}
 
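
Because the default is captured from the environment when helion's settings module is imported, a global opt-out has to happen before that import; a minimal sketch:

    import os

    # Set before importing helion: HELION_ALLOW_WARP_SPECIALIZE's default is
    # read at import time of helion.runtime.settings.
    os.environ["HELION_ALLOW_WARP_SPECIALIZE"] = "0"

    import helion  # noqa: E402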

test/test_autotuner.py (12 additions, 10 deletions)
@@ -19,6 +19,7 @@
 from helion.autotuner.config_generation import ConfigGeneration
 from helion.autotuner.random_search import RandomSearch
 import helion.language as hl
+from helion.language import loops
 
 datadir = Path(__file__).parent / "data"
 basic_kernels = import_path(datadir / "basic_kernels.py")
@@ -34,6 +35,7 @@ def setUp(self):
         random.seed(112)
 
     @patch.object(_compat, "_supports_tensor_descriptor", lambda: True)
+    @patch.object(loops, "_supports_warp_specialize", lambda: True)
     def test_config_fragment0(self):
         args = (
             torch.randn([512, 512], device=DEVICE),
@@ -44,16 +46,16 @@ def test_config_fragment0(self):
         self.assertExpectedInline(
             "\n".join(map(repr, configs)),
             """\
-helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=4, num_stages=3, indexing='pointer')
-helion.Config(block_sizes=[16, 64, 32], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[True], num_warps=8, num_stages=1, indexing='block_ptr')
-helion.Config(block_sizes=[16, 128, 32], loop_orders=[[1, 0]], l2_groupings=[4], range_unroll_factors=[1], range_num_stages=[1], range_multi_buffers=[True], range_flattens=[True], num_warps=32, num_stages=3, indexing='tensor_descriptor')
-helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[None], num_warps=4, num_stages=7, indexing='tensor_descriptor')
-helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[32], range_unroll_factors=[4], range_num_stages=[3], range_multi_buffers=[None], range_flattens=[False], num_warps=32, num_stages=7, indexing='tensor_descriptor')
-helion.Config(block_sizes=[16, 64, 32], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[4], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[True], num_warps=32, num_stages=3, indexing='block_ptr')
-helion.Config(block_sizes=[16, 32, 64], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[4], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[None], num_warps=16, num_stages=6, indexing='pointer')
-helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[4], range_multi_buffers=[True], range_flattens=[True], num_warps=1, num_stages=6, indexing='block_ptr')
-helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[64], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[None], num_warps=32, num_stages=7, indexing='block_ptr')
-helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[3], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[None], num_warps=8, num_stages=6, indexing='block_ptr')""",
+helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=4, num_stages=3, indexing='pointer')
+helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_warp_specializes=[False], range_num_stages=[1], range_multi_buffers=[True], range_flattens=[False], num_warps=1, num_stages=7, indexing='tensor_descriptor')
+helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[2], range_unroll_factors=[1], range_warp_specializes=[True], range_num_stages=[4], range_multi_buffers=[True], range_flattens=[True], num_warps=2, num_stages=8, indexing='tensor_descriptor')
+helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[0], range_warp_specializes=[True], range_num_stages=[1], range_multi_buffers=[False], range_flattens=[False], num_warps=32, num_stages=2, indexing='tensor_descriptor')
+helion.Config(block_sizes=[16, 32, 64], loop_orders=[[1, 0]], l2_groupings=[64], range_unroll_factors=[2], range_warp_specializes=[True], range_num_stages=[3], range_multi_buffers=[True], range_flattens=[None], num_warps=4, num_stages=7, indexing='pointer')
+helion.Config(block_sizes=[256, 128, 16], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[4], range_warp_specializes=[True], range_num_stages=[4], range_multi_buffers=[None], range_flattens=[False], num_warps=8, num_stages=4, indexing='tensor_descriptor')
+helion.Config(block_sizes=[16, 32, 16], loop_orders=[[1, 0]], l2_groupings=[16], range_unroll_factors=[0], range_warp_specializes=[True], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[False], num_warps=1, num_stages=8, indexing='tensor_descriptor')
+helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[32], range_unroll_factors=[2], range_warp_specializes=[False], range_num_stages=[2], range_multi_buffers=[False], range_flattens=[False], num_warps=4, num_stages=4, indexing='tensor_descriptor')
+helion.Config(block_sizes=[16, 16, 64], loop_orders=[[0, 1]], l2_groupings=[8], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[2], range_multi_buffers=[False], range_flattens=[True], num_warps=16, num_stages=4, indexing='block_ptr')
+helion.Config(block_sizes=[32, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[2], range_warp_specializes=[False], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[False], num_warps=4, num_stages=1, indexing='tensor_descriptor')""",
         )
 
     @patch.object(_compat, "_supports_tensor_descriptor", lambda: True)