Skip to content

Commit e0bbed2

Browse files
committed
Allow unifying with OpPatterns
1 parent 808528e commit e0bbed2

File tree

4 files changed

+311
-40
lines changed

4 files changed

+311
-40
lines changed

pytensor/graph/rewriting/basic.py

Lines changed: 110 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from pytensor.graph.features import AlreadyThere, Feature
3030
from pytensor.graph.fg import FunctionGraph, Output
3131
from pytensor.graph.op import Op
32-
from pytensor.graph.rewriting.unify import Var, convert_strs_to_vars
32+
from pytensor.graph.rewriting.unify import OpPattern, Var, convert_strs_to_vars
3333
from pytensor.graph.utils import AssocList, InconsistencyError
3434
from pytensor.misc.ordered_set import OrderedSet
3535
from pytensor.utils import flatten
@@ -1312,6 +1312,7 @@ class PatternNodeRewriter(NodeRewriter):
13121312
The input and output patterns have the following syntax:
13131313
13141314
input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
1315+
input_pattern ::= (OpPattern(type(op), {<param>: <value>, ...}), <sub_pattern1>, <sub_pattern2>, ...)
13151316
input_pattern ::= dict(pattern = <input_pattern>,
13161317
constraint = <constraint>)
13171318
sub_pattern ::= input_pattern
@@ -1325,6 +1326,7 @@ class PatternNodeRewriter(NodeRewriter):
13251326
output_pattern ::= string
13261327
output_pattern ::= int
13271328
output_pattern ::= float
1329+
output_pattern ::= callable
13281330
13291331
Each string in the input pattern is a variable that will be set to
13301332
whatever expression is found in its place. If the same string is
@@ -1350,20 +1352,73 @@ class PatternNodeRewriter(NodeRewriter):
13501352
Examples
13511353
--------
13521354
1353-
PatternNodeRewriter((add, 'x', 'y'), (add, 'y', 'x'))
1354-
PatternNodeRewriter((multiply, 'x', 'x'), (square, 'x'))
1355-
PatternNodeRewriter((subtract, (add, 'x', 'y'), 'y'), 'x')
1356-
PatternNodeRewriter((power, 'x', Constant(double, 2.0)), (square, 'x'))
1357-
PatternNodeRewriter((boggle, {'pattern': 'x',
1358-
'constraint': lambda expr: expr.type == scrabble}),
1359-
(scrabble, 'x'))
1355+
.. code-block:: python
13601356
1357+
from pytensor.graph.rewriting.basic import PatternNodeRewriter
1358+
from pytensor.tensor import add, mul, sub, pow, square
1359+
1360+
PatternNodeRewriter((add, "x", "y"), (add, "y", "x"))
1361+
PatternNodeRewriter((mul, "x", "x"), (square, "x"))
1362+
PatternNodeRewriter((sub, (add, "x", "y"), "y"), "x")
1363+
PatternNodeRewriter((pow, "x", 2.0), (square, "x"))
1364+
PatternNodeRewriter(
1365+
(mul, {"pattern": "x", "constraint": lambda expr: expr.ndim == 0}, "y"),
1366+
(mul, "y", "x"),
1367+
)
1368+
1369+
You can use OpPattern to match a subtype of an Op, with some parameter constraints.
1370+
You can also specify a callable as the output pattern, which will be called with (fgraph, node, subs_dict) as arguments.
1371+
1372+
1373+
.. code-block:: python
1374+
1375+
from pytensor.graph.rewriting.basic import PatternNodeRewriter
1376+
from pytensor.graph.rewriting.unify import OpPattern
1377+
from pytensor.tensor.basic import Join
1378+
from pytensor.tensor.elemwise import CAReduce, Elemwise
1379+
1380+
1381+
def output_fn(fgraph, node, s):
1382+
reduce_op = node.op
1383+
reduced_a = reduce_op(s["a"])
1384+
reduced_b = reduce_op(s["b"])
1385+
return Elemwise(s["scalar_op"])(reduced_a, reduced_b)
1386+
1387+
1388+
PatternNodeRewriter(
1389+
(
1390+
OpPattern(CAReduce, scalar_op="scalar_op", axis=None),
1391+
(Join(), "join_axis", "a", "b"),
1392+
),
1393+
output_fn,
1394+
)
1395+
1396+
1397+
If you want to test a string parameter, you must use LiteralString to avoid it being interpreted as a unification variable.
1398+
1399+
.. code-block:: python
1400+
1401+
1402+
from pytensor.graph.rewriting.basic import PatternNodeRewriter
1403+
from pytensor.graph.rewriting.unify import OpPattern, LiteralString
1404+
from pytensor.tensor.blockwise import Blockwise
1405+
from pytensor.tensor.slinalg import Solve
1406+
1407+
PatternNodeRewriter(
1408+
(
1409+
OpPattern(
1410+
Blockwise, core_op=OpPattern(Solve, assume_a=LiteralString("gen"))
1411+
),
1412+
"A",
1413+
"b",
1414+
)
1415+
)
13611416
"""
13621417

13631418
def __init__(
13641419
self,
1365-
in_pattern,
1366-
out_pattern,
1420+
in_pattern: tuple,
1421+
out_pattern: tuple | Callable,
13671422
allow_multiple_clients: bool = False,
13681423
name: str | None = None,
13691424
tracks=(),
@@ -1378,7 +1433,7 @@ def __init__(
13781433
in_pattern
13791434
The input pattern that we want to replace.
13801435
out_pattern
1381-
The replacement pattern.
1436+
The replacement pattern, or a callable that takes (fgraph, node, subs_dict) as inputs.
13821437
allow_multiple_clients
13831438
If ``False``, the pattern matching will fail if one of the subpatterns has
13841439
more than one client.
@@ -1407,26 +1462,35 @@ def __init__(
14071462
self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map)
14081463
self.values_eq_approx = values_eq_approx
14091464
self.allow_cast = allow_cast
1410-
if isinstance(in_pattern, list | tuple):
1411-
self.op = self.in_pattern[0]
1412-
elif isinstance(in_pattern, dict):
1413-
self.op = self.in_pattern["pattern"][0]
1414-
else:
1415-
raise TypeError(
1416-
"The pattern to search for must start with a specific Op instance."
1417-
)
14181465
self.allow_multiple_clients = allow_multiple_clients
14191466
if name:
14201467
self.__name__ = name
1421-
self._tracks = tracks
14221468
self.get_nodes = get_nodes
14231469
if tracks != ():
1424-
assert get_nodes
1470+
if not get_nodes:
1471+
raise ValueError("Custom `tracks` requires `get_nodes` to be provided.")
1472+
self._tracks = tracks
1473+
else:
1474+
if isinstance(in_pattern, list | tuple):
1475+
op = self.in_pattern[0]
1476+
elif isinstance(in_pattern, dict):
1477+
op = self.in_pattern["pattern"][0]
1478+
else:
1479+
raise TypeError(
1480+
"The pattern to search for must start with a specific Op instance."
1481+
)
1482+
if isinstance(op, Op):
1483+
self._tracks = [op]
1484+
elif isinstance(op, OpPattern):
1485+
self._tracks = [op.op_type]
1486+
else:
1487+
raise ValueError(
1488+
f"The pattern to search for must start with a specific Op instance or an OpPattern class. "
1489+
f"Got {op}, with type {type(op)}."
1490+
)
14251491

14261492
def tracks(self):
1427-
if self._tracks != ():
1428-
return self._tracks
1429-
return [self.op]
1493+
return self._tracks
14301494

14311495
def transform(self, fgraph, node, get_nodes=True):
14321496
"""Check if the graph from node corresponds to ``in_pattern``.
@@ -1447,28 +1511,39 @@ def transform(self, fgraph, node, get_nodes=True):
14471511
# PatternNodeRewriter doesn't support replacing multi-output nodes
14481512
return False
14491513

