Skip to content

Commit 183f4c5

Browse files
committed
Allow OpPattern in tracks
Also avoid repeated checks when an outer rewriter enforces tracks before calling individual node rewriters
1 parent 1b99b08 commit 183f4c5

File tree

5 files changed

+181
-48
lines changed

5 files changed

+181
-48
lines changed

doc/gallery/rewrites/graph_rewrites.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,7 @@
583583
" def tracks(self):\n",
584584
" return [pt.log]\n",
585585
" \n",
586-
" def transform(self, fgraph, node):\n",
586+
" def transform(self, fgraph, node, enforce_tracks=True):\n",
587587
" return local_log1p(node) \n",
588588
" \n",
589589
" def __str__(self):\n",
@@ -669,8 +669,8 @@
669669
"@node_rewriter(tracks=[pt.abs])\n",
670670
"def local_useless_abs_exp(fgraph, node):\n",
671671
" # Because of the tracks we don't need to check \n",
672-
" # that `node` has a `Sign` Op.\n",
673-
" # We still need to check whether it's input is an `Abs` Op\n",
672+
" # that `node` has an `Abs` Op.\n",
673+
" # We still need to check whether its input is an `Exp` Op\n",
674674
" exp_node = node.inputs[0].owner\n",
675675
" if exp_node is None or exp_node.op != pt.exp:\n",
676676
" return None\n",

pytensor/graph/rewriting/basic.py

Lines changed: 138 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,12 @@ def tracks(self) -> Sequence[Op] | None:
141141

142142
@abc.abstractmethod
143143
def transform(
144-
self, fgraph: FunctionGraph, node: Apply, *args, **kwargs
144+
self,
145+
fgraph: FunctionGraph,
146+
node: Apply,
147+
enforce_tracks: bool = True,
148+
*args,
149+
**kwargs,
145150
) -> TransformOutputType:
146151
r"""Rewrite the sub-graph given by `node`.
147152
@@ -159,7 +164,9 @@ def transform(
159164
A `FunctionGraph` containing `node`.
160165
node
161166
An `Apply` node to be rewritten.
162-
167+
enforce_tracks: bool
168+
Whether the transform method should enforce tracks, or it can be assumed the caller already enforced them in a pre-filter stage.
169+
See `node_rewriter` tracks argument for more details.
163170
"""
164171

165172
raise NotImplementedError()
@@ -935,15 +942,48 @@ class FromFunctionNodeRewriter(NodeRewriter):
935942
def __init__(self, fn, tracks=None, requirements=()):
936943
self.fn = fn
937944
self._tracks = tracks
938-
self._tracked_types = (
939-
tuple(t for t in tracks if isinstance(t, type)) if tracks else ()
940-
)
945+
self._tracked_ops = set()
946+
self._tracked_types = type(None)
947+
self._tracked_op_pattern_types = type(None)
948+
self._tracked_op_patterns: list[OpPattern] = []
949+
if tracks is not None:
950+
if not tracks:
951+
raise ValueError(
952+
"To specify a general rewrite leave tracks as None instead of an empty container"
953+
)
954+
for t in tracks:
955+
if isinstance(t, Op):
956+
self._tracked_ops.add(t)
957+
elif isinstance(t, type):
958+
self._tracked_types |= t
959+
elif isinstance(t, OpPattern):
960+
if t.parameters:
961+
self._tracked_op_patterns.append(t)
962+
self._tracked_op_pattern_types |= t.op_type
963+
else:
964+
# An OpPattern without parameters behaves like a regular tracked_type
965+
self._tracked_types |= t
966+
else:
967+
raise TypeError(
968+
"`tracks` must consist of `Op` classes, `Op` instances or `OpPattern` instances. "
969+
f"Got {t} of type {type(t)}"
970+
)
941971
self.requirements = requirements
942972

