Skip to content

Commit 916b508

Browse files
committed
Allow OpPattern in tracks
Also avoid repeated checks when an outer rewriter enforces tracks before calling individual node rewriters
1 parent 62a616b commit 916b508

File tree

4 files changed

+173
-46
lines changed

4 files changed

+173
-46
lines changed

doc/gallery/rewrites/graph_rewrites.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,7 @@
583583
" def tracks(self):\n",
584584
" return [pt.log]\n",
585585
" \n",
586-
" def transform(self, fgraph, node):\n",
586+
" def transform(self, fgraph, node, enforce_tracks=True):\n",
587587
" return local_log1p(node) \n",
588588
" \n",
589589
" def __str__(self):\n",
@@ -669,8 +669,8 @@
669669
"@node_rewriter(tracks=[pt.abs])\n",
670670
"def local_useless_abs_exp(fgraph, node):\n",
671671
" # Because of the tracks we don't need to check \n",
672-
" # that `node` has a `Sign` Op.\n",
673-
" # We still need to check whether it's input is an `Abs` Op\n",
672+
" # that `node` has a `Abs` Op.\n",
673+
" # We still need to check whether it's input is an `Exp` Op\n",
674674
" exp_node = node.inputs[0].owner\n",
675675
" if exp_node is None or exp_node.op != pt.exp:\n",
676676
" return None\n",

pytensor/graph/rewriting/basic.py

Lines changed: 132 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import warnings
1212
from collections import Counter, UserList, defaultdict, deque
1313
from collections.abc import Callable, Iterable, Sequence
14-
from functools import _compose_mro, partial # type: ignore
14+
from functools import _compose_mro, lru_cache, partial # type: ignore
1515
from itertools import chain
1616
from typing import Literal
1717

@@ -141,7 +141,12 @@ def tracks(self) -> Sequence[Op] | None:
141141

142142
@abc.abstractmethod
143143
def transform(
144-
self, fgraph: FunctionGraph, node: Apply, *args, **kwargs
144+
self,
145+
fgraph: FunctionGraph,
146+
node: Apply,
147+
enfoce_tracks: bool = True,
148+
*args,
149+
**kwargs,
145150
) -> TransformOutputType:
146151
r"""Rewrite the sub-graph given by `node`.
147152
@@ -159,7 +164,8 @@ def transform(
159164
A `FunctionGraph` containing `node`.
160165
node
161166
An `Apply` node to be rewritten.
162-
167+
enforce_tracks: bool
168+
Whether the transform method should enforce tracks, or it can be assumed the caller already enforced them in a pre-filter stage.
163169
"""
164170

165171
raise NotImplementedError()
@@ -935,15 +941,43 @@ class FromFunctionNodeRewriter(NodeRewriter):
935941
def __init__(self, fn, tracks=None, requirements=()):
936942
self.fn = fn
937943
self._tracks = tracks
938-
self._tracked_types = (
939-
tuple(t for t in tracks if isinstance(t, type)) if tracks else ()
940-
)
944+
self._tracked_ops = set()
945+
self._tracked_types = type(None)
946+
self._tracked_op_pattern_types = type(None)
947+
self._tracked_op_patterns: list[OpPattern] = []
948+
if tracks is not None:
949+
if not tracks:
950+
raise ValueError(
951+
"To specify a general rewrite leave tracks as None instead of an empty container"
952+
)
953+
for t in tracks:
954+
if isinstance(t, Op):
955+
self._tracked_ops.add(t)
956+
if isinstance(t, type):
957+
self._tracked_types |= t
958+
elif isinstance(t, OpPattern):
959+
if t.parameters:
960+
self._tracked_op_patterns.append(t)
961+
self._tracked_op_pattern_types |= t.op_type
962+
else:
963+
# An OpPattern without parameters behaves like a regular tracked_type
964+
self._tracked_types |= t
941965
self.requirements = requirements
942966

943-
def transform(self, fgraph, node):
944-
if self._tracks:
967+
def transform(self, fgraph, node, enforce_tracks: bool = True):
968+
if enforce_tracks and self._tracks:
969+
node_op = node.op
945970
if not (
946-
node.op in self._tracks or isinstance(node.op, self._tracked_types)
971+
node_op in self._tracked_ops
972+
or isinstance(node_op, self._tracked_types)
973+
or (
974+
isinstance(node.op, self._tracked_op_pattern_types)
975+
and any(
976+
t.match_parameters(node_op)
977+
for t in self._tracked_op_patterns
978+
if isinstance(node_op, t.op_type)
979+
)
980+
)
947981
):
948982
return False
949983

