@@ -44,6 +44,8 @@ def pattern(self, repeat, dim, input_shape, empty_device, view_1_shape,
44
44
    def replacement(self, repeat, dim):
        # Replacement graph: collapse the matched subgraph into a single
        # aten.repeat_interleave call.
        # NOTE(review): `self` here is an FX placeholder for the input
        # tensor, not an instance — pattern/replacement functions are
        # traced by torch.fx, never invoked as bound methods.
        return torch.ops.aten.repeat_interleave.self_int(self, repeat, dim)
# Bind the Ascend op singletons as module-level callables and register them
# with torch.fx.wrap so the tracer treats each one as an opaque leaf call
# (a single node) inside the pattern/replacement graphs below.
Muls = torch.fx.wrap(ascend_op.Muls.get_singleton())
Shape = torch.fx.wrap(ascend_op.Shape.get_singleton())
Const = torch.fx.wrap(ascend_op.Const.get_singleton())
Transpose = torch.fx.wrap(ascend_op.Transpose.get_singleton())
Identity = torch.fx.wrap(ascend_op.Identity.get_singleton())
@@ -71,6 +73,27 @@ def replacement(x1, x2, dtype):
71
73
return BatchMatMul (x1 , reshape , adj_x1 = False , adj_x2 = True )
72
74
73
75
76
@register_ascend_pattern
class FuseBmmTransposeMulsPattern(BackendPatternBase):
    """Fuse a Transpose -> Muls -> Identity x2 -> Reshape chain feeding a
    BatchMatMul into a Reshape/Permute/Shape/Muls chain consumed by a
    BatchMatMul with `adj_x2=True` (i.e. fold the explicit transpose into
    the matmul's transposed-operand flag).

    Both staticmethods are traced by torch.fx; the matcher compares graph
    structure, so the exact sequence of op calls below is load-bearing.
    """

    @staticmethod
    def pattern(x1, x2, c1, c2):
        # Subgraph to match: x2 transposed by permutation const c1, scaled
        # by 0.3535533905932738 (== 1/(2*sqrt(2)); presumably an attention
        # 1/sqrt(d) scale — TODO confirm), passed through two Identity ops,
        # reshaped by const c2, then batch-matmul'd with x1.
        transpose = Transpose(x2, c1)
        muls = Muls(transpose, 0.3535533905932738)
        identity = Identity(muls, None)
        identity1 = Identity(identity, None)
        reshape = Reshape(identity1, c2)
        return BatchMatMul(x1, reshape, False, False, 0)

    @staticmethod
    def replacement(x1, x2, c1, c2):
        # Replacement graph. `c1` is unused here but must stay in the
        # signature: replacement's parameters must mirror pattern's.
        x2 = Reshape(x2, c2)
        # Compute the shape of the permuted tensor dynamically, then
        # reshape x2 to it; the actual transposition is delegated to
        # BatchMatMul via adj_x2=True instead of a standalone Transpose.
        perm = Permute(x2, [0, 2, 1])
        shape = Shape(perm)
        reshape = Reshape(x2, shape)
        # Same fixed scale as in pattern() — must match byte-for-byte.
        muls = Muls(reshape, 0.3535533905932738)
        return BatchMatMul(x1, muls, adj_x1=False, adj_x2=True, keep_dtype=0)
# TODO(@pandaoxin): to be negotiated with @tangzhiyi;
# a follow-up commit will implement this pattern.
# @register_ascend_pattern
0 commit comments