943-
def transform(self, fgraph, node):
944-
if self._tracks:
973+
def transform(self, fgraph, node, enforce_tracks: bool = True):
974+
if enforce_tracks and self._tracks:
975+
node_op = node.op
945976
if not (
946-
node.op in self._tracks or isinstance(node.op, self._tracked_types)
977+
node_op in self._tracked_ops
978+
or isinstance(node_op, self._tracked_types)
979+
or (
980+
isinstance(node.op, self._tracked_op_pattern_types)
981+
and any(
982+
t.match_parameters(node_op)
983+
for t in self._tracked_op_patterns
984+
if isinstance(node_op, t.op_type)
985+
)
986+
)
947987
):
948988
return False
949989

@@ -967,7 +1007,7 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1):
9671007

9681008

9691009
def node_rewriter(
970-
tracks: Sequence[Op | type] | None,
1010+
tracks: Sequence[Op | type, OpPattern] | None,
9711011
inplace: bool = False,
9721012
requirements: tuple[type, ...] | None = (),
9731013
):
@@ -976,7 +1016,7 @@ def node_rewriter(
9761016
Parameters
9771017
----------
9781018
tracks
979-
The `Op` types or instances to which this rewrite applies.
1019+
The `Op` types, `Op` instances or `OpPattern` instances to which this rewrite applies.
9801020
Use ``None`` instead of an empty list to have the rewrite apply to
9811021
all `Op`\s.
9821022
inplace
@@ -995,14 +1035,16 @@ def decorator(f):
9951035
if tracks is not None:
9961036
if len(tracks) == 0:
9971037
raise ValueError(
998-
"Use `None` instead of an empty list to make an rewrite apply to all nodes."
1038+
"Use `None` instead of an empty list to make a rewrite apply to all nodes."
9991039
)
10001040
for t in tracks:
10011041
if not (
1002-
isinstance(t, Op) or (isinstance(t, type) and issubclass(t, Op))
1042+
isinstance(t, Op | OpPattern)
1043+
or (isinstance(t, type) and issubclass(t, Op))
10031044
):
10041045
raise TypeError(
1005-
"`tracks` must consist of `Op` classes or instances."
1046+
"`tracks` must consist of `Op` classes, `Op` instances or `OpPattern` instances. "
1047+
f"Got {t} of type {type(t)}"
10061048
)
10071049
req = requirements
10081050
if inplace:
@@ -1024,47 +1066,93 @@ class OpToRewriterTracker:
10241066
def __init__(self):
10251067
self.tracked_instances: dict[Op, list[NodeRewriter]] = defaultdict(list)
10261068
self.tracked_types: dict[type, list[NodeRewriter]] = defaultdict(list)
1069+
self.tracked_pattern_types: dict[type, dict[OpPattern, list[NodeRewriter]]] = (
1070+
defaultdict(lambda: defaultdict(list))
1071+
)
10271072
self.untracked_rewrites: list[NodeRewriter] = []
1073+
self._cached_composed_mro = None
10281074

10291075
def add_tracker(self, rw: NodeRewriter):
10301076
"""Add a `NodeRewriter` to be keyed by its `NodeRewriter.tracks` or applied generally."""
1077+
if self._cached_composed_mro is not None:
1078+
# We shouldn't actually add_trackers after the first call to get_trackers
1079+
# But just to be safe we kill the cache here
1080+
self._cached_composed_mro = None
1081+
10311082
tracks = rw.tracks()
10321083

10331084
if tracks is None:
10341085
self.untracked_rewrites.append(rw)
10351086
else:
10361087
for c in tracks:
1088+
if isinstance(c, OpPattern):
1089+
if not isinstance(c.op_type, type):
1090+
# OpPattern allows anything that you can check with isinstance(op, op_type),
1091+
# including tuples or union types. But for OpToRewriterTracker we need a single type.
1092+
raise NotImplementedError(
1093+
"OpToRewriterTracker requires the outermost `OpPattern.op_type` to be a type. "
1094+
f"Got {c.op_type} of type {type(c.op_type)}"
1095+
)
1096+
1097+
if c.parameters:
1098+
self.tracked_pattern_types[c.op_type][c].append(rw)
1099+
else:
1100+
# An OpPattern without parameters behaves like a regular tracked_type
1101+
self.tracked_types[c.op_type].append(rw)
10371102
if isinstance(c, type):
10381103
self.tracked_types[c].append(rw)
10391104
else:
10401105
self.tracked_instances[c].append(rw)
10411106

1042-
def _find_impl(self, cls) -> list[NodeRewriter]:
1043-
r"""Returns the `NodeRewriter`\s that apply to `cls` based on inheritance.
1107+
@functools.cache
1108+
def get_trackers(self, op: Op) -> list[NodeRewriter]:
1109+
"""Get all the rewrites applicable to an `Op`."""
1110+
1111+
if self._cached_composed_mro is None:
1112+
# Cache the mro call on the Op type. We have a small subset of op_types we actually care about
1113+
# like Elemwise, Blockwise, and so on, which we don't need to repeatedly investigate
1114+
tracked_types = (
1115+
self.tracked_types.keys() | self.tracked_pattern_types.keys()
1116+
)
1117+
1118+
@functools.cache
1119+
def cached_composed_mro(op_type, tracked_types=tracked_types):
1120+
return _compose_mro(op_type, tracked_types)
1121+
1122+
self._cached_composed_mro = cached_composed_mro
10441123

1045-
This based on `functools._find_impl`.
1046-
"""
1047-
mro = _compose_mro(cls, self.tracked_types.keys())
10481124
matches = []
1049-
for t in mro:
1050-
match = self.tracked_types.get(t, None)
1051-
if match:
1052-
matches.extend(match)
1125+
if self.tracked_types or self.tracked_pattern_types:
1126+
# Find matches for type(op) (and their subclasses) using the same approach that functools.singledispatch uses
1127+
mro = self._cached_composed_mro(type(op))
1128+
for t in mro:
1129+
if (match := self.tracked_types.get(t, None)) is not None:
1130+
matches.extend(match)
1131+
if (
1132+
potential_matches := self.tracked_pattern_types.get(t, None)
1133+
) is not None:
1134+
# We still need to check if the Op parameters match the constraints
1135+
matches.extend(
1136+
[
1137+
item
1138+
for op_pattern, r_list in potential_matches.items()
1139+
if op_pattern.match_parameters(op)
1140+
for item in r_list
1141+
]
1142+
)
1143+
matches.extend(self.tracked_instances.get(op, []))
1144+
matches.extend(self.untracked_rewrites)
10531145
return matches
10541146

1055-
@functools.lru_cache
1056-
def get_trackers(self, op: Op) -> list[NodeRewriter]:
1057-
"""Get all the rewrites applicable to `op`."""
1058-
return (
1059-
self._find_impl(type(op))
1060-
+ self.tracked_instances.get(op, [])
1061-
+ self.untracked_rewrites
1062-
)
1063-
1064-
def get_rewriters(self):
1147+
def get_rewriters(self) -> Iterable[NodeRewriter]:
1148+
"""Get all the registered rewriters."""
10651149
return chain(
1150+
chain.from_iterable(self.tracked_types.values()),
1151+
chain.from_iterable(self.tracked_instances.values()),
10661152
chain.from_iterable(
1067-
chain(self.tracked_types.values(), self.tracked_instances.values())
1153+
item
1154+
for sub_dict in self.tracked_pattern_types.values()
1155+
for item in sub_dict.values()
10681156
),
10691157
self.untracked_rewrites,
10701158
)
@@ -1138,7 +1226,7 @@ def tracks(self):
11381226
t.extend(at)
11391227
return t
11401228

1141-
def transform(self, fgraph, node):
1229+
def transform(self, fgraph, node, enforce_tracks=False):
11421230
if len(self.rewrites) == 0:
11431231
return
11441232

@@ -1150,7 +1238,8 @@ def transform(self, fgraph, node):
11501238
new_repl = None
11511239
for rewrite in rewrites:
11521240
rewrite_start = time.perf_counter()
1153-
new_repl = rewrite.transform(fgraph, node)
1241+
# Tracks are already enforced by `self.tracker.get_trackers`
1242+
new_repl = rewrite.transform(fgraph, node, enforce_tracks=False)
11541243
rewrite_finish = time.perf_counter()
11551244
if self.profile:
11561245
self.time_rewrites[rewrite] += rewrite_start - rewrite_finish
@@ -1292,8 +1381,8 @@ def __init__(self, op1, op2, transfer_tags=True):
12921381
def tracks(self):
12931382
return [self.op1]
12941383

1295-
def transform(self, fgraph, node):
1296-
if node.op != self.op1:
1384+
def transform(self, fgraph, node, enforce_tracks=True):
1385+
if enforce_tracks and (node.op != self.op1):
12971386
return False
12981387
repl = self.op2.make_node(*node.inputs)
12991388
if self.transfer_tags:
@@ -1498,7 +1587,7 @@ def __init__(
14981587
def tracks(self):
14991588
return self._tracks
15001589

1501-
def transform(self, fgraph, node, get_nodes=True):
1590+
def transform(self, fgraph, node, enforce_tracks: bool = False, get_nodes=True):
15021591
"""Check if the graph from node corresponds to ``in_pattern``.
15031592
15041593
If it does, it constructs ``out_pattern`` and performs the replacement.
@@ -1788,6 +1877,7 @@ def process_node(
17881877
fgraph: FunctionGraph,
17891878
node: Apply,
17901879
node_rewriter: NodeRewriter | None = None,
1880+
enforce_tracks: bool = True,
17911881
):
17921882
r"""Apply `node_rewriter` to `node`.
17931883
@@ -1805,6 +1895,9 @@ def process_node(
18051895
node_rewriter
18061896
A `NodeRewriter` instance that may have a better idea for
18071897
how to compute node's outputs.
1898+
enforce_tracks: bool
1899+
Whether the transform method should enforce tracks,
1900+
or it can be assumed the caller already enforced them in a pre-filter stage.
18081901
18091902
Returns
18101903
-------
@@ -1820,7 +1913,9 @@ def process_node(
18201913
# TODO FIXME: This class's interface is broken
18211914
assert node_rewriter is not None
18221915
try:
1823-
replacements = node_rewriter.transform(fgraph, node)
1916+
replacements = node_rewriter.transform(
1917+
fgraph, node, enforce_tracks=enforce_tracks
1918+
)
18241919
except Exception as e:
18251920
if self.failure_callback is not None:
18261921
self.failure_callback(
@@ -1938,7 +2033,8 @@ def importer(node):
19382033
if node not in fgraph.apply_nodes:
19392034
continue
19402035
current_node = node
1941-
nb += self.process_node(fgraph, node)
2036+
# This rewriter does not enforce tracks itself
2037+
nb += self.process_node(fgraph, node, enforce_tracks=True)
19422038
loop_t = time.perf_counter() - t0
19432039
finally:
19442040
self.detach_updater(fgraph, u)
@@ -2279,8 +2375,9 @@ def chin_(node, i, r, new_r, reason):
22792375
for node_rewriter in self.node_tracker.get_trackers(node.op):
22802376
nb = change_tracker.nb_imported
22812377
t_rewrite = time.perf_counter()
2378+
# Tracks are already enforced by `self.node_tracker.get_trackers`
22822379
node_rewriter_change = self.process_node(
2283-
fgraph, node, node_rewriter
2380+
fgraph, node, node_rewriter, enforce_tracks=False
22842381
)
22852382
time_rewriters[node_rewriter] += time.perf_counter() - t_rewrite
22862383
if not node_rewriter_change:

pytensor/graph/rewriting/kanren.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def results_filter(
7474
self.node_filter = node_filter
7575
super().__init__()
7676

77-
def transform(self, fgraph, node):
77+
def transform(self, fgraph, node, enforce_tracks: bool = True):
7878
if self.node_filter(node) is False:
7979
return False
8080

@@ -92,7 +92,7 @@ def transform(self, fgraph, node):
9292
if isinstance(chosen_res, list):
9393
new_outputs = [eval_if_etuple(v) for v in chosen_res]
9494
else:
95-
new_outputs = [eval_if_etuple(chosen_res)]
95+
new_outputs = [eval_if_etuple(chosen_res)] # type: ignore[unreachable]
9696

9797
return new_outputs
9898
else:

0 commit comments

Comments
 (0)