@@ -29,7 +29,7 @@
 import pytensor.scalar.basic as ps
 from pytensor import compile, config
 from pytensor.compile.ops import ViewOp
-from pytensor.graph import FunctionGraph
+from pytensor.graph import FunctionGraph, Op
 from pytensor.graph.basic import Constant
 from pytensor.graph.rewriting.basic import (
     NodeProcessingGraphRewriter,
@@ -40,9 +40,9 @@
     node_rewriter,
 )
 from pytensor.graph.rewriting.db import RewriteDatabase
+from pytensor.graph.rewriting.unify import OpPattern
 from pytensor.npy_2_compat import normalize_axis_index
 from pytensor.raise_op import Assert, CheckAndRaise, assert_op
-from pytensor.scalar.basic import Second
 from pytensor.tensor.basic import (
     Alloc,
     AllocEmpty,
@@ -225,6 +225,12 @@ def register(inner_rewriter: RewriteDatabase | Rewriter):
         return node_rewriter


+def elemwise_of(scalar_op) -> OpPattern:
+    if not isinstance(scalar_op, Op | OpPattern):
+        scalar_op = OpPattern(scalar_op)
+    return OpPattern(Elemwise, scalar_op=scalar_op)
+
+
 @register_canonicalize
 @register_specialize
 @node_rewriter([TensorFromScalar])
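
Note (illustrative, not part of the diff): the new `elemwise_of` helper builds an `OpPattern` that matches `Elemwise` nodes by their `scalar_op`, which is what lets the rewriters below drop their manual `isinstance` guards. A minimal sketch of registering a rewriter through it, using `ps.Exp` as a stand-in scalar op and a hypothetical rewriter name:

    @node_rewriter([elemwise_of(ps.Exp)])
    def local_exp_example(fgraph, node):
        # The pattern guarantees node.op is an Elemwise whose scalar_op
        # is a ps.Exp instance, so no isinstance check is needed here.
        ...
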
@@ -324,15 +330,12 @@ def dimshuffled_alloc(i):
     return new_outs


-@node_rewriter([Elemwise])
+@node_rewriter([fill])
 def local_fill_sink(fgraph, node):
     """
     f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e)))
     f need to be an elemwise that isn't a fill.
     """
-    if isinstance(node.op.scalar_op, Second):
-        return False
-
     models = []
     inputs = []
     for inp in node.inputs:
@@ -653,7 +656,7 @@ def local_alloc_unary(fgraph, node):

 @register_canonicalize
 @register_specialize
-@node_rewriter([Elemwise])
+@node_rewriter([elemwise_of(ps.Cast)])
 def local_cast_cast(fgraph, node):
     """cast(cast(x, dtype1), dtype2)

@@ -663,8 +666,6 @@ def local_cast_cast(fgraph, node):
     and the first cast cause an upcast.

     """
-    if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Cast)):
-        return
     x = node.inputs[0]
     if not (
         x.owner
@@ -1031,19 +1032,13 @@ def local_useless_switch(fgraph, node):


 @register_canonicalize
-@node_rewriter([Elemwise])
+@node_rewriter([elemwise_of(ps.BinaryScalarOp | ps.Add | ps.Mul)])
 def local_merge_switch_same_cond(fgraph, node):
     """
     Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
     condition, to enable further simplification of their branches
     Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
     """
-    # node must be binary elemwise or add or mul
-    if not (
-        isinstance(node.op, Elemwise)
-        and isinstance(node.op.scalar_op, ps.BinaryScalarOp | ps.Add | ps.Mul)
-    ):
-        return
     # all inputs must be switch
     if not all(
         s.owner
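
Aside on the union form above, derived only from `elemwise_of`'s own definition: `ps.BinaryScalarOp | ps.Add | ps.Mul` is a `types.UnionType` rather than an `Op` instance, so the helper wraps it in an `OpPattern` before parameterizing `Elemwise`:

    # What the decorator argument above expands to, per elemwise_of's definition:
    pattern = OpPattern(Elemwise, scalar_op=OpPattern(ps.BinaryScalarOp | ps.Add | ps.Mul))
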
@@ -1174,10 +1169,9 @@ def constant_folding(fgraph, node):
 @register_infer_shape
 @register_canonicalize("fast_compile")
 @register_useless("fast_compile")
-@node_rewriter(None)
+@node_rewriter([ViewOp])
 def local_view_op(fgraph, node):
-    if isinstance(node.op, ViewOp):
-        return node.inputs
+    return node.inputs


 @register_infer_shape
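
Why the guard in `local_view_op` can be deleted (a sketch contrasting the two registration styles, per pytensor's standard tracking semantics): `node_rewriter(None)` fires on every node and must type-check in the body, while `node_rewriter([ViewOp])` only dispatches to `ViewOp` applications:

    @node_rewriter(None)  # fires on every node; the body must guard
    def old_style(fgraph, node):
        if isinstance(node.op, ViewOp):
            return node.inputs

    @node_rewriter([ViewOp])  # fires only on ViewOp applications
    def new_style(fgraph, node):
        return node.inputs
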