Skip to content

Commit dfef95b

Browse files
committed
Implement trie algorithm pattern matching
This allows matching multiple patterns to a single graph, avoiding repeated comparisons across similar patterns
1 parent 40ccab1 commit dfef95b

File tree

3 files changed

+455
-2
lines changed

3 files changed

+455
-2
lines changed
Lines changed: 331 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,331 @@
1+
from dataclasses import dataclass, field
2+
from typing import Any, Union
3+
4+
from pytensor import Variable
5+
from pytensor.graph import Op
6+
7+
8+
@dataclass(frozen=True, eq=False)
9+
class MatchPattern:
10+
name: str | None
11+
pattern: tuple
12+
13+
def __repr__(self):
14+
if self.name is not None:
15+
return self.name
16+
return str(self.pattern)
17+
18+
def __hash__(self):
19+
return id(self)
20+
21+
def __eq__(self, other):
22+
return self is other
23+
24+
25+
@dataclass(frozen=True)
class Literal:
    """Wrapper signalling that a pattern entry is a literal value,
    not a pattern variable."""

    pattern: Any
29+
30+
31+
@dataclass(frozen=True)
class TrieNode:
    # Class for Op level trie nodes
    # Each node has edges for exact Op matches, Op type matches, variable matches, and
    # edges for starting parametrized Op matches (which lead to ParameterTrieNodes)
    # Terminal patterns are stored at the nodes where patterns end
    # NOTE: the dataclass is frozen (fields cannot be rebound), but the container
    # fields themselves are mutated in place as the trie is expanded.
    # Edges keyed by an exact Op instance (matched with ==)
    op_edges: dict[Op, "TrieNode"] = field(default_factory=dict)
    # Edges keyed by an Op subclass (matched by type(node.op))
    op_type_edges: dict[type[Op], "TrieNode"] = field(default_factory=dict)
    # Edges that begin matching the parameters of an Op of the given type
    start_parameter_edges: dict[type[Op], "ParameterTrieNode"] = field(
        default_factory=dict
    )
    # Wildcard edges keyed by pattern-variable name; the matched Variable is
    # recorded in the substitutions dict
    variable_edges: dict[str, "TrieNode"] = field(default_factory=dict)
    # Patterns that end exactly at this node
    terminal_patterns: list[MatchPattern] = field(default_factory=list)
