Skip to content

Commit 18c1ef7

Browse files
committed
Allow OpPattern in tracks
Also avoid repeated checks when an outer rewriter enforces tracks before calling individual node rewriters
1 parent e0bbed2 commit 18c1ef7

File tree

3 files changed

+132
-44
lines changed

3 files changed

+132
-44
lines changed

pytensor/graph/rewriting/basic.py

Lines changed: 129 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import warnings
1212
from collections import Counter, UserList, defaultdict, deque
1313
from collections.abc import Callable, Iterable, Sequence
14-
from functools import _compose_mro, partial # type: ignore
14+
from functools import _compose_mro, lru_cache, partial # type: ignore
1515
from itertools import chain
1616
from typing import Literal
1717

@@ -141,7 +141,12 @@ def tracks(self) -> Sequence[Op] | None:
141141

142142
@abc.abstractmethod
143143
def transform(
144-
self, fgraph: FunctionGraph, node: Apply, *args, **kwargs
144+
self,
145+
fgraph: FunctionGraph,
146+
node: Apply,
147+
enfoce_tracks: bool = True,
148+
*args,
149+
**kwargs,
145150
) -> TransformOutputType:
146151
r"""Rewrite the sub-graph given by `node`.
147152
@@ -159,7 +164,8 @@ def transform(
159164
A `FunctionGraph` containing `node`.
160165
node
161166
An `Apply` node to be rewritten.
162-
167+
enforce_tracks: bool
168+
Whether the transform method should enforce tracks, or it can be assumed the caller already enforced them in a pre-filter stage.
163169
"""
164170

165171
raise NotImplementedError()
@@ -935,15 +941,43 @@ class FromFunctionNodeRewriter(NodeRewriter):
935941
def __init__(self, fn, tracks=None, requirements=()):
936942
self.fn = fn
937943
self._tracks = tracks
938-
self._tracked_types = (
939-
tuple(t for t in tracks if isinstance(t, type)) if tracks else ()
940-
)
944+
self._tracked_ops = set()
945+
self._tracked_types = type(None)
946+
self._tracked_parametrized_types = type(None)
947+
self._tracked_op_instance_patterns: list[OpPattern] = []
948+
if tracks is not None:
949+
if not tracks:
950+
raise ValueError(
951+
"To specify a general rewrite leave tracks as None instead of an empty container"
952+
)
953+
for t in tracks:
954+
if isinstance(t, Op):
955+
self._tracked_ops.add(t)
956+
if isinstance(t, type):
957+
self._tracked_types |= t
958+
elif isinstance(t, OpPattern):
959+
if t.parameters:
960+
self._tracked_op_instance_patterns.append(t)
961+
self._tracked_parametrized_types |= t.op_type
962+
else:
963+
# It's a regular tracked_type
964+
self._tracked_types |= t
941965
self.requirements = requirements
942966

943-
def transform(self, fgraph, node):
944-
if self._tracks:
967+
def transform(self, fgraph, node, enforce_tracks: bool = True):
968+
if enforce_tracks and self._tracks:
969+
node_op = node.op
945970
if not (
946-
node.op in self._tracks or isinstance(node.op, self._tracked_types)
971+
node_op in self._tracked_ops
972+
or isinstance(node_op, self._tracked_types)
973+
or (
974+
isinstance(node.op, self._tracked_parametrized_types)
975+
and any(
976+
t.match_parameters(node_op)
977+
for t in self._tracked_op_instance_patterns
978+
if isinstance(node_op, t.op_type)
979+
)
980+
)
947981
):
948982
return False
949983

@@ -967,7 +1001,7 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1):
9671001

9681002

9691003
def node_rewriter(
970-
tracks: Sequence[Op | type] | None,
1004+
tracks: Sequence[Op | type, OpPattern] | None,
9711005
inplace: bool = False,
9721006
requirements: tuple[type, ...] | None = (),
9731007
):
@@ -995,14 +1029,15 @@ def decorator(f):
9951029
if tracks is not None:
9961030
if len(tracks) == 0:
9971031
raise ValueError(
998-
"Use `None` instead of an empty list to make an rewrite apply to all nodes."
1032+
"Use `None` instead of an empty list to make a rewrite apply to all nodes."
9991033
)
10001034
for t in tracks:
10011035
if not (
1002-
isinstance(t, Op) or (isinstance(t, type) and issubclass(t, Op))
1036+
isinstance(t, Op | OpPattern)
1037+
or (isinstance(t, type) and issubclass(t, Op))
10031038
):
10041039
raise TypeError(
1005-
"`tracks` must consist of `Op` classes or instances."
1040+
"`tracks` must consist of `Op` classes, instances or `OpPattern` instances."
10061041
)
10071042
req = requirements
10081043
if inplace:
@@ -1024,47 +1059,91 @@ class OpToRewriterTracker:
10241059
def __init__(self):
10251060
self.tracked_instances: dict[Op, list[NodeRewriter]] = defaultdict(list)
10261061
self.tracked_types: dict[type, list[NodeRewriter]] = defaultdict(list)
1062+
self.tracked_parametrized_types: dict[
1063+
type, dict[OpPattern, list[NodeRewriter]]
1064+
] = defaultdict(lambda: defaultdict(list))
10271065
self.untracked_rewrites: list[NodeRewriter] = []
1066+
self._cached_composed_mro = None
10281067

