from typing import Any, NamedTuple

import opt_einsum
import torch
from torch.fx.node import Node

from ._fuse import _EINSUM_FUNCS


class SimpleMeta(NamedTuple):
    """
    The full ShapeProp defines and uses a NamedTuple to
    store a whole bunch of metadata about the tensors
    going into and out of the Node op. But we don't
    have most of that info, and anyway, I don't think
    most of it's used in opt_einsum or opt_einsum_fx.
    (These are only concerned with computing a summation
    order.)

    Rather than give dummy or default values, which I
    only *assume* would be fine, I'm defining a NamedTuple
    with only the values we actually know. So if I'm wrong
    we will get a very clear error message, rather than
    some invisible error.
    """

    shape: torch.Size
    dtype: torch.dtype
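
# For illustration only (nothing below relies on this comment): the metadata
# recorded for a float32 tensor of shape (2, 3) would be
#     SimpleMeta(shape=torch.Size([2, 3]), dtype=torch.float32)
# and that is exactly the object stored into node.meta["tensor_meta"] below.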


class EfficientShapeProp(torch.fx.Interpreter):
    """
    Like ShapeProp, traverses a graph Node-by-Node
    and records the shape and type of the result
    into each Node.

    Except we treat 'einsum' as a special case.
    We don't actually execute 'einsum' on tensors,
    since the einsums will typically not be optimized
    yet (ShapeProp is called before optimization),
    and an inefficient summation order can create
    enormous intermediate tensors, which often causes
    needless out-of-memory errors.

    So we override 'run_node' only for einsums.
    It's straightforward to determine the shape of the
    result just from the output indices.

    (The call to opt_einsum that typically follows this
    also doesn't actually build the tensors during its
    exploration.)

    A rough usage sketch appears after the class definition below.
    """

    def run_node(self, n: Node) -> Any:
        if n.op == "call_function" and n.target in _EINSUM_FUNCS:
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            equation, *operands = args
            shapes = [op.shape for op in operands]

            # All operands must share a dtype for the output dtype to be unambiguous.
            assert len({op.dtype for op in operands}) == 1
            meta = SimpleMeta(einsum_shape(equation, *shapes), operands[0].dtype)
            # Make a tensor with the right shape/dtype/device without allocating
            # (or computing) the full result: expand a single-element tensor.
            result = torch.zeros((1,) * len(meta.shape), dtype=meta.dtype, device=operands[0].device).expand(meta.shape)
        elif n.op == "call_function" and n.target == torch.tensordot:
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            # The output shape of tensordot is the uncontracted dims of the first
            # operand followed by the uncontracted dims of the second.
            shape_a = [dim for i, dim in enumerate(args[0].shape) if i not in kwargs['dims'][0]]
            shape_b = [dim for i, dim in enumerate(args[1].shape) if i not in kwargs['dims'][1]]

            assert len({op.dtype for op in args}) == 1
            meta = SimpleMeta(shape_a + shape_b, args[0].dtype)
            result = torch.zeros((1,) * len(meta.shape), dtype=meta.dtype, device=args[0].device).expand(meta.shape)
        else:
            result = super().run_node(n)

        if isinstance(result, torch.Tensor):
            meta = SimpleMeta(result.shape, result.dtype)
        else:
            meta = None

        n.meta = dict()
        n.meta['tensor_meta'] = meta
        n.meta['type'] = type(result)

        return result

    def propagate(self, *args):
        return super().run(*args)
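
# A rough usage sketch, mirroring how torch.fx's stock ShapeProp is used.
# `MyEinsumModule` is a hypothetical nn.Module whose forward calls torch.einsum,
# and the input shapes are illustrative:
#
#     graph_module = torch.fx.symbolic_trace(MyEinsumModule())
#     EfficientShapeProp(graph_module).propagate(torch.randn(4, 8), torch.randn(8, 16))
#     for node in graph_module.graph.nodes:
#         print(node.name, node.meta["tensor_meta"])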


def einsum_shape(subscripts, *shapes):
    """
    Given an einsum equation and input shapes, returns the output
    shape of the einsum.

    Args:
        subscripts: the einsum formula
        shapes: the input shapes
    """
    # A lightweight stand-in exposing only a `.shape` attribute, so opt_einsum
    # can parse the equation without us building real tensors.
    Shaped = NamedTuple('Shaped', [('shape', tuple)])
    input_subscripts, output_subscript, _ = opt_einsum.parser.parse_einsum_input(
        (subscripts,) + tuple(Shaped(shape) for shape in shapes)
    )
    # Map each index letter to its dimension, then read off the output shape.
    dims = {
        i: dim
        for ii, shape in zip(input_subscripts.split(','), shapes)
        for i, dim in zip(ii, shape)
    }
    return tuple(dims[i] for i in output_subscript)
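
# Illustrative examples of what `einsum_shape` returns (a matrix multiply and
# its batched variant; the shapes are made up for the example):
#
#     einsum_shape("ij,jk->ik", (2, 3), (3, 4))           # -> (2, 4)
#     einsum_shape("bij,bjk->bik", (8, 2, 3), (8, 3, 4))  # -> (8, 2, 4)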