
Switch from TensorDescriptor to tl.make_tensor_descriptor #232

Merged
merged 1 commit on Jul 1, 2025
2 changes: 2 additions & 0 deletions .gitignore
@@ -86,3 +86,5 @@ venv
.watchmanconfig
*.zip
CLAUDE.md
triton
torch
32 changes: 2 additions & 30 deletions helion/_compat.py
@@ -1,11 +1,11 @@
from __future__ import annotations

import functools
import importlib

import torch
from torch._inductor.runtime.hints import DeviceProperties
from torch._inductor.utils import triton_type
import triton
from triton.backends.compiler import GPUTarget
import triton.language as tl

@@ -22,35 +22,7 @@ def _supports_tensor_descriptor() -> bool:
major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
if major < 9:
return False
try:
return get_triton_tensor_descriptor_class() is not None
except ImportError:
return False


@functools.cache
def get_triton_tensor_descriptor_class_import_path() -> str:
cls = get_triton_tensor_descriptor_class()
return f"from {cls.__module__} import {cls.__qualname__}"


@functools.cache
def get_triton_tensor_descriptor_class() -> type[object]:
"""Attempt to import TensorDescriptor class from known Triton modules."""
possible_modules = [
"triton.tools.tensor_descriptor",
"triton.tools.experimental_descriptor",
]
for module_name in possible_modules:
try:
module = importlib.import_module(module_name)
if hasattr(module, "TensorDescriptor"):
return module.TensorDescriptor
except ImportError:
continue
raise ImportError(
"TensorDescriptor class not found in any of the known Triton modules."
)
return hasattr(triton.language, "make_tensor_descriptor")


@functools.cache
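The net effect in helion/_compat.py is that support detection no longer imports a TensorDescriptor class from host-side Triton modules; it only checks the device capability and probes the Triton API surface. A minimal standalone sketch of the same pattern (illustrative, not part of this diff; the cuda-availability guard and the function name are assumptions added for standalone use):

import functools

import torch
import triton.language as tl


@functools.cache
def supports_tensor_descriptor() -> bool:
    # Tensor descriptors need Hopper (compute capability >= 9) and a
    # Triton build that exposes tl.make_tensor_descriptor.
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
    if major < 9:
        return False
    return hasattr(tl, "make_tensor_descriptor")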
113 changes: 91 additions & 22 deletions helion/_compiler/device_function.py
@@ -74,15 +74,35 @@ def sort_key(self) -> tuple[object, ...]:
@dataclasses.dataclass
class TensorArg(Argument):
fake_value: torch.Tensor
_host_str: str
_host_str: str | None

def host_str(self) -> str:
if self._host_str is None:
raise RuntimeError("TensorArg has no host representation")
return self._host_str


@dataclasses.dataclass
class TensorDescriptorArg(TensorArg):
pass
# Permutation applied to make stride==1 dimension last
permutation: list[int] | None = None

def host_str(self) -> str:
if self._host_str is None:
raise RuntimeError(
"TensorDescriptorArg is device-only and has no host representation"
)
return self._host_str

@property
def inverse_permutation(self) -> list[int]:
"""Get the inverse permutation to undo the applied permutation."""
if (permutation := self.permutation) is None:
raise RuntimeError("TensorDescriptorArg.permutation is None")
inverse_perm = [0] * len(permutation)
for i, p in enumerate(permutation):
inverse_perm[p] = i
return inverse_perm


@dataclasses.dataclass
@@ -144,6 +164,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
self.config = config
self.codegen = codegen
self.arguments: list[Argument] = []
self.preamble: list[ast.AST] = []
self.body: list[ast.AST] = []
self._tensor_args: dict[torch.Tensor, TensorArg] = {}
self._tensor_descriptor_args: dict[
@@ -272,20 +293,59 @@ def tensor_arg(

def tensor_descriptor_arg(
self, fake_value: torch.Tensor, block_size: list[int | torch.SymInt]
) -> TensorArg:
) -> TensorDescriptorArg:
host_function = HostFunction.current()
block_size_expr = ", ".join(
map(HostFunction.current().literal_expr, block_size)
)
block_size_expr = ", ".join(map(self.literal_expr, block_size))
key = (fake_value, block_size_expr)
if key not in self._tensor_descriptor_args:
origin = host_function.tensor_to_origin[fake_value]
desc_name = self.new_var(origin.suggest_var_name() + "_desc")
env = CompileEnvironment.current()

# Find which dimension has stride==1
stride_one_dim = [*map(env.size_hint, fake_value.stride())].index(1)

# Determine if we need permutation (stride==1 dimension is not last)
permutation = None
if stride_one_dim != fake_value.ndim - 1:
# Create permutation to move stride==1 dimension to last position
permutation = [*range(fake_value.ndim)]
permutation.pop(stride_one_dim)
permutation.append(stride_one_dim)

# Create the regular tensor arg and size/stride args
tensor_arg = self.tensor_arg(fake_value)
size_args = [
self.tensor_size(fake_value, i) for i in range(fake_value.ndim)
]
stride_args = [
self.tensor_stride(fake_value, i) for i in range(fake_value.ndim)
]

# Apply permutation if needed
if permutation is not None:
size_args = [size_args[i] for i in permutation]
stride_args = [stride_args[i] for i in permutation]
block_size = [block_size[i] for i in permutation]
# Update block_size_expr for the permuted order
block_size_expr = ", ".join(map(self.literal_expr, block_size))

# Add tl.make_tensor_descriptor call to preamble
sizes = ", ".join([arg.name for arg in size_args])
strides = ", ".join([arg.name for arg in stride_args])

descriptor_stmt = statement_from_string(
f"{desc_name} = tl.make_tensor_descriptor({tensor_arg.name}, [{sizes}], [{strides}], [{block_size_expr}])"
)
self.preamble.append(descriptor_stmt)

arg = TensorDescriptorArg(
self.new_var(origin.suggest_var_name() + "_desc"),
desc_name,
fake_value,
f"TensorDescriptor.from_tensor({origin.host_str()}, [{block_size_expr}])",
None, # No host_str since this is device-only
permutation,
)
self.arguments.append(arg)
# Don't add to self.arguments since this is device-only
self._tensor_descriptor_args[key] = arg
return self._tensor_descriptor_args[key]

@@ -342,20 +402,28 @@ def sorted_args(self) -> list[Argument]:
self.arguments.sort(key=lambda arg: arg.sort_key())
return self.arguments

def codegen_function_def(self) -> ast.FunctionDef:
return ast_rename(
create(
ast.FunctionDef,
name=self.name,
args=create_arguments(
[arg.arg_def_node() for arg in self.sorted_args()]
def codegen_function_def(self) -> list[ast.stmt]:
prefix = []
if self._tensor_descriptor_args:
prefix.append(
statement_from_string("helion.runtime.set_triton_allocator()")
)
return [
*prefix,
ast_rename(
create(
ast.FunctionDef,
name=self.name,
args=create_arguments(
[arg.arg_def_node() for arg in self.sorted_args()]
),
body=[*self.preamble, *self.body],
decorator_list=[expr_from_string("triton.jit")],
type_params=[],
),
body=self.body,
decorator_list=[expr_from_string("triton.jit")],
type_params=[],
{k: v[0] for k, v in self._variable_renames.items()},
),
{k: v[0] for k, v in self._variable_renames.items()},
)
]

def codegen_function_call(self) -> ast.AST:
args = [arg.host_str() for arg in self.sorted_args()]
@@ -390,14 +458,15 @@ def dead_code_elimination(self) -> None:
"""

for _ in range(8):
rw = ReadWrites.from_list(self.body)
rw = ReadWrites.from_list([*self.preamble, *self.body])
to_remove = set()
for name in self.dce_vars:
if name in rw.writes and name not in rw.reads:
to_remove.add(name)
if not to_remove:
break
self.body[:] = ast_delete_assignments(self.body, to_remove)
self.preamble[:] = ast_delete_assignments(self.preamble, to_remove)

# drop any unused args
args_to_remove = {
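With these device_function.py changes the descriptor is no longer passed in as a kernel argument; the kernel receives the plain pointer, sizes, and strides, and a tl.make_tensor_descriptor call is emitted into the preamble, with dimensions permuted so the stride==1 dimension comes last (tensor descriptors expect the innermost dimension to be contiguous). A hypothetical sketch of what the generated kernel roughly looks like for a 2D row-major tensor; the names (x, x_size_0, x_desc, _BLOCK_SIZE_0, ...) are illustrative, not literal codegen output:

import triton
import triton.language as tl


@triton.jit
def _kernel(x, x_size_0, x_size_1, x_stride_0, x_stride_1,
            _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
    # Preamble emitted by tensor_descriptor_arg(): the descriptor is built
    # on the device from the tensor argument plus its sizes and strides.
    x_desc = tl.make_tensor_descriptor(
        x,
        [x_size_0, x_size_1],
        [x_stride_0, x_stride_1],
        [_BLOCK_SIZE_0, _BLOCK_SIZE_1],
    )
    pid = tl.program_id(0)
    # Loads and stores go through the descriptor with block offsets.
    block = x_desc.load([pid * _BLOCK_SIZE_0, 0])
    x_desc.store([pid * _BLOCK_SIZE_0, 0], block)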
2 changes: 1 addition & 1 deletion helion/_compiler/generate_ast.py
@@ -407,7 +407,7 @@ def generate_ast(func: HostFunction, config: Config) -> ast.AST:
result = ast.Module(
[
*func.codegen_imports(),
kernel_def,
*kernel_def,
host_def,
precompile_def,
],
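Because codegen_function_def now returns a list, the generated module also carries the helion.runtime.set_triton_allocator() call that is prepended whenever tensor descriptors are used; device-side tl.make_tensor_descriptor needs a host-side allocator for its scratch workspace. The helper's body is not part of this diff; a plausible minimal version built on Triton's public triton.set_allocator hook (an assumption, not confirmed here) might look like:

import torch
import triton


def set_triton_allocator() -> None:
    # Workspace allocator required when kernels build tensor descriptors
    # (TMA) on the device.
    def alloc_fn(size: int, alignment: int, stream):
        return torch.empty(size, device="cuda", dtype=torch.int8)

    triton.set_allocator(alloc_fn)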
93 changes: 79 additions & 14 deletions helion/_compiler/indexing_strategy.py
@@ -18,6 +18,7 @@

if TYPE_CHECKING:
from ..runtime.config import Config
from .device_function import TensorDescriptorArg
from .inductor_lowering import CodegenState


@@ -145,28 +146,71 @@ def codegen_store(
class TensorDescriptorIndexingStrategy(IndexingStrategy):
"""Use TensorDescriptor to load/store from tensors"""

def codegen_load(
self,
@staticmethod
def is_supported(
state: CodegenState,
fake_tensor: torch.Tensor,
subscript: list[object],
extra_mask: ast.AST | None,
) -> ast.AST:
) -> bool:
"""Check if tensor descriptor indexing is supported with additional requirements."""
# First check the basic BlockedSubscriptIndexing requirements
if not BlockedSubscriptIndexing.is_supported(
state, fake_tensor, subscript, extra_mask
):
return False

# Additional tensor descriptor requirements:
# 1) ndim must be between 2 and 5
if not (2 <= fake_tensor.ndim <= 5):
return False

# 2) Exactly 1 dimension should have stride==1
env = CompileEnvironment.current()
stride_one_count = 0
element_size = fake_tensor.element_size()
for dim in range(fake_tensor.ndim):
stride = env.size_hint(fake_tensor.stride(dim))
if stride == 1:
stride_one_count += 1
else:
# 3) All other dimensions should have 16-byte aligned strides
byte_stride = stride * element_size
if byte_stride % 16 != 0:
return False

# TODO(jansel): check that base_ptr is aligned to 16 bytes
return stride_one_count == 1

def codegen_load(
self,
state: CodegenState,
fake_tensor: torch.Tensor,
subscript: list[object],
extra_mask: ast.AST | None,
) -> ast.AST:
if not self.is_supported(state, fake_tensor, subscript, extra_mask):
return PointerIndexingStrategy().codegen_load(
state, fake_tensor, subscript, extra_mask
)
assert extra_mask is None
indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)
return indexing.reshape_load(
state,
expr_from_string(
f"{indexing.tensor_descriptor(state)}.load({indexing.offsets_str()})"
),

# Load from tensor descriptor with permuted offsets
load_expr = expr_from_string(
f"{indexing.tensor_descriptor(state)}.load({indexing.offsets_str_permuted(state)})"
)

# Apply inverse permutation to the loaded result if needed
desc_arg = indexing.tensor_descriptor_arg(state)
if desc_arg.permutation is not None:
load_expr = expr_from_string(
f"tl.permute(load_result, {desc_arg.inverse_permutation!r})",
load_result=load_expr,
)

return indexing.reshape_load(state, load_expr)

def codegen_store(
self,
state: CodegenState,
@@ -175,17 +219,27 @@ def codegen_store(
value: ast.AST,
extra_mask: ast.AST | None,
) -> ast.AST:
if not BlockedSubscriptIndexing.is_supported(
state, fake_tensor, subscript, extra_mask
):
if not self.is_supported(state, fake_tensor, subscript, extra_mask):
return PointerIndexingStrategy().codegen_store(
state, fake_tensor, subscript, value, extra_mask
)
assert extra_mask is None
indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)

# Apply permutation to the value being stored if needed
desc_arg = indexing.tensor_descriptor_arg(state)
store_value = indexing.reshape_store(state, value)

if desc_arg.permutation is not None:
# Apply permutation to the value
store_value = expr_from_string(
f"tl.permute(store_val, {desc_arg.permutation!r})",
store_val=store_value,
)

return expr_from_string(
f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str()}, value)",
value=indexing.reshape_store(state, value),
f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, value)",
value=store_value,
)


@@ -371,9 +425,21 @@ def tensor_descriptor(self, state: CodegenState) -> str:
self.base, self.block_shape
).name

def tensor_descriptor_arg(self, state: CodegenState) -> TensorDescriptorArg:
return state.device_function.tensor_descriptor_arg(self.base, self.block_shape)

def offsets_str(self) -> str:
return f"[{', '.join(self.offsets)}]"

def offsets_str_permuted(self, state: CodegenState) -> str:
"""Get offsets string with permutation applied if needed."""
desc_arg = self.tensor_descriptor_arg(state)
if desc_arg.permutation is not None:
# Apply permutation to offsets
permuted_offsets = [self.offsets[i] for i in desc_arg.permutation]
return f"[{', '.join(permuted_offsets)}]"
return self.offsets_str()

@property
def ndim(self) -> int:
return self.base.ndim
@@ -427,7 +493,6 @@ def is_supported(
index: list[object],
extra_mask: ast.AST | None,
) -> bool:
# TODO(jansel): TensorDescriptor has some extra restrictions that are not captured here.
if extra_mask is not None:
# TODO(jansel): support block_ptr with extra_mask
return False
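The indexing strategy funnels every descriptor access through one permuted layout: sizes, strides, block sizes, and offsets are permuted so the stride==1 dimension is last, loaded blocks are un-permuted with tl.permute, and stored values are permuted before writing. A standalone illustration of the layout math and the extra support checks (hypothetical helper names, not Helion APIs):

import torch


def stride_one_last_permutation(t: torch.Tensor) -> list[int] | None:
    # Permutation that moves the single stride==1 dim to the last axis,
    # or None if it is already last (mirrors tensor_descriptor_arg above).
    stride_one_dim = list(t.stride()).index(1)
    if stride_one_dim == t.ndim - 1:
        return None
    perm = [*range(t.ndim)]
    perm.pop(stride_one_dim)
    perm.append(stride_one_dim)
    return perm


def inverse_permutation(perm: list[int]) -> list[int]:
    # Applying perm and then its inverse restores the original order.
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return inv


def descriptor_supported(t: torch.Tensor) -> bool:
    # Mirrors the extra is_supported() checks: 2-5 dims, exactly one
    # stride==1 dim, and 16-byte-aligned byte strides everywhere else.
    if not (2 <= t.ndim <= 5):
        return False
    stride_one_count = 0
    for stride in t.stride():
        if stride == 1:
            stride_one_count += 1
        elif stride * t.element_size() % 16 != 0:
            return False
    return stride_one_count == 1


# Example: a transposed view with shape (8, 16) and strides (1, 8),
# so the stride==1 dimension is first rather than last.
x = torch.empty(16, 8).t()
perm = stride_one_last_permutation(x)    # [1, 0]
inv = inverse_permutation(perm)          # [1, 0]
assert descriptor_supported(torch.empty(8, 16, dtype=torch.float32))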