Skip to content

Commit 0643a56

Browse files
torchscript in 1.9
1 parent 4dcf221 commit 0643a56

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

opt_einsum_fx/_script.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Union
2+
from packaging import version
23

34
import torch
45
from torch import fx
@@ -21,6 +22,8 @@ def jitable(obj: Union[fx.GraphModule, fx.Graph]) -> Union[fx.GraphModule, fx.Gr
2122
else:
2223
graph = obj
2324

25+
torch_is_ge_19: bool = version.parse(torch.__version__) >= version.parse("1.9.0")
26+
2427
for node in graph.nodes:
2528
if node.op == "call_function":
2629
if (
@@ -32,8 +35,13 @@ def jitable(obj: Union[fx.GraphModule, fx.Graph]) -> Union[fx.GraphModule, fx.Gr
3235
kwargs = dict(node.kwargs)
3336
dim_self, dim_other = kwargs.pop("dims")
3437
assert len(args) == 2 # tensors 1 and 2
35-
args.append(list(dim_self))
36-
args.append(list(dim_other))
38+
if torch_is_ge_19:
39+
# In torch >= 1.9.0, they've corrected the torchscript interface
40+
# to align with the python one:
41+
args.append((list(dim_self), list(dim_other)))
42+
else:
43+
args.append(list(dim_self))
44+
args.append(list(dim_other))
3745
node.args = tuple(args)
3846
node.kwargs = kwargs
3947
elif node.op == "call_method":

0 commit comments

Comments
 (0)