44+
45+
46+
@dataclass(frozen=False)
47+
class ParameterTrieNode:
48+
# Class for Op parameter level trie nodes
49+
# Each node has edges for matching Op parameters (key, pattern) pairs
50+
# (where pattern can be a variable name, an Op type, a literal value, or a nested parametrized Op (OpType, {param: value, ...}))
51+
52+
# A ParameterTrieNode may have multiple parameter edges to move to the next ParameterTrieNode
53+
# A ParameterTrieNode may have an end_parameter_edge, to move back to the outer TrieNode/ ParameterTrieNode
54+
# This allows different patterns to match a different number of parameters.
55+
# Parameters are arranged in alphabetical order to help sharing of common paths.
56+
57+
# A ParameterTrieNode may also have a sub_op_parameter_edge, to start matching parameters of a nested parametrized Op
58+
# A sub_op_parameter_edge always follows a parameter_edge for the same parameter key and op type.
59+
60+
parameter_edges: list[tuple[str, Any], "ParameterTrieNode"] = field(
61+
default_factory=list
62+
)
63+
sub_op_parameter_edge: tuple[str, "ParameterTrieNode"] | None = field(default=None)
64+
# A ParameterTrieNode may end up followed by a ParameterTrieNode, if it was a nested parametrized op
65+
# Or with a regular TrieNode, if it was the end of a parametrized op pattern
66+
end_parameter_edge: Union["TrieNode", "ParameterTrieNode"] | None = field(
67+
default=None
68+
)
69+
70+
# We can also have variable edges at the parameter level, to match parameter values that are variables
71+
variable_edges: dict[str, "ParameterTrieNode"] = field(default_factory=dict)
72+
73+
74+
@dataclass(frozen=True)
class Trie:
    """Trie of `MatchPattern`s.

    Stores many patterns with shared prefixes so that a single traversal of a
    graph can test all of them at once, avoiding repeated comparisons across
    similar patterns.
    """

    root_node: TrieNode = field(default_factory=TrieNode)

    def add_pattern(self, pattern: MatchPattern | tuple):
        """Expand Trie with new pattern.

        Plain tuples are wrapped in an anonymous `MatchPattern`.
        """
        if not isinstance(pattern, MatchPattern):
            pattern = MatchPattern(None, pattern)

        def validate_head_tuple(head):
            # We only allow very specific head tuples (to parametrize Ops)
            # FIX: the condition was `not isinstance(head, tuple) and len(head) == 2`,
            # which (since callers only pass tuples) never raised and let a
            # wrong-length tuple fail later with an unpacking ValueError.
            if not (isinstance(head, tuple) and len(head) == 2):
                raise TypeError(f"Head tuple must have exactly two entries: {head}")
            head_op_type, head_dict = head
            if not (isinstance(head_op_type, type) and issubclass(head_op_type, Op)):
                raise TypeError(
                    f"Invalid type for first entry of head tuple {type(head_op_type)}: {head_op_type}. Expected type(Op)"
                )
            if not isinstance(head_dict, dict):
                raise TypeError(
                    f"Invalid type for second entry of head tuple {head_dict}. Expected dict"
                )
            return head_op_type, head_dict

        def get_parametrized_edge(parameter_edges, key, pattern) -> ParameterTrieNode:
            # Find an existing edge for this (key, pattern) pair, or create one
            for edge in parameter_edges:
                (key_edge, pattern_edge), next_trie_node = edge
                # FIX: an edge can only be reused when the parameter *key*
                # matches; the comparison was previously inverted (`!=`), which
                # both reused edges with the wrong key and duplicated paths for
                # identical (key, pattern) pairs.
                if key == key_edge:
                    if isinstance(pattern, type):
                        # FIX: guard issubclass — pattern_edge may be a
                        # non-type pattern (e.g. a string variable name)
                        if isinstance(pattern_edge, type) and issubclass(
                            pattern, pattern_edge
                        ):
                            break
                    elif pattern == pattern_edge:
                        break
            else:  # no-break, there's no trie yet for this key-pattern pair
                next_trie_node = ParameterTrieNode()
                parameter_edges.append(((key, pattern), next_trie_node))
            return next_trie_node

        def recurse_with_op_parameters(trie_node, parameters, nested=False):
            # Thread the sorted (key, pattern) pairs of one parametrized Op
            # through ParameterTrieNodes, returning the node reached at the end
            assert isinstance(trie_node, ParameterTrieNode)
            if not parameters:
                # Base case: We consumed all the parameters. Add an end_parameter edge to signal we're done
                if trie_node.end_parameter_edge is None:
                    trie_node.end_parameter_edge = (
                        ParameterTrieNode() if nested else TrieNode()
                    )
                return trie_node.end_parameter_edge

            (item_key, item_pattern), *rest_key_pattern_pairs = parameters

            if isinstance(item_pattern, tuple):
                # Nested parametrized op
                sub_op_type, sub_dict = validate_head_tuple(item_pattern)
                # Start with a parameter edge for the op parameter
                start_trie_node = get_parametrized_edge(
                    trie_node.parameter_edges, item_key, sub_op_type
                )
                if sub_dict:
                    # Add a sub_op_parameter edge to start matching the nested Op parameters
                    # A trie node can only have one sub_op_parameter edge, since it's always preceded by a parameter edge
                    if start_trie_node.sub_op_parameter_edge is None:
                        start_trie_node.sub_op_parameter_edge = (
                            item_key,
                            ParameterTrieNode(),
                        )
                    (sub_op_key, sub_op_trie_node) = (
                        start_trie_node.sub_op_parameter_edge
                    )
                    assert sub_op_key == item_key
                    next_trie_node = recurse_with_op_parameters(
                        sub_op_trie_node, sorted(sub_dict.items()), nested=True
                    )
                else:
                    # No parameters, just continue with the start_trie_node
                    next_trie_node = start_trie_node
            else:
                # Simple parameter pattern: add a parameter edge
                next_trie_node = get_parametrized_edge(
                    trie_node.parameter_edges, item_key, item_pattern
                )

            # Recurse with the rest of the parameters
            return recurse_with_op_parameters(
                next_trie_node, rest_key_pattern_pairs, nested=nested
            )

        def recurse(trie_node, sub_pattern):
            # Consume sub_pattern elements left-to-right, creating edges as needed
            if not sub_pattern:
                # Base case: we've consumed the entire pattern
                trie_node.terminal_patterns.append(pattern)
                return

            head, *tail = sub_pattern

            if isinstance(head, tuple):
                if isinstance(head[0], tuple):
                    # recurse on the head tuple, until it becomes an Op
                    head_head, *tail_head = head
                    return recurse(trie_node, (head_head, *tail_head, *tail))
                else:
                    # Parametrized Op (OpType, {param: value, ...})
                    head_op_type, head_dict = validate_head_tuple(head)
                    if head_dict:
                        # Start with an edge for the op type
                        next_trie_node = trie_node.start_parameter_edges.get(
                            head_op_type, None
                        )
                        if next_trie_node is None:
                            trie_node.start_parameter_edges[head_op_type] = (
                                next_trie_node
                            ) = ParameterTrieNode()
                        # Recurse into the parameters, with parameter edges
                        next_trie_node = recurse_with_op_parameters(
                            next_trie_node, sorted(head_dict.items())
                        )
                    else:
                        # No parameters, just add an op_type edge
                        next_trie_node = trie_node.op_type_edges.get(head_op_type, None)
                        if next_trie_node is None:
                            trie_node.op_type_edges[head_op_type] = next_trie_node = (
                                TrieNode()
                            )
            else:
                # Pick the edge container matching the head's kind
                if isinstance(head, Op):
                    edge_type = trie_node.op_edges
                elif isinstance(head, type) and issubclass(head, Op):
                    edge_type = trie_node.op_type_edges
                elif isinstance(head, str):
                    edge_type = trie_node.variable_edges
                else:
                    raise TypeError(f"Invalid head type {type(head)}: {head}")
                next_trie_node = edge_type.get(head, None)
                if next_trie_node is None:
                    edge_type[head] = next_trie_node = TrieNode()

            # Recurse with the tail of the pattern
            recurse(next_trie_node, tail)

        recurse(self.root_node, pattern.pattern)

    def match(self, variable):
        """Yield ``(pattern, substitutions)`` pairs for every stored pattern
        that unifies with the graph rooted at ``variable``."""
        if not isinstance(variable, Variable):
            # Non-variables never match anything; yield nothing
            return

        def recurse(
            trie_node: TrieNode | ParameterTrieNode,
            subject_pattern: tuple[Variable, tuple[Variable, ...]],
            subs: dict[str, Any],
        ):
            if isinstance(trie_node, TrieNode):
                # Base case, terminal patterns are successfully matched
                # whenever trie node is reached with no subject pattern left to unify
                if not subject_pattern:
                    for terminal_pattern in trie_node.terminal_patterns:
                        yield terminal_pattern, subs
                    return

                head, *tail = subject_pattern
                assert isinstance(head, Variable), (type(head), head)

                # Unify variables
                for var_name, next_trie_node in trie_node.variable_edges.items():
                    if var_name in subs:
                        # Already bound: only proceed on a consistent binding
                        if subs[var_name] == head:
                            yield from recurse(next_trie_node, tail, subs)
                    else:
                        subs_copy = subs.copy()
                        subs_copy[var_name] = head
                        yield from recurse(next_trie_node, tail, subs_copy)

                if head.owner is None:
                    # head is a root variable, can only be matched to wildcard patterns above
                    return
                head_op = head.owner.op

                # Match op type or exact op
                # We consume the head variable and extend the tail pattern with its inputs
                if (
                    next_trie_node := trie_node.op_edges.get(head_op, None)
                ) is not None:
                    yield from recurse(
                        next_trie_node, (*head.owner.inputs, *tail), subs
                    )
                if (
                    next_trie_node := trie_node.op_type_edges.get(type(head_op), None)
                ) is not None:
                    yield from recurse(
                        next_trie_node, (*head.owner.inputs, *tail), subs
                    )

                # Match start of parametrized op pattern
                if (
                    next_trie_node := trie_node.start_parameter_edges.get(
                        type(head_op), None
                    )
                ) is not None:
                    # We place the Op variable at the head of the subject pattern
                    # And extend the tail pattern with the inputs of the head variable, just like a regular op match
                    yield from recurse(
                        next_trie_node, (head_op, *head.owner.inputs, *tail), subs
                    )

            else:  # ParameterTrieNode
                head_op, *tail = subject_pattern
                assert isinstance(head_op, Op), (type(head_op), head_op)

                # Exit parametrized op pattern matching
                if (next_trie_node := trie_node.end_parameter_edge) is not None:
                    # We discard the head op and keep working on the tail pattern
                    yield from recurse(next_trie_node, tail, subs)

                # Match op parameters
                for (
                    op_param_key,
                    op_param_pattern,
                ), next_trie_node in trie_node.parameter_edges:
                    op_param_value = getattr(head_op, op_param_key)
                    subs_copy = subs

                    # Match variable pattern
                    if isinstance(op_param_pattern, str):
                        if op_param_pattern in subs:
                            if subs[op_param_pattern] != op_param_value:
                                continue  # mismatch
                        else:
                            subs_copy = subs.copy()
                            subs_copy[op_param_pattern] = op_param_value
                    # Match op type
                    elif isinstance(op_param_pattern, type) and issubclass(
                        op_param_pattern, Op
                    ):
                        if not isinstance(op_param_value, op_param_pattern):
                            continue  # mismatch
                    # Match literal value
                    elif isinstance(op_param_pattern, Literal):
                        if op_param_value != op_param_pattern.pattern:
                            continue  # mismatch
                    # Match exact value
                    elif op_param_value != op_param_pattern:
                        continue  # mismatch

                    # We arrive here if there was no mismatch
                    # For parameter edges, we continue to the next trie_node with the same pattern
                    # as we may still need to check other parameters from the same Op
                    # We'll eventually move to the tail pattern via an end_parameter edge
                    yield from recurse(next_trie_node, subject_pattern, subs_copy)

                # Match nested op parametrizations
                # This always follows an op parameter edge
                if trie_node.sub_op_parameter_edge is not None:
                    (sub_op_param_key, next_trie_node) = trie_node.sub_op_parameter_edge
                    sub_op = getattr(head_op, sub_op_param_key)
                    # For sub_op parameter edges, we continue to the next trie_node with the sub_op as the head
                    yield from recurse(next_trie_node, (sub_op, *subject_pattern), subs)

        yield from recurse(self.root_node, (variable,), {})

pytensor/tensor/elemwise.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,7 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
159159
# This is the list of the original dimensions that we keep
160160
self.shuffle = [x for x in new_order if x != "x"]
161161
self.transposition = self.shuffle + drop
162-
# List of dimensions of the output that are broadcastable and were not
163-
# in the original input
162+
# List of dimensions of the output that are broadcastable and were not in the original input
164163
self.augment = augment = sorted(i for i, x in enumerate(new_order) if x == "x")
165164
self.drop = drop
166165

@@ -175,6 +174,12 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
175174
self.is_right_expand_dims = self.is_expand_dims and new_order[
176175
:input_ndim
177176
] == list(range(input_ndim))
177+
self.is_matrix_transpose = False
178+
if dims_are_shuffled and (not drop) and input_ndim >= 2:
179+
# We consider a matrix transpose if we only flip the last two dims
180+
# Regardless of whether there's an expand_dims or not
181+
mt_pattern = [*range(input_ndim - 2), input_ndim - 1, input_ndim - 2]
182+
self.is_matrix_transpose = new_order[len(augment) :] == mt_pattern
178183

179184
def __setstate__(self, state):
180185
self.__dict__.update(state)

0 commit comments

Comments
 (0)