import logging
import operator
from typing import Callable, Sequence, Tuple

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)
REPLACEABLE_ATEN_OPS = {
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
}


def lower_scaled_dot_product_attention(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    """Replace specific versions of scaled_dot_product_attention with an equivalent
    implementation which can be easily converted to TRT
    """
    original_fns, replacement = scaled_dot_product_attention_replacement()
    replaced_nodes = []

    # For each original function, search for it in the graph and replace
    for original in original_fns:
        replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters(
            gm,
            original,
            replacement,
            ignore_literals=True,
        )

    if replaced_nodes:
        # Repair instances which use the kwargs field (specifically the "scale" kwarg)
        for match in replaced_nodes:
            attention_node_replaced = None
            # Seek the attention operator being replaced
            for node in match.nodes_map:
                if node.target in REPLACEABLE_ATEN_OPS:
                    attention_node_replaced = match.nodes_map[node]
                    break

            assert attention_node_replaced is not None

            # If the attention operator had keyword-args, copy them to the new node
            if attention_node_replaced.kwargs:
                assert len(match.replacements) == 1
                new_attention_node = match.replacements[0]
                assert (
                    new_attention_node.target
                    == torch.nn.functional.scaled_dot_product_attention
                )
                new_attention_node.kwargs = {**attention_node_replaced.kwargs}

        gm = clean_up_graph_after_modifications(gm)
        logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")

    return gm


def scaled_dot_product_attention_replacement() -> Tuple[
    Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]:
    """Constructs the original and replacement functions for efficient attention"""

    # Efficient Attention original graph
    def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
            q,
            k,
            v,
            None,
            False,
        )
        out = operator.getitem(outputs, 0)
        return out

    # Flash Attention original graph
    def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        outputs = torch.ops.aten._scaled_dot_product_flash_attention.default(
            q,
            k,
            v,
        )
        out = operator.getitem(outputs, 0)
        return out

    # Efficient Attention w/Scale original graph
    def efficient_scale(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
            q,
            k,
            v,
            None,
            False,
            scale=1.0,
        )
        out = operator.getitem(outputs, 0)
        return out

    # Flash Attention w/Scale original graph
    def flash_scale(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        outputs = torch.ops.aten._scaled_dot_product_flash_attention.default(
            q,
            k,
            v,
            scale=1.0,
        )
        out = operator.getitem(outputs, 0)
        return out

    # Replacement graph consists of the functional version of scaled_dot_product_attention
    def replacement(
        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> torch.Tensor:
        return torch.nn.functional.scaled_dot_product_attention(query, key, value)

    return (efficient, flash, efficient_scale, flash_scale), replacement
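

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption for demonstration, not part of the
# original pass): a toy module that calls the fused aten efficient-attention op
# directly, so the subgraph rewriter above has a pattern to match. The module
# name and input shapes are hypothetical; note the pass does not inspect
# `sample_inputs`, so the tensors below serve only as placeholders.
if __name__ == "__main__":

    class ToyAttention(torch.nn.Module):
        def forward(
            self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
        ) -> torch.Tensor:
            outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
                q, k, v, None, False
            )
            return operator.getitem(outputs, 0)

    sample_inputs = tuple(torch.randn(2, 4, 8, 16) for _ in range(3))
    gm = torch.fx.symbolic_trace(ToyAttention())

    # After the pass, the fused aten op should appear in the graph as a call to
    # torch.nn.functional.scaled_dot_product_attention.
    gm = lower_scaled_dot_product_attention(gm, sample_inputs)
    print(gm.graph)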