@@ -967,7 +1001,7 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1):
9671001

9681002

9691003
def node_rewriter(
970-
tracks: Sequence[Op | type] | None,
1004+
tracks: Sequence[Op | type, OpPattern] | None,
9711005
inplace: bool = False,
9721006
requirements: tuple[type, ...] | None = (),
9731007
):
@@ -995,14 +1029,16 @@ def decorator(f):
9951029
if tracks is not None:
9961030
if len(tracks) == 0:
9971031
raise ValueError(
998-
"Use `None` instead of an empty list to make an rewrite apply to all nodes."
1032+
"Use `None` instead of an empty list to make a rewrite apply to all nodes."
9991033
)
10001034
for t in tracks:
10011035
if not (
1002-
isinstance(t, Op) or (isinstance(t, type) and issubclass(t, Op))
1036+
isinstance(t, Op | OpPattern)
1037+
or (isinstance(t, type) and issubclass(t, Op))
10031038
):
10041039
raise TypeError(
1005-
"`tracks` must consist of `Op` classes or instances."
1040+
"`tracks` must consist of `Op` classes, `Op` instances or `OpPattern` instances. "
1041+
f"Got {t} of type {type(t)}"
10061042
)
10071043
req = requirements
10081044
if inplace:
@@ -1024,47 +1060,93 @@ class OpToRewriterTracker:
10241060
def __init__(self):
10251061
self.tracked_instances: dict[Op, list[NodeRewriter]] = defaultdict(list)
10261062
self.tracked_types: dict[type, list[NodeRewriter]] = defaultdict(list)
1063+
self.tracked_pattern_types: dict[type, dict[OpPattern, list[NodeRewriter]]] = (
1064+
defaultdict(lambda: defaultdict(list))
1065+
)
10271066
self.untracked_rewrites: list[NodeRewriter] = []
1067+
self._cached_composed_mro = None
10281068

10291069
def add_tracker(self, rw: NodeRewriter):
10301070
"""Add a `NodeRewriter` to be keyed by its `NodeRewriter.tracks` or applied generally."""
1071+
if self._cached_composed_mro is not None:
1072+
# We shouldn't actually add_trackers after the first call to get_trackers
1073+
# But just to be safe we kill the cache here
1074+
self._cached_composed_mro = None
1075+
10311076
tracks = rw.tracks()
10321077

10331078
if tracks is None:
10341079
self.untracked_rewrites.append(rw)
10351080
else:
10361081
for c in tracks:
1082+
if isinstance(c, OpPattern):
1083+
if not isinstance(c.op_type, type):
1084+
# OpPattern allows anything that you can check with isinstance(op, op_type),
1085+
# including tuples or union types. But for OpToRewriterTracker we need a single type.
1086+
raise NotImplementedError(
1087+
"OpToRewriterTracker requires the outermost `OpPattern.op_type` to be a type. "
1088+
f"Got {c.op_type} of type {type(c.op_type)}"
1089+
)
1090+
1091+
if c.parameters:
1092+
self.tracked_pattern_types[c.op_type][c].append(rw)
1093+
else:
1094+
# An OpPattern without parameters behaves like a regular tracked_type
1095+
self.tracked_types[c.op_type].append(rw)
10371096
if isinstance(c, type):
10381097
self.tracked_types[c].append(rw)
10391098
else:
10401099
self.tracked_instances[c].append(rw)
10411100

