Commit 6c9d1cb

Fix shape_env handling in SpecPropPass (WIP) (pytorch#15485)
Summary: Pull Request resolved: pytorch#15485
Differential Revision: D85913581
1 parent 3405317 commit 6c9d1cb
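For orientation, a minimal usage sketch of the reworked pass (hedged: gm is a placeholder for an edge-dialect torch.fx.GraphModule you already have; it is not defined in this commit). With this change, SpecPropPass is invoked directly on the graph module and returns a PassResult whose graph_module carries the propagated node.meta["spec"] entries:

    from executorch.exir.passes.spec_prop_pass import SpecPropPass

    result = SpecPropPass()(gm)       # gm: an edge-dialect torch.fx.GraphModule (assumed to exist)
    assert result is not None
    annotated = result.graph_module   # nodes now carry node.meta["spec"] TensorSpecs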

2 files changed (+91, −85)

exir/passes/spec_prop_pass.py

Lines changed: 35 additions & 85 deletions
@@ -6,14 +6,15 @@
 
 # pyre-strict
 
-from typing import List, Optional
+import operator
+from typing import Optional
 
 import torch
-from executorch.exir.delegate import executorch_call_delegate
-from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
+from executorch.exir.pass_base import ExportPass, ProxyValue
 from executorch.exir.tensor import TensorSpec
 from torch.export.exported_program import ExportGraphSignature
 from torch.fx.node import Node
+from torch.fx.passes.infra.pass_base import PassResult
 from torch.utils import _pytree as pytree
 
 
@@ -52,6 +53,37 @@ class SpecPropPass(ExportPass):
     def __init__(self) -> None:
         super().__init__()
 
+    def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        # Re-trace metadata to ensure it's up to date.
+        res = ExportPass()(graph_module)
+        assert res is not None
+        gm = res.graph_module
+
+        def get_spec(x):
+            if hasattr(x, "meta"):
+                return x.meta.get("spec", None)
+            else:
+                return None
+
+        for module in gm.modules():
+            if isinstance(module, torch.fx.GraphModule):
+                for node in module.graph.nodes:
+                    # Preserve pre-existing specs, such as for call_delegate nodes.
+                    if not "spec" in node.meta:
+                        if node.op == "output":
+                            node.meta["spec"] = pytree.tree_map(get_spec, node.args[0])
+                        elif (
+                            node.op == "call_function"
+                            and node.target == operator.getitem
+                        ):
+                            value_spec = pytree.tree_map(get_spec, node.args[0])
+                            node.meta["spec"] = value_spec[node.args[1]]
+                        else:
+                            meta_val = node.meta.get("val")
+                            if meta_val is not None:
+                                node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
+        return res
+
     def on_attr(self, attr: ProxyValue) -> None:
         attr.node.meta["spec"] = pytree.tree_map_only(
             torch.Tensor,
@@ -84,85 +116,3 @@ def update_placeholder_tensor_specs(
                 in exported_program.graph_signature.inputs_to_lifted_tensor_constants
             ):
                 spec.const = True
-
-    # pyre-ignore
-    def placeholder(self, name: str, arg, meta):
-        meta["spec"] = make_spec(arg)
-        return super().placeholder(name, arg, meta)
-
-    # pyre-ignore
-    def call_operator(self, op, args, kwargs, meta):
-        args_data, kwargs_data = pytree.tree_map_only(
-            ProxyValue, lambda x: x.data, (args, kwargs)
-        )
-        meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data))
-        return super().call_operator(op, args, kwargs, meta)
-
-    # pyre-ignore
-    def call_getitem(self, value, key: int, meta):
-        meta["spec"] = value.node.meta["spec"][key]
-        return super().call_getitem(value, key, meta)
-
-    # pyre-ignore
-    def call_cond(self, pred, true_fn, false_fn, inputs, meta):
-        # true_fn/false_fn return tensors of the same shape, so we can pick
-        # either one here.
-        *_, true_out_node = true_fn.graph.nodes
-        meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"])
-        return super().call_cond(pred, true_fn, false_fn, inputs, meta)
-
-    def call_while(
-        self,
-        cond_fn: torch.fx.GraphModule,
-        body_fn: torch.fx.GraphModule,
-        carried_inputs: List[ProxyValue],
-        additional_inputs: List[ProxyValue],
-        meta: NodeMetadata,
-    ):
-        meta["spec"] = pytree.tree_map(make_spec, carried_inputs)
-        return super().call_while(
-            cond_fn, body_fn, carried_inputs, additional_inputs, meta
-        )
-
-    def call_map(
-        self,
-        f: torch.fx.GraphModule,
-        mapped_args: List[ProxyValue],
-        operands: List[ProxyValue],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        mapped_dim_size = [arg.data for arg in mapped_args][0].size(0)
-        *_, body_out_node = f.graph.nodes
-        body_out_node_fake_tensor = body_out_node.meta["val"]
-        map_fake_tensor = pytree.tree_map_only(
-            torch.Tensor,
-            lambda x: x.new_empty(mapped_dim_size, *x.shape),
-            body_out_node_fake_tensor,
-        )
-        meta["spec"] = pytree.tree_map(make_spec, map_fake_tensor)
-        return super().call_map(f, mapped_args, operands, meta)
-
-    # pyre-ignore
-    def call_delegate(self, lowered_module, args, kwargs, meta):
-        args_data, kwargs_data = pytree.tree_map_only(
-            ProxyValue, lambda x: x.data, (args, kwargs)
-        )
-        # If spec is missing, re-generate it with args data
-        if "spec" not in meta:
-            meta["spec"] = pytree.tree_map(
-                make_spec,
-                executorch_call_delegate(lowered_module, *args_data),
-            )
-        return super().call_delegate(lowered_module, args, kwargs, meta)
-
-    # pyre-ignore
-    def output(self, results, meta):
-        # pyre-ignore
-        def get_spec(x):
-            if isinstance(x, ProxyValue):
-                return x.node.meta["spec"]
-            else:
-                return make_spec(x)
-
-        meta["spec"] = pytree.tree_map(get_spec, results)
-        return super().output(results, meta)

exir/tests/test_passes.py

Lines changed: 56 additions & 0 deletions
@@ -74,6 +74,7 @@
 from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
 from executorch.exir.program._program import lift_constant_tensor_pass
 from executorch.exir.schema import TensorShapeDynamism
+from executorch.exir.sym_util import eval_upper_bound
 from executorch.exir.tensor import TensorSpec
 from executorch.exir.tests.common import register_additional_test_aten_ops
 from executorch.exir.tests.control_flow_models import FTCondDeadCode, FTMapBasic
@@ -113,6 +114,7 @@ def collect_ops(gm: torch.fx.GraphModule):
 
 lib.define("foo(Tensor self) -> (Tensor, Tensor)")
 lib.define("add_relu(Tensor self, Tensor other) -> Tensor")
+lib.define("unbacked(Tensor self) -> Tensor")
 
 
 @impl(lib, "foo", "CompositeExplicitAutograd")
@@ -132,6 +134,29 @@ def foo_out(
     return a + 1, None
 
 
+@impl(lib, "unbacked", "CPU")
+def unbacked(a: torch.Tensor) -> torch.Tensor:
+    return a[: a[0]]
+
+
+@torch.library.register_fake(f"{lib.ns}::unbacked")
+def meta_unbacked(x):
+    ctx = torch._custom_ops.get_ctx()
+    out_size = ctx.create_unbacked_symint()
+    return x.new_empty(out_size)
+
+
+lib.define("unbacked.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)")
+
+
+@impl(lib, "unbacked.out", "CPU")
+def unbacked_out(
+    x: torch.Tensor,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    out.copy_(x[: x[0]])
+
+
 def simple_promote_dtype(
     dtype: torch.dtype, promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND
 ) -> torch.dtype:
@@ -611,6 +636,37 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
 
         self.assertEqual(counter, 1)
 
+    def test_spec_prop_pass_unbacked_symint(self) -> None:
+        # Verify that the spec prop pass picks up on guards for
+        # unbacked symints.
+        class Unbacked(torch.nn.Module):
+            def forward(self, x):
+                output = torch.ops.DO_NOT_USE_TEST_ONLY.unbacked(x)
+                torch._constrain_as_size(output.shape[0], max=10)
+                return output
+
+        model = Unbacked()
+        gm = (
+            to_edge(export(model, (torch.LongTensor([5, 4, 3, 2, 1, 0, 1, 2]),)))
+            .exported_program()
+            .graph_module
+        )
+        new_gm = SpecPropPass()(gm)
+        self.assertIsNotNone(new_gm)
+
+        # Check the spec for the custom op node. It should have a max size of 10.
+        op_node = next(
+            n
+            for n in new_gm.graph_module.graph.nodes
+            if n.target == exir_ops.edge.DO_NOT_USE_TEST_ONLY.unbacked.default
+        )
+        self.assertIsNotNone(op_node)
+
+        spec: TensorSpec = op_node.meta["spec"]
+        self.assertEqual(len(spec.shape), 1)  # Should be rank 1
+        upper_bound = eval_upper_bound(spec.shape[0])
+        self.assertEqual(upper_bound, 10)  # Should be a concrete value
+
     def test_compile_fix_broken_ops(self) -> None:
         class ExportableLoop(nn.Module):
             def __init__(self, hidden_size, out_channels):
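As a follow-up to the new test, a hedged sketch of how a propagated spec could be inspected once the pass has run (op_node stands for any graph node that received a spec, as in the test above; this is an illustration, not code from the commit):

    from executorch.exir.sym_util import eval_upper_bound

    spec = op_node.meta["spec"]        # TensorSpec written by SpecPropPass (op_node assumed)
    sym_dim = spec.shape[0]            # may be a plain int or an unbacked SymInt
    print(eval_upper_bound(sym_dim))   # resolves to 10 when constrained as in the test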
