@@ -67,7 +67,7 @@ def partition_types(self) -> list[OpOverload]:
67
67
@abstractmethod
68
68
def get_anchors (
69
69
self , gm : torch .fx .GraphModule , fused_partition : List [fx .GraphModule ]
70
- ) -> Optional [PartitionAnchors ]:
70
+ ) -> Tuple [PartitionAnchors , fx . Node ]:
71
71
pass
72
72
73
73
@abstractmethod
@@ -85,7 +85,7 @@ def partition_types(self) -> List[OpOverload]:
85
85
86
86
def get_anchors (
87
87
self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
88
- ) -> PartitionAnchors :
88
+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
89
89
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
90
90
addmm_node = fused_partition [0 ].nodes [- 1 ]
91
91
@@ -101,12 +101,12 @@ def get_anchors(
101
101
qscheme = torch .per_tensor_affine ,
102
102
)
103
103
104
- return PartitionAnchors (
104
+ return ( PartitionAnchors (
105
105
inputs = [(addmm_node , 1 )],
106
106
weights = [(addmm_node , 2 )],
107
107
biases = [(addmm_node , 0 , bias_qspec )],
108
108
output = [(addmm_node ,)],
109
- )
109
+ ), addmm_node )
110
110
111
111
def replacement_op (self ) -> OpOverload :
112
112
return torch .ops .cadence .quantized_linear .default
@@ -118,7 +118,7 @@ def partition_types(self) -> List[OpOverload]:
118
118
119
119
def get_anchors (
120
120
self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
121
- ) -> PartitionAnchors :
121
+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
122
122
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
123
123
add_node = fused_partition [0 ].nodes [- 1 ]
124
124
@@ -129,16 +129,16 @@ def get_anchors(
129
129
add_node .args [1 ], fx .Node
130
130
)
131
131
if not is_tensor_add or len (add_node .kwargs ) > 0 :
132
- return PartitionAnchors (
132
+ return ( PartitionAnchors (
133
133
empty = True ,
134
- )
134
+ ), add_node )
135
135
136
- return PartitionAnchors (
136
+ return ( PartitionAnchors (
137
137
inputs = [(add_node , 0 ), (add_node , 1 )],
138
138
weights = [],
139
139
biases = [],
140
140
output = [(add_node ,)],
141
- )
141
+ ), add_node )
142
142
143
143
def replacement_op (self ) -> OpOverload :
144
144
return torch .ops .cadence .quantized_add .default
@@ -150,16 +150,16 @@ def partition_types(self) -> List[OpOverload]:
150
150
151
151
def get_anchors (
152
152
self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
153
- ) -> PartitionAnchors :
153
+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
154
154
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
155
155
bmm_node = fused_partition [0 ].nodes [- 1 ]
156
156
157
- return PartitionAnchors (
157
+ return ( PartitionAnchors (
158
158
inputs = [(bmm_node , 0 ), (bmm_node , 1 )],
159
159
weights = [],
160
160
biases = [],
161
161
output = [(bmm_node ,)],
162
- )
162
+ ), bmm_node )
163
163
164
164
def replacement_op (self ) -> OpOverload :
165
165
return torch .ops .cadence .quantized_matmul .default
@@ -171,7 +171,7 @@ def partition_types(self) -> List[OpOverload]:
171
171
172
172
def get_anchors (
173
173
self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
174
- ) -> PartitionAnchors :
174
+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
175
175
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
176
176
cat_node = fused_partition [0 ].nodes [- 1 ]
177
177
@@ -198,14 +198,14 @@ def get_anchors(
198
198
)
199
199
)
200
200
201
- return PartitionAnchors (
201
+ return ( PartitionAnchors (
202
202
inputs = args ,
203
203
weights = [],
204
204
biases = [],
205
205
output = [
206
206
(cat_node , SharedQuantizationSpec ((cat_node .args [0 ][0 ], cat_node )))
207
207
],
208
- )
208
+ ), cat_node )
209
209
210
210
def replacement_op (self ) -> OpOverload :
211
211
return torch .ops .aten .cat .default
@@ -217,7 +217,7 @@ def partition_types(self) -> List[OpOverload]:
217
217
218
218
def get_anchors (
219
219
self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
220
- ) -> PartitionAnchors :
220
+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
221
221
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
222
222
conv1d_node = fused_partition [0 ].nodes [- 1 ]
223
223
@@ -238,13 +238,13 @@ def get_anchors(
238
238
if len (conv1d_node .args ) > 2 and conv1d_node .args [2 ] is not None :
239
239
bias = [(conv1d_node , 2 , bias_qspec )]
240
240
241
- return PartitionAnchors (
241
+ return ( PartitionAnchors (
242
242
inputs = [(conv1d_node , 0 )],
243
243
weights = [(conv1d_node , 1 )],
244
244
# pyre-fixme[6]: Incompatible parameter type
245
245
biases = bias ,
246
246
output = [(conv1d_node ,)],
247
- )
247
+ ), conv1d_node )
248
248
249
249
def replacement_op (self ) -> OpOverload :
250
250
return torch .ops .cadence .quantized_conv_nchw .default
@@ -256,7 +256,7 @@ def partition_types(self) -> List[OpOverload]:
256
256
257
257
def get_anchors (
258
258
self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
259
- ) -> PartitionAnchors :
259
+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
260
260
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
261
261
conv2d_node = fused_partition [0 ].nodes [- 1 ]
262
262
@@ -277,13 +277,13 @@ def get_anchors(
277
277
if len (conv2d_node .args ) > 2 and conv2d_node .args [2 ] is not None :
278
278
bias = [(conv2d_node , 2 , bias_qspec )]
279
279
280
- return PartitionAnchors (
280
+ return ( PartitionAnchors (
281
281
inputs = [(conv2d_node , 0 )],
282
282
weights = [(conv2d_node , 1 )],
283
283
# pyre-fixme[6]: Incompatible parameter type
284
284
biases = bias ,
285
285
output = [(conv2d_node ,)],
286
- )
286
+ ), conv2d_node )
287
287
288
288
def replacement_op (self ) -> OpOverload :
289
289
return torch .ops .cadence .quantized_conv_nchw .default
@@ -295,7 +295,7 @@ def partition_types(self) -> List[OpOverload]:
295
295
296
296
def get_anchors (
297
297
self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
298
- ) -> PartitionAnchors :
298
+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
299
299
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
300
300
layer_norm_node = fused_partition [0 ].nodes [- 1 ]
301
301
@@ -311,14 +311,14 @@ def get_anchors(
311
311
312
312
# Weights are used in quantized mode by our kernel, so they are
313
313
# passed in as others here along with the normalized shape.
314
- return PartitionAnchors (
314
+ return ( PartitionAnchors (
315
315
inputs = [(layer_norm_node , 0 )],
316
316
weights = [],
317
317
biases = [],
318
318
# Ordering: normalized_shape, weights, bias
319
319
others = others ,
320
320
output = [(layer_norm_node ,)],
321
- )
321
+ ), layer_norm_node )
322
322
323
323
def replacement_op (self ) -> OpOverload :
324
324
return torch .ops .cadence .quantized_layer_norm .default
@@ -330,7 +330,7 @@ def partition_types(self) -> List[OpOverload]:
330
330
331
331
def get_anchors (
332
332
self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
333
- ) -> PartitionAnchors :
333
+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
334
334
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
335
335
linear_node = fused_partition [0 ].nodes [- 1 ]
336
336
@@ -351,13 +351,13 @@ def get_anchors(
351
351
if len (linear_node .args ) > 2 :
352
352
bias = [(linear_node , 2 , bias_qspec )]
353
353
354
- return PartitionAnchors (
354
+ return ( PartitionAnchors (
355
355
inputs = [(linear_node , 0 )],
356
356
weights = [(linear_node , 1 )],
357
357
# pyre-fixme[6]: Incompatible parameter type
358
358
biases = bias ,
359
359
output = [(linear_node ,)],
360
- )
360
+ ), linear_node )
361
361
362
362
def replacement_op (self ) -> OpOverload :
363
363
return torch .ops .cadence .quantized_linear .default
@@ -369,16 +369,16 @@ def partition_types(self) -> List[OpOverload]:
369
369
370
370
def get_anchors (
371
371
self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
372
- ) -> PartitionAnchors :
372
+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
373
373
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
374
374
matmul_node = fused_partition [0 ].nodes [- 1 ]
375
375
376
- return PartitionAnchors (
376
+ return ( PartitionAnchors (
377
377
inputs = [(matmul_node , 0 ), (matmul_node , 1 )],
378
378
weights = [],
379
379
biases = [],
380
380
output = [(matmul_node ,)],
381
- )
381
+ ), matmul_node )
382
382
383
383
def replacement_op (self ) -> OpOverload :
384
384
return torch .ops .cadence .quantized_matmul .default
@@ -392,16 +392,16 @@ def partition_types(self) -> List[OpOverload]:
392
392
393
393
def get_anchors (
394
394
self , gm : fx .GraphModule , fused_partition : List [fx .GraphModule ]
395
- ) -> PartitionAnchors :
395
+ ) -> Tuple [ PartitionAnchors , fx . Node ] :
396
396
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
397
397
relu_node = fused_partition [0 ].nodes [- 1 ]
398
398
399
- return PartitionAnchors (
399
+ return ( PartitionAnchors (
400
400
inputs = [(relu_node , 0 )],
401
401
weights = [],
402
402
biases = [],
403
403
output = [(relu_node ,)],
404
- )
404
+ ), relu_node )
405
405
406
406
def replacement_op (self ) -> OpOverload :
407
407
return torch .ops .cadence .quantized_relu .default
0 commit comments