10291068
def add_tracker(self, rw: NodeRewriter):
10301069
"""Add a `NodeRewriter` to be keyed by its `NodeRewriter.tracks` or applied generally."""
1070+
if self._cached_composed_mro is not None:
1071+
# We shouldn't actually add_trackers after the first call to get_trackers
1072+
# But just to be safe we kill the cache here
1073+
self._cached_composed_mro = None
1074+
10311075
tracks = rw.tracks()
10321076

10331077
if tracks is None:
10341078
self.untracked_rewrites.append(rw)
10351079
else:
10361080
for c in tracks:
1081+
if isinstance(c, OpPattern):
1082+
if not isinstance(c.op_type, type):
1083+
raise NotImplementedError(
1084+
"OpToRewriterTracker requires the outermost `OpPattern.op_type` to be a type. "
1085+
f"Got {c.op_type} of type {type(c.op_type)}"
1086+
)
1087+
1088+
if c.parameters:
1089+
self.tracked_parametrized_types[c.op_type][c].append(rw)
1090+
else:
1091+
# It's a simple type track
1092+
self.tracked_types[c.op_type].append(rw)
10371093
if isinstance(c, type):
10381094
self.tracked_types[c].append(rw)
10391095
else:
10401096
self.tracked_instances[c].append(rw)
10411097

1042-
def _find_impl(self, cls) -> list[NodeRewriter]:
1043-
r"""Returns the `NodeRewriter`\s that apply to `cls` based on inheritance.
1098+
@functools.lru_cache
1099+
def get_trackers(self, op: Op) -> list[NodeRewriter]:
1100+
"""Get all the rewrites applicable to an `Op`."""
1101+
1102+
if self._cached_composed_mro is None:
1103+
# Cache the mro call on the Op type. We have a small subset of op_types we actuall care about
1104+
# like Elemwise, Blockwise, and so on, which we don't need to repeatedly investigate
1105+
tracked_types = (
1106+
self.tracked_types.keys() | self.tracked_parametrized_types.keys()
1107+
)
1108+
1109+
@lru_cache
1110+
def cached_composed_mro(op_type, tracked_types=tracked_types):
1111+
return _compose_mro(op_type, tracked_types)
1112+
1113+
self._cached_composed_mro = cached_composed_mro
10441114

1045-
This based on `functools._find_impl`.
1046-
"""
1047-
mro = _compose_mro(cls, self.tracked_types.keys())
10481115
matches = []
1049-
for t in mro:
1050-
match = self.tracked_types.get(t, None)
1051-
if match:
1052-
matches.extend(match)
1116+
if self.tracked_types or self.tracked_parametrized_types:
1117+
# Find matches for type(op) (and their subclasses) using the same approach that functools.singledispatch uses
1118+
mro = self._cached_composed_mro(type(op))
1119+
for t in mro:
1120+
if (match := self.tracked_types.get(t, None)) is not None:
1121+
matches.extend(match)
1122+
if (
1123+
potential_matches := self.tracked_parametrized_types.get(t, None)
1124+
) is not None:
1125+
# We still need to check if the Op parameters match the constraints
1126+
matches.extend(
1127+
[
1128+
item
1129+
for op_pattern, r_list in potential_matches.items()
1130+
if op_pattern.match_parameters(op)
1131+
for item in r_list
1132+
]
1133+
)
1134+
matches.extend(self.tracked_instances.get(op, []))
1135+
matches.extend(self.untracked_rewrites)
10531136
return matches
10541137

1055-
@functools.lru_cache
1056-
def get_trackers(self, op: Op) -> list[NodeRewriter]:
1057-
"""Get all the rewrites applicable to `op`."""
1058-
return (
1059-
self._find_impl(type(op))
1060-
+ self.tracked_instances.get(op, [])
1061-
+ self.untracked_rewrites
1062-
)
1063-
1064-
def get_rewriters(self):
1138+
def get_rewriters(self) -> Iterable[NodeRewriter]:
1139+
"""Get all the registered rewriters."""
10651140
return chain(
1141+
chain.from_iterable(self.tracked_types.values()),
1142+
chain.from_iterable(self.tracked_instances.values()),
10661143
chain.from_iterable(
1067-
chain(self.tracked_types.values(), self.tracked_instances.values())
1144+
item
1145+
for sub_dict in self.tracked_parametrized_types.values()
1146+
for item in sub_dict.values()
10681147
),
10691148
self.untracked_rewrites,
10701149
)
@@ -1138,7 +1217,7 @@ def tracks(self):
11381217
t.extend(at)
11391218
return t
11401219

1141-
def transform(self, fgraph, node):
1220+
def transform(self, fgraph, node, enforce_tracks=False):
11421221
if len(self.rewrites) == 0:
11431222
return
11441223

