@@ -141,7 +141,12 @@ def tracks(self) -> Sequence[Op] | None:
141
141
142
142
@abc .abstractmethod
143
143
def transform (
144
- self , fgraph : FunctionGraph , node : Apply , * args , ** kwargs
144
+ self ,
145
+ fgraph : FunctionGraph ,
146
+ node : Apply ,
147
+ enforce_tracks : bool = True ,
148
+ * args ,
149
+ ** kwargs ,
145
150
) -> TransformOutputType :
146
151
r"""Rewrite the sub-graph given by `node`.
147
152
@@ -159,7 +164,9 @@ def transform(
159
164
A `FunctionGraph` containing `node`.
160
165
node
161
166
An `Apply` node to be rewritten.
162
-
167
+ enforce_tracks: bool
168
+ Whether the transform method should enforce tracks, or it can be assumed the caller already enforced them in a pre-filter stage.
169
+ See `node_rewriter` tracks argument for more details.
163
170
"""
164
171
165
172
raise NotImplementedError ()
@@ -935,15 +942,48 @@ class FromFunctionNodeRewriter(NodeRewriter):
935
942
def __init__ (self , fn , tracks = None , requirements = ()):
936
943
self .fn = fn
937
944
self ._tracks = tracks
938
- self ._tracked_types = (
939
- tuple (t for t in tracks if isinstance (t , type )) if tracks else ()
940
- )
945
+ self ._tracked_ops = set ()
946
+ self ._tracked_types = type (None )
947
+ self ._tracked_op_pattern_types = type (None )
948
+ self ._tracked_op_patterns : list [OpPattern ] = []
949
+ if tracks is not None :
950
+ if not tracks :
951
+ raise ValueError (
952
+ "To specify a general rewrite leave tracks as None instead of an empty container"
953
+ )
954
+ for t in tracks :
955
+ if isinstance (t , Op ):
956
+ self ._tracked_ops .add (t )
957
+ elif isinstance (t , type ):
958
+ self ._tracked_types |= t
959
+ elif isinstance (t , OpPattern ):
960
+ if t .parameters :
961
+ self ._tracked_op_patterns .append (t )
962
+ self ._tracked_op_pattern_types |= t .op_type
963
+ else :
964
+ # An OpPattern without parameters behaves like a regular tracked_type
965
+ self ._tracked_types |= t
966
+ else :
967
+ raise TypeError (
968
+ "`tracks` must consist of `Op` classes, `Op` instances or `OpPattern` instances. "
969
+ f"Got { t } of type { type (t )} "
970
+ )
941
971
self .requirements = requirements
942
972
943
- def transform (self , fgraph , node ):
944
- if self ._tracks :
973
+ def transform (self , fgraph , node , enforce_tracks : bool = True ):
974
+ if enforce_tracks and self ._tracks :
975
+ node_op = node .op
945
976
if not (
946
- node .op in self ._tracks or isinstance (node .op , self ._tracked_types )
977
+ node_op in self ._tracked_ops
978
+ or isinstance (node_op , self ._tracked_types )
979
+ or (
980
+ isinstance (node .op , self ._tracked_op_pattern_types )
981
+ and any (
982
+ t .match_parameters (node_op )
983
+ for t in self ._tracked_op_patterns
984
+ if isinstance (node_op , t .op_type )
985
+ )
986
+ )
947
987
):
948
988
return False
949
989
@@ -967,7 +1007,7 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1):
967
1007
968
1008
969
1009
def node_rewriter (
970
- tracks : Sequence [Op | type ] | None ,
1010
+ tracks : Sequence [Op | type , OpPattern ] | None ,
971
1011
inplace : bool = False ,
972
1012
requirements : tuple [type , ...] | None = (),
973
1013
):
@@ -976,7 +1016,7 @@ def node_rewriter(
976
1016
Parameters
977
1017
----------
978
1018
tracks
979
- The `Op` types or instances to which this rewrite applies.
1019
+ The `Op` type, instances or `OpPattern` to which this rewrite applies.
980
1020
Use ``None`` instead of an empty list to have the rewrite apply to
981
1021
all `Op`\s.
982
1022
inplace
@@ -995,14 +1035,16 @@ def decorator(f):
995
1035
if tracks is not None :
996
1036
if len (tracks ) == 0 :
997
1037
raise ValueError (
998
- "Use `None` instead of an empty list to make an rewrite apply to all nodes."
1038
+ "Use `None` instead of an empty list to make a rewrite apply to all nodes."
999
1039
)
1000
1040
for t in tracks :
1001
1041
if not (
1002
- isinstance (t , Op ) or (isinstance (t , type ) and issubclass (t , Op ))
1042
+ isinstance (t , Op | OpPattern )
1043
+ or (isinstance (t , type ) and issubclass (t , Op ))
1003
1044
):
1004
1045
raise TypeError (
1005
- "`tracks` must consist of `Op` classes or instances."
1046
+ "`tracks` must consist of `Op` classes, `Op` instances or `OpPattern` instances. "
1047
+ f"Got { t } of type { type (t )} "
1006
1048
)
1007
1049
req = requirements
1008
1050
if inplace :
@@ -1024,47 +1066,93 @@ class OpToRewriterTracker:
1024
1066
def __init__ (self ):
1025
1067
self .tracked_instances : dict [Op , list [NodeRewriter ]] = defaultdict (list )
1026
1068
self .tracked_types : dict [type , list [NodeRewriter ]] = defaultdict (list )
1069
+ self .tracked_pattern_types : dict [type , dict [OpPattern , list [NodeRewriter ]]] = (
1070
+ defaultdict (lambda : defaultdict (list ))
1071
+ )
1027
1072
self .untracked_rewrites : list [NodeRewriter ] = []
1073
+ self ._cached_composed_mro = None
1028
1074
1029
1075
def add_tracker (self , rw : NodeRewriter ):
1030
1076
"""Add a `NodeRewriter` to be keyed by its `NodeRewriter.tracks` or applied generally."""
1077
+ if self ._cached_composed_mro is not None :
1078
+ # We shouldn't actually add_trackers after the first call to get_trackers
1079
+ # But just to be safe we kill the cache here
1080
+ self ._cached_composed_mro = None
1081
+
1031
1082
tracks = rw .tracks ()
1032
1083
1033
1084
if tracks is None :
1034
1085
self .untracked_rewrites .append (rw )
1035
1086
else :
1036
1087
for c in tracks :
1088
+ if isinstance (c , OpPattern ):
1089
+ if not isinstance (c .op_type , type ):
1090
+ # OpPattern allows anything that you can check with isinstance(op, op_type),
1091
+ # including tuples or union types. But for OpToRewriterTracker we need a single type.
1092
+ raise NotImplementedError (
1093
+ "OpToRewriterTracker requires the outermost `OpPattern.op_type` to be a type. "
1094
+ f"Got { c .op_type } of type { type (c .op_type )} "
1095
+ )
1096
+
1097
+ if c .parameters :
1098
+ self .tracked_pattern_types [c .op_type ][c ].append (rw )
1099
+ else :
1100
+ # An OpPattern without parameters behaves like a regular tracked_type
1101
+ self .tracked_types [c .op_type ].append (rw )
1037
1102
if isinstance (c , type ):
1038
1103
self .tracked_types [c ].append (rw )
1039
1104
else :
1040
1105
self .tracked_instances [c ].append (rw )
1041
1106
1042
- def _find_impl (self , cls ) -> list [NodeRewriter ]:
1043
- r"""Returns the `NodeRewriter`\s that apply to `cls` based on inheritance.
1107
+ @functools .cache
1108
+ def get_trackers (self , op : Op ) -> list [NodeRewriter ]:
1109
+ """Get all the rewrites applicable to an `Op`."""
1110
+
1111
+ if self ._cached_composed_mro is None :
1112
+ # Cache the mro call on the Op type. We have a small subset of op_types we actually care about
1113
+ # like Elemwise, Blockwise, and so on, which we don't need to repeatedly investigate
1114
+ tracked_types = (
1115
+ self .tracked_types .keys () | self .tracked_pattern_types .keys ()
1116
+ )
1117
+
1118
+ @functools .cache
1119
+ def cached_composed_mro (op_type , tracked_types = tracked_types ):
1120
+ return _compose_mro (op_type , tracked_types )
1121
+
1122
+ self ._cached_composed_mro = cached_composed_mro
1044
1123
1045
- This based on `functools._find_impl`.
1046
- """
1047
- mro = _compose_mro (cls , self .tracked_types .keys ())
1048
1124
matches = []
1049
- for t in mro :
1050
- match = self .tracked_types .get (t , None )
1051
- if match :
1052
- matches .extend (match )
1125
+ if self .tracked_types or self .tracked_pattern_types :
1126
+ # Find matches for type(op) (and their subclasses) using the same approach that functools.singledispatch uses
1127
+ mro = self ._cached_composed_mro (type (op ))
1128
+ for t in mro :
1129
+ if (match := self .tracked_types .get (t , None )) is not None :
1130
+ matches .extend (match )
1131
+ if (
1132
+ potential_matches := self .tracked_pattern_types .get (t , None )
1133
+ ) is not None :
1134
+ # We still need to check if the Op parameters match the constraints
1135
+ matches .extend (
1136
+ [
1137
+ item
1138
+ for op_pattern , r_list in potential_matches .items ()
1139
+ if op_pattern .match_parameters (op )
1140
+ for item in r_list
1141
+ ]
1142
+ )
1143
+ matches .extend (self .tracked_instances .get (op , []))
1144
+ matches .extend (self .untracked_rewrites )
1053
1145
return matches
1054
1146
1055
- @functools .lru_cache
1056
- def get_trackers (self , op : Op ) -> list [NodeRewriter ]:
1057
- """Get all the rewrites applicable to `op`."""
1058
- return (
1059
- self ._find_impl (type (op ))
1060
- + self .tracked_instances .get (op , [])
1061
- + self .untracked_rewrites
1062
- )
1063
-
1064
- def get_rewriters (self ):
1147
+ def get_rewriters (self ) -> Iterable [NodeRewriter ]:
1148
+ """Get all the registered rewriters."""
1065
1149
return chain (
1150
+ chain .from_iterable (self .tracked_types .values ()),
1151
+ chain .from_iterable (self .tracked_instances .values ()),
1066
1152
chain .from_iterable (
1067
- chain (self .tracked_types .values (), self .tracked_instances .values ())
1153
+ item
1154
+ for sub_dict in self .tracked_pattern_types .values ()
1155
+ for item in sub_dict .values ()
1068
1156
),
1069
1157
self .untracked_rewrites ,
1070
1158
)
@@ -1138,7 +1226,7 @@ def tracks(self):
1138
1226
t .extend (at )
1139
1227
return t
1140
1228
1141
- def transform (self , fgraph , node ):
1229
+ def transform (self , fgraph , node , enforce_tracks = False ):
1142
1230
if len (self .rewrites ) == 0 :
1143
1231
return
1144
1232
@@ -1150,7 +1238,8 @@ def transform(self, fgraph, node):
1150
1238
new_repl = None
1151
1239
for rewrite in rewrites :
1152
1240
rewrite_start = time .perf_counter ()
1153
- new_repl = rewrite .transform (fgraph , node )
1241
+ # Tracks are already enforced by `self.tracker.get_trackers`
1242
+ new_repl = rewrite .transform (fgraph , node , enforce_tracks = False )
1154
1243
rewrite_finish = time .perf_counter ()
1155
1244
if self .profile :
1156
1245
self .time_rewrites [rewrite ] += rewrite_start - rewrite_finish
@@ -1292,8 +1381,8 @@ def __init__(self, op1, op2, transfer_tags=True):
1292
1381
def tracks (self ):
1293
1382
return [self .op1 ]
1294
1383
1295
- def transform (self , fgraph , node ):
1296
- if node .op != self .op1 :
1384
+ def transform (self , fgraph , node , enforce_tracks = True ):
1385
+ if enforce_tracks and ( node .op != self .op1 ) :
1297
1386
return False
1298
1387
repl = self .op2 .make_node (* node .inputs )
1299
1388
if self .transfer_tags :
@@ -1498,7 +1587,7 @@ def __init__(
1498
1587
def tracks (self ):
1499
1588
return self ._tracks
1500
1589
1501
- def transform (self , fgraph , node , get_nodes = True ):
1590
+ def transform (self , fgraph , node , enforce_tracks : bool = False , get_nodes = True ):
1502
1591
"""Check if the graph from node corresponds to ``in_pattern``.
1503
1592
1504
1593
If it does, it constructs ``out_pattern`` and performs the replacement.
@@ -1788,6 +1877,7 @@ def process_node(
1788
1877
fgraph : FunctionGraph ,
1789
1878
node : Apply ,
1790
1879
node_rewriter : NodeRewriter | None = None ,
1880
+ enforce_tracks : bool = True ,
1791
1881
):
1792
1882
r"""Apply `node_rewriter` to `node`.
1793
1883
@@ -1805,6 +1895,9 @@ def process_node(
1805
1895
node_rewriter
1806
1896
A `NodeRewriter` instance that may have a better idea for
1807
1897
how to compute node's outputs.
1898
+ enforce_tracks: bool
1899
+ Whether the transform method should enforce tracks,
1900
+ or it can be assumed the caller already enforced them in a pre-filter stage.
1808
1901
1809
1902
Returns
1810
1903
-------
@@ -1820,7 +1913,9 @@ def process_node(
1820
1913
# TODO FIXME: This class's interface is broken
1821
1914
assert node_rewriter is not None
1822
1915
try :
1823
- replacements = node_rewriter .transform (fgraph , node )
1916
+ replacements = node_rewriter .transform (
1917
+ fgraph , node , enforce_tracks = enforce_tracks
1918
+ )
1824
1919
except Exception as e :
1825
1920
if self .failure_callback is not None :
1826
1921
self .failure_callback (
@@ -1938,7 +2033,8 @@ def importer(node):
1938
2033
if node not in fgraph .apply_nodes :
1939
2034
continue
1940
2035
current_node = node
1941
- nb += self .process_node (fgraph , node )
2036
+ # This rewriter does not enforce tracks itself
2037
+ nb += self .process_node (fgraph , node , enforce_tracks = True )
1942
2038
loop_t = time .perf_counter () - t0
1943
2039
finally :
1944
2040
self .detach_updater (fgraph , u )
@@ -2279,8 +2375,9 @@ def chin_(node, i, r, new_r, reason):
2279
2375
for node_rewriter in self .node_tracker .get_trackers (node .op ):
2280
2376
nb = change_tracker .nb_imported
2281
2377
t_rewrite = time .perf_counter ()
2378
+ # Tracks are already enfoced by `self.node_tracker.get_trackers`
2282
2379
node_rewriter_change = self .process_node (
2283
- fgraph , node , node_rewriter
2380
+ fgraph , node , node_rewriter , enforce_tracks = False
2284
2381
)
2285
2382
time_rewriters [node_rewriter ] += time .perf_counter () - t_rewrite
2286
2383
if not node_rewriter_change :
0 commit comments