1450-
s = unify(self.in_pattern, node.out)
1514+
s = unify(self.in_pattern, node.out, {})
14511515

14521516
if s is False:
14531517
return False
14541518

1455-
ret = reify(self.out_pattern, s)
1456-
1457-
if isinstance(ret, ExpressionTuple):
1458-
ret = ret.evaled_obj
1459-
1460-
if self.values_eq_approx:
1461-
ret.tag.values_eq_approx = self.values_eq_approx
1462-
14631519
if not self.allow_multiple_clients:
1464-
input_vars = list(s.values())
1520+
input_vars = set(s.values())
1521+
clients = fgraph.clients
14651522
if any(
1466-
len(fgraph.clients[v]) > 1
1523+
len(clients[v]) > 1
14671524
for v in vars_between(input_vars, node.inputs)
14681525
if v not in input_vars
14691526
):
14701527
return False
14711528

1529+
if callable(self.out_pattern):
1530+
# token is the variable name used in the original pattern
1531+
ret = self.out_pattern(fgraph, node, {k.token: v for k, v in s.items()})
1532+
if ret is None or ret is False:
1533+
# The output function is still allowed to reject the rewrite
1534+
return False
1535+
if not isinstance(ret, Variable):
1536+
raise ValueError(
1537+
f"The output of the PatternNodeRewriter callable must be a variable got {ret} of type {type(ret)}."
1538+
)
1539+
else:
1540+
ret = reify(self.out_pattern, s)
1541+
if isinstance(ret, ExpressionTuple):
1542+
ret = ret.evaled_obj
1543+
1544+
if self.values_eq_approx:
1545+
ret.tag.values_eq_approx = self.values_eq_approx
1546+
14721547
[old_out] = node.outputs
14731548
if not old_out.type.is_super(ret.type):
14741549
from pytensor.tensor.type import TensorType

