26
26
27
27
import numpy as np
28
28
29
- import pytensor .scalar .basic as ps
30
29
from pytensor import compile , config
31
30
from pytensor .compile .ops import ViewOp
32
- from pytensor .graph import FunctionGraph
31
+ from pytensor .graph import FunctionGraph , Op
33
32
from pytensor .graph .basic import Constant
34
33
from pytensor .graph .rewriting .basic import (
35
34
NodeProcessingGraphRewriter ,
40
39
node_rewriter ,
41
40
)
42
41
from pytensor .graph .rewriting .db import RewriteDatabase
42
+ from pytensor .graph .rewriting .unify import OpPattern
43
43
from pytensor .npy_2_compat import normalize_axis_index
44
44
from pytensor .raise_op import Assert , CheckAndRaise , assert_op
45
- from pytensor .scalar .basic import Second
45
+ from pytensor .scalar import (
46
+ AND ,
47
+ EQ ,
48
+ LE ,
49
+ NEQ ,
50
+ OR ,
51
+ XOR ,
52
+ Add ,
53
+ BinaryScalarOp ,
54
+ Cast ,
55
+ Identity ,
56
+ Mul ,
57
+ Second ,
58
+ Switch ,
59
+ )
46
60
from pytensor .tensor .basic import (
47
61
Alloc ,
48
62
AllocEmpty ,
@@ -225,6 +239,12 @@ def register(inner_rewriter: RewriteDatabase | Rewriter):
225
239
return node_rewriter
226
240
227
241
242
+ def elemwise_of (scalar_op : type [Op ] | OpPattern ) -> OpPattern :
243
+ if not isinstance (scalar_op , Op | OpPattern ):
244
+ scalar_op = OpPattern (scalar_op )
245
+ return OpPattern (Elemwise , scalar_op = scalar_op )
246
+
247
+
228
248
@register_canonicalize
229
249
@register_specialize
230
250
@node_rewriter ([TensorFromScalar ])
@@ -551,15 +571,15 @@ def local_useless_elemwise(fgraph, node):
551
571
dtype = node .outputs [0 ].type .dtype
552
572
scalar_op = node .op .scalar_op
553
573
554
- if isinstance (scalar_op , ps . EQ ) and len (node .inputs ) == 2 :
574
+ if isinstance (scalar_op , EQ ) and len (node .inputs ) == 2 :
555
575
if node .inputs [0 ] is node .inputs [1 ]:
556
576
# it is the same var in the graph. That will always be true
557
577
ret = ones_like (node .inputs [0 ], dtype = dtype , opt = True )
558
578
559
579
# Copy stack trace from input to constant output
560
580
copy_stack_trace (node .outputs [0 ], ret )
561
581
return [ret ]
562
- elif isinstance (scalar_op , ps . NEQ | ps . XOR ) and len (node .inputs ) == 2 :
582
+ elif isinstance (scalar_op , NEQ | XOR ) and len (node .inputs ) == 2 :
563
583
if node .inputs [0 ] is node .inputs [1 ]:
564
584
# it is the same var in the graph. That will always be false
565
585
ret = zeros_like (node .inputs [0 ], dtype = dtype , opt = True )
@@ -568,14 +588,11 @@ def local_useless_elemwise(fgraph, node):
568
588
copy_stack_trace (node .outputs [0 ], ret )
569
589
return [ret ]
570
590
571
- elif (
572
- isinstance (node .op .scalar_op , ps .Mul | ps .Add | ps .Identity )
573
- and len (node .inputs ) == 1
574
- ):
591
+ elif isinstance (node .op .scalar_op , Mul | Add | Identity ) and len (node .inputs ) == 1 :
575
592
# No need to copy over any stack trace
576
593
return [node .inputs [0 ]]
577
594
578
- elif isinstance (node .op .scalar_op , ps . AND ) and len (node .inputs ) == 2 :
595
+ elif isinstance (node .op .scalar_op , AND ) and len (node .inputs ) == 2 :
579
596
if (
580
597
isinstance (node .inputs [0 ], TensorConstant )
581
598
and node .inputs [1 ].type .broadcastable == out_bcast
@@ -602,7 +619,7 @@ def local_useless_elemwise(fgraph, node):
602
619
# and this rewrite would be wrong
603
620
return [node .inputs [0 ].astype (node .outputs [0 ].dtype )]
604
621
605
- elif isinstance (node .op .scalar_op , ps . OR ) and len (node .inputs ) == 2 :
622
+ elif isinstance (node .op .scalar_op , OR ) and len (node .inputs ) == 2 :
606
623
if (
607
624
isinstance (node .inputs [0 ], TensorConstant )
608
625
and node .inputs [1 ].type .broadcastable == out_bcast
@@ -653,7 +670,7 @@ def local_alloc_unary(fgraph, node):
653
670
654
671
@register_canonicalize
655
672
@register_specialize
656
- @node_rewriter ([Elemwise ])
673
+ @node_rewriter ([elemwise_of ( Cast ) ])
657
674
def local_cast_cast (fgraph , node ):
658
675
"""cast(cast(x, dtype1), dtype2)
659
676
@@ -663,13 +680,11 @@ def local_cast_cast(fgraph, node):
663
680
and the first cast cause an upcast.
664
681
665
682
"""
666
- if not (isinstance (node .op , Elemwise ) and isinstance (node .op .scalar_op , ps .Cast )):
667
- return
668
683
x = node .inputs [0 ]
669
684
if not (
670
685
x .owner
671
686
and isinstance (x .owner .op , Elemwise )
672
- and isinstance (x .owner .op .scalar_op , ps . Cast )
687
+ and isinstance (x .owner .op .scalar_op , Cast )
673
688
):
674
689
return
675
690
@@ -1009,7 +1024,7 @@ def local_useless_switch(fgraph, node):
1009
1024
node .outputs [0 ].type .ndim == 0
1010
1025
and cond_var .owner
1011
1026
and isinstance (cond_var .owner .op , Elemwise )
1012
- and isinstance (cond_var .owner .op .scalar_op , ps . LE )
1027
+ and isinstance (cond_var .owner .op .scalar_op , LE )
1013
1028
and cond_var .owner .inputs [0 ].owner
1014
1029
and isinstance (cond_var .owner .inputs [0 ].owner .op , Shape_i )
1015
1030
and get_scalar_constant_value (
@@ -1031,24 +1046,18 @@ def local_useless_switch(fgraph, node):
1031
1046
1032
1047
1033
1048
@register_canonicalize
1034
- @node_rewriter ([Elemwise ])
1049
+ @node_rewriter ([elemwise_of ( BinaryScalarOp | Add | Mul ) ])
1035
1050
def local_merge_switch_same_cond (fgraph , node ):
1036
1051
"""
1037
1052
Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
1038
1053
condition, to enable further simplification of their branches
1039
1054
Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
1040
1055
"""
1041
- # node must be binary elemwise or add or mul
1042
- if not (
1043
- isinstance (node .op , Elemwise )
1044
- and isinstance (node .op .scalar_op , ps .BinaryScalarOp | ps .Add | ps .Mul )
1045
- ):
1046
- return
1047
1056
# all inputs must be switch
1048
1057
if not all (
1049
1058
s .owner
1050
1059
and isinstance (s .owner .op , Elemwise )
1051
- and isinstance (s .owner .op .scalar_op , ps . Switch )
1060
+ and isinstance (s .owner .op .scalar_op , Switch )
1052
1061
for s in node .inputs
1053
1062
):
1054
1063
return
@@ -1174,10 +1183,9 @@ def constant_folding(fgraph, node):
1174
1183
@register_infer_shape
1175
1184
@register_canonicalize ("fast_compile" )
1176
1185
@register_useless ("fast_compile" )
1177
- @node_rewriter (None )
1186
+ @node_rewriter ([ ViewOp ] )
1178
1187
def local_view_op (fgraph , node ):
1179
- if isinstance (node .op , ViewOp ):
1180
- return node .inputs
1188
+ return node .inputs
1181
1189
1182
1190
1183
1191
@register_infer_shape
0 commit comments