
Commit 8ebf24d

perf: Add lowering passes to improve TRT runtime on SD (#2351)

1 parent: ef07bea

18 files changed: +575 −28 lines

docsrc/contributors/writing_dynamo_aten_lowering_passes.rst

Lines changed: 5 additions & 5 deletions
@@ -12,7 +12,7 @@ Lowering Pass Requirements
 ------------
 
 An ATen lowering pass function in Torch-TRT must satisfy two requirements:
-- The function must take as input a single `torch.fx.GraphModule` and return the lowered `torch.fx.GraphModule`
+- The function must take as input a `torch.fx.GraphModule` and a sequence of torch Tensors, `Sequence[torch.Tensor]`, and return the lowered `torch.fx.GraphModule`
 - The function must leave the graph in a valid and invoke-able state, including performing any necessary linting and recompilation
 
 See this link for information on `Graph Manipulations <https://pytorch.org/docs/stable/fx.html#graph-manipulation>`_ in FX. See below for an example of a lowering pass which repairs graphs that have inputs which are also outputs, a disallowed configuration for TRT Engines.
@@ -22,7 +22,7 @@ Example Lowering Pass
 
 .. code-block:: python
 
-    def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    def repair_input_as_output(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule:
         """Repair scenarios where inputs are also outputs of the graph
 
         TRT does not allow such cases, so we insert a clone (identity) layer
@@ -82,15 +82,15 @@ For instance, to insert the pass at the default location (end of the list), the
 .. code-block:: python
 
     @_aten_lowering_pass
-    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    def my_custom_pass(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule:
         ...
 
 Alternatively, to insert the pass at a custom index (such as the front of the list) in the passlist, the following code can be used:
 
 .. code-block:: python
 
     @_aten_lowering_pass(index=0)
-    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    def my_custom_pass(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule:
         ...
 
 There are also provided utilities in `torch_tensorrt.dynamo.lowering.passes` for displaying the currently-available lowering pass list, applying those passes to an arbitrary `torch.fx.GraphModule`, and removing the lowering pass at a specific index.
@@ -101,7 +101,7 @@ There are also provided utilities in `torch_tensorrt.dynamo.lowering.passes` for
     print(dump_lowering_passes())
 
     # Apply lowering passes to a GraphModule
-    apply_lowering_passes(graph_module)
+    apply_lowering_passes(graph_module, sample_inputs)
 
     # Remove the lowering pass at index 1
     _remove_lowering_pass(index=1)
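
With the updated two-argument signature, registering a custom pass looks like the following minimal sketch (assuming, as in the docs above, that the decorator is importable from `torch_tensorrt.dynamo.lowering.passes`; the `log_input_shapes` pass itself is hypothetical):

    from typing import Sequence

    import torch

    from torch_tensorrt.dynamo.lowering.passes import _aten_lowering_pass


    @_aten_lowering_pass
    def log_input_shapes(
        gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
    ) -> torch.fx.GraphModule:
        # Inspect the sample inputs now available to every pass, then
        # return the graph unchanged (already valid, no recompile needed)
        print([tuple(t.shape) for t in sample_inputs])
        return gm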

py/torch_tensorrt/dynamo/aten_tracer.py

Lines changed: 1 addition & 1 deletion
@@ -28,6 +28,6 @@ def trace(
         "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
     ):
         graph_module = export(model, tuple(inputs)).module()
-        graph_module = apply_lowering_passes(graph_module)
+        graph_module = apply_lowering_passes(graph_module, inputs)
         logger.debug("Post export graph: " + str(graph_module.graph))
         return graph_module

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ def _pretraced_backend(
 
     logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
 
-    gm = apply_lowering_passes(gm)
+    gm = apply_lowering_passes(gm, sample_inputs)
 
     trt_compiled = compile_module(
         gm,

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 15 additions & 0 deletions
@@ -1517,3 +1517,18 @@ def aten_ops_max_pool(
         dilation=args_bounds_check(args, 4, replacement=1),
         ceil_mode=args_bounds_check(args, 5, replacement=False),
     )
+
+
+@dynamo_tensorrt_converter(
+    torch.nn.functional.scaled_dot_product_attention,
+)  # type: ignore[misc]
+def tensorrt_scaled_dot_product_attention(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.attention.scaled_dot_product_attention(
+        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
+    )

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 
 from . import (
     activation,
+    attention,
     cast,
     condition,
     conv,
py/torch_tensorrt/dynamo/conversion/impl/attention.py

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+import math
+from typing import Optional, Union
+
+import tensorrt as trt
+from torch.fx.node import Target
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
+from torch_tensorrt.fx.types import TRTTensor
+
+
+def scaled_dot_product_attention(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    query: TRTTensor,
+    key: TRTTensor,
+    value: TRTTensor,
+) -> TRTTensor:
+    mm = impl.matmul.matrix_multiply(
+        ctx,
+        target,
+        source_ir,
+        name + "_mm",
+        query,
+        key,
+        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
+    )
+    div = impl.elementwise.div(
+        ctx,
+        target,
+        source_ir,
+        name + "_scale",
+        mm,
+        math.sqrt(query.shape[-1]),
+    )
+    softmax = impl.normalization.softmax(
+        ctx, target, source_ir, name + "_softmax", div, -1
+    )
+    out = impl.matmul.matrix_multiply(
+        ctx,
+        target,
+        source_ir,
+        name + "_out",
+        softmax,
+        value,
+    )
+
+    return out
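
For reference, the four TRT layers above mirror the unmasked scaled dot-product attention formula, softmax(Q · Kᵀ / √d_k) · V. A minimal PyTorch sketch of the same computation (the `sdpa_reference` helper is illustrative only, not part of the commit):

    import math

    import torch
    import torch.nn.functional as F


    def sdpa_reference(query, key, value):
        # Q @ K^T, scaled by 1/sqrt(d_k), softmax over the last dim, then weight V
        scores = query @ key.transpose(-2, -1) / math.sqrt(query.shape[-1])
        return torch.softmax(scores, dim=-1) @ value


    q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))
    ref = sdpa_reference(q, k, v)
    assert torch.allclose(ref, F.scaled_dot_product_attention(q, k, v), atol=1e-5)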

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

Lines changed: 3 additions & 0 deletions
@@ -149,6 +149,7 @@
     aten.special_log_ndtr,
     aten.special_xlog1py,
     aten.stack,
+    aten.std,
     aten.t,
     aten.tanh_backward,
     aten.threshold,
@@ -163,6 +164,8 @@
     aten.upsample_bilinear2d,
     aten.upsample_bilinear2d.vec,
     aten.upsample_nearest2d_backward,
+    aten.var,
+    aten.var_mean,
     aten.xlogy,
     aten.zero,
     aten.zero_,
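
Delegating `aten.std`, `aten.var`, and `aten.var_mean` to the core ATen decomposition table replaces the custom `std_replacement` removed in the next file. A quick sanity check of the identity that replacement encoded (std is the square root of variance):

    import torch

    x = torch.randn(3, 5)

    # std_replacement computed sqrt(var(...)); the core decomposition preserves this
    assert torch.allclose(torch.std(x), torch.sqrt(torch.var(x)))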

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 49 additions & 6 deletions
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from torch._decomp import register_decomposition
@@ -83,11 +83,6 @@ def inplace_op(*args, **kwargs):  # type: ignore
 replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce)
 
 
-@register_torch_trt_decomposition(aten.std, registry=TORCH_TRT_DECOMPOSITIONS)
-def std_replacement(*args, **kwargs) -> torch.Tensor:  # type: ignore
-    return torch.sqrt(torch.var(*args, **kwargs))
-
-
 @register_torch_trt_decomposition(aten.rsqrt, registry=TORCH_TRT_DECOMPOSITIONS)
 def rsqrt_replacement(*args, **kwargs) -> torch.Tensor:  # type: ignore
     return torch.reciprocal(torch.sqrt(*args, **kwargs))
@@ -135,6 +130,54 @@ def reciprocal_replacement(
     return torch.div(1, input_)
 
 
+@register_torch_trt_decomposition(
+    torch.ops.prims.var.default, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def var_decomposition(
+    input_tensor: torch.Tensor,
+    dims: Optional[List[int]],
+    correction: int,
+    output_dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    if dims is None:
+        dims = []
+
+    # If the dimensions are empty, variance is taken over all dimensions
+    if isinstance(dims, (tuple, list)) and len(dims) == 0:
+        N = input_tensor.numel()
+    # Otherwise, the number of samples is the product of the dimensions reduced over
+    else:
+        N = 1
+        for dim_i in dims:
+            N *= input_tensor.shape[dim_i]
+
+    # Compute the mean, difference, and correction term as per the formula:
+    # https://pytorch.org/docs/stable/generated/torch.var.html
+
+    # Additionally, prims does not support keepdim, and so we only keep dimensions
+    # on the first reduction, then remove it for the second
+    sample_mean = torch.mean(input_tensor, dims, keepdim=True)
+    diff = input_tensor - sample_mean
+    squared_diff = diff * diff
+    variance_unnormalized = torch.sum(squared_diff, dims, keepdim=False)
+
+    if correction is None:
+        correction_term = float(N - 1)
+    elif isinstance(correction, int):
+        correction_term = float(N - correction)
+    elif isinstance(correction, float):
+        correction_term = float(N) - correction
+    else:
+        raise RuntimeError("correction must be int or float")
+
+    if correction_term <= 0:
+        raise RuntimeError(f"correction term was non-positive, got: {correction_term}")
+
+    variance = variance_unnormalized / correction_term
+
+    return variance
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
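
As a sanity check, the new `prims.var` decomposition should agree with eager-mode `torch.var` for a given reduction dimension and correction term; a minimal sketch calling the function above directly:

    import torch

    x = torch.randn(4, 8)

    # correction=1 is the default Bessel correction (divide by N - 1)
    expected = torch.var(x, dim=[1], correction=1)
    actual = var_decomposition(x, dims=[1], correction=1)
    assert torch.allclose(expected, actual)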

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 15 additions & 5 deletions
@@ -1,9 +1,11 @@
 import logging
-from typing import Callable, Optional
+from typing import Callable, Optional, Sequence, Union
 
 import torch
 
 from .constant_folding import constant_fold
+from .fuse_prims_broadcast import fuse_prims_broadcast
+from .lower_efficient_attention import lower_efficient_attention
 from .pass_manager import DynamoPassManager
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
@@ -13,19 +15,25 @@
         remove_input_alias_fixing_clones,
         constant_fold,
         repair_input_as_output,
+        lower_efficient_attention,
+        fuse_prims_broadcast,
     ]
 )
 
 logger = logging.getLogger(__name__)
 
 
-LoweringPassSignature = Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
+LoweringPassSignature = Callable[
+    [torch.fx.GraphModule, Sequence[torch.Tensor]], torch.fx.GraphModule
+]
 
 
 def _aten_lowering_pass(
     *args: LoweringPassSignature,
     index: Optional[int] = None,
-) -> LoweringPassSignature:
+) -> Union[
+    LoweringPassSignature, Callable[[LoweringPassSignature], LoweringPassSignature]
+]:
     """Adds a lowering pass to the registry, at a specified index if desired
 
     If no index is specified, the lowering pass is inserted at the end of the list
@@ -65,12 +73,14 @@ def _remove_lowering_pass(*, index: int) -> None:
     return
 
 
-def apply_lowering_passes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+def apply_lowering_passes(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
     """Applies the lowering passes to a graph module, returns the modified GraphModule"""
     logging.debug(
         f"Invoking DynamoPassManager and applying lowering passes: {ATEN_LOWERING_PASSES}"
     )
-    return ATEN_LOWERING_PASSES(gm)
+    return ATEN_LOWERING_PASSES(gm, sample_inputs)
 
 
 def dump_lowering_passes() -> str:

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 4 additions & 1 deletion
@@ -1,4 +1,5 @@
 import logging
+from typing import Sequence
 
 import torch
 from torch_tensorrt._utils import sanitized_torch_version
@@ -21,7 +22,9 @@
 
 
 @torch.utils._python_dispatch._disable_current_modes()  # type: ignore
-def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+def constant_fold(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
     """Adapted from:
         https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197

py/torch_tensorrt/dynamo/lowering/passes/fuse_prims_broadcast.py

Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
+import logging
+from typing import Sequence
+
+import torch
+from torch.fx.passes.shape_prop import ShapeProp
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+# TODO: Add relevant prims to this fusion
+def fuse_prims_broadcast(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    """Fuses prim nodes which are effectively the ATen equivalents with keep_dim=True"""
+    modified_graph = False
+
+    # Propagate shapes through the graph to determine if broadcast can be resolved
+    try:
+        ShapeProp(gm).propagate(*sample_inputs)
+    except (RuntimeError, AssertionError):
+        logger.warning(
+            "Shape Propagation Failed on Graph, skipping fuse_prims_broadcast lowering pass",
+            exc_info=True,
+        )
+        return gm
+
+    for node in gm.graph.nodes:
+        # If the node is a sum prims operator, with broadcast_in_dim being the only consumer,
+        # it is a candidate for fusing
+        if (
+            node.target in (torch.ops.prims.sum.default,)
+            and len(node.users) == 1
+            and list(node.users)[0].target == torch.ops.prims.broadcast_in_dim.default
+        ):
+            # Get broadcasted shape, reduced dimensions, and original tensor shape
+            broadcast_node = list(node.users)[0]
+            broadcasted_shape = broadcast_node.args[1]
+            reduced_dims = node.args[1]
+            original_shape = node.args[0].meta["tensor_meta"].shape
+
+            # If the rank of the broadcasted shape is the same as the original
+            # and the broadcasts are all singletons for the reduced dimensions
+            # and all of the non-reduced dimensions are identical to the originals,
+
+            # then the broadcast is effectively performing a "keep_dim=True" operation
+            if (
+                len(broadcasted_shape) == len(original_shape)
+                and all(broadcasted_shape[i] == 1 for i in reduced_dims)
+                and all(
+                    broadcasted_shape[j] == original_shape[j]
+                    for j in range(len(original_shape))
+                    if j not in reduced_dims
+                )
+            ):
+                # Fuse the operator to its convertible alternative
+                with gm.graph.inserting_after(broadcast_node):
+                    modified_graph = True
+
+                    if node.target == torch.ops.prims.sum.default:
+                        fused_node = gm.graph.call_function(
+                            torch.ops.aten.sum.dim_IntList,
+                            args=(node.args[0], reduced_dims, True),
+                        )
+
+                # Replace all uses of the broadcast node with the
+                # fused keepdim sum node
+                broadcast_node.replace_all_uses_with(
+                    fused_node,
+                )
+
+                # Erase the broadcast node and the original prims node
+                gm.graph.erase_node(broadcast_node)
+                gm.graph.erase_node(node)
+
+    if modified_graph:
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after fusing prims-broadcast paradigm:\n{gm.graph}")
+
+    return gm
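
The equivalence this pass relies on can be checked directly in eager mode: a prims sum (which always drops the reduced dims) followed by a `broadcast_in_dim` that restores singletons at the reduced positions matches a single ATen sum with `keepdim=True`. A minimal sketch with assumed shapes:

    import torch

    x = torch.randn(2, 3, 4)
    reduced_dims = [1]

    # The matched pattern: reduce without keepdim, then broadcast back to rank 3
    summed = torch.ops.prims.sum.default(x, reduced_dims)
    restored = torch.ops.prims.broadcast_in_dim.default(summed, [2, 1, 4], [0, 2])

    # The fused replacement emitted by the pass
    fused = torch.ops.aten.sum.dim_IntList(x, reduced_dims, True)

    assert torch.allclose(restored, fused)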
