forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 67
/
Copy pathselect_algorithm.py
1937 lines (1672 loc) · 67.4 KB
/
select_algorithm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# mypy: allow-untyped-defs
import builtins
import contextlib
import dataclasses
import functools
import inspect
import itertools
import json
import logging
import math
import operator
import os
import sys
import textwrap
import time
from collections import namedtuple
from concurrent.futures import as_completed, ThreadPoolExecutor
from io import StringIO
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
from unittest.mock import patch
import sympy
from filelock import FileLock
import torch
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import counters, dynamo_timed, identity, preserve_rng_state
from . import config, ir
from .autotune_process import (
TensorMeta,
TritonBenchmarkRequest,
TritonCPUBenchmarkRequest,
TritonGPUBenchmarkRequest,
)
from .codecache import code_hash, PersistentCache, PyCodeCache
from .codegen.common import (
CSEVariable,
IndentedBuffer,
KernelTemplate,
OpOverrides,
WorkspaceArg,
)
from .codegen.simd_kernel_features import SIMDKernelFeatures
from .codegen.triton import (
gen_common_triton_imports,
texpr,
TritonKernel,
TritonScheduling,
)
from .codegen.triton_utils import config_of, signature_to_meta
from .exc import CUDACompileError
from .ir import ChoiceCaller, PrimitiveInfoType
from .ops_handler import StoreMode
from .runtime.benchmarking import benchmarker
from .runtime.hints import DeviceProperties
from .utils import (
FakeIndentedBuffer,
get_dtype_size,
Placeholder,
restore_stdout_stderr,
sympy_dot,
sympy_index_symbol,
sympy_product,
triton_type_to_torch,
unique,
)
from .virtualized import V
# Logger for this module.
log = logging.getLogger(name=__name__)

# correctness checks struggle with fp16/tf32
# NOTE(review): presumably forwarded as keyword arguments to a correctness
# check such as AutotuneArgs.verify; kept empty by default — confirm at caller.
VERIFY: Dict[str, Any] = dict()

# Module-wide toggles: autotune printing on, debug off.
PRINT_AUTOTUNE = True
DEBUG = False
class KernelNamespace:
    """Empty attribute container; extern kernels are attached to an
    instance of it with ``setattr``."""


# these objects are imported from the generated wrapper code
extern_kernels = KernelNamespace()

# Self-type used by AutotuneArgs' alternate constructor so subclasses get
# their own type back.
_T = TypeVar("_T", bound="AutotuneArgs")
@dataclasses.dataclass
class BenchmarkTensors:
    """Tensors handed to a single autotuning choice.

    Attributes are the list of input tensors and an optional output tensor.
    """

    input_tensors: List[torch.Tensor]
    output_tensor: Optional[torch.Tensor]

    def unpack(self):
        """Return the pair ``(input_tensors, output_tensor)``."""
        inputs = self.input_tensors
        output = self.output_tensor
        return inputs, output
@dataclasses.dataclass
class AutotuneArgs:
    """Inputs and outputs shared by all autotuning choices.

    Note:
        Since we typically have a mix of external choices and triton choices,
        two tensor sets are kept for the same underlying buffers:
        - ``extern`` (for aten kernels): includes the offset for sliced tensors
        - ``triton``: uses the base pointer for sliced tensors, without offset
    ``expected`` optionally holds a reference output used by ``verify``.
    """

    triton: BenchmarkTensors
    extern: BenchmarkTensors
    expected: Optional[torch.Tensor] = None

    def get_benchmark_tensors(self, extern=False) -> BenchmarkTensors:
        """Select the tensor set for a choice: ``self.extern`` when ``extern``
        is truthy, otherwise ``self.triton``."""
        if extern:
            return self.extern
        return self.triton

    @classmethod
    def from_choice_args(
        cls: Type[_T],
        example_inputs: List[torch.Tensor],
        example_inputs_extern: List[torch.Tensor],
        out: torch.Tensor,
        out_extern: torch.Tensor,
        expected: Optional[torch.Tensor] = None,
    ) -> _T:
        """Alternate constructor: build an AutotuneArgs (or subclass) from
        separate triton/extern inputs and outputs."""
        triton_set = BenchmarkTensors(example_inputs, out)
        extern_set = BenchmarkTensors(example_inputs_extern, out_extern)
        return cls(triton=triton_set, extern=extern_set, expected=expected)

    def verify(self, **kwargs):
        """Check that the extern output tensor matches ``expected``.

        ``kwargs`` are forwarded unchanged to ``torch.testing.assert_close``
        (e.g. ``rtol``/``atol``); a mismatch raises from that call.
        """
        actual = self.extern.output_tensor
        torch.testing.assert_close(actual, self.expected, **kwargs)