1042-
def _find_impl(self, cls) -> list[NodeRewriter]:
1043-
r"""Returns the `NodeRewriter`\s that apply to `cls` based on inheritance.
1101+
@functools.lru_cache
1102+
def get_trackers(self, op: Op) -> list[NodeRewriter]:
1103+
"""Get all the rewrites applicable to an `Op`."""
1104+
1105+
if self._cached_composed_mro is None:
1106+
# Cache the mro call on the Op type. We have a small subset of op_types we actuall care about
1107+
# like Elemwise, Blockwise, and so on, which we don't need to repeatedly investigate
1108+
tracked_types = (
1109+
self.tracked_types.keys() | self.tracked_pattern_types.keys()
1110+
)
1111+
1112+
@lru_cache
1113+
def cached_composed_mro(op_type, tracked_types=tracked_types):
1114+
return _compose_mro(op_type, tracked_types)
1115+
1116+
self._cached_composed_mro = cached_composed_mro
10441117

1045-
This based on `functools._find_impl`.
1046-
"""
1047-
mro = _compose_mro(cls, self.tracked_types.keys())
10481118
matches = []
1049-
for t in mro:
1050-
match = self.tracked_types.get(t, None)
1051-
if match:
1052-
matches.extend(match)
1119+
if self.tracked_types or self.tracked_pattern_types:
1120+
# Find matches for type(op) (and their subclasses) using the same approach that functools.singledispatch uses
1121+
mro = self._cached_composed_mro(type(op))
1122+
for t in mro:
1123+
if (match := self.tracked_types.get(t, None)) is not None:
1124+
matches.extend(match)
1125+
if (
1126+
potential_matches := self.tracked_pattern_types.get(t, None)
1127+
) is not None:
1128+
# We still need to check if the Op parameters match the constraints
1129+
matches.extend(
1130+
[
1131+
item
1132+
for op_pattern, r_list in potential_matches.items()
1133+
if op_pattern.match_parameters(op)
1134+
for item in r_list
1135+
]
1136+
)
1137+
matches.extend(self.tracked_instances.get(op, []))
1138+
matches.extend(self.untracked_rewrites)
10531139
return matches
10541140

1055-
@functools.lru_cache
1056-
def get_trackers(self, op: Op) -> list[NodeRewriter]:
1057-
"""Get all the rewrites applicable to `op`."""
1058-
return (
1059-
self._find_impl(type(op))
1060-
+ self.tracked_instances.get(op, [])
1061-
+ self.untracked_rewrites
1062-
)
1063-
1064-
def get_rewriters(self):
1141+
def get_rewriters(self) -> Iterable[NodeRewriter]:
1142+
"""Get all the registered rewriters."""
10651143
return chain(
1144+
chain.from_iterable(self.tracked_types.values()),
1145+
chain.from_iterable(self.tracked_instances.values()),
10661146
chain.from_iterable(
1067-
chain(self.tracked_types.values(), self.tracked_instances.values())
1147+
item
1148+
for sub_dict in self.tracked_pattern_types.values()
1149+
for item in sub_dict.values()
10681150
),
10691151
self.untracked_rewrites,
10701152
)
@@ -1138,7 +1220,7 @@ def tracks(self):
11381220
t.extend(at)
11391221
return t
11401222

1141-
def transform(self, fgraph, node):
1223+
def transform(self, fgraph, node, enforce_tracks=False):
11421224
if len(self.rewrites) == 0:
11431225
return
11441226

@@ -1150,7 +1232,8 @@ def transform(self, fgraph, node):
11501232
new_repl = None
11511233
for rewrite in rewrites:
11521234
rewrite_start = time.perf_counter()
1153-
new_repl = rewrite.transform(fgraph, node)
1235+
# Tracks are already enforced by `self.tracker.get_trackers`
1236+
new_repl = rewrite.transform(fgraph, node, enforce_tracks=False)
11541237
rewrite_finish = time.perf_counter()
11551238
if self.profile:
11561239
self.time_rewrites[rewrite] += rewrite_start - rewrite_finish
@@ -1292,8 +1375,8 @@ def __init__(self, op1, op2, transfer_tags=True):
12921375
def tracks(self):
12931376
return [self.op1]
12941377

