
Commit f76771a

Use OpPattern in tracks
1 parent 183f4c5 commit f76771a
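
This commit replaces broad node_rewriter tracks (Elemwise, Blockwise, or None) with OpPattern tracks that also match the wrapped scalar_op/core_op, moving the isinstance guards out of the rewriter bodies and into dispatch. A minimal sketch of the pattern, using the elemwise_of helper added in this commit (the rewriter below is hypothetical, not part of the commit):

# Hypothetical rewriter showing the new OpPattern-based tracking; elemwise_of
# and node_rewriter are as used in the diffs below, the rewrite itself is made up.
from pytensor.graph import node_rewriter
from pytensor.scalar import Cast
from pytensor.tensor.rewriting.basic import elemwise_of

@node_rewriter([elemwise_of(Cast)])
def my_cast_rewrite(fgraph, node):
    # The tracker only dispatches Elemwise nodes whose scalar_op is a Cast,
    # so the body no longer needs an isinstance guard.
    ...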

File tree

8 files changed (+189, -267 lines)


pytensor/scalar/basic.py

Lines changed: 5 additions & 3 deletions

@@ -1228,6 +1228,8 @@ def __init__(self, output_types_preference=None, name=None):
                     f"(got: {output_types_preference})"
                 )
             self.output_types_preference = output_types_preference
+        elif not hasattr(self, "output_types_preference"):
+            self.output_types_preference = None
 
     def make_node(self, *inputs):
         if self.nin >= 0:
@@ -1247,7 +1249,7 @@ def make_node(self, *inputs):
         return Apply(self, inputs, outputs)
 
     def output_types(self, types):
-        if hasattr(self, "output_types_preference"):
+        if self.output_types_preference is not None:
             variables = self.output_types_preference(*types)
             if not isinstance(variables, list | tuple) or any(
                 not isinstance(x, CType) for x in variables
@@ -2696,7 +2698,7 @@ class Sign(UnaryScalarOp):
     nfunc_spec = ("sign", 1, 1)
 
     @staticmethod
-    def output_types_preference(x):
+    def _output_types_preference(x):
         if x == bool:
             raise TypeError(x)
         return same_out_nocomplex(x)
@@ -2737,7 +2739,7 @@ def c_code_cache_version(self):
         return s
 
 
-sign = Sign(name="sign")
+sign = Sign(name="sign", output_types_preference=Sign._output_types_preference)
 
 
 class Ceil(UnaryScalarOp):
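
Illustrative sketch (not in the commit) of the pattern these hunks move to: an explicit None sentinel plus an underscore-prefixed staticmethod passed in at construction, instead of a hasattr() probe that would also find a same-named method on the class. The names BaseOp/MyOp/pref are hypothetical stand-ins for ScalarOp/Sign/output_types_preference:

class BaseOp:
    def __init__(self, pref=None):
        if pref is not None:
            self.pref = pref
        elif not hasattr(self, "pref"):
            self.pref = None  # explicit sentinel; the attribute always exists now

    def output_types(self, t):
        if self.pref is not None:  # replaces the old hasattr(self, "pref") check
            return self.pref(t)
        return t

class MyOp(BaseOp):
    @staticmethod
    def _pref(t):  # underscore avoids colliding with the instance attribute
        return t

my_op = MyOp(pref=MyOp._pref)  # mirrors sign = Sign(name="sign", output_types_preference=...)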

pytensor/tensor/_linalg/solve/rewriting.py

Lines changed: 2 additions & 1 deletion

@@ -14,6 +14,7 @@
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.rewriting.basic import register_specialize
+from pytensor.tensor.rewriting.blockwise import blockwise_of
 from pytensor.tensor.rewriting.linalg import is_matrix_transpose
 from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve
 from pytensor.tensor.variable import TensorVariable
@@ -227,7 +228,7 @@ def _scan_split_non_sequence_decomposition_and_solve(
 
 
 @register_specialize
-@node_rewriter([Blockwise])
+@node_rewriter([blockwise_of(Solve)])
 def reuse_decomposition_multiple_solves(fgraph, node):
     return _split_decomp_and_solve_steps(
         fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal", "pos"}
    )
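
Illustrative sketch (not in the commit) of what blockwise_of(Solve) expands to, per the helper added in pytensor/tensor/rewriting/blockwise.py below. Since Solve is a class rather than an Op instance, the helper wraps it in an OpPattern, so the track presumably matches any Blockwise whose core_op is any Solve instance:

from pytensor.graph.rewriting.unify import OpPattern
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import Solve

# What blockwise_of(Solve) builds: a pattern over the wrapper Op and its core_op.
pattern = OpPattern(Blockwise, core_op=OpPattern(Solve))
# node_rewriter dispatches on this pattern, so reuse_decomposition_multiple_solves
# no longer has to inspect node.op.core_op itself.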

pytensor/tensor/rewriting/basic.py

Lines changed: 35 additions & 27 deletions

@@ -26,10 +26,9 @@
 
 import numpy as np
 
-import pytensor.scalar.basic as ps
 from pytensor import compile, config
 from pytensor.compile.ops import ViewOp
-from pytensor.graph import FunctionGraph
+from pytensor.graph import FunctionGraph, Op
 from pytensor.graph.basic import Constant
 from pytensor.graph.rewriting.basic import (
     NodeProcessingGraphRewriter,
@@ -40,9 +39,24 @@
     node_rewriter,
 )
 from pytensor.graph.rewriting.db import RewriteDatabase
+from pytensor.graph.rewriting.unify import OpPattern
 from pytensor.npy_2_compat import normalize_axis_index
 from pytensor.raise_op import Assert, CheckAndRaise, assert_op
-from pytensor.scalar.basic import Second
+from pytensor.scalar import (
+    AND,
+    EQ,
+    LE,
+    NEQ,
+    OR,
+    XOR,
+    Add,
+    BinaryScalarOp,
+    Cast,
+    Identity,
+    Mul,
+    Second,
+    Switch,
+)
 from pytensor.tensor.basic import (
     Alloc,
     AllocEmpty,
@@ -225,6 +239,12 @@ def register(inner_rewriter: RewriteDatabase | Rewriter):
     return node_rewriter
 
 
+def elemwise_of(scalar_op: type[Op] | OpPattern) -> OpPattern:
+    if not isinstance(scalar_op, Op | OpPattern):
+        scalar_op = OpPattern(scalar_op)
+    return OpPattern(Elemwise, scalar_op=scalar_op)
+
+
 @register_canonicalize
 @register_specialize
 @node_rewriter([TensorFromScalar])
@@ -551,15 +571,15 @@ def local_useless_elemwise(fgraph, node):
     dtype = node.outputs[0].type.dtype
     scalar_op = node.op.scalar_op
 
-    if isinstance(scalar_op, ps.EQ) and len(node.inputs) == 2:
+    if isinstance(scalar_op, EQ) and len(node.inputs) == 2:
         if node.inputs[0] is node.inputs[1]:
             # it is the same var in the graph. That will always be true
             ret = ones_like(node.inputs[0], dtype=dtype, opt=True)
 
             # Copy stack trace from input to constant output
             copy_stack_trace(node.outputs[0], ret)
             return [ret]
-    elif isinstance(scalar_op, ps.NEQ | ps.XOR) and len(node.inputs) == 2:
+    elif isinstance(scalar_op, NEQ | XOR) and len(node.inputs) == 2:
         if node.inputs[0] is node.inputs[1]:
             # it is the same var in the graph. That will always be false
             ret = zeros_like(node.inputs[0], dtype=dtype, opt=True)
@@ -568,14 +588,11 @@ def local_useless_elemwise(fgraph, node):
             copy_stack_trace(node.outputs[0], ret)
             return [ret]
 
-    elif (
-        isinstance(node.op.scalar_op, ps.Mul | ps.Add | ps.Identity)
-        and len(node.inputs) == 1
-    ):
+    elif isinstance(node.op.scalar_op, Mul | Add | Identity) and len(node.inputs) == 1:
         # No need to copy over any stack trace
         return [node.inputs[0]]
 
-    elif isinstance(node.op.scalar_op, ps.AND) and len(node.inputs) == 2:
+    elif isinstance(node.op.scalar_op, AND) and len(node.inputs) == 2:
         if (
             isinstance(node.inputs[0], TensorConstant)
             and node.inputs[1].type.broadcastable == out_bcast
@@ -602,7 +619,7 @@ def local_useless_elemwise(fgraph, node):
             # and this rewrite would be wrong
             return [node.inputs[0].astype(node.outputs[0].dtype)]
 
-    elif isinstance(node.op.scalar_op, ps.OR) and len(node.inputs) == 2:
+    elif isinstance(node.op.scalar_op, OR) and len(node.inputs) == 2:
         if (
             isinstance(node.inputs[0], TensorConstant)
             and node.inputs[1].type.broadcastable == out_bcast
@@ -653,7 +670,7 @@ def local_alloc_unary(fgraph, node):
 
 @register_canonicalize
 @register_specialize
-@node_rewriter([Elemwise])
+@node_rewriter([elemwise_of(Cast)])
 def local_cast_cast(fgraph, node):
     """cast(cast(x, dtype1), dtype2)
@@ -663,13 +680,11 @@ def local_cast_cast(fgraph, node):
     and the first cast cause an upcast.
 
     """
-    if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Cast)):
-        return
     x = node.inputs[0]
     if not (
         x.owner
         and isinstance(x.owner.op, Elemwise)
-        and isinstance(x.owner.op.scalar_op, ps.Cast)
+        and isinstance(x.owner.op.scalar_op, Cast)
     ):
         return
 
@@ -1009,7 +1024,7 @@ def local_useless_switch(fgraph, node):
         node.outputs[0].type.ndim == 0
         and cond_var.owner
         and isinstance(cond_var.owner.op, Elemwise)
-        and isinstance(cond_var.owner.op.scalar_op, ps.LE)
+        and isinstance(cond_var.owner.op.scalar_op, LE)
         and cond_var.owner.inputs[0].owner
        and isinstance(cond_var.owner.inputs[0].owner.op, Shape_i)
        and get_scalar_constant_value(
@@ -1031,24 +1046,18 @@ def local_useless_switch(fgraph, node):
 
 
 @register_canonicalize
-@node_rewriter([Elemwise])
+@node_rewriter([elemwise_of(BinaryScalarOp | Add | Mul)])
 def local_merge_switch_same_cond(fgraph, node):
     """
     Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
     condition, to enable further simplification of their branches
     Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
     """
-    # node must be binary elemwise or add or mul
-    if not (
-        isinstance(node.op, Elemwise)
-        and isinstance(node.op.scalar_op, ps.BinaryScalarOp | ps.Add | ps.Mul)
-    ):
-        return
     # all inputs must be switch
     if not all(
         s.owner
         and isinstance(s.owner.op, Elemwise)
-        and isinstance(s.owner.op.scalar_op, ps.Switch)
+        and isinstance(s.owner.op.scalar_op, Switch)
         for s in node.inputs
     ):
         return
@@ -1174,10 +1183,9 @@ def constant_folding(fgraph, node):
 @register_infer_shape
 @register_canonicalize("fast_compile")
 @register_useless("fast_compile")
-@node_rewriter(None)
+@node_rewriter([ViewOp])
 def local_view_op(fgraph, node):
-    if isinstance(node.op, ViewOp):
-        return node.inputs
+    return node.inputs
 
 
 @register_infer_shape
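
Illustrative sketch (not in the commit) of the two ways elemwise_of is used in the hunks above, a single scalar-op class versus a type union, which the helper wraps in an OpPattern:

from pytensor.scalar import Add, BinaryScalarOp, Cast, Mul
from pytensor.tensor.rewriting.basic import elemwise_of

# A single class: the track matches Elemwise nodes whose scalar_op is a Cast.
cast_pattern = elemwise_of(Cast)

# A type union: the track matches an Elemwise whose scalar_op is any
# BinaryScalarOp, Add, or Mul instance, as local_merge_switch_same_cond uses.
switch_merge_pattern = elemwise_of(BinaryScalarOp | Add | Mul)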

pytensor/tensor/rewriting/blockwise.py

Lines changed: 26 additions & 20 deletions

@@ -1,8 +1,9 @@
 from pytensor.compile.mode import optdb
-from pytensor.graph import Constant, node_rewriter
+from pytensor.graph import Constant, Op, node_rewriter
 from pytensor.graph.destroyhandler import inplace_candidates
 from pytensor.graph.replace import vectorize_node
 from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
+from pytensor.graph.rewriting.unify import OpPattern
 from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
 from pytensor.tensor.blockwise import Blockwise, _squeeze_left
 from pytensor.tensor.math import Dot
@@ -20,6 +21,12 @@
 )
 
 
+def blockwise_of(core_op: type[Op] | OpPattern) -> OpPattern:
+    if not isinstance(core_op, Op | OpPattern):
+        core_op = OpPattern(core_op)
+    return OpPattern(Blockwise, core_op=core_op)
+
+
 @node_rewriter([Blockwise])
 def local_useless_blockwise(fgraph, node):
     """
@@ -71,22 +78,24 @@ def local_useless_unbatched_blockwise(fgraph, node):
 @register_canonicalize
 @register_stabilize
 @register_specialize
-@node_rewriter(tracks=[Blockwise])
+@node_rewriter(
+    tracks=[
+        blockwise_of(
+            Dot
+            | Alloc
+            | ARange
+            | Subtensor
+            | AdvancedSubtensor
+            | AdvancedIncSubtensor
+            | Reshape
+        )
+    ]
+)
 def local_eager_useless_unbatched_blockwise(fgraph, node):
-    if isinstance(
-        node.op.core_op,
-        Dot
-        | Alloc
-        | ARange
-        | Subtensor
-        | AdvancedSubtensor
-        | AdvancedIncSubtensor
-        | Reshape,
-    ):
-        # Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
-        # These other Ops can't always be trivially vectorized at runtime,
-        # since their inputs may imply non-rectangular shapes.
-        return local_useless_unbatched_blockwise.fn(fgraph, node)
+    # Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
+    # These other Ops can't always be trivially vectorized at runtime,
+    # since their inputs may imply non-rectangular shapes.
+    return local_useless_unbatched_blockwise.fn(fgraph, node)
 
 
 @register_specialize("shape_unsafe")
@@ -204,7 +213,7 @@ def local_blockwise_alloc(fgraph, node):
 
 
 @register_specialize
-@node_rewriter([Blockwise])
+@node_rewriter([blockwise_of(Reshape)])
 def local_blockwise_reshape(fgraph, node):
     """Rewrite away square Blockwise reshapes.
@@ -215,9 +224,6 @@ def local_blockwise_reshape(fgraph, node):
     For the square Reshape case, we must wait for all the intermediate
     operations to be lifted as Allocs
     """
-    if not isinstance(node.op.core_op, Reshape):
-        return None
-
     x, output_shape = node.inputs
     batch_ndim = node.op.batch_ndim(node)
     if all(output_shape.type.broadcastable[:batch_ndim]):
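
Illustrative sketch (not in the commit): the isinstance branch in blockwise_of means it accepts either a core-op class, which it wraps in an OpPattern, or an already-built Op/OpPattern, which it passes through unchanged as core_op. The parameterized OpPattern(Reshape, ndim=2) form is an assumption about OpPattern's keyword matching, mirroring the core_op= usage above:

from pytensor.graph.rewriting.unify import OpPattern
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.shape import Reshape

blockwise_of(Reshape)                     # class -> OpPattern(Blockwise, core_op=OpPattern(Reshape))
blockwise_of(OpPattern(Reshape, ndim=2))  # pre-built pattern passed through as core_op (assumed)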