class PartialRender:
    """
    Some parts of a template need to be generated at the end, but
    inserted into the template at the start. This allows doing a bunch
    of replacements after the initial render.

    ``replacement_hooks`` maps a placeholder string that appears in ``code``
    to a zero-argument callable producing its replacement text. Once a hook
    has been applied through ``finalize_hook`` its entry is set to ``None``.
    Note that the dict is stored by reference, so that mutation is visible to
    whoever passed it in.
    """

    def __init__(self, code, replacement_hooks) -> None:
        super().__init__()
        self.code = code
        self.replacement_hooks = replacement_hooks

    def finalize_hook(self, hook_key: str, strict=True) -> None:
        """Apply a single hook now and mark it as used.

        Args:
            hook_key: placeholder registered in ``replacement_hooks``.
            strict: if True, an unregistered ``hook_key`` is an error; if
                False it is silently ignored.

        Raises:
            RuntimeError: ``hook_key`` is not registered and ``strict`` is True.
        """
        if hook_key not in self.replacement_hooks:
            if strict:
                raise RuntimeError(
                    f"{hook_key} not registered in self.replacement_hooks"
                )
            return

        assert (
            self.replacement_hooks[hook_key] is not None
        ), "hook_key can only be called once"
        self.code = self.code.replace(hook_key, self.replacement_hooks[hook_key]())
        self.replacement_hooks[hook_key] = None

    def finalize_all(self) -> str:
        """Apply every hook that has not been finalized yet and return the code.

        Hooks already consumed by ``finalize_hook`` are stored as ``None`` and
        are skipped; calling them would raise ``TypeError``.
        """
        for key, fn in self.replacement_hooks.items():
            if fn is None:
                # Already substituted by finalize_hook().
                continue
            self.code = self.code.replace(key, fn())
        return self.code
# Per-subgraph lowering state for triton templates: the generated body
# buffer together with the mask and output names active while emitting it.
SubgraphInfo = namedtuple("SubgraphInfo", "body template_mask template_out")
class ModificationWrapper(V.WrapperHandler):  # type: ignore[name-defined]
    """Ops handler used while lowering a modification subgraph.

    It wraps the current ``V.ops`` handler and rewrites ``load``/``store`` into
    Triton source strings that reference inputs of ``kernel``. Names present in
    ``fixed_inputs`` are substituted inline instead of being loaded.
    """

    def __init__(
        self,
        kernel,
        subgraph_number: int,
        fixed_inputs: Dict[str, Any],
        mask: Optional[str],
    ):
        super().__init__(V.ops)
        self.name = f"PlaceholderSubstitution_{subgraph_number}"
        self.kernel = kernel
        self.fixed_inputs = fixed_inputs
        self.mask = mask

    def load(self, name: str, index: sympy.Expr):
        """Return either the parenthesized fixed input for ``name`` or a
        ``tl.load`` from the corresponding kernel input at ``index``."""
        if name in self.fixed_inputs:
            return f"({self.fixed_inputs[name]})"
        # Keep this call order: both helpers may register kernel arguments.
        offset_expr = self._process_indexing(index)
        input_ref = self._add_kernel_input(name)
        return f"tl.load({input_ref} + {offset_expr})"

    def indirect_indexing(self, index_var: str, size, check, wrap_neg=True):
        """Turn ``index_var`` into a sympy index symbol of the same name;
        ``size``, ``check`` and ``wrap_neg`` are not used."""
        symbol_name = str(index_var)
        return sympy_index_symbol(symbol_name)

    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> str:
        """Emit an inner store for a modification subgraph.

        Currently only supports stores for atomic adds coming from scatter
        nodes. This is used by flex_attention's backwards grad for captured
        buffers, see zeros_and_scatter lowering. The store is masked with
        ``self.mask`` and the index is broadcast to ``value``'s shape.
        """
        assert (
            self.mask is not None
        ), "Mask is required for inner stores in modifications"
        assert mode == "atomic_add", "Only atomic_add is supported for inner stores"

        # Keep this call order: the buffer is registered before indexing.
        target = self._add_kernel_input(name)
        offset_expr = self._process_indexing(index)
        broadcast_offset = f"tl.broadcast_to({offset_expr}, {value}.shape)"
        return (
            f"tl.atomic_add({target} + {broadcast_offset}, {value}, "
            f"{self.mask}, sem='relaxed')"
        )

    def _add_kernel_input(self, name: str):
        """Register ``name`` as a kernel input and return its argument name."""
        kernel_args = self.kernel.args
        return kernel_args.input(name)

    def _process_indexing(self, index):
        """Rename ``index`` through the kernel and render it as a kernel expression."""
        renamed = self.kernel.rename_indexing(index)
        return self.kernel.kexpr(renamed)
