Commit b5fb239

Use OpPattern in tracks
1 parent 288bd72 commit b5fb239
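
In short: instead of tracking a broad Op class and filtering with isinstance inside each rewrite, the rewrites now declare what they match directly in their tracks through OpPattern helpers (elemwise_of / blockwise_of). A condensed before/after sketch based on the local_cast_cast hunk below (the guard in the "before" is quoted from the old code; bodies are elided):

import pytensor.scalar.basic as ps
from pytensor.graph import node_rewriter
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.rewriting.basic import elemwise_of

# Before: track every Elemwise node and filter by hand inside the rewrite.
@node_rewriter([Elemwise])
def local_cast_cast(fgraph, node):
    if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Cast)):
        return
    ...

# After: the track itself restricts scheduling to Elemwise nodes wrapping a Cast.
@node_rewriter([elemwise_of(ps.Cast)])
def local_cast_cast(fgraph, node):
    ...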

File tree: 8 files changed (+165 −260 lines)

pytensor/scalar/basic.py

Lines changed: 4 additions & 4 deletions
@@ -1227,7 +1227,7 @@ def __init__(self, output_types_preference=None, name=None):
                     f"Expected a callable for the 'output_types_preference' argument to {self.__class__}. "
                     f"(got: {output_types_preference})"
                 )
-            self.output_types_preference = output_types_preference
+        self.output_types_preference = output_types_preference
 
     def make_node(self, *inputs):
         if self.nin >= 0:
@@ -1247,7 +1247,7 @@ def make_node(self, *inputs):
         return Apply(self, inputs, outputs)
 
     def output_types(self, types):
-        if hasattr(self, "output_types_preference"):
+        if self.output_types_preference is not None:
             variables = self.output_types_preference(*types)
             if not isinstance(variables, list | tuple) or any(
                 not isinstance(x, CType) for x in variables
@@ -2696,7 +2696,7 @@ class Sign(UnaryScalarOp):
     nfunc_spec = ("sign", 1, 1)
 
     @staticmethod
-    def output_types_preference(x):
+    def _output_types_preference(x):
        if x == bool:
            raise TypeError(x)
        return same_out_nocomplex(x)
@@ -2737,7 +2737,7 @@ def c_code_cache_version(self):
         return s
 
 
-sign = Sign(name="sign")
+sign = Sign(name="sign", output_types_preference=Sign._output_types_preference)
 
 
 class Ceil(UnaryScalarOp):
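
Why always assign output_types_preference (even when it is None) and wire Sign's preference through the constructor? The preference then behaves like an ordinary Op parameter, which is exactly what the OpPattern tracks added in pytensor/tensor/rewriting/elemwise.py below match on. A small sketch of the kind of pattern this unlocks (assumes OpPattern parameter matching works as it is used later in this commit):

import pytensor.scalar.basic as ps
from pytensor.graph.rewriting.unify import OpPattern

# Match any ScalarOp whose output_types_preference is upcast_out. This only
# works reliably if the attribute always exists, hence the hasattr -> is-not-None
# change and the explicit kwarg when instantiating Sign.
upcasting_scalar = OpPattern(ps.ScalarOp, output_types_preference=ps.upcast_out)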

pytensor/tensor/_linalg/solve/rewriting.py

Lines changed: 2 additions & 1 deletion
@@ -14,6 +14,7 @@
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.rewriting.basic import register_specialize
+from pytensor.tensor.rewriting.blockwise import blockwise_of
 from pytensor.tensor.rewriting.linalg import is_matrix_transpose
 from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve
 from pytensor.tensor.variable import TensorVariable
@@ -227,7 +228,7 @@ def _scan_split_non_sequence_decomposition_and_solve(
 
 
 @register_specialize
-@node_rewriter([Blockwise])
+@node_rewriter([blockwise_of(Solve)])
 def reuse_decomposition_multiple_solves(fgraph, node):
     return _split_decomp_and_solve_steps(
         fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal", "pos"}
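
As a usage note, blockwise_of(Solve) is the helper added in pytensor/tensor/rewriting/blockwise.py later in this commit; it expands to a nested pattern, so the rewriter is only scheduled on Blockwise nodes whose core op is a Solve rather than on every Blockwise. A short sketch of the equivalence, assuming the helper behaves as defined below:

from pytensor.graph.rewriting.unify import OpPattern
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.slinalg import Solve

# These two tracks should be equivalent: blockwise_of wraps the core-op class
# in an OpPattern and nests it inside a Blockwise pattern.
track_a = blockwise_of(Solve)
track_b = OpPattern(Blockwise, core_op=OpPattern(Solve))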

pytensor/tensor/rewriting/basic.py

Lines changed: 13 additions & 19 deletions
@@ -29,7 +29,7 @@
 import pytensor.scalar.basic as ps
 from pytensor import compile, config
 from pytensor.compile.ops import ViewOp
-from pytensor.graph import FunctionGraph
+from pytensor.graph import FunctionGraph, Op
 from pytensor.graph.basic import Constant
 from pytensor.graph.rewriting.basic import (
     NodeProcessingGraphRewriter,
@@ -40,9 +40,9 @@
     node_rewriter,
 )
 from pytensor.graph.rewriting.db import RewriteDatabase
+from pytensor.graph.rewriting.unify import OpPattern
 from pytensor.npy_2_compat import normalize_axis_index
 from pytensor.raise_op import Assert, CheckAndRaise, assert_op
-from pytensor.scalar.basic import Second
 from pytensor.tensor.basic import (
     Alloc,
     AllocEmpty,
@@ -225,6 +225,12 @@ def register(inner_rewriter: RewriteDatabase | Rewriter):
         return node_rewriter
 
 
+def elemwise_of(scalar_op) -> OpPattern:
+    if not isinstance(scalar_op, Op | OpPattern):
+        scalar_op = OpPattern(scalar_op)
+    return OpPattern(Elemwise, scalar_op=scalar_op)
+
+
 @register_canonicalize
 @register_specialize
 @node_rewriter([TensorFromScalar])
@@ -324,15 +330,12 @@ def dimshuffled_alloc(i):
     return new_outs
 
 
-@node_rewriter([Elemwise])
+@node_rewriter([fill])
 def local_fill_sink(fgraph, node):
     """
     f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e)))
     f need to be an elemwise that isn't a fill.
     """
-    if isinstance(node.op.scalar_op, Second):
-        return False
-
     models = []
     inputs = []
     for inp in node.inputs:
@@ -653,7 +656,7 @@ def local_alloc_unary(fgraph, node):
 
 @register_canonicalize
 @register_specialize
-@node_rewriter([Elemwise])
+@node_rewriter([elemwise_of(ps.Cast)])
 def local_cast_cast(fgraph, node):
     """cast(cast(x, dtype1), dtype2)
 
@@ -663,8 +666,6 @@ def local_cast_cast(fgraph, node):
     and the first cast cause an upcast.
 
     """
-    if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Cast)):
-        return
     x = node.inputs[0]
     if not (
         x.owner
@@ -1031,19 +1032,13 @@ def local_useless_switch(fgraph, node):
 
 
 @register_canonicalize
-@node_rewriter([Elemwise])
+@node_rewriter([elemwise_of(ps.BinaryScalarOp | ps.Add | ps.Mul)])
 def local_merge_switch_same_cond(fgraph, node):
     """
     Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
     condition, to enable further simplification of their branches
     Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
     """
-    # node must be binary elemwise or add or mul
-    if not (
-        isinstance(node.op, Elemwise)
-        and isinstance(node.op.scalar_op, ps.BinaryScalarOp | ps.Add | ps.Mul)
-    ):
-        return
     # all inputs must be switch
     if not all(
         s.owner
@@ -1174,10 +1169,9 @@ def constant_folding(fgraph, node):
 @register_infer_shape
 @register_canonicalize("fast_compile")
 @register_useless("fast_compile")
-@node_rewriter(None)
+@node_rewriter([ViewOp])
 def local_view_op(fgraph, node):
-    if isinstance(node.op, ViewOp):
-        return node.inputs
+    return node.inputs
 
 
 @register_infer_shape
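
With the elemwise_of helper in place, a rewrite can state its scalar-op restriction in the decorator instead of in the body. A hypothetical example (the rewrite name and body are made up; only the decorator pattern mirrors this commit):

import pytensor.scalar.basic as ps
from pytensor.graph import node_rewriter
from pytensor.tensor.rewriting.basic import elemwise_of

@node_rewriter([elemwise_of(ps.Add | ps.Mul)])
def local_example_rewrite(fgraph, node):
    # By construction node.op is an Elemwise whose scalar_op is Add or Mul,
    # so no isinstance guard is needed here.
    return None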

pytensor/tensor/rewriting/blockwise.py

Lines changed: 26 additions & 20 deletions
@@ -1,8 +1,9 @@
 from pytensor.compile.mode import optdb
-from pytensor.graph import Constant, node_rewriter
+from pytensor.graph import Constant, Op, node_rewriter
 from pytensor.graph.destroyhandler import inplace_candidates
 from pytensor.graph.replace import vectorize_node
 from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
+from pytensor.graph.rewriting.unify import OpPattern
 from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
 from pytensor.tensor.blockwise import Blockwise, _squeeze_left
 from pytensor.tensor.math import Dot
@@ -20,6 +21,12 @@
 )
 
 
+def blockwise_of(core_op) -> OpPattern:
+    if not isinstance(core_op, Op | OpPattern):
+        core_op = OpPattern(core_op)
+    return OpPattern(Blockwise, core_op=core_op)
+
+
 @node_rewriter([Blockwise])
 def local_useless_blockwise(fgraph, node):
     """
@@ -71,22 +78,24 @@ def local_useless_unbatched_blockwise(fgraph, node):
 @register_canonicalize
 @register_stabilize
 @register_specialize
-@node_rewriter(tracks=[Blockwise])
+@node_rewriter(
+    tracks=[
+        blockwise_of(
+            Dot
+            | Alloc
+            | ARange
+            | Subtensor
+            | AdvancedSubtensor
+            | AdvancedIncSubtensor
+            | Reshape
+        )
+    ]
+)
 def local_eager_useless_unbatched_blockwise(fgraph, node):
-    if isinstance(
-        node.op.core_op,
-        Dot
-        | Alloc
-        | ARange
-        | Subtensor
-        | AdvancedSubtensor
-        | AdvancedIncSubtensor
-        | Reshape,
-    ):
-        # Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
-        # These other Ops can't always be trivially vectorized at runtime,
-        # since their inputs may imply non-rectangular shapes.
-        return local_useless_unbatched_blockwise.fn(fgraph, node)
+    # Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
+    # These other Ops can't always be trivially vectorized at runtime,
+    # since their inputs may imply non-rectangular shapes.
+    return local_useless_unbatched_blockwise.fn(fgraph, node)
 
 
 @register_specialize("shape_unsafe")
@@ -204,7 +213,7 @@ def local_blockwise_alloc(fgraph, node):
 
 
 @register_specialize
-@node_rewriter([Blockwise])
+@node_rewriter([blockwise_of(Reshape)])
 def local_blockwise_reshape(fgraph, node):
     """Rewrite away square Blockwise reshapes.
 
@@ -215,9 +224,6 @@ def local_blockwise_reshape(fgraph, node):
     For the square Reshape case, we must wait for all the intermediate
     operations to be lifted as Allocs
     """
-    if not isinstance(node.op.core_op, Reshape):
-        return None
-
     x, output_shape = node.inputs
     batch_ndim = node.op.batch_ndim(node)
     if all(output_shape.type.broadcastable[:batch_ndim]):
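
blockwise_of accepts either a single core-op class or a union of classes, wrapping whatever is not already an Op or OpPattern; that is what lets the eager rewrite above express its former isinstance check as a track. A brief sketch (the variable names are hypothetical; behaviour is assumed from the helper's definition above):

from pytensor.graph.rewriting.unify import OpPattern
from pytensor.tensor.math import Dot
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.shape import Reshape

# A union of core-op classes works because blockwise_of only wraps the argument
# in an OpPattern when it is not already an Op or OpPattern instance.
eager_track = blockwise_of(Dot | Reshape)
already_a_pattern = blockwise_of(OpPattern(Reshape))  # passed through unwrapped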

pytensor/tensor/rewriting/elemwise.py

Lines changed: 14 additions & 29 deletions
@@ -26,6 +26,7 @@
     out2in,
 )
 from pytensor.graph.rewriting.db import SequenceDB
+from pytensor.graph.rewriting.unify import OpPattern
 from pytensor.graph.utils import InconsistencyError, MethodNotDefined
 from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
 from pytensor.tensor.basic import (
@@ -37,6 +38,7 @@
 from pytensor.tensor.rewriting.basic import (
     alloc_like,
     broadcasted_by,
+    elemwise_of,
     register_canonicalize,
     register_specialize,
     register_stabilize,
@@ -422,7 +424,14 @@ def local_useless_dimshuffle_makevector(fgraph, node):
 
 
 @register_canonicalize
-@node_rewriter([Elemwise])
+@node_rewriter(
+    [
+        elemwise_of(
+            OpPattern(ps.ScalarOp, output_types_preference=ps.upgrade_to_float)
+        ),
+        elemwise_of(OpPattern(ps.ScalarOp, output_types_preference=ps.upcast_out)),
+    ]
+)
 def local_upcast_elemwise_constant_inputs(fgraph, node):
     """This explicitly upcasts constant inputs to elemwise Ops, when
     those Ops do implicit upcasting anyway.
@@ -433,12 +442,6 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
     if len(node.outputs) > 1:
         return None
 
-    if getattr(node.op.scalar_op, "output_types_preference", None) not in (
-        ps.upgrade_to_float,
-        ps.upcast_out,
-    ):
-        return None
-
     # this is the kind of op that we can screw with the input
     # dtypes by upcasting explicitly
     [old_out] = node.outputs
@@ -988,13 +991,9 @@ def print_profile(stream, prof, level=0):
 
 @register_canonicalize
 @register_specialize
-@node_rewriter([Elemwise])
+@node_rewriter([elemwise_of(ps.Composite)])
 def local_useless_composite_outputs(fgraph, node):
     """Remove inputs and outputs of Composite Ops that are not used anywhere."""
-    if not (
-        isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Composite)
-    ):
-        return
     comp = node.op.scalar_op
     used_outputs_idxs = [
         i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]
@@ -1104,14 +1103,10 @@ def local_careduce_fusion(fgraph, node):
     return [new_car_op(*elm_inputs)]
 
 
-@node_rewriter([Elemwise])
+@node_rewriter([elemwise_of(ps.Composite)])
 def local_inline_composite_constants(fgraph, node):
     """Inline scalar constants in Composite graphs."""
     composite_op = node.op.scalar_op
-
-    if not isinstance(composite_op, ps.Composite):
-        return None
-
     new_outer_inputs = []
     new_inner_inputs = []
     inner_replacements = {}
@@ -1287,14 +1282,9 @@ def _rebuild_partial_2f1grad_loop(node, wrt):
 
 
 @register_specialize
-@node_rewriter([Elemwise])
+@node_rewriter([elemwise_of(Grad2F1Loop)])
 def local_useless_2f1grad_loop(fgraph, node):
     # Remove unused terms from the hyp2f1 grad loop
-
-    loop_op = node.op.scalar_op
-    if not isinstance(loop_op, Grad2F1Loop):
-        return
-
     grad_related_vars = node.outputs[:-4]
     # Rewrite was already applied
     if len(grad_related_vars) // 3 != 3:
@@ -1326,18 +1316,13 @@ def local_useless_2f1grad_loop(fgraph, node):
     return replacements
 
 
-@node_rewriter([Elemwise])
+@node_rewriter([elemwise_of(Grad2F1Loop)])
 def split_2f1grad_loop(fgraph, node):
     """
     2f1grad loop has too many operands for Numpy frompyfunc code used by Elemwise nodes on python mode.
 
     This rewrite splits it across 3 different operations. It is not needed if `local_useless_2f1grad_loop` was applied
     """
-    loop_op = node.op.scalar_op
-
-    if not isinstance(loop_op, Grad2F1Loop):
-        return None
-
     grad_related_vars = node.outputs[:-4]
     # local_useless_2f1grad_loop was used, we should be safe
     if len(grad_related_vars) // 3 != 3:
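
The local_upcast_elemwise_constant_inputs change is the one place in this commit where OpPattern constrains an Op parameter rather than just the Op class: the old getattr check on output_types_preference becomes part of the track. A sketch of what the two tracks above presumably expand to (relies on the scalar/basic.py change that makes the attribute always present):

import pytensor.scalar.basic as ps
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.tensor.elemwise import Elemwise

# elemwise_of(OpPattern(ps.ScalarOp, output_types_preference=ps.upcast_out))
# nests a parameter-constrained scalar pattern inside an Elemwise pattern:
upcast_track = OpPattern(
    Elemwise,
    scalar_op=OpPattern(ps.ScalarOp, output_types_preference=ps.upcast_out),
)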