1295-
def transform(self, fgraph, node):
1296-
if node.op != self.op1:
1378+
def transform(self, fgraph, node, enforce_tracks=True):
1379+
if enforce_tracks and (node.op != self.op1):
12971380
return False
12981381
repl = self.op2.make_node(*node.inputs)
12991382
if self.transfer_tags:
@@ -1497,7 +1580,7 @@ def __init__(
14971580
def tracks(self):
14981581
return self._tracks
14991582

1500-
def transform(self, fgraph, node, get_nodes=True):
1583+
def transform(self, fgraph, node, enforce_tracks: bool = False, get_nodes=True):
15011584
"""Check if the graph from node corresponds to ``in_pattern``.
15021585
15031586
If it does, it constructs ``out_pattern`` and performs the replacement.
@@ -1787,6 +1870,7 @@ def process_node(
17871870
fgraph: FunctionGraph,
17881871
node: Apply,
17891872
node_rewriter: NodeRewriter | None = None,
1873+
enforce_tracks: bool = True,
17901874
):
17911875
r"""Apply `node_rewriter` to `node`.
17921876
@@ -1804,6 +1888,9 @@ def process_node(
18041888
node_rewriter
18051889
A `NodeRewriter` instance that may have a better idea for
18061890
how to compute node's outputs.
1891+
enforce_tracks: bool
1892+
Whether the transform method should enforce tracks,
1893+
or it can be assumed the caller already enforced them in a pre-filter stage.
18071894
18081895
Returns
18091896
-------
@@ -1819,7 +1906,9 @@ def process_node(
18191906
# TODO FIXME: This class's interface is broken
18201907
assert node_rewriter is not None
18211908
try:
1822-
replacements = node_rewriter.transform(fgraph, node)
1909+
replacements = node_rewriter.transform(
1910+
fgraph, node, enforce_tracks=enforce_tracks
1911+
)
18231912
except Exception as e:
18241913
if self.failure_callback is not None:
18251914
self.failure_callback(
@@ -1937,7 +2026,8 @@ def importer(node):
19372026
if node not in fgraph.apply_nodes:
19382027
continue
19392028
current_node = node
1940-
nb += self.process_node(fgraph, node)
2029+
# This rewriter does not enforce tracks itself
2030+
nb += self.process_node(fgraph, node, enforce_tracks=True)
19412031
loop_t = time.perf_counter() - t0
19422032
finally:
19432033
self.detach_updater(fgraph, u)
@@ -2278,8 +2368,9 @@ def chin_(node, i, r, new_r, reason):
22782368
for node_rewriter in self.node_tracker.get_trackers(node.op):
22792369
nb = change_tracker.nb_imported
22802370
t_rewrite = time.perf_counter()
2371+
# Tracks are already enfoced by `self.node_tracker.get_trackers`
22812372
node_rewriter_change = self.process_node(
2282-
fgraph, node, node_rewriter
2373+
fgraph, node, node_rewriter, enforce_tracks=False
22832374
)
22842375
time_rewriters[node_rewriter] += time.perf_counter() - t_rewrite
22852376
if not node_rewriter_change:

pytensor/graph/rewriting/unify.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,42 @@ class OpPattern:
275275
Examples
276276
--------
277277
278+
OpPattern can be used in the `tracks` functionality of `node_rewriter` to more flexible filter out nodes.
279+
For Ops that are parametrized by other Ops, it's possible to use nested OpPatterns.
280+
281+
.. test-code::
282+
283+
from pytensor.graph.rewriting.basic import node_rewriter
284+
from pytensor.graph.rewriting.unify import OpPattern
285+
from pytensor.tensor.elemwise import CAReduce
286+
from pytensor.tensor.blockwise import Blockwise
287+
from pytensor.tensor.slinalg import Solve
288+
289+
@node_rewriter(tracks=[OpPattern(CAReduce, axis=None)])
290+
def local_car_reduce_all_rewriter(fgraph, node):
291+
# This will always be true!
292+
assert isinstance(node.op, CAReduce) and node.op.axis is None
293+
...
294+
295+
# Any Blockwise whose core_op is a Solve Op (or subclass) instance
296+
@node_rewriter(tracks=[OpPattern(Blockwise, core_op=OpPattern(Solve))])
297+
def local_blockwise_solve_triangular_rewriter(fgraph, node):
298+
# This will always be true!
299+
assert isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Solve)
300+
...
301+
302+
# Any Blockwise whose core_op is a Solve Op (or subclass) instance with b_ndim==1
303+
@node_rewriter(tracks=[OpPattern(Blockwise, core_op=OpPattern(Solve, b_ndim=1))])
304+
def local_blockwise_vector_solve_rewriter(fgraph, node):
305+
# This will always be true!
306+
assert (
307+
isinstance(node.op, Blockwise)
308+
and isinstance(node.op.core_op, Solve)
309+
and node.op.core_op.b_ndim == 1
310+
)
311+
...
312+
313+
278314
OpPattern can be used with `PatternNodeRewriter` to define graph rewrites that match Ops with specific parameters.
279315
The example below matches two nested CAReduce Ops with the same `scalar_op`,
280316
the outer with `axis=None` (full reduction) and fuses them into a single CAReduce.

pytensor/tensor/rewriting/math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,9 +1338,9 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
13381338

13391339
return ct + num, denum
13401340

1341-
def transform(self, fgraph, node):
1341+
def transform(self, fgraph, node, enforce_tracks=True):
13421342
op = node.op
1343-
if op not in [self.main, self.inverse, self.reciprocal]:
1343+
if enforce_tracks and (op not in {self.main, self.inverse, self.reciprocal}):
13441344
return False
13451345

13461346
assert len(node.outputs) == 1

0 commit comments

Comments
 (0)