class TritonTemplateKernel(TritonKernel):
    """A TritonKernel whose source is produced by rendering a jinja2 template.

    The template calls back into this object through the helpers returned by
    ``template_env`` (``def_kernel``, ``size``, ``stride``, ``store_output``,
    ``make_load``, ``modification``, ``gen_argdefs``, ``gen_defines``). Pieces
    that can only be generated after the whole template has rendered are
    emitted as placeholder strings whose generator callables live in
    ``self.render_hooks``; ``render`` wraps the result in a ``PartialRender``
    that substitutes them later.
    """

    def __init__(
        self,
        kernel_name,
        input_nodes,
        output_node,
        defines,
        num_stages,
        num_warps,
        grid_fn,
        meta,
        call_sizes,
        use_jit=False,
        prefix_args=0,
        suffix_args=0,
        epilogue_fn=identity,
        subgraphs: Optional[List[ir.ComputedBuffer]] = None,
        workspace_arg: Optional[WorkspaceArg] = None,
    ) -> None:
        numel = sympy_product(output_node.get_size())
        # A single "x" range over every output element; the "r" range is 1.
        super().__init__(
            {
                "x": numel,
                "r": sympy.S.One,
            },
            features=SIMDKernelFeatures([], numel),
        )
        self.input_nodes = input_nodes
        self.output_node = output_node
        self.named_input_nodes = {}  # type: ignore[var-annotated]
        self.defines = defines
        self.kernel_name = kernel_name
        self.use_jit = use_jit
        self.num_stages = num_stages
        self.num_warps = num_warps
        self.grid_fn = grid_fn
        self.meta = meta
        self.call_sizes = call_sizes
        # for templates with fixed epilogues
        self.prefix_args = prefix_args
        self.suffix_args = suffix_args
        self.epilogue_fn = epilogue_fn
        self.render_hooks = {}  # type: ignore[var-annotated]
        self.triton_meta: Optional[Dict[str, object]] = None
        # For Templated Attention this can be a list of ir.Subgraph
        self.subgraphs: Optional[List[ir.ComputedBuffer]] = subgraphs

        # Some templates use extra global memory as a workspace
        self.workspace_arg = workspace_arg
        if workspace_arg is not None:
            self.args.workspace_args.append(workspace_arg)

        # The following attributes (body, template_mask, output_val) are all
        # used for triton kernel codegen.
        # They are swapped onto the TritonTemplateKernel object by
        # `set_subgraph_body`
        self.subgraph_bodies: Dict[str, SubgraphInfo] = {}

        self.body: IndentedBuffer = FakeIndentedBuffer()
        self.template_mask: Optional[str] = None
        self.template_out: Optional[str] = None

    @contextlib.contextmanager
    def set_subgraph_body(self, body_name: str):
        """Temporarily install the (body, mask, out) state saved under ``body_name``.

        On exit, the possibly-modified state is written back to
        ``self.subgraph_bodies[body_name]`` and the previous state is
        restored. This also happens when the managed block raises, so the
        kernel is never left pointing at a subgraph's buffers.
        """
        old_body, old_mask, old_out = self.body, self.template_mask, self.template_out
        assert body_name in self.subgraph_bodies, body_name
        self.body, self.template_mask, self.template_out = self.subgraph_bodies[
            body_name
        ]
        try:
            yield
        finally:
            self.subgraph_bodies[body_name] = SubgraphInfo(
                self.body, self.template_mask, self.template_out
            )
            self.body, self.template_mask, self.template_out = (
                old_body,
                old_mask,
                old_out,
            )

    @contextlib.contextmanager
    def create_subgraph_body(self, body_name: str):
        """Register a fresh, empty body under ``body_name`` and make it current."""
        assert body_name not in self.subgraph_bodies
        self.subgraph_bodies[body_name] = SubgraphInfo(IndentedBuffer(), None, None)
        with self.set_subgraph_body(body_name):
            yield

    def need_numel_args(self):
        """Template kernels report that no numel arguments are needed."""
        return False

    def estimate_kernel_num_bytes(self):
        """
        Estimate the total number of bytes this kernel takes.
        For in/out nodes, sizes are counted twice: once for reading and
        once for writing.

        NOTE(review): the double count is applied to the first
        ``len(unique(inplace_buffers))`` entries of ``input_nodes`` followed by
        ``output_node``, which assumes in/out buffers come first — confirm.
        """
        ninplace_args = len(unique(self.args.inplace_buffers.values()))
        num_bytes = []
        for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))):
            size = V.graph.sizevars.size_hints(inp.get_size())
            numel = functools.reduce(operator.mul, size, 1)
            dtype_size = get_dtype_size(inp.get_dtype())
            num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
        return sum(num_bytes)

    def jit_lines(self):
        """Return the decorator source placed above the kernel ``def``.

        With ``use_jit`` this is just ``@triton.jit``. Otherwise it is a
        ``@triton_heuristics.template(...)`` call followed by ``@triton.jit``,
        and as a side effect ``self.triton_meta`` is recorded for later use
        by ``call_kernel``.
        """
        if self.use_jit:
            return "@triton.jit"

        argdefs, _, signature, _ = self.args.python_argdefs()
        triton_meta: Dict[str, Any] = {
            "signature": signature_to_meta(
                signature, size_dtype=self.index_dtype, argdefs=argdefs
            ),
            "device": DeviceProperties.create(self.output_node.get_device()),
            "constants": {},
        }
        triton_meta["configs"] = [config_of(signature)]
        # Arguments known to equal 1 are passed to triton as constants.
        for arg_num in triton_meta["configs"][0].equal_to_1:  # type: ignore[index]
            triton_meta["constants"][signature[arg_num].name] = 1  # type: ignore[index]
        # Optional tuning knobs are forwarded only when set to a truthy value.
        matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", None)
        waves_per_eu = self.meta.get("waves_per_eu", None)
        kpack = self.meta.get("kpack", None)
        if matrix_instr_nonkdim:
            triton_meta["matrix_instr_nonkdim"] = matrix_instr_nonkdim
        if waves_per_eu:
            triton_meta["waves_per_eu"] = waves_per_eu
        if kpack:
            triton_meta["kpack"] = kpack

        self.triton_meta = triton_meta

        inductor_meta = {
            "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
            **TritonKernel.inductor_meta_common(),
        }
        if config.profile_bandwidth or config.benchmark_kernel:
            num_gb = self.estimate_kernel_num_bytes() / 1e9
            inductor_meta["kernel_num_gb"] = num_gb
        return f"""
@triton_heuristics.template(
num_stages={self.num_stages},
num_warps={self.num_warps},
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
)
@triton.jit
"""

    def gen_argdefs(self):
        """Return a placeholder later replaced by the comma-separated kernel
        argument definitions."""

        def hook():
            # python_argdefs() cannot be run until after the rest of the template lazily adds more args
            arg_defs, *_ = self.args.python_argdefs()
            return f"{', '.join(arg_defs)}"

        self.render_hooks["<ARGDEFS>"] = hook
        return "<ARGDEFS>"

    def gen_defines(self):
        """Return the ``tl.constexpr`` define lines passed in at construction."""
        return self.defines

    def def_kernel(self, *argnames):
        """
        Hook called from template code to generate function def and
        needed args.

        ``argnames`` name the input nodes between the prefix and suffix
        arguments, in order. Returns a placeholder that is later replaced with
        imports, decorators, the ``def`` line, the defines and one
        ``name = arg_name [+ offset]`` line per named input.
        """
        assert all(isinstance(x, str) for x in argnames)
        renames = IndentedBuffer(initial_indent=1)

        named_args = self.input_nodes[
            self.prefix_args : len(self.input_nodes) - self.suffix_args
        ]

        assert len(argnames) == len(named_args), (
            len(argnames),
            len(named_args),
            self.prefix_args,
            len(self.input_nodes),
        )

        for input_node in self.input_nodes[: self.prefix_args]:
            # get args in correct order
            self.args.input(input_node.get_name())

        for name, input_node in zip(argnames, named_args):
            arg_name = f"arg_{name}"
            self.named_input_nodes[name] = input_node
            self.args.input_buffers[input_node.get_name()] = arg_name

        # The args may be duplicated, so renaming must be after args are de-duplicated.
        for name in argnames:
            input_node = self.named_input_nodes[name]
            arg_name = self.args.input_buffers[input_node.get_name()]
            if input_node.get_layout().offset == 0:
                renames.writeline(f"{name} = {arg_name}")
            else:
                offset = texpr(self.rename_indexing(input_node.get_layout().offset))
                renames.writeline(f"{name} = {arg_name} + {offset}")

        for input_node in self.input_nodes[len(self.input_nodes) - self.suffix_args :]:
            # get args in correct order
            self.args.input(input_node.get_name())

        def hook():
            # python_argdefs() cannot be run until after the rest of the template lazily adds more args
            arg_defs, *_ = self.args.python_argdefs()
            code = IndentedBuffer()
            code.splice(gen_common_triton_imports())
            code.splice(self.jit_lines())
            code.writeline(f"def {self.kernel_name}({', '.join(arg_defs)}):")
            with code.indent():
                code.splice(self.defines)
                code.splice(renames.getvalue())
            return code.getvalue()

        assert "<DEF_KERNEL>" not in self.render_hooks
        self.render_hooks["<DEF_KERNEL>"] = hook
        return "<DEF_KERNEL>"

    def size(self, name: Optional[str], index: int):
        """
        Hook called from template code to get the size of an arg.
        Will add needed args to pass it in if it is dynamic.

        ``name`` is an argument name given to ``def_kernel``, or None for the
        output node. Returns the size as a Triton expression string.
        """
        assert isinstance(index, int)
        if name is None:
            val = self.output_node.get_size()[index]
        else:
            assert isinstance(name, str)
            val = self.named_input_nodes[name].get_size()[index]
        return texpr(self.rename_indexing(val))

    def stride(self, name, index=None):
        """
        Hook called from template code to get the stride of an arg.
        Will add needed args to pass it in if it is dynamic.

        ``name`` is an argument name given to ``def_kernel``, or None for the
        output node. With an int ``index`` a single stride expression is
        returned; otherwise all strides are joined with ``", "``.
        """
        if name is None:
            val = self.output_node.get_stride()
        else:
            assert isinstance(name, str)
            val = self.named_input_nodes[name].get_stride()

        if isinstance(index, int):
            return texpr(self.rename_indexing(val[index]))
        return ", ".join([texpr(self.rename_indexing(i)) for i in val])

    def _get_subgraph(self, subgraph_number: int):
        """Return ``self.subgraphs[subgraph_number]`` after validating the index
        and checking that the current body is empty."""
        assert isinstance(subgraph_number, int)
        assert isinstance(self.subgraphs, list)
        assert subgraph_number < len(
            self.subgraphs
        ), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}"
        assert (
            self.body.getvalue() == ""
        ), "Body should be clear before adding a modification"
        return self.subgraphs[subgraph_number]

    def _handle_scatter_graph(self, scatter_graph):
        """Handle processing for a single scatter graph.

        Args:
            scatter_graph: The scatter graph to process

        Returns:
            Whatever ``scatter_graph.data.store_output`` returns for a store
            indexed with ``scatter_graph``'s own strides.
        """
        assert isinstance(
            scatter_graph, ir.ComputedBuffer
        ), f"scatter_graph must be an instance of ComputedBuffer but got {type(scatter_graph)}"

        def contiguous_strides(x):
            # We always create a fresh contiguous grad for scattering into
            return sum(
                x_i * stride for x_i, stride in zip(x, scatter_graph.get_stride())
            )

        return scatter_graph.data.store_output(scatter_graph.name, contiguous_strides, [])  # type: ignore[attr-defined]

    def modification(
        self,
        subgraph_number: int,
        output_name: Optional[str],
        mask: Optional[str] = None,
        **fixed_inputs,
    ) -> str:
        """This creates a modification function for a subgraph.
        To use this inside a template, the first argument should specify which subgraph to codegen for

        Args:
            subgraph_number (int): The index of the subgraph in self.subgraphs
            output_name (Optional[str]): The name of the output variable to store the result in.
                Must be None exactly when the subgraph is a list of scatter graphs.
            mask (Optional[str]): An optional mask to use for the store operation. If provided, this mask
                will be applied to the store.
            **fixed_inputs: values substituted inline for loads of the matching
                names instead of reading them from memory.

        Returns:
            The generated source for the modification body.
        """
        num = 0
        out = None
        scatters = []
        # Each call gets its own uniquely named subgraph body.
        while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies:
            num += 1
        with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"):
            subgraph = self._get_subgraph(subgraph_number)
            modification_handler = ModificationWrapper(
                self, subgraph_number, fixed_inputs, mask
            )
            with V.set_ops_handler(modification_handler):
                assert isinstance(
                    subgraph, (ir.ComputedBuffer, list)
                ), f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}"
                # Handle scatter stores
                if isinstance(subgraph, list):
                    for scatter_graph in subgraph:
                        scatters.append(self._handle_scatter_graph(scatter_graph))
                elif isinstance(subgraph.data, ir.InputBuffer):
                    out = subgraph.data.make_loader()(())
                else:
                    out = subgraph.data.inner_fn(())

            self.codegen_body()
            if output_name is not None:
                assert isinstance(output_name, str)
                assert out is not None
                self.body.writeline(f"{output_name} = {out.value}")
            else:
                assert out is None
                for scatter in scatters:
                    self.body.writeline(str(scatter))

            # Must be read before the subgraph body is swapped back out.
            body_val = self.body.getvalue()
            self.cse.invalidate(set())  # type: ignore[arg-type]
            return body_val

    def store_output(
        self,
        indices: Union[List[Any], Tuple[Any]],
        val: str,
        mask: Optional[str] = None,
        indent_width: int = 4,
    ):
        """Stores the final output and appends any epilogue fusions if the buffer hasn't been optimized away.

        Args:
            indices (Union[List, Tuple]): The index for each dimension of the output. The dot product of
                these indices and output strides must match `val`.
            val (str): The value to store.
            mask (Optional[str]): An optional mask to use for the store operation. If provided, this mask
                will be applied to the store.
            indent_width (int): The number of spaces to use for indentation. This is used when the call to
                store_output is indented in the kernel definition.

        Returns:
            The ``"<STORE_OUTPUT>"`` placeholder; its hook re-runs
            ``codegen_body`` so later additions (e.g. fused epilogues) are
            included.
        """
        with self.create_subgraph_body("<STORE_OUTPUT>"):
            assert isinstance(indices, (list, tuple))
            assert isinstance(val, str)
            assert isinstance(mask, (str, type(None)))
            assert self.template_mask is None
            indices = list(map(OpOverrides.paren, indices))
            index_symbols = [sympy.Symbol(x, integer=True) for x in indices]
            lengths = [
                V.graph.sizevars.simplify(s) for s in self.output_node.get_size()
            ]
            assert len(indices) == len(lengths)

            # glue to make generated code use same indexing from template
            for name, range_tree_entry in zip(
                indices, self.range_trees[0].construct_entries(lengths)
            ):
                range_tree_entry.set_name(name)
            contiguous_index = sympy_dot(
                ir.FlexibleLayout.contiguous_strides(lengths), index_symbols
            )
            contiguous_index = self.rename_indexing(contiguous_index)
            self.body.writeline("xindex = " + texpr(contiguous_index))
            self.range_trees[0].lookup(sympy.S.One, sympy_product(lengths)).set_name(
                "xindex"
            )
            self.template_mask = mask
            self.template_out = val
            self.template_indices = indices
            output_index = self.output_node.get_layout().make_indexer()(index_symbols)
            output_index = self.rename_indexing(output_index)
            # Reuse the already-emitted xindex when the output is contiguous.
            if output_index == contiguous_index:
                output_index = sympy.Symbol("xindex", integer=True)

            acc_dtype = (
                triton_type_to_torch(self.meta["ACC_TYPE"])
                if "ACC_TYPE" in self.meta
                else torch.float32
            )
            # epilogue_fn receives the accumulator followed by loads of the
            # prefix and suffix inputs at the output indices.
            epilogue_args = [V.kernel.cse.namedvar(val, dtype=acc_dtype)]
            for input_node in itertools.chain(
                self.input_nodes[: self.prefix_args],
                self.input_nodes[len(self.input_nodes) - self.suffix_args :],
            ):
                input_node.freeze_layout()
                epilogue_args.append(input_node.make_loader()(index_symbols))

            V.ops.store(
                self.output_node.get_name(),
                output_index,
                self.epilogue_fn(*epilogue_args),
            )
            self.codegen_body()

        def hook():
            # more stuff might have been added since the codegen_body above
            self.codegen_body()

            return textwrap.indent(self.body.getvalue(), " " * indent_width).strip()

        assert "<STORE_OUTPUT>" not in self.render_hooks
        self.render_hooks["<STORE_OUTPUT>"] = hook
        return "<STORE_OUTPUT>"

    def render(self, template, kwargs):
        """Render ``template`` with the helper namespace plus ``kwargs`` and wrap
        the text in a PartialRender sharing ``self.render_hooks``."""
        return PartialRender(
            template.render(**self.template_env(), **kwargs),
            self.render_hooks,
        )

    def make_load(self, name, indices, mask):
        """
        Optional helper called from template code to generate the code
        needed to load from an tensor.

        Produces ``tl.load(name + (sum of stride * index), mask, other=0.0)``
        using the strides of the input registered under ``name``.
        """
        assert isinstance(indices, (list, tuple))
        assert isinstance(name, str)
        assert isinstance(mask, str)
        stride = self.named_input_nodes[name].get_stride()
        indices = list(map(OpOverrides.paren, indices))
        assert len(indices) == len(stride)
        index = " + ".join(
            f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices)
        )
        return f"tl.load({name} + ({index}), {mask}, other=0.0)"

    def template_env(self):
        """
        Generate the namespace visible in the template.

        Keys are the helper methods' ``__name__`` values.
        """
        return {
            fn.__name__: fn
            for fn in [
                self.def_kernel,
                self.size,
                self.stride,
                self.store_output,
                self.make_load,
                self.modification,
                self.gen_argdefs,
                self.gen_defines,
            ]
        }

    def indexing(
        self,
        index: sympy.Expr,
        *,
        dense_indexing=False,
        copy_shape=None,
        override_mask=None,
        block_ptr=False,
    ):
        """
        Override the default indexing to use the template's own mask and
        output shape.

        The caller's ``dense_indexing``, ``copy_shape`` and ``override_mask``
        are ignored: the base implementation always receives
        ``dense_indexing=False``, ``copy_shape=self.template_out`` and
        ``override_mask=self.template_mask``. Only ``block_ptr`` is forwarded.
        """
        return super().indexing(
            index,
            dense_indexing=False,
            # We pass template_out as the shape to broadcast the indexing to as
            # the mask might be broadcast to the output shape
            copy_shape=self.template_out,
            override_mask=self.template_mask,
            block_ptr=block_ptr,
        )

    def codegen_range_tree(self):
        """No-op: templates provide their own index setup."""
        pass  # ignore default codegen

    def call_kernel(self, name: str, node: Optional[ir.IRNode] = None):
        """Emit the launch of this kernel into the graph's wrapper code,
        allocating and freeing ``workspace_arg`` around it when present."""
        wrapper = V.graph.wrapper_code
        _, call_args, _, arg_types = self.args.python_argdefs()

        # Handle workspace allocation
        if self.workspace_arg is not None:
            wrapper.generate_workspace_allocation(self.workspace_arg)

        if V.graph.cpp_wrapper:
            # In the cpp_wrapper case, we have to compute CUDA launch grid at runtime
            # if any dynamic dimension is involved. We rely on the Python version
            # of the grid function to generate those grid configs, which may contain
            # symbolic values. The wrapper will use cexpr to print out C++ code
            # appropriately for the grid configs.
            grid = self.call_sizes + [self.meta]
            wrapper.generate_kernel_call(
                name,
                call_args,
                grid=self.grid_fn(*grid),
                arg_types=arg_types,
                triton_meta=self.triton_meta,
            )
        else:
            wrapper.add_import_once(f"import {self.grid_fn.__module__}")
            meta = wrapper.add_meta_once(self.meta)
            grid = self.call_sizes + [meta]
            wrapper.generate_kernel_call(
                name,
                call_args,
                grid=grid,
                grid_fn=f"{self.grid_fn.__module__}.{self.grid_fn.__name__}",
                arg_types=arg_types,
                triton_meta=self.triton_meta,
                gpu="cpu" not in V.graph.device_types,
            )

        if self.workspace_arg is not None:
            wrapper.generate_workspace_deallocation(self.workspace_arg)
