Commit dfdb023

Merge pull request #15 from mister-bailey/efficient_opt_einsum

fx shape propagation avoids allocating expensive default-summation-order intermediates and instead computes intermediate shapes formally
2 parents: 4888c79 + e159c75
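In short: the stock torch.fx.passes.shape_prop.ShapeProp discovers shapes by actually executing every node, so an unoptimized einsum can materialize an enormous intermediate just to learn an output shape. The EfficientShapeProp introduced here computes einsum output shapes symbolically instead. A minimal usage sketch (the function and tensor sizes below are illustrative, not taken from the diff):

import torch
import torch.fx

from opt_einsum_fx import EfficientShapeProp

def chained(x, y, z):
    # In general an unoptimized contraction order can build very large
    # intermediates; EfficientShapeProp never executes the einsum at all.
    return torch.einsum("ab,bc,cd->ad", x, y, z)

gm = torch.fx.symbolic_trace(chained)
EfficientShapeProp(gm).run(torch.randn(4, 5), torch.randn(5, 6), torch.randn(6, 7))

for node in gm.graph.nodes:
    # Each node now carries SimpleMeta(shape, dtype) in node.meta['tensor_meta']
    print(node.name, node.meta["tensor_meta"])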

File tree

opt_einsum_fx/__init__.py
opt_einsum_fx/_efficient_shape_prop.py
opt_einsum_fx/_opt_ein.py
opt_einsum_fx/fx_utils.py
tests/test_einsum_optimizer.py

5 files changed: +120 -26 lines

opt_einsum_fx/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -3,11 +3,13 @@
 from ._script import jitable
 from ._opt_ein import optimize_einsums, optimize_einsums_full
 from ._fuse import fuse_einsums, fuse_scalars
+from ._efficient_shape_prop import EfficientShapeProp

 __all__ = [
     "jitable",
     "optimize_einsums",
     "optimize_einsums_full",
     "fuse_einsums",
     "fuse_scalars",
+    "EfficientShapeProp",
 ]
opt_einsum_fx/_efficient_shape_prop.py (new file)

Lines changed: 107 additions & 0 deletions

@@ -0,0 +1,107 @@
+from typing import Any, NamedTuple
+
+import opt_einsum
+import torch
+from torch.fx.node import Node
+
+from ._fuse import _EINSUM_FUNCS
+
+
+class SimpleMeta(NamedTuple):
+    """
+    The full ShapeProp defines and uses a NamedTuple to
+    store a whole bunch of metadata about the tensors
+    going into and out of the Node op. But we don't
+    have most of that info, and anyway, I don't think
+    most of it is used in opt_einsum or opt_einsum_fx.
+    (These are only concerned with computing a summation
+    order.)
+
+    Rather than give dummy or default values, which I
+    only *assume* would be fine, I'm defining a NamedTuple
+    with only the values we actually know. So if I'm wrong,
+    we will get a very clear error message rather than
+    some invisible error.
+    """
+
+    shape: torch.Size
+    dtype: torch.dtype
+
+
+class EfficientShapeProp(torch.fx.Interpreter):
+    """
+    Like ShapeProp, traverses a graph Node-by-Node
+    and records the shape and type of the result
+    into each Node.
+
+    Except we treat 'einsum' as a special case.
+    We don't actually execute 'einsum' on tensors,
+    since the einsums will typically not be optimized
+    yet (ShapeProp is called before optimization),
+    and an inefficient summation order can create
+    enormous intermediate tensors, often causing
+    needless out-of-memory errors.
+
+    So we override 'run_node' only for einsums.
+    It's straightforward to determine the shape of the
+    result just from the output indices.
+
+    (The call to opt_einsum that will typically follow
+    this also doesn't actually build the tensors
+    during its exploration.)
+    """
+
+    def run_node(self, n: Node) -> Any:
+        if n.op == "call_function" and n.target in _EINSUM_FUNCS:
+            args, kwargs = self.fetch_args_kwargs_from_env(n)
+            equation, *operands = args
+            shapes = [op.shape for op in operands]
+
+            assert len({op.dtype for op in operands}) == 1
+            meta = SimpleMeta(einsum_shape(equation, *shapes), operands[0].dtype)
+            result = torch.zeros((1,) * len(meta.shape), dtype=meta.dtype, device=operands[0].device).expand(meta.shape)
+        elif n.op == "call_function" and n.target == torch.tensordot:
+            args, kwargs = self.fetch_args_kwargs_from_env(n)
+            shape_a = [dim for i, dim in enumerate(args[0].shape) if i not in kwargs['dims'][0]]
+            shape_b = [dim for i, dim in enumerate(args[1].shape) if i not in kwargs['dims'][1]]
+
+            assert len({op.dtype for op in args}) == 1
+            meta = SimpleMeta(shape_a + shape_b, args[0].dtype)
+            result = torch.zeros((1,) * len(meta.shape), dtype=meta.dtype, device=args[0].device).expand(meta.shape)
+        else:
+            result = super().run_node(n)
+
+        if isinstance(result, torch.Tensor):
+            meta = SimpleMeta(result.shape, result.dtype)
+        else:
+            meta = None
+
+        n.meta = dict()
+        n.meta['tensor_meta'] = meta
+        n.meta['type'] = type(result)
+
+        return result
+
+    def propagate(self, *args):
+        return super().run(*args)
+
+
+def einsum_shape(subscripts, *shapes):
+    """
+    Given an einsum equation and input shapes, returns the output
+    shape of the einsum.
+
+    Args:
+        subscripts: the einsum formula
+        shapes: the input shapes
+    """
+    Shaped = NamedTuple('Shaped', [('shape', tuple)])
+    input_subscripts, output_subscript, _ = opt_einsum.parser.parse_einsum_input(
+        (subscripts,) + tuple(Shaped(shape) for shape in shapes)
+    )
+    dims = {
+        i: dim
+        for ii, shape in zip(input_subscripts.split(','), shapes)
+        for i, dim in zip(ii, shape)
+    }
+    return tuple(dims[i] for i in output_subscript)
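Two details worth noting in the file above. First, the torch.zeros((1,) * len(shape), ...).expand(shape) pattern: expand returns a stride-0 view, so downstream nodes see a stand-in tensor with the correct shape, dtype, and device while only a single element is ever allocated. Second, einsum_shape never touches tensor data; each output dimension is read off the output subscript. A couple of illustrative calls (values worked out by hand, assuming the module path matches its import name in __init__.py):

from opt_einsum_fx._efficient_shape_prop import einsum_shape

# "ij,jk->ik": i is bound to 2 and k to 4, so the result shape is (2, 4).
assert einsum_shape("ij,jk->ik", (2, 3), (3, 4)) == (2, 4)
# Multi-operand chains work the same way; no intermediate is ever formed.
assert einsum_shape("ab,bc,cd->ad", (8, 3), (3, 5), (5, 7)) == (8, 7)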

opt_einsum_fx/_opt_ein.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@

 import torch
 from torch import fx
-from torch.fx.passes.shape_prop import ShapeProp
+from ._efficient_shape_prop import EfficientShapeProp as ShapeProp

 import opt_einsum
 from opt_einsum.contract import _core_contract

opt_einsum_fx/fx_utils.py

Lines changed: 8 additions & 22 deletions

@@ -1,28 +1,14 @@
 from typing import Optional
-from packaging import version

 import torch
 from torch import fx

-_TORCH_IS_GE_19: bool = version.parse(torch.__version__) >= version.parse("1.9.0")

-# The torch FX APIs are not stable, so we need helper wrappers
-
-if _TORCH_IS_GE_19:
-
-    def get_shape(n: fx.Node) -> Optional[torch.Size]:
-        """Get the shape of a node after ``ShapeProp``"""
-        try:
-            return n.meta["tensor_meta"].shape
-        except KeyError:
-            return None
-
-
-else:
-
-    def get_shape(n: fx.Node) -> Optional[torch.Size]:
-        """Get the shape of a node after ``ShapeProp``"""
-        try:
-            return n.shape
-        except AttributeError:
-            return None
+def get_shape(n: fx.Node) -> Optional[torch.Size]:
+    """Get the shape of a node after ``ShapeProp``"""
+    try:
+        return n.meta["tensor_meta"].shape
+    except KeyError:
+        return None
+    except AttributeError:
+        return None
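With the torch version branch gone, both except clauses are still needed: KeyError when no propagation pass has run (no 'tensor_meta' entry in node.meta), and AttributeError when EfficientShapeProp stored None for a non-tensor result, since None.shape raises AttributeError. A small hypothetical check, reusing the traced gm from the sketch under the commit message:

from opt_einsum_fx.fx_utils import get_shape

# torch.Size for tensor-valued nodes, None where no shape was recorded.
shapes = {node.name: get_shape(node) for node in gm.graph.nodes}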

tests/test_einsum_optimizer.py

Lines changed: 2 additions & 3 deletions

@@ -2,9 +2,8 @@

 import torch
 import torch.fx
-from torch.fx.passes.shape_prop import ShapeProp

-from opt_einsum_fx import optimize_einsums, optimize_einsums_full, jitable
+from opt_einsum_fx import optimize_einsums, optimize_einsums_full, jitable, EfficientShapeProp


 def einmatmul(x, y):

@@ -74,7 +73,7 @@ def test_optimize_einsums(einfunc, allclose):
     func_res = einfunc(x, y)

     func_fx = torch.fx.symbolic_trace(einfunc)
-    sp = ShapeProp(func_fx)
+    sp = EfficientShapeProp(func_fx)
     sp.run(x, y)

     func_fx_res = func_fx(x, y)