@@ -1150,7 +1229,8 @@ def transform(self, fgraph, node):
11501229
new_repl = None
11511230
for rewrite in rewrites:
11521231
rewrite_start = time.perf_counter()
1153-
new_repl = rewrite.transform(fgraph, node)
1232+
# Tracks are already enforced by `self.tracker.get_trackers`
1233+
new_repl = rewrite.transform(fgraph, node, enforce_tracks=False)
11541234
rewrite_finish = time.perf_counter()
11551235
if self.profile:
11561236
self.time_rewrites[rewrite] += rewrite_start - rewrite_finish
@@ -1292,8 +1372,8 @@ def __init__(self, op1, op2, transfer_tags=True):
12921372
def tracks(self):
12931373
return [self.op1]
12941374

1295-
def transform(self, fgraph, node):
1296-
if node.op != self.op1:
1375+
def transform(self, fgraph, node, enforce_tracks=True):
1376+
if enforce_tracks and (node.op != self.op1):
12971377
return False
12981378
repl = self.op2.make_node(*node.inputs)
12991379
if self.transfer_tags:
@@ -1492,7 +1572,7 @@ def __init__(
14921572
def tracks(self):
14931573
return self._tracks
14941574

1495-
def transform(self, fgraph, node, get_nodes=True):
1575+
def transform(self, fgraph, node, enforce_tracks: bool = False, get_nodes=True):
14961576
"""Check if the graph from node corresponds to ``in_pattern``.
14971577
14981578
If it does, it constructs ``out_pattern`` and performs the replacement.
@@ -1782,6 +1862,7 @@ def process_node(
17821862
fgraph: FunctionGraph,
17831863
node: Apply,
17841864
node_rewriter: NodeRewriter | None = None,
1865+
enforce_tracks: bool = True,
17851866
):
17861867
r"""Apply `node_rewriter` to `node`.
17871868
@@ -1799,6 +1880,9 @@ def process_node(
17991880
node_rewriter
18001881
A `NodeRewriter` instance that may have a better idea for
18011882
how to compute node's outputs.
1883+
enforce_tracks: bool
1884+
Whether the transform method should enforce tracks,
1885+
or it can be assumed the caller already enforced them in a pre-filter stage.
18021886
18031887
Returns
18041888
-------
@@ -1814,7 +1898,9 @@ def process_node(
18141898
# TODO FIXME: This class's interface is broken
18151899
assert node_rewriter is not None
18161900
try:
1817-
replacements = node_rewriter.transform(fgraph, node)
1901+
replacements = node_rewriter.transform(
1902+
fgraph, node, enforce_tracks=enforce_tracks
1903+
)
18181904
except Exception as e:
18191905
if self.failure_callback is not None:
18201906
self.failure_callback(
@@ -1932,7 +2018,8 @@ def importer(node):
19322018
if node not in fgraph.apply_nodes:
19332019
continue
19342020
current_node = node
1935-
nb += self.process_node(fgraph, node)
2021+
# This rewriter does not enforce tracks itself
2022+
nb += self.process_node(fgraph, node, enforce_tracks=True)
19362023
loop_t = time.perf_counter() - t0
19372024
finally:
19382025
self.detach_updater(fgraph, u)
@@ -2273,8 +2360,9 @@ def chin_(node, i, r, new_r, reason):
22732360
for node_rewriter in self.node_tracker.get_trackers(node.op):
22742361
nb = change_tracker.nb_imported
22752362
t_rewrite = time.perf_counter()
2363+
# Tracks are already enfoced by `self.node_tracker.get_trackers`
22762364
node_rewriter_change = self.process_node(
2277-
fgraph, node, node_rewriter
2365+
fgraph, node, node_rewriter, enforce_tracks=False
22782366
)
22792367
time_rewriters[node_rewriter] += time.perf_counter() - t_rewrite
22802368
if not node_rewriter_change:

pytensor/graph/rewriting/unify.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def match_parameters(self, op):
360360
return True
361361

362362
def __str__(self):
363-
return f"{self.op_type.__name__}({self.op_type}, {', '.join(f'{k}={v}' for k, v in self.parameters)})"
363+
return f"OpPattern({self.op_type}, {', '.join(f'{k}={v}' for k, v in self.parameters)})"
364364

365365

366366
def _unify_parametrized_op(v: Op, u: OpPattern, s: Mapping):

pytensor/tensor/rewriting/math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,9 +1338,9 @@ def simplify_constants(self, orig_num, orig_denum, out_type=None):
13381338

13391339
return ct + num, denum
13401340

1341-
def transform(self, fgraph, node):
1341+
def transform(self, fgraph, node, enforce_tracks=True):
13421342
op = node.op
1343-
if op not in [self.main, self.inverse, self.reciprocal]:
1343+
if enforce_tracks and (op not in {self.main, self.inverse, self.reciprocal}):
13441344
return False
13451345

13461346
assert len(node.outputs) == 1

0 commit comments

Comments
 (0)