@functools.lru_cache(None)
def _jinja2_env():
    """Return a jinja2 Environment that raises on undefined template
    variables (``StrictUndefined``), or None if jinja2 cannot be imported.

    The result is memoized, so all callers share the same object.
    """
    try:
        import jinja2

        env = jinja2.Environment(undefined=jinja2.StrictUndefined)
    except ImportError:
        env = None
    return env
class TritonTemplate(KernelTemplate):
    """A named jinja2 source template for Triton kernels.

    Every instance is registered in the class-level ``all_templates`` dict
    under its (unique) name; ``generate`` renders, compiles and wraps one
    configuration of the template as a TritonTemplateCaller.
    """

    index_counter = itertools.count()
    all_templates: Dict[str, "TritonTemplate"] = {}

    def __init__(self, name: str, grid: Any, source: str, debug=False) -> None:
        super().__init__(name)
        self.grid = grid
        self.template = self._template_from_string(source)
        assert name not in self.all_templates, "duplicate template name"
        self.all_templates[name] = self
        self.debug = debug

    def generate(  # type: ignore[override]
        self,
        input_nodes,
        layout,
        num_stages,
        num_warps,
        prefix_args=0,
        suffix_args=0,
        epilogue_fn=identity,
        subgraphs=None,
        mutated_inputs=None,
        call_sizes=None,
        workspace_arg: Optional[WorkspaceArg] = None,
        **kwargs,
    ):
        """This function generates a TritonTemplateCaller

        Args:
            input_nodes: List of input nodes
            layout: Output layout
            num_stages: Number of stages for triton launch
            num_warps: Number of warps for triton launch
            prefix_args: Number of leading input_nodes that are not named by the
                template's ``def_kernel`` call. They are registered as kernel args
                before the named inputs, and their values (loaded at the output
                indices) are passed to ``epilogue_fn`` after the accumulator.
            suffix_args: Like ``prefix_args`` but for trailing input_nodes, which are
                registered after the named inputs.
            epilogue_fn: Optional epilogue function to be called on the output
            subgraphs: Optional subgraphs to be passed as arguments, these will be inlined
                into the triton template string
            mutated_inputs: Optional list of input nodes that are mutated by the kernel, this is helpful
                if you need to return multiple outputs. You can pass them as inputs and mark them as
                being mutated by the kernel.
            call_sizes: Sizes passed to the grid function; defaults to ``layout.size``.
            workspace_arg: Optional extra workspace buffer forwarded to the kernel,
                the benchmark request and the returned caller.
            **kwargs: Template meta-parameters. Each one is emitted as a
                ``name : tl.constexpr = value`` define, passed to the jinja
                template, and passed to the grid function.

        Returns:
            A TritonTemplateCaller, or None if rendering raised ZeroDivisionError.

        Raises:
            NotImplementedError: if the kernel would need 64-bit indexing.
        """
        assert self.template, "requires jinja2"
        defines = "".join(
            f"{name} : tl.constexpr = {val}\n" for name, val in kwargs.items()
        )

        fake_out = ir.Buffer(name="buf_out", layout=layout)
        kernel_name = f"triton_{self.name}"

        numel = sympy_product(layout.size)
        buffers = itertools.chain(input_nodes, (fake_out,))
        if not TritonScheduling.can_use_32bit_indexing(numel, buffers):
            raise NotImplementedError(
                "64-bit indexing is not yet implemented for triton templates"
            )

        if call_sizes is None:
            call_sizes = layout.size

        # Shared by the benchmarking kernel below and the kernel built later
        # by make_kernel_render.
        kernel_options = {
            "input_nodes": input_nodes,
            "defines": defines,
            "num_stages": num_stages,
            "num_warps": num_warps,
            "grid_fn": self.grid,
            "meta": kwargs,
            "call_sizes": call_sizes,
            "prefix_args": prefix_args,
            "suffix_args": suffix_args,
            "epilogue_fn": epilogue_fn,
            "subgraphs": subgraphs,
        }

        with patch.object(
            V.graph, "get_dtype", self._fake_get_dtype(fake_out)
        ), V.graph.set_current_device(layout.device), TritonTemplateKernel(
            kernel_name=kernel_name,
            output_node=fake_out,
            workspace_arg=workspace_arg,
            use_jit=False,
            **kernel_options,
        ) as kernel:
            try:
                template = kernel.render(self.template, kwargs)
                with kernel.set_subgraph_body("<STORE_OUTPUT>"):
                    code = template.finalize_all()
            except ZeroDivisionError:
                # TODO(nmacchioni): fix sympy division by zero
                return None
            if self.debug:
                print("Generated Code:\n", code)
            extra = (
                "-".join(
                    [
                        *[
                            f"{kwarg}={repr(kwargs[kwarg])}"
                            for kwarg in sorted(kwargs.keys())
                        ],
                        f"num_stages={num_stages}",
                        f"num_warps={num_warps}",
                    ]
                )
                + "-"
            )
            mod = PyCodeCache.load(code, extra)

        input_call_args = tuple(kernel.args.input_buffers.keys())

        # We expect the input_buffer order to be [*input_nodes, *captured_buffers]
        expected_input_args = tuple(unique(x.get_name() for x in input_nodes))
        assert input_call_args[: len(expected_input_args)] == expected_input_args, (
            input_call_args,
            expected_input_args,
        )

        full_input_nodes = tuple(V.graph.get_buffer(k) for k in input_call_args)
        extra_args = V.graph.sizevars.size_hints(
            map(sympy.expand, tuple(kernel.args.sizevars.keys())),
            fallback=config.unbacked_symint_fallback,
        )

        kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}"

        def make_kernel_render(out_node):
            # Build a fresh kernel for the real output node; rendering is deferred.
            kernel = TritonTemplateKernel(
                kernel_name=str(Placeholder.KERNEL_NAME),
                output_node=out_node,
                workspace_arg=workspace_arg,
                use_jit=False,
                **kernel_options,
            )
            render = functools.partial(
                kernel.render,
                self.template,
                kwargs,
            )
            return kernel, render

        # create the BenchmarkRequest
        assert mod.__file__ is not None
        grid = self.grid(
            *V.graph.sizevars.size_hints(
                call_sizes,
                fallback=config.unbacked_symint_fallback,
            ),
            kwargs,
        )
        bmreq_cls: Type[TritonBenchmarkRequest]
        if layout.device.type == "cpu":
            bmreq_cls = TritonCPUBenchmarkRequest
        else:
            bmreq_cls = TritonGPUBenchmarkRequest
        bmreq = bmreq_cls(
            module_path=mod.__file__,
            module_cache_key=mod.key,
            kernel_name=kernel_name,
            grid=grid,
            extra_args=extra_args,
            num_stages=num_stages,
            num_warps=num_warps,
            matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
            waves_per_eu=kwargs.get("waves_per_eu", 0),
            kpack=kwargs.get("kpack", 2),
            input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes),  # type: ignore[arg-type]
            output_tensor_meta=TensorMeta.from_irnodes(layout),
            workspace_arg=workspace_arg,
        )

        return TritonTemplateCaller(
            kernel_hash_name,
            full_input_nodes,
            layout,
            make_kernel_render,
            extra.strip("-").replace("-", ", "),
            bmreq,
            log_info={
                "tile_shape": str(
                    (
                        kwargs.get("BLOCK_M", -1),
                        kwargs.get("BLOCK_K", -1),
                        kwargs.get("BLOCK_N", -1),
                    )
                ),
                "num_stages": num_stages,
                "num_warps": num_warps,
                "allow_tf32": str(kwargs.get("ALLOW_TF32", None)),
                "acc_type": str(kwargs.get("ACC_TYPE", None)),
            },
            mutated_inputs=mutated_inputs,
            workspace_arg=workspace_arg,
        )
class ExternKernelChoice:
def __init__(
self,
kernel,
cpp_kernel=None,
*,
name=None,
has_out_variant=True,
op_overload=None,
use_fallback_kernel=False,
kernel_creator=None,
) -> None:
super().__init__()
name = name or kernel.__name__
assert callable(kernel)
assert not hasattr(extern_kernels, name), f"duplicate extern kernel: {name}"
self.name = name
self.cpp_kernel_name = cpp_kernel
self.has_out_variant = has_out_variant
setattr(extern_kernels, name, kernel)
self.op_overload = op_overload
self.use_fallback_kernel = use_fallback_kernel
self.kernel_creator = kernel_creator
def to_callable(self):
return getattr(extern_kernels, self.name)
def call_name(self):
return f"extern_kernels.{self.name}"
@functools.lru_cache(None) # noqa: B019
def hash_key(self):
fn = self.to_callable()
parts = [
self.name,
getattr(fn, "__name__", ""),
getattr(fn, "__module__", ""),
]
try: