
Commit 5eb323f

Merge pull request #2641 from pytorch/attention_converter_cherry_pick
cherry-pick: Attention converter and linting fixes
2 parents c189b4c + 4d11385 commit 5eb323f


7 files changed (+271, -63 lines)

.pre-commit-config.yaml

+1-1
@@ -47,7 +47,7 @@ repos:
     hooks:
       - id: ruff
   - repo: https://github.com/psf/black
-    rev: 23.7.0
+    rev: 24.1.1
     hooks:
       - id: black
         exclude: ^examples/custom_converters/elu_converter/setup.py|^docs

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+8-1
@@ -2243,7 +2243,14 @@ def tensorrt_scaled_dot_product_attention(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return impl.attention.scaled_dot_product_attention(
-        ctx, target, SourceIR.TORCHTRT_LOWERED, name, args[0], args[1], args[2]
+        ctx,
+        target,
+        SourceIR.TORCHTRT_LOWERED,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        kwargs.get("scale", None),
     )
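For context on where that kwarg comes from: the fused aten attention ops only receive scale as a keyword argument, so it lands in the FX node's kwargs rather than args. A minimal sketch with plain torch.fx (no TensorRT needed); FlashAttentionExample is an illustrative module, not part of this PR:

import torch

class FlashAttentionExample(torch.nn.Module):  # illustrative module only
    def forward(self, q, k, v):
        attn = torch.ops.aten._scaled_dot_product_flash_attention.default(
            q, k, v, scale=0.25
        )
        return attn[0]

gm = torch.fx.symbolic_trace(FlashAttentionExample())
for node in gm.graph.nodes:
    if node.target == torch.ops.aten._scaled_dot_product_flash_attention.default:
        # The scale only appears in node.kwargs, hence kwargs.get("scale", None)
        print(node.kwargs)  # expected to show {'scale': 0.25}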

py/torch_tensorrt/dynamo/conversion/impl/attention.py

+20-9
@@ -17,6 +17,7 @@ def scaled_dot_product_attention(
     query: TRTTensor,
     key: TRTTensor,
     value: TRTTensor,
+    scale: Optional[float],
 ) -> TRTTensor:
     mm = impl.matmul.matrix_multiply(
         ctx,
@@ -27,16 +28,26 @@ def scaled_dot_product_attention(
         key,
         other_matrix_op=trt.MatrixOperation.TRANSPOSE,
     )
-    div = impl.elementwise.div(
-        ctx,
-        target,
-        source_ir,
-        name + "_scale",
-        mm,
-        math.sqrt(query.shape[-1]),
-    )
+    if scale is None:
+        scaled = impl.elementwise.div(
+            ctx,
+            target,
+            source_ir,
+            name + "_scale",
+            mm,
+            math.sqrt(query.shape[-1]),
+        )
+    else:
+        scaled = impl.elementwise.mul(
+            ctx,
+            target,
+            source_ir,
+            name + "_scale",
+            mm,
+            scale,
+        )
     softmax = impl.normalization.softmax(
-        ctx, target, source_ir, name + "_softmax", div, -1
+        ctx, target, source_ir, name + "_softmax", scaled, -1
     )
     out = impl.matmul.matrix_multiply(
         ctx,
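The converter mirrors the scaling rule of torch.nn.functional.scaled_dot_product_attention: divide the Q·Kᵀ scores by sqrt(d_k) when no scale is given, otherwise multiply by the provided scale. Below is a minimal pure-PyTorch sketch of that semantics for sanity-checking (not the TRT converter itself); it assumes a PyTorch build where F.scaled_dot_product_attention accepts the scale keyword (2.1+), and sdpa_reference is an illustrative name:

import math
import torch
import torch.nn.functional as F

def sdpa_reference(q, k, v, scale=None):
    # Same rule as the converter: default scale is 1/sqrt(d_k) (a division),
    # an explicit scale is applied as a multiplication.
    scores = q @ k.transpose(-2, -1)
    scores = scores / math.sqrt(q.shape[-1]) if scale is None else scores * scale
    return F.softmax(scores, dim=-1) @ v

q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))
for s in (None, 0.15):
    expected = F.scaled_dot_product_attention(q, k, v, scale=s)
    assert torch.allclose(sdpa_reference(q, k, v, s), expected, atol=1e-5)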

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

+2-2
@@ -5,8 +5,8 @@

 from .constant_folding import constant_fold
 from .fuse_prims_broadcast import fuse_prims_broadcast
-from .lower_efficient_attention import lower_efficient_attention
 from .lower_linear import lower_linear
+from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
 from .pass_manager import DynamoPassManager
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
@@ -18,7 +18,7 @@
     remove_input_alias_fixing_clones,
     constant_fold,
     repair_input_as_output,
-    lower_efficient_attention,
+    lower_scaled_dot_product_attention,
     lower_linear,
     fuse_prims_broadcast,
     replace_max_pool_with_indices,
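All of these lowering passes share the (GraphModule, sample_inputs) -> GraphModule signature, so they can be chained in the listed order. A hypothetical sketch of that composition; the real orchestration lives in DynamoPassManager, and apply_passes below is illustrative only:

from typing import Callable, Sequence

import torch

LoweringPass = Callable[
    [torch.fx.GraphModule, Sequence[torch.Tensor]], torch.fx.GraphModule
]

def apply_passes(
    gm: torch.fx.GraphModule,
    sample_inputs: Sequence[torch.Tensor],
    passes: Sequence[LoweringPass],
) -> torch.fx.GraphModule:
    # Each pass returns a (possibly rewritten) GraphModule that feeds the next pass.
    for lowering_pass in passes:
        gm = lowering_pass(gm, sample_inputs)
    return gm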

py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py

-50

This file was deleted.

py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py

+123
@@ -0,0 +1,123 @@
+import logging
+import operator
+from typing import Callable, Sequence, Tuple
+
+import torch
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+REPLACEABLE_ATEN_OPS = {
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+    torch.ops.aten._scaled_dot_product_flash_attention.default,
+}
+
+
+def lower_scaled_dot_product_attention(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    """Replace specific versions of scaled_dot_product_attention with an equivalent
+    implementation which can be easily converted to TRT
+    """
+    original_fns, replacement = scaled_dot_product_attention_replacement()
+    replaced_nodes = []
+
+    # For each original function, search for it in the graph and replace
+    for original in original_fns:
+        replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters(
+            gm,
+            original,
+            replacement,
+            ignore_literals=True,
+        )
+
+    if replaced_nodes:
+        # Repair instances which use the kwargs field (specifically the "scale" kwarg)
+        for match in replaced_nodes:
+            attention_node_replaced = None
+            # Seek the attention operator being replaced
+            for node in match.nodes_map:
+                if node.target in REPLACEABLE_ATEN_OPS:
+                    attention_node_replaced = match.nodes_map[node]
+                    break
+
+            assert attention_node_replaced is not None
+
+            # If the attention operator had keyword-args, copy them to the new node
+            if attention_node_replaced.kwargs:
+                assert len(match.replacements) == 1
+                new_attention_node = match.replacements[0]
+                assert (
+                    new_attention_node.target
+                    == torch.nn.functional.scaled_dot_product_attention
+                )
+                new_attention_node.kwargs = {**attention_node_replaced.kwargs}
+
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")
+
+    return gm
+
+
+def scaled_dot_product_attention_replacement() -> Tuple[
+    Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
+    """Constructs the original and replacement functions for efficient attention"""
+
+    # Efficient Attention original graph
+    def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
+            q,
+            k,
+            v,
+            None,
+            False,
+        )
+        out = operator.getitem(outputs, 0)
+        return out
+
+    # Flash Attention original graph
+    def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+        outputs = torch.ops.aten._scaled_dot_product_flash_attention.default(
+            q,
+            k,
+            v,
+        )
+        out = operator.getitem(outputs, 0)
+        return out
+
+    # Efficient Attention w/Scale original graph
+    def efficient_scale(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> torch.Tensor:
+        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
+            q,
+            k,
+            v,
+            None,
+            False,
+            scale=1.0,
+        )
+        out = operator.getitem(outputs, 0)
+        return out
+
+    # Flash Attention w/Scale original graph
+    def flash_scale(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+        outputs = torch.ops.aten._scaled_dot_product_flash_attention.default(
+            q,
+            k,
+            v,
+            scale=1.0,
+        )
+        out = operator.getitem(outputs, 0)
+        return out
+
+    # Replacement graph consists of the functional version of scaled_dot_product_attention
+    def replacement(
+        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+    ) -> torch.Tensor:
+        return torch.nn.functional.scaled_dot_product_attention(query, key, value)
+
+    return (efficient, flash, efficient_scale, flash_scale), replacement
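A minimal sketch of exercising this pass directly on an FX-traced module (assuming torch_tensorrt with this change installed; in the normal compile flow the pass runs on the dynamo-captured graph instead, and FlashAttentionExample is an illustrative module, not part of the PR):

import torch
from torch_tensorrt.dynamo.lowering.passes.lower_scaled_dot_product_attention import (
    lower_scaled_dot_product_attention,
)

class FlashAttentionExample(torch.nn.Module):  # illustrative module only
    def forward(self, q, k, v):
        attn = torch.ops.aten._scaled_dot_product_flash_attention.default(
            q, k, v, scale=0.15
        )
        return attn[0]

inputs = [torch.rand(8, 4, 16, 8) for _ in range(3)]
gm = torch.fx.symbolic_trace(FlashAttentionExample())
gm = lower_scaled_dot_product_attention(gm, inputs)

# The fused aten op should now be replaced by the functional form,
# with the original scale kwarg copied onto the new node.
for node in gm.graph.nodes:
    if node.target == torch.nn.functional.scaled_dot_product_attention:
        print(node.kwargs)  # expected to show {'scale': 0.15}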

tests/py/dynamo/lowering/test_aten_lowering_passes.py

+117
@@ -267,6 +267,123 @@ def forward(self, q, k, v):
         torch._dynamo.reset()


+class TestLowerFlashAttention(TestCase):
+    def test_lower_flash_attention(self):
+        class FlashAttention(torch.nn.Module):
+            def forward(self, q, k, v):
+                attn = torch.ops.aten._scaled_dot_product_flash_attention.default(
+                    q,
+                    k,
+                    v,
+                    scale=0.15,
+                )
+                return attn[0]
+
+        inputs = [
+            torch.rand(8, 4, 16, 8).half().cuda(),
+            torch.rand(8, 4, 16, 8).half().cuda(),
+            torch.rand(8, 4, 16, 8).half().cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(FlashAttention())
+        expected_ops = {torch.nn.functional.scaled_dot_product_attention}
+        unexpected_ops = {torch.ops.aten._scaled_dot_product_flash_attention.default}
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
+        )
+        torch_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
+        )
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        # Remove 1 decimal from the requirement for FP16
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT - 1,
+            msg=f"FlashAttention TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+    def test_flash_attention_converter(self):
+        class FlashAttention(torch.nn.Module):
+            def forward(self, q, k, v):
+                attn = torch.ops.aten._scaled_dot_product_flash_attention.default(
+                    q,
+                    k,
+                    v,
+                    scale=0.25,
+                )
+                return attn[0]
+
+        inputs = [
+            torch.rand(1, 3, 6, 8).half().cuda(),
+            torch.rand(1, 3, 2, 8).half().cuda(),
+            torch.rand(1, 3, 2, 8).half().cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(FlashAttention())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
+        )
+        torch_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
+        )
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        # Remove 1 decimal from the requirement for FP16
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT - 1,
+            msg=f"FlashAttention TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+
 class TestLowerLinear(TestCase):
     def test_lower_linear(self):
         class Linear(torch.nn.Module):
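On the relaxed FP16 tolerance in the new tests ("Remove 1 decimal from the requirement for FP16"): half precision carries a 10-bit mantissa, so its machine epsilon is about 1e-3, which is why the comparisons drop one decimal place from DECIMALS_OF_AGREEMENT. A one-line check of that epsilon:

import torch

# FP16 machine epsilon; roughly three decimal digits of relative precision.
print(torch.finfo(torch.float16).eps)  # 0.0009765625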
