Skip to content

Commit 01418d5

Browse files
GregoryComerfacebook-github-bot
authored andcommitted
Fix shape_env handling in SpecPropPass (WIP) (#15485)
Summary: Pull Request resolved: #15485 Differential Revision: D85913581
1 parent c6308a9 commit 01418d5

File tree

2 files changed

+87
-84
lines changed

2 files changed

+87
-84
lines changed

exir/passes/spec_prop_pass.py

Lines changed: 37 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@
66

77
# pyre-strict
88

9-
from typing import List, Optional
9+
from typing import Optional
1010

11+
import operator
1112
import torch
1213
from executorch.exir.delegate import executorch_call_delegate
13-
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
14+
from executorch.exir.pass_base import ExportPass, ProxyValue
1415
from executorch.exir.tensor import TensorSpec
1516
from torch.export.exported_program import ExportGraphSignature
1617
from torch.fx.node import Node
18+
from torch.fx.passes.infra.pass_base import PassResult
1719
from torch.utils import _pytree as pytree
1820

1921

@@ -51,6 +53,39 @@ def _is_mutable_buffer(
5153
class SpecPropPass(ExportPass):
5254
def __init__(self) -> None:
5355
super().__init__()
56+
57+
def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
58+
# Re-trace metadata to ensure it's up to date.
59+
res = ExportPass()(graph_module)
60+
assert res is not None
61+
gm = res.graph_module
62+
63+
def get_spec(x):
64+
if hasattr(x, "meta"):
65+
return x.meta.get("spec", None)
66+
else:
67+
return None
68+
69+
for module in gm.modules():
70+
if isinstance(module, torch.fx.GraphModule):
71+
for node in module.graph.nodes:
72+
meta_val = node.meta.get("val", None)
73+
74+
if node.op == "output":
75+
node.meta["spec"] = pytree.tree_map(get_spec, node.args[0])
76+
elif node.op == "call_function" and node.target == operator.getitem:
77+
value_spec = pytree.tree_map(get_spec, node.args[0])
78+
node.meta["spec"] = value_spec[node.args[1]]
79+
elif node.op == "call_function" and node.target == executorch_call_delegate:
80+
if "spec" not in node.meta:
81+
node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
82+
else:
83+
if meta_val is not None:
84+
node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
85+
return res
86+
87+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
88+
return self(graph_module)
5489

5590
def on_attr(self, attr: ProxyValue) -> None:
5691
attr.node.meta["spec"] = pytree.tree_map_only(
@@ -84,85 +119,3 @@ def update_placeholder_tensor_specs(
84119
in exported_program.graph_signature.inputs_to_lifted_tensor_constants
85120
):
86121
spec.const = True
87-
88-
# pyre-ignore
89-
def placeholder(self, name: str, arg, meta):
90-
meta["spec"] = make_spec(arg)
91-
return super().placeholder(name, arg, meta)
92-
93-
# pyre-ignore
94-
def call_operator(self, op, args, kwargs, meta):
95-
args_data, kwargs_data = pytree.tree_map_only(
96-
ProxyValue, lambda x: x.data, (args, kwargs)
97-
)
98-
meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data))
99-
return super().call_operator(op, args, kwargs, meta)
100-
101-
# pyre-ignore
102-
def call_getitem(self, value, key: int, meta):
103-
meta["spec"] = value.node.meta["spec"][key]
104-
return super().call_getitem(value, key, meta)
105-
106-
# pyre-ignore
107-
def call_cond(self, pred, true_fn, false_fn, inputs, meta):
108-
# true_fn/false_fn return tensors of the same shape, so we can pick
109-
# either one here.
110-
*_, true_out_node = true_fn.graph.nodes
111-
meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"])
112-
return super().call_cond(pred, true_fn, false_fn, inputs, meta)
113-
114-
def call_while(
115-
self,
116-
cond_fn: torch.fx.GraphModule,
117-
body_fn: torch.fx.GraphModule,
118-
carried_inputs: List[ProxyValue],
119-
additional_inputs: List[ProxyValue],
120-
meta: NodeMetadata,
121-
):
122-
meta["spec"] = pytree.tree_map(make_spec, carried_inputs)
123-
return super().call_while(
124-
cond_fn, body_fn, carried_inputs, additional_inputs, meta
125-
)
126-
127-
def call_map(
128-
self,
129-
f: torch.fx.GraphModule,
130-
mapped_args: List[ProxyValue],
131-
operands: List[ProxyValue],
132-
meta: NodeMetadata,
133-
) -> ProxyValue:
134-
mapped_dim_size = [arg.data for arg in mapped_args][0].size(0)
135-
*_, body_out_node = f.graph.nodes
136-
body_out_node_fake_tensor = body_out_node.meta["val"]
137-
map_fake_tensor = pytree.tree_map_only(
138-
torch.Tensor,
139-
lambda x: x.new_empty(mapped_dim_size, *x.shape),
140-
body_out_node_fake_tensor,
141-
)
142-
meta["spec"] = pytree.tree_map(make_spec, map_fake_tensor)
143-
return super().call_map(f, mapped_args, operands, meta)
144-
145-
# pyre-ignore
146-
def call_delegate(self, lowered_module, args, kwargs, meta):
147-
args_data, kwargs_data = pytree.tree_map_only(
148-
ProxyValue, lambda x: x.data, (args, kwargs)
149-
)
150-
# If spec is missing, re-genenrate it with args data
151-
if "spec" not in meta:
152-
meta["spec"] = pytree.tree_map(
153-
make_spec,
154-
executorch_call_delegate(lowered_module, *args_data),
155-
)
156-
return super().call_delegate(lowered_module, args, kwargs, meta)
157-
158-
# pyre-ignore
159-
def output(self, results, meta):
160-
# pyre-ignore
161-
def get_spec(x):
162-
if isinstance(x, ProxyValue):
163-
return x.node.meta["spec"]
164-
else:
165-
return make_spec(x)
166-
167-
meta["spec"] = pytree.tree_map(get_spec, results)
168-
return super().output(results, meta)

exir/tests/test_passes.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
7575
from executorch.exir.program._program import lift_constant_tensor_pass
7676
from executorch.exir.schema import TensorShapeDynamism
77+
from executorch.exir.sym_util import eval_upper_bound
7778
from executorch.exir.tensor import TensorSpec
7879
from executorch.exir.tests.common import register_additional_test_aten_ops
7980
from executorch.exir.tests.control_flow_models import FTCondDeadCode, FTMapBasic
@@ -113,6 +114,7 @@ def collect_ops(gm: torch.fx.GraphModule):
113114

114115
lib.define("foo(Tensor self) -> (Tensor, Tensor)")
115116
lib.define("add_relu(Tensor self, Tensor other) -> Tensor")
117+
lib.define("unbacked(Tensor self) -> Tensor")
116118

117119

118120
@impl(lib, "foo", "CompositeExplicitAutograd")
@@ -131,6 +133,27 @@ def foo_out(
131133
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
132134
return a + 1, None
133135

136+
@impl(lib, "unbacked", "CPU")
137+
def unbacked(a: torch.Tensor) -> torch.Tensor:
138+
return a[:a[0]]
139+
140+
@torch.library.register_fake(f"{lib.ns}::unbacked")
141+
def meta_unbacked(x):
142+
ctx = torch._custom_ops.get_ctx()
143+
out_size = ctx.create_unbacked_symint()
144+
return x.new_empty(out_size)
145+
146+
lib.define(
147+
"unbacked.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"
148+
)
149+
150+
@impl(lib, "unbacked.out", "CPU")
151+
def unbacked_out(
152+
x: torch.Tensor,
153+
out: torch.Tensor,
154+
) -> torch.Tensor:
155+
out.copy_(x[x[0]])
156+
134157

135158
def simple_promote_dtype(
136159
dtype: torch.dtype, promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND
@@ -610,6 +633,33 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
610633
self.assertIs(node.meta["spec"][0], node.args[0][0].meta["spec"])
611634

612635
self.assertEqual(counter, 1)
636+
637+
def test_spec_prop_pass_unbacked_symint(self) -> None:
638+
# Verify that the spec prop pass picks up on guards for
639+
# unbacked symints.
640+
class Unbacked(torch.nn.Module):
641+
def forward(self, x):
642+
output = torch.ops.DO_NOT_USE_TEST_ONLY.unbacked(x)
643+
torch._constrain_as_size(output.shape[0], max=10)
644+
return output
645+
646+
model = Unbacked()
647+
gm = (
648+
to_edge(export(model, (torch.LongTensor([5, 4, 3, 2, 1, 0, 1, 2]),)))
649+
.exported_program()
650+
.graph_module
651+
)
652+
new_gm = SpecPropPass()(gm)
653+
self.assertIsNotNone(new_gm)
654+
655+
# Check the spec for the custom op node. It should have a max size of 10.
656+
op_node = next(n for n in new_gm.graph_module.graph.nodes if n.target == exir_ops.edge.DO_NOT_USE_TEST_ONLY.unbacked.default)
657+
self.assertIsNotNone(op_node)
658+
659+
spec: TensorSpec = op_node.meta["spec"]
660+
self.assertEqual(len(spec.shape), 1) # Should be rank 1
661+
upper_bound = eval_upper_bound(spec.shape[0])
662+
self.assertEqual(upper_bound, 10) # Should be a concrete value
613663

614664
def test_compile_fix_broken_ops(self) -> None:
615665
class ExportableLoop(nn.Module):

0 commit comments

Comments
 (0)