@@ -2407,47 +2407,52 @@ def call_operator(self, op, args, kwargs, meta):
2407
2407
)
2408
2408
2409
2409
2410
- # This class encapsulates all the functions that replace/switch one op in the
2411
- # graph with another.
2412
- class CadenceReplaceOpsInGraph :
2410
+ # Encapsulates the generic graph improvement passes.
2411
+ class GenericReplaceOpsInGraph :
2413
2412
passes = [
2414
- ReplaceEmptyTensorsWithFullPass ,
2415
- ReplaceFunctionallyEquivalentOpTargets ,
2413
+ ReplaceSqueezeAndUnsqueezeWithViewPass ,
2414
+ ReplaceSplitWithSlicePass ,
2415
+ ReplaceSelectWithViewOpPass ,
2416
+ ReplaceMMWithAddMMPass ,
2417
+ ReplaceRepeatWithCatPass ,
2418
+ ReplaceFullLikeWithFullPass ,
2419
+ ReplaceFunctionallyEquivalentOpTargets ,
2420
+ ReplaceConvolutionOptionalArgsWithConcreteArgsPass ,
2421
+ ReplaceAddMMWithLinearPass ,
2422
+ ReplacePadWithCatPass ,
2423
+ ReplaceConstantPadNdWithSlicePass ,
2424
+ ReplaceTrivialConvWithLinear ,
2425
+ ReplaceScalarTensorWithFullPass ,
2426
+ ReplaceInfArgInFullWithValuePass ,
2427
+ ReplaceLogicalNotBooleanWhereWithWherePass ,
2428
+ ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass ,
2429
+ ReplaceAtenApproxGeluWithApproxGeluPass ,
2430
+ ReplaceAtenConvolutionWithJarvisConvolutionPass ,
2431
+ ReplacePT2QuantWithCadenceQuantPass ,
2432
+ ReplacePT2DequantWithCadenceDequantPass ,
2433
+ ReplacePowWithMulPass ,
2434
+ ReplaceEmptyTensorsWithFullPass ,
2435
+ ReplaceScalarWithTensorArgPass ,
2436
+ RemoveNopSelectOpPass ,
2437
+ ReplaceNopTransposeOrPermuteWithViewPass ,
2438
+ ]
2439
+
2440
+ # Encapsulates all the passes which replace ops in graph including Sadence specific ones.
2441
+ class CadenceReplaceOpsInGraph :
2442
+ cadence_specific_passes = [
2416
2443
ReplacePermuteWithTransposePass ,
2417
- ReplaceScalarWithTensorArgPass ,
2418
- ReplaceConvolutionOptionalArgsWithConcreteArgsPass ,
2419
- ReplaceMMWithAddMMPass ,
2420
- ReplaceSqueezeAndUnsqueezeWithViewPass ,
2421
- ReplaceAddMMWithLinearPass ,
2422
- RemoveNopSelectOpPass ,
2423
- ReplaceSelectWithViewOpPass ,
2424
- ReplaceRepeatWithCatPass ,
2425
2444
ReplacePadWithCatPass ,
2426
- ReplaceConstantPadNdWithSlicePass ,
2427
2445
ReplaceConvWithChannelLastConvPass ,
2428
- ReplaceAtenConvolutionWithJarvisConvolutionPass ,
2429
2446
ForceChannelLastForConvPass ,
2430
- ReplaceTrivialConvWithLinear ,
2431
2447
ReplaceConvWithIm2RowAndLinear ,
2432
2448
ReplaceTransposedConvWithLinearPass ,
2433
2449
# This pass should be after passes that replace conv -> im2row + linear.
2434
2450
ReplaceIm2RowWithViewPass ,
2435
2451
MakeSliceAndCatDimOutermostPass ,
2436
2452
ReplaceMatmulWithTransposedMatmulPass ,
2437
- ReplaceNopTransposeOrPermuteWithViewPass ,
2438
2453
ReplaceLinearWithFullyConnectedOpPass ,
2439
- ReplaceScalarTensorWithFullPass ,
2440
- ReplaceFullLikeWithFullPass ,
2441
- ReplaceInfArgInFullWithValuePass ,
2442
- ReplaceLogicalNotBooleanWhereWithWherePass ,
2443
- ReplacePT2QuantWithCadenceQuantPass ,
2444
- ReplacePT2DequantWithCadenceDequantPass ,
2445
2454
ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass ,
2446
- ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass ,
2447
2455
ReplaceAtenAvgPoolWithJarvisAvgPoolPass ,
2448
2456
ReplaceWhereWithFullArgsWithWhereScalar ,
2449
- ReplaceAtenApproxGeluWithApproxGeluPass ,
2450
- ReplaceSplitWithSlicePass ,
2451
- ReplacePowWithMulPass ,
2452
- ReplaceMulTensorWithMulAndFullOpsPass ,
2453
2457
]
2458
+ passes = GenericReplaceOpsInGraph .passes + cadence_specific_passes
0 commit comments