[Release 2.6] Triton/inductor related optimisations #2008

Open
wants to merge 5 commits into base: release/2.6

2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/triton.txt
@@ -1 +1 @@
-6da9e66008b58a7b8553f96c69021cca0d0028f0
+a34a79dbd711ea9f8fb5090bcaf24a7717574206
4 changes: 4 additions & 0 deletions torch/_inductor/autotune_process.py
@@ -630,6 +630,8 @@ def __init__(
num_stages: int,
num_warps: int,
matrix_instr_nonkdim: int = 0, # only used for hip to choose the shape of mfma instruction.
waves_per_eu: int = 0, # only used for hip to schedule waves per execution unit
kpack: int = 0,  # ROCm specific GEMM parameter
workspace_arg: Optional[WorkspaceArg] = None,
) -> None:
super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
@@ -639,6 +641,8 @@ def __init__(
self.num_stages = num_stages
self.num_warps = num_warps
self.matrix_instr_nonkdim = matrix_instr_nonkdim
self.waves_per_eu = waves_per_eu
self.kpack = kpack
self.workspace_arg = workspace_arg

def make_run_fn(
193 changes: 123 additions & 70 deletions torch/_inductor/kernel/flex_attention.py
@@ -1,5 +1,7 @@
# mypy: allow-untyped-defs
""" Triton Implementation of the flex_attention Kernel"""
import os
import itertools

import logging
import math
@@ -1206,10 +1208,24 @@ def flex_attention(
if torch.version.hip:
configs = [(c[0], c[1], c[2], 1) for c in configs]

# Experimental: expand to an exhaustive autotuning search space over block sizes, num_warps, and num_stages
if os.getenv("TORCHINDUCTOR_EXHAUSTIVE_FLEX_ATTENTION_EXPERIMENTAL") == "1":
param1 = [16, 32, 64, 128, 256, 512]
param2 = [16, 32, 64, 128, 256, 512]
param3 = [2, 4, 8, 16]
param4 = [1]

# Generate full search space
configs = list(itertools.product(param1, param2, param3, param4))

# Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)

# ROCm-specific default for the kpack GEMM parameter
if torch.version.hip:
kernel_options["kpack"] = 2

# Note, we don't need to pass in the captured buffers explicitly
# because they're implicitly added by the score_mod function
# We do need to explicitly pass it in for autotuning though.
@@ -1234,33 +1250,67 @@ def flex_attention(
cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)

error = flex_attention_template.maybe_append_choice(
choices=choices,
input_nodes=[
query,
key,
value,
logsumexp,
kv_num_blocks,
kv_indices,
full_kv_num_blocks,
full_kv_indices,
],
layout=layout,
subgraphs=[
subgraph_buffer,
mask_graph_buffer,
],
mutated_inputs=[
logsumexp,
],
num_stages=num_stages,
num_warps=num_warps,
call_sizes=query.get_size(),
**cur_kernel_options,
)
if error is not None and len(configs) == 1:
raise error
if os.getenv("TORCHINDUCTOR_EXHAUSTIVE_FLEX_ATTENTION_EXPERIMENTAL") == "1":
for mfma in [0, 16]:
for wpeu in [0, 1, 2, 4, 8]:
cur_kernel_options["waves_per_eu"] = wpeu
cur_kernel_options["matrix_instr_non_kdim"] = mfma
error = flex_attention_template.maybe_append_choice(
choices=choices,
input_nodes=[
query,
key,
value,
logsumexp,
kv_num_blocks,
kv_indices,
full_kv_num_blocks,
full_kv_indices,
],
layout=layout,
subgraphs=[
subgraph_buffer,
mask_graph_buffer,
],
mutated_inputs=[
logsumexp,
],
num_stages=num_stages,
num_warps=num_warps,
call_sizes=query.get_size(),
**cur_kernel_options,
)
if error is not None and len(configs) == 1:
raise error
else:
error = flex_attention_template.maybe_append_choice(
choices=choices,
input_nodes=[
query,
key,
value,
logsumexp,
kv_num_blocks,
kv_indices,
full_kv_num_blocks,
full_kv_indices,
],
layout=layout,
subgraphs=[
subgraph_buffer,
mask_graph_buffer,
],
mutated_inputs=[
logsumexp,
],
num_stages=num_stages,
num_warps=num_warps,
call_sizes=query.get_size(),
**cur_kernel_options,
)
if error is not None and len(configs) == 1:
raise error

inputs_for_autotuning = (
[
query,
@@ -2257,13 +2307,15 @@ def flex_attention_backward(*args, **kwargs):
configs.extend(
[
(BLOCK1, BLOCK2, w, s)
for BLOCK1 in [32, 64]
for BLOCK2 in [32, 64, 128]
for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4])
for BLOCK1 in [16, 32, 64, 128, 256, 512]
for BLOCK2 in [16, 32, 64, 128, 256, 512]
for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4, 8])
for s in num_stages_list
if BLOCK2 % BLOCK1 == 0
]
)


original_kernel_options = kernel_options.copy()
for BLOCK1, BLOCK2, num_warps, num_stages in configs:
if (
@@ -2273,9 +2325,6 @@ def flex_attention_backward(*args, **kwargs):
or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0
):
continue
if num_warps == 8:
# Working around https://github.com/pytorch/pytorch/issues/141603
continue

# Performance tuning
cur_kernel_options = original_kernel_options.copy()
@@ -2287,43 +2336,47 @@ def flex_attention_backward(*args, **kwargs):
cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)

flex_attention_backward_template.maybe_append_choice(
choices=choices,
input_nodes=[
query,
key,
value,
logsumexp,
delta,
grad_out,
grad_query,
broadcasted_grad_value,
kv_num_blocks,
kv_indices,
q_num_blocks,
q_indices,
full_kv_num_blocks,
full_kv_indices,
full_q_num_blocks,
full_q_indices,
],
layout=layout_broadcasted_k, # We use store_output only for grad_key
subgraphs=[
fw_subgraph_buffer,
joint_outputs.grad_input,
mask_graph_buffer,
joint_outputs.captured_grads_compute,
],
mutated_inputs=[
grad_query,
broadcasted_grad_value,
*joint_outputs.mutated_grads,
],
call_sizes=query.get_size() + key.get_size()[1:3],
num_stages=num_stages,
num_warps=num_warps,
**cur_kernel_options,
)
for wpeu in [0, 1, 2, 4, 8]:
for mfma in [0, 16]:
cur_kernel_options["waves_per_eu"] = wpeu
cur_kernel_options["matrix_instr_non_kdim"] = mfma
flex_attention_backward_template.maybe_append_choice(
choices=choices,
input_nodes=[
query,
key,
value,
logsumexp,
delta,
grad_out,
grad_query,
broadcasted_grad_value,
kv_num_blocks,
kv_indices,
q_num_blocks,
q_indices,
full_kv_num_blocks,
full_kv_indices,
full_q_num_blocks,
full_q_indices,
],
layout=layout_broadcasted_k, # We use store_output only for grad_key
subgraphs=[
fw_subgraph_buffer,
joint_outputs.grad_input,
mask_graph_buffer,
joint_outputs.captured_grads_compute,
],
mutated_inputs=[
grad_query,
broadcasted_grad_value,
*joint_outputs.mutated_grads,
],
call_sizes=query.get_size() + key.get_size()[1:3],
num_stages=num_stages,
num_warps=num_warps,
**cur_kernel_options,
)
inputs_for_autotuning = (
[
query,
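
To exercise the new experimental search space end to end, a user-level sketch along these lines should work. Only the TORCHINDUCTOR_EXHAUSTIVE_FLEX_ATTENTION_EXPERIMENTAL variable comes from this diff; the tensor shapes, the score_mod, and the max-autotune compile mode are illustrative assumptions.

import os

# Must be set before inductor builds its list of flex_attention choices.
os.environ["TORCHINDUCTOR_EXHAUSTIVE_FLEX_ATTENTION_EXPERIMENTAL"] = "1"

import torch
from torch.nn.attention.flex_attention import flex_attention

def rel_bias(score, b, h, q_idx, kv_idx):
    # Toy score_mod: add a relative-position bias to each attention score.
    return score + (q_idx - kv_idx)

q, k, v = (
    torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16)
    for _ in range(3)
)

# Compile mode is an illustrative choice; the lowering benchmarks its choices.
compiled_flex = torch.compile(flex_attention, mode="max-autotune")
out = compiled_flex(q, k, v, score_mod=rel_bias)
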
9 changes: 8 additions & 1 deletion torch/_inductor/kernel/mm_common.py
@@ -4,6 +4,8 @@
import logging
from typing import Any, cast, Dict, Sequence, Tuple

from torch.utils._ordered_set import OrderedSet

import sympy

import torch
@@ -75,7 +77,7 @@ def filtered_configs(
),
min_block_size_k,
)
used = set()
used = OrderedSet[tuple[int, ...]]()
for block_m, block_n, block_k, num_stages, num_warps in configs:
# shrink configs for small sizes
block_m = max(min(int(block_m * scale), m), min_block_size)
@@ -88,20 +90,23 @@ def filtered_configs(
# each warp computes 16x16 tile = 256
num_warps = min(num_warps, block_m * block_n // 256)
if torch.version.hip:
kpack = 2
for matrix_instr_nonkdim in [0, 16]:
if matrix_instr_nonkdim != 0 and (
block_m % matrix_instr_nonkdim != 0
or block_n % matrix_instr_nonkdim != 0
):
# block_m and block_n must be a multiple of matrix_instr_nonkdim
continue

if (
block_m,
block_n,
block_k,
num_stages,
num_warps,
matrix_instr_nonkdim,
kpack,
) not in used:
used.add(
(
@@ -111,6 +116,7 @@
num_stages,
num_warps,
matrix_instr_nonkdim,
kpack,
)
)
yield triton_config(
@@ -120,6 +126,7 @@
num_stages=num_stages,
num_warps=num_warps,
matrix_instr_nonkdim=matrix_instr_nonkdim,
kpack=kpack,
)
else:
if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used:
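
For reference, the ROCm-only knobs that filtered_configs now threads through (matrix_instr_nonkdim and kpack here, waves_per_eu elsewhere in this PR) could be written out by hand roughly as below. This is a sketch, not inductor's code path: it assumes the ROCm Triton backend accepts these keys as meta-parameters inside a triton.Config, and the block-size list is arbitrary.

import triton

def rocm_mm_configs():
    # Hand-rolled analogue of the HIP branch above: enumerate MFMA shapes and
    # skip combinations where the tile is not a multiple of matrix_instr_nonkdim.
    configs = []
    for block_m, block_n, block_k, num_stages, num_warps in [
        (64, 64, 32, 2, 4),
        (128, 128, 32, 2, 8),
    ]:
        for matrix_instr_nonkdim in (0, 16):
            if matrix_instr_nonkdim and (
                block_m % matrix_instr_nonkdim or block_n % matrix_instr_nonkdim
            ):
                continue
            configs.append(
                triton.Config(
                    {
                        "BLOCK_M": block_m,
                        "BLOCK_N": block_n,
                        "BLOCK_K": block_k,
                        "matrix_instr_nonkdim": matrix_instr_nonkdim,
                        "waves_per_eu": 0,  # 0 leaves the choice to the compiler
                        "kpack": 2,         # matches the value hard-coded above
                    },
                    num_stages=num_stages,
                    num_warps=num_warps,
                )
            )
    return configs
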
6 changes: 6 additions & 0 deletions torch/_inductor/runtime/coordinate_descent_tuner.py
@@ -21,6 +21,8 @@ def get_field(config, name):
return config.num_warps
elif name == "num_stages":
return config.num_stages
elif name == "waves_per_eu":
return config.kwargs.get(name, int(8 // config.num_warps))
else:
return config.kwargs.get(name, None)

@@ -97,6 +99,8 @@ def tunable_fields(self):
]
if self.is_mm:
out.append("num_stages")
if self.inductor_meta.get("is_hip") is True:
out.append("waves_per_eu")

return out

@@ -105,6 +109,8 @@ def value_too_large(self, name: str, val: int) -> bool:
return val > self.get_config_max(name[0].lower())
if name == "num_warps":
return val > self.get_warpsmax()
if name == "waves_per_eu":
return val > 8

return False

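
The waves_per_eu handling added to the coordinate-descent tuner can be restated in isolation as below. The FakeConfig stand-in and the assert values are illustrative, but the 8 // num_warps fallback and the cap of 8 mirror the hunks above.

from dataclasses import dataclass, field

@dataclass
class FakeConfig:
    # Minimal stand-in for a Triton config as the tuner sees it.
    num_warps: int
    num_stages: int
    kwargs: dict = field(default_factory=dict)

def get_waves_per_eu(config: FakeConfig) -> int:
    # Mirrors get_field(): if the config does not pin waves_per_eu,
    # start from 8 // num_warps.
    return config.kwargs.get("waves_per_eu", 8 // config.num_warps)

def waves_per_eu_too_large(val: int) -> bool:
    # Mirrors value_too_large(): the tuner never proposes more than 8.
    return val > 8

cfg = FakeConfig(num_warps=4, num_stages=2)
assert get_waves_per_eu(cfg) == 2
assert waves_per_eu_too_large(16) and not waves_per_eu_too_large(8)
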
13 changes: 11 additions & 2 deletions torch/_inductor/select_algorithm.py
@@ -354,9 +354,16 @@ def jit_lines(self):
triton_meta["configs"] = [config_of(signature)]
for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index]
triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index]
matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", 0)
if matrix_instr_nonkdim != 0:
matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", None)
waves_per_eu = self.meta.get("waves_per_eu", None)
kpack = self.meta.get("kpack", None)
if matrix_instr_nonkdim:
triton_meta["matrix_instr_nonkdim"] = matrix_instr_nonkdim
if waves_per_eu:
triton_meta["waves_per_eu"] = waves_per_eu
if kpack:
triton_meta["kpack"] = kpack


self.triton_meta = triton_meta

@@ -920,6 +927,8 @@ def make_kernel_render(out_node):
num_stages=num_stages,
num_warps=num_warps,
matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
waves_per_eu=kwargs.get("waves_per_eu", 0),
kpack=kwargs.get("kpack", 2),
input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes), # type: ignore[arg-type]
output_tensor_meta=TensorMeta.from_irnodes(layout),
workspace_arg=workspace_arg,
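
The jit_lines change boils down to forwarding the three ROCm knobs into triton_meta only when they are present and truthy, so CUDA builds and zero-valued defaults leave the meta dict untouched. A standalone restatement, with plain dicts standing in for self.meta and triton_meta:

def merge_rocm_knobs(meta: dict, triton_meta: dict) -> dict:
    # Copy only the knobs that are set; None and 0 are both skipped,
    # matching the truthiness checks in jit_lines above.
    for key in ("matrix_instr_nonkdim", "waves_per_eu", "kpack"):
        val = meta.get(key)
        if val:
            triton_meta[key] = val
    return triton_meta

print(merge_rocm_knobs({"matrix_instr_nonkdim": 16, "waves_per_eu": 0}, {}))
# prints {'matrix_instr_nonkdim': 16}
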