29
29
from pytensor .graph .features import AlreadyThere , Feature
30
30
from pytensor .graph .fg import FunctionGraph , Output
31
31
from pytensor .graph .op import Op
32
- from pytensor .graph .rewriting .unify import Var , convert_strs_to_vars
32
+ from pytensor .graph .rewriting .unify import OpPattern , Var , convert_strs_to_vars
33
33
from pytensor .graph .utils import AssocList , InconsistencyError
34
34
from pytensor .misc .ordered_set import OrderedSet
35
35
from pytensor .utils import flatten
@@ -1312,6 +1312,7 @@ class PatternNodeRewriter(NodeRewriter):
1312
1312
The input and output patterns have the following syntax:
1313
1313
1314
1314
input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
1315
+ input_pattern ::= (OpPattern(type(op), {<param>: <value>, ...}), <sub_pattern1>, <sub_pattern2>, ...)
1315
1316
input_pattern ::= dict(pattern = <input_pattern>,
1316
1317
constraint = <constraint>)
1317
1318
sub_pattern ::= input_pattern
@@ -1325,6 +1326,7 @@ class PatternNodeRewriter(NodeRewriter):
1325
1326
output_pattern ::= string
1326
1327
output_pattern ::= int
1327
1328
output_pattern ::= float
1329
+ output_pattern ::= callable
1328
1330
1329
1331
Each string in the input pattern is a variable that will be set to
1330
1332
whatever expression is found in its place. If the same string is
@@ -1350,20 +1352,73 @@ class PatternNodeRewriter(NodeRewriter):
1350
1352
Examples
1351
1353
--------
1352
1354
1353
- PatternNodeRewriter((add, 'x', 'y'), (add, 'y', 'x'))
1354
- PatternNodeRewriter((multiply, 'x', 'x'), (square, 'x'))
1355
- PatternNodeRewriter((subtract, (add, 'x', 'y'), 'y'), 'x')
1356
- PatternNodeRewriter((power, 'x', Constant(double, 2.0)), (square, 'x'))
1357
- PatternNodeRewriter((boggle, {'pattern': 'x',
1358
- 'constraint': lambda expr: expr.type == scrabble}),
1359
- (scrabble, 'x'))
1355
+ .. code-block:: python
1360
1356
1357
+ from pytensor.graph.rewriting.basic import PatternNodeRewriter
1358
+ from pytensor.tensor import add, mul, sub, pow, square
1359
+
1360
+ PatternNodeRewriter((add, "x", "y"), (add, "y", "x"))
1361
+ PatternNodeRewriter((mul, "x", "x"), (square, "x"))
1362
+ PatternNodeRewriter((sub, (add, "x", "y"), "y"), "x")
1363
+ PatternNodeRewriter((pow, "x", 2.0), (square, "x"))
1364
+ PatternNodeRewriter(
1365
+ (mul, {"pattern": "x", "constraint": lambda expr: expr.ndim == 0}, "y"),
1366
+ (mul, "y", "x"),
1367
+ )
1368
+
1369
+ You can use OpPattern to match a subtype of an Op, with some parameter constraints
1370
+ You can also specify a callable as the output pattern, which will be called with (fgraph, node, subs_dict) as arguments.
1371
+
1372
+
1373
+ .. code-block:: python
1374
+
1375
+ from pytensor.graph.rewriting.basic import PatternNodeRewriter
1376
+ from pytensor.graph.rewriting.unify import OpPattern
1377
+ from pytensor.tensor.basic import Join
1378
+ from pytensor.tensor.elemwise import CAReduce, Elemwise
1379
+
1380
+
1381
+ def output_fn(fgraph, node, s):
1382
+ reduce_op = node.op
1383
+ reduced_a = reduce_op(s["a"])
1384
+ reduced_b = reduce_op(s["b"])
1385
+ return Elemwise(s["scalar_op"])(reduced_a, reduced_b)
1386
+
1387
+
1388
+ PatternNodeRewriter(
1389
+ (
1390
+ OpPattern(CAReduce, scalar_op="scalar_op", axis=None),
1391
+ (Join(), "join_axis", "a", "b"),
1392
+ ),
1393
+ output_fn,
1394
+ )
1395
+
1396
+
1397
+ If you want to test a string parameter, you must use LiteralString to avoid it being interpreted as a unification variable.
1398
+
1399
+ .. code-block:: python
1400
+
1401
+
1402
+ from pytensor.graph.rewriting.basic import PatternNodeRewriter
1403
+ from pytensor.graph.rewriting.unify import OpPattern, LiteralString
1404
+ from pytensor.tensor.blockwise import Blockwise
1405
+ from pytensor.tensor.slinalg import Solve
1406
+
1407
+ PatternNodeRewriter(
1408
+ (
1409
+ OpPattern(
1410
+ Blockwise, core_op=OpPattern(Solve, assume_a=LiteralString("gen"))
1411
+ ),
1412
+ "A",
1413
+ "b",
1414
+ )
1415
+ )
1361
1416
"""
1362
1417
1363
1418
def __init__ (
1364
1419
self ,
1365
- in_pattern ,
1366
- out_pattern ,
1420
+ in_pattern : tuple ,
1421
+ out_pattern : tuple | Callable ,
1367
1422
allow_multiple_clients : bool = False ,
1368
1423
name : str | None = None ,
1369
1424
tracks = (),
@@ -1378,7 +1433,7 @@ def __init__(
1378
1433
in_pattern
1379
1434
The input pattern that we want to replace.
1380
1435
out_pattern
1381
- The replacement pattern.
1436
+ The replacement pattern. Or a callable that takes (fgraph, node, subs_dict) as inputs
1382
1437
allow_multiple_clients
1383
1438
If ``False``, the pattern matching will fail if one of the subpatterns has
1384
1439
more than one client.
@@ -1407,26 +1462,35 @@ def __init__(
1407
1462
self .out_pattern = convert_strs_to_vars (out_pattern , var_map = var_map )
1408
1463
self .values_eq_approx = values_eq_approx
1409
1464
self .allow_cast = allow_cast
1410
- if isinstance (in_pattern , list | tuple ):
1411
- self .op = self .in_pattern [0 ]
1412
- elif isinstance (in_pattern , dict ):
1413
- self .op = self .in_pattern ["pattern" ][0 ]
1414
- else :
1415
- raise TypeError (
1416
- "The pattern to search for must start with a specific Op instance."
1417
- )
1418
1465
self .allow_multiple_clients = allow_multiple_clients
1419
1466
if name :
1420
1467
self .__name__ = name
1421
- self ._tracks = tracks
1422
1468
self .get_nodes = get_nodes
1423
1469
if tracks != ():
1424
- assert get_nodes
1470
+ if not get_nodes :
1471
+ raise ValueError ("Custom `tracks` requires `get_nodes` to be provided." )
1472
+ self ._tracks = tracks
1473
+ else :
1474
+ if isinstance (in_pattern , list | tuple ):
1475
+ op = self .in_pattern [0 ]
1476
+ elif isinstance (in_pattern , dict ):
1477
+ op = self .in_pattern ["pattern" ][0 ]
1478
+ else :
1479
+ raise TypeError (
1480
+ "The pattern to search for must start with a specific Op instance."
1481
+ )
1482
+ if isinstance (op , Op ):
1483
+ self ._tracks = [op ]
1484
+ elif isinstance (op , OpPattern ):
1485
+ self ._tracks = [op .op_type ]
1486
+ else :
1487
+ raise ValueError (
1488
+ f"The pattern to search for must start with a specific Op instance or an OpPattern class. "
1489
+ f"Got { op } , with type { type (op )} ."
1490
+ )
1425
1491
1426
1492
def tracks (self ):
1427
- if self ._tracks != ():
1428
- return self ._tracks
1429
- return [self .op ]
1493
+ return self ._tracks
1430
1494
1431
1495
def transform (self , fgraph , node , get_nodes = True ):
1432
1496
"""Check if the graph from node corresponds to ``in_pattern``.
@@ -1447,28 +1511,39 @@ def transform(self, fgraph, node, get_nodes=True):
1447
1511
# PatternNodeRewriter doesn't support replacing multi-output nodes
1448
1512
return False
1449
1513
1450
- s = unify (self .in_pattern , node .out )
1514
+ s = unify (self .in_pattern , node .out , {} )
1451
1515
1452
1516
if s is False :
1453
1517
return False
1454
1518
1455
- ret = reify (self .out_pattern , s )
1456
-
1457
- if isinstance (ret , ExpressionTuple ):
1458
- ret = ret .evaled_obj
1459
-
1460
- if self .values_eq_approx :
1461
- ret .tag .values_eq_approx = self .values_eq_approx
1462
-
1463
1519
if not self .allow_multiple_clients :
1464
- input_vars = list (s .values ())
1520
+ input_vars = set (s .values ())
1521
+ clients = fgraph .clients
1465
1522
if any (
1466
- len (fgraph . clients [v ]) > 1
1523
+ len (clients [v ]) > 1
1467
1524
for v in vars_between (input_vars , node .inputs )
1468
1525
if v not in input_vars
1469
1526
):
1470
1527
return False
1471
1528
1529
+ if callable (self .out_pattern ):
1530
+ # token is the variable name used in the original pattern
1531
+ ret = self .out_pattern (fgraph , node , {k .token : v for k , v in s .items ()})
1532
+ if ret is None or ret is False :
1533
+ # The output function is still allowed to reject the rewrite
1534
+ return False
1535
+ if not isinstance (ret , Variable ):
1536
+ raise ValueError (
1537
+ f"The output of the PatternNodeRewriter callable must be a variable got { ret } of type { type (ret )} ."
1538
+ )
1539
+ else :
1540
+ ret = reify (self .out_pattern , s )
1541
+ if isinstance (ret , ExpressionTuple ):
1542
+ ret = ret .evaled_obj
1543
+
1544
+ if self .values_eq_approx :
1545
+ ret .tag .values_eq_approx = self .values_eq_approx
1546
+
1472
1547
[old_out ] = node .outputs
1473
1548
if not old_out .type .is_super (ret .type ):
1474
1549
from pytensor .tensor .type import TensorType
0 commit comments