pytensor/graph/rewriting/unify.py

Lines changed: 137 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010
1111
"""
1212

13-
from collections.abc import Mapping
13+
from collections.abc import Mapping, Sequence
14+
from dataclasses import dataclass
1415
from numbers import Number
16+
from types import UnionType
17+
from typing import Any
1518

1619
import numpy as np
1720
from cons.core import ConsError, _car, _cdr
@@ -254,6 +257,128 @@ def _unify_ConstrainedVar_object(u, v, s):
254257
_unify.add((object, ConstrainedVar, Mapping), _unify_ConstrainedVar_object)
255258

256259

260+
@dataclass(frozen=True)
class LiteralString:
    """Wrapper that marks a string as a literal value in a pattern.

    Plain strings in a pattern are treated as unification variable names;
    wrapping the string in ``LiteralString`` keeps it as the raw ``value``.
    """

    # The raw string carried by the wrapper.
    value: str
263+
264+
265+
@dataclass(unsafe_hash=True)
266+
class OpPattern:
267+
"""Class that can be unified with Op instances of a given type and parameters.
268+
269+
An op instance is unified as long as the parameters specified in the OpPattern can be unified as well.
270+
Parameters that are not specified in the OpPattern are ignored during unification.
271+
272+
This is needed because some Ops can be complex to parametrize fully,
273+
and not all parameters are relevant for a given pattern.
274+
275+
Examples
276+
--------
277+
278+
.. testcode::
279+
280+
from unification import var, unify
281+
from etuples import etuple
282+
283+
import pytensor.tensor as pt
284+
from pytensor.graph.rewriting.unify import OpPattern
285+
from pytensor.tensor.blockwise import Blockwise
286+
from pytensor.tensor.slinalg import Solve
287+
288+
A = var("A")
289+
b = var("b")
290+
pattern = etuple(
291+
OpPattern(Blockwise, core_op=OpPattern(Solve, assume_a="gen")), A, b
292+
)
293+
294+
A_pt = pt.tensor3("A")
295+
b_pt = pt.tensor3("b")
296+
out1 = pt.linalg.solve(A_pt, b_pt)
297+
out2 = pt.linalg.solve(A_pt, b_pt, assume_a="pos")
298+
299+
assert unify(pattern, out1) == {A: A_pt, b: b_pt}
300+
assert unify(pattern, out2) is False
301+
302+
assume_a = var("assume_a")
303+
pattern = etuple(
304+
OpPattern(Blockwise, core_op=OpPattern(Solve, assume_a=assume_a)),
305+
A,
306+
b,
307+
)
308+
assert unify(pattern, out1) == {A: A_pt, b: b_pt, assume_a: "gen"}
309+
assert unify(pattern, out2) == {A: A_pt, b: b_pt, assume_a: "pos"}
310+
311+
312+
"""
313+
314+
op_type: type[Op] | tuple[type[Op]] | UnionType
315+
parameters: tuple[str, Any]
316+
317+
def __init__(
318+
self,
319+
op_type: type[Op] | UnionType | tuple[type[Op]],
320+
parameters: dict[str, Any] | Sequence[tuple[str, Any]] | None = None,
321+
**kwargs,
322+
):
323+
if kwargs:
324+
if parameters is not None:
325+
raise ValueError(
326+
"Cannot provide both parameters dict and keyword arguments"
327+
)
328+
parameters = kwargs
329+
if isinstance(parameters, dict):
330+
parameters = tuple(sorted(parameters.items()))
331+
elif isinstance(parameters, list | tuple):
332+
parameters = tuple(sorted(parameters))
333+
elif parameters is None:
334+
parameters = ()
335+
self.op_type = op_type
336+
self.parameters = parameters
337+
338+
def match_op(self, op: Op):
339+
if not isinstance(op, self.op_type):
340+
return False
341+
return self.match_parameters(op)
342+
343+
def match_parameters(self, op):
344+
# This is used by methods that already check the op_type is satisfied
345+
# Some methods may index on the op_type and know in advance the op is matched
346+
# Also recursive calls to OpPattern.match_parameters do the op check outside to exit early (see below)
347+
for key, param in self.parameters:
348+
if isinstance(param, OpPattern):
349+
# Parameters can itself be other OpPatterns
350+
# We check the op_type to avoid a nested call in cases we can reject early
351+
sub_op = getattr(op, key)
352+
if not isinstance(sub_op, param.op_type):
353+
return False
354+
# Match the pattern of the inner Op
355+
# Skip if there are no parameters
356+
if param.parameters and not param.match_parameters(sub_op):
357+
return False
358+
elif getattr(op, key) != param:
359+
return False
360+
return True
361+
362+
def __str__(self):
363+
return f"{self.op_type.__name__}({self.op_type}, {', '.join(f'{k}={v}' for k, v in self.parameters)})"
364+
365+
366+
def _unify_parametrized_op(v: Op, u: OpPattern, s: Mapping):
    """Unify the Op instance ``v`` with the pattern ``u`` under the substitution ``s``.

    This is a generator: for every parameter of ``u`` it yields the ``_unify``
    goal for the matching attribute of ``v`` and receives the resulting
    substitution back through ``send``. The last value yielded is the final
    substitution, or ``False`` when ``v`` has the wrong type or any parameter
    fails to unify.
    """
    if isinstance(v, u.op_type):
        for attr_name, attr_pattern in u.parameters:
            s = yield _unify(getattr(v, attr_name), attr_pattern, s)
            if s is False:
                # A parameter failed to unify; `s` is False and is yielded below as the result.
                break
        yield s
    else:
        yield False
377+
378+
379+
_unify.add((Op, OpPattern, Mapping), _unify_parametrized_op)
380+
381+
257382
def convert_strs_to_vars(
258383
x: tuple | str | dict, var_map: dict[str, Var] | None = None
259384
) -> ExpressionTuple | Var:
@@ -266,11 +391,13 @@ def convert_strs_to_vars(
266391
if var_map is None:
267392
var_map = {}
268393

269-
def _convert(y):
394+
def _convert(y, op_prop=False):
270395
if isinstance(y, str):
271396
v = var_map.get(y, var(y))
272397
var_map[y] = v
273398
return v
399+
if isinstance(y, LiteralString):
400+
return y.value
274401
elif isinstance(y, dict):
275402
pattern = y["pattern"]
276403
if not isinstance(pattern, str):
@@ -282,8 +409,14 @@ def _convert(y):
282409
var_map[pattern] = v
283410
return v
284411
elif isinstance(y, tuple):
285-
return etuple(*(_convert(e) for e in y))
286-
elif isinstance(y, Number | np.ndarray):
412+
return etuple(*(_convert(e, op_prop=op_prop) for e in y))
413+
elif isinstance(y, OpPattern):
414+
return OpPattern(
415+
y.op_type,
416+
{k: _convert(v, op_prop=True) for k, v in y.parameters},
417+
)
418+
elif (not op_prop) and isinstance(y, Number | np.ndarray):
419+
# If we are converting an Op property, we don't want to convert numbers to PyTensor constants
287420
from pytensor.tensor import as_tensor_variable
288421

289422
return as_tensor_variable(y)

0 commit comments

Comments
 (0)