|
6 | 6 |
|
7 | 7 | # pyre-strict |
8 | 8 |
|
9 | | -from typing import List, Optional |
| 9 | +import operator |
| 10 | +from typing import Optional |
10 | 11 |
|
11 | 12 | import torch |
12 | 13 | from executorch.exir.delegate import executorch_call_delegate |
13 | | -from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue |
| 14 | +from executorch.exir.pass_base import ExportPass, ProxyValue |
14 | 15 | from executorch.exir.tensor import TensorSpec |
15 | 16 | from torch.export.exported_program import ExportGraphSignature |
16 | 17 | from torch.fx.node import Node |
| 18 | +from torch.fx.passes.infra.pass_base import PassResult |
17 | 19 | from torch.utils import _pytree as pytree |
18 | 20 |
|
19 | 21 |
|
@@ -52,6 +54,42 @@ class SpecPropPass(ExportPass): |
52 | 54 | def __init__(self) -> None: |
53 | 55 | super().__init__() |
54 | 56 |
|
| 57 | + def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult: |
| 58 | + # Re-trace metadata to ensure it's up to date. |
| 59 | + res = ExportPass()(graph_module) |
| 60 | + assert res is not None |
| 61 | + gm = res.graph_module |
| 62 | + |
| 63 | + def get_spec(x): |
| 64 | + if hasattr(x, "meta"): |
| 65 | + return x.meta.get("spec", None) |
| 66 | + else: |
| 67 | + return None |
| 68 | + |
| 69 | + for module in gm.modules(): |
| 70 | + if isinstance(module, torch.fx.GraphModule): |
| 71 | + for node in module.graph.nodes: |
| 72 | + meta_val = node.meta.get("val", None) |
| 73 | + |
| 74 | + if node.op == "output": |
| 75 | + node.meta["spec"] = pytree.tree_map(get_spec, node.args[0]) |
| 76 | + elif node.op == "call_function" and node.target == operator.getitem: |
| 77 | + value_spec = pytree.tree_map(get_spec, node.args[0]) |
| 78 | + node.meta["spec"] = value_spec[node.args[1]] |
| 79 | + elif ( |
| 80 | + node.op == "call_function" |
| 81 | + and node.target == executorch_call_delegate |
| 82 | + ): |
| 83 | + if "spec" not in node.meta: |
| 84 | + node.meta["spec"] = pytree.tree_map(make_spec, meta_val) |
| 85 | + else: |
| 86 | + if meta_val is not None: |
| 87 | + node.meta["spec"] = pytree.tree_map(make_spec, meta_val) |
| 88 | + return res |
| 89 | + |
| 90 | + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: |
| 91 | + return self(graph_module) |
| 92 | + |
55 | 93 | def on_attr(self, attr: ProxyValue) -> None: |
56 | 94 | attr.node.meta["spec"] = pytree.tree_map_only( |
57 | 95 | torch.Tensor, |
@@ -84,85 +122,3 @@ def update_placeholder_tensor_specs( |
84 | 122 | in exported_program.graph_signature.inputs_to_lifted_tensor_constants |
85 | 123 | ): |
86 | 124 | spec.const = True |
87 | | - |
88 | | - # pyre-ignore |
89 | | - def placeholder(self, name: str, arg, meta): |
90 | | - meta["spec"] = make_spec(arg) |
91 | | - return super().placeholder(name, arg, meta) |
92 | | - |
93 | | - # pyre-ignore |
94 | | - def call_operator(self, op, args, kwargs, meta): |
95 | | - args_data, kwargs_data = pytree.tree_map_only( |
96 | | - ProxyValue, lambda x: x.data, (args, kwargs) |
97 | | - ) |
98 | | - meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data)) |
99 | | - return super().call_operator(op, args, kwargs, meta) |
100 | | - |
101 | | - # pyre-ignore |
102 | | - def call_getitem(self, value, key: int, meta): |
103 | | - meta["spec"] = value.node.meta["spec"][key] |
104 | | - return super().call_getitem(value, key, meta) |
105 | | - |
106 | | - # pyre-ignore |
107 | | - def call_cond(self, pred, true_fn, false_fn, inputs, meta): |
108 | | - # true_fn/false_fn return tensors of the same shape, so we can pick |
109 | | - # either one here. |
110 | | - *_, true_out_node = true_fn.graph.nodes |
111 | | - meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"]) |
112 | | - return super().call_cond(pred, true_fn, false_fn, inputs, meta) |
113 | | - |
114 | | - def call_while( |
115 | | - self, |
116 | | - cond_fn: torch.fx.GraphModule, |
117 | | - body_fn: torch.fx.GraphModule, |
118 | | - carried_inputs: List[ProxyValue], |
119 | | - additional_inputs: List[ProxyValue], |
120 | | - meta: NodeMetadata, |
121 | | - ): |
122 | | - meta["spec"] = pytree.tree_map(make_spec, carried_inputs) |
123 | | - return super().call_while( |
124 | | - cond_fn, body_fn, carried_inputs, additional_inputs, meta |
125 | | - ) |
126 | | - |
127 | | - def call_map( |
128 | | - self, |
129 | | - f: torch.fx.GraphModule, |
130 | | - mapped_args: List[ProxyValue], |
131 | | - operands: List[ProxyValue], |
132 | | - meta: NodeMetadata, |
133 | | - ) -> ProxyValue: |
134 | | - mapped_dim_size = [arg.data for arg in mapped_args][0].size(0) |
135 | | - *_, body_out_node = f.graph.nodes |
136 | | - body_out_node_fake_tensor = body_out_node.meta["val"] |
137 | | - map_fake_tensor = pytree.tree_map_only( |
138 | | - torch.Tensor, |
139 | | - lambda x: x.new_empty(mapped_dim_size, *x.shape), |
140 | | - body_out_node_fake_tensor, |
141 | | - ) |
142 | | - meta["spec"] = pytree.tree_map(make_spec, map_fake_tensor) |
143 | | - return super().call_map(f, mapped_args, operands, meta) |
144 | | - |
145 | | - # pyre-ignore |
146 | | - def call_delegate(self, lowered_module, args, kwargs, meta): |
147 | | - args_data, kwargs_data = pytree.tree_map_only( |
148 | | - ProxyValue, lambda x: x.data, (args, kwargs) |
149 | | - ) |
150 | | - # If spec is missing, re-genenrate it with args data |
151 | | - if "spec" not in meta: |
152 | | - meta["spec"] = pytree.tree_map( |
153 | | - make_spec, |
154 | | - executorch_call_delegate(lowered_module, *args_data), |
155 | | - ) |
156 | | - return super().call_delegate(lowered_module, args, kwargs, meta) |
157 | | - |
158 | | - # pyre-ignore |
159 | | - def output(self, results, meta): |
160 | | - # pyre-ignore |
161 | | - def get_spec(x): |
162 | | - if isinstance(x, ProxyValue): |
163 | | - return x.node.meta["spec"] |
164 | | - else: |
165 | | - return make_spec(x) |
166 | | - |
167 | | - meta["spec"] = pytree.tree_map(get_spec, results) |
168 | | - return super().output(results, meta) |
0 commit comments