Commit b33227d

[dicp][ascend] Optimize stable_diffusion performance on ascend. (#704)
* Optimize stable_diffusion performance.
* Add switch config for sd bmm_fp16.
* Fix ci llama hf case.
* Fix review comments.
* Remove redundant singleton design.
* Redesign infer_shape interface.
* Remove unused codes.
* Clean more code.
* Add empty line for code-style.
1 parent cda8286 commit b33227d

File tree

5 files changed: +38 -4 lines changed
dicp/dicp/vendor/AscendGraph/codegen/ascend.py

+2 -2

@@ -1116,13 +1116,13 @@ def MatMul(name, x1, x2, trans_x1: bool, trans_x2: bool):
         return op.to_node()
 
     @staticmethod
-    def BatchMatMul(name, x1, x2, adj_x1: bool, adj_x2: bool):
+    def BatchMatMul(name, x1, x2, adj_x1: bool, adj_x2: bool, keep_dtype=1):
         op = OP(name, "BatchMatMul")
         op.set_input("x1", x1)
         op.set_attr_bool("adj_x1", adj_x1)
         op.set_input("x2", x2)
         op.set_attr_bool("adj_x2", adj_x2)
-        op.set_attr_int("_keep_dtype", 1)
+        op.set_attr_int("_keep_dtype", keep_dtype)
         return op.to_node()
 
     @staticmethod
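The new `keep_dtype` argument is forwarded straight into the `_keep_dtype` attribute of the generated BatchMatMul node; the default of 1 reproduces the previously hard-coded behaviour, while 0 presumably frees the Ascend graph engine to run the matmul in fp16. A minimal sketch of the plumbing, using a hypothetical stand-in `OP` builder (the real one lives elsewhere in `ascend.py`):

```python
# Hypothetical stand-in for the OP builder used by the codegen; only the calls
# visible in the diff (set_input / set_attr_* / to_node) are mimicked here.
class OP:
    def __init__(self, name, op_type):
        self.node = {"name": name, "type": op_type, "inputs": {}, "attrs": {}}

    def set_input(self, key, value):
        self.node["inputs"][key] = value

    def set_attr_bool(self, key, value):
        self.node["attrs"][key] = bool(value)

    def set_attr_int(self, key, value):
        self.node["attrs"][key] = int(value)

    def to_node(self):
        return self.node


def BatchMatMul(name, x1, x2, adj_x1: bool, adj_x2: bool, keep_dtype=1):
    op = OP(name, "BatchMatMul")
    op.set_input("x1", x1)
    op.set_attr_bool("adj_x1", adj_x1)
    op.set_input("x2", x2)
    op.set_attr_bool("adj_x2", adj_x2)
    op.set_attr_int("_keep_dtype", keep_dtype)
    return op.to_node()


# {'adj_x1': False, 'adj_x2': False, '_keep_dtype': 1}  -> old behaviour
print(BatchMatMul("bmm_keep", "x1", "x2", False, False)["attrs"])
# {'adj_x1': False, 'adj_x2': False, '_keep_dtype': 0}  -> fp16-friendly path
print(BatchMatMul("bmm_fp16", "x1", "x2", False, False, keep_dtype=0)["attrs"])
```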

dicp/dicp/vendor/AscendGraph/conversion.py

+6 -1

@@ -1,3 +1,4 @@
+import os
 import functools
 import operator
 import _operator
@@ -26,6 +27,8 @@
 prims = torch.ops.prims
 conversions = {}
 
+sd_fp16 = int(os.environ.get("SD_FP16", 0))
+
 
 def get_reduction_str(r):
     if r == 0:
@@ -1173,7 +1176,7 @@ def mm(self, x, y):
     @register_conversion(aten.bmm.default)
     def bmm(self, x, y):
         out_dtype = fx_traceback.get_current_meta()['val'].dtype
-        bmm = self.get_proxy(ascend_op.BatchMatMul, (x, y, False, False))
+        bmm = self.get_proxy(ascend_op.BatchMatMul, (x, y, False, False, sd_fp16 ^ 1))
         return self.get_proxy(ascend_op.Cast, (bmm, get_ascend_dtype(out_dtype)))
 
     @register_conversion(torch.torch.ops.aten.addmm)
@@ -1292,6 +1295,8 @@ def _softmax(self, x, dim=-1, half_to_float=False):
         if isinstance(dim, int):
             dim = [dim]
         assert (half_to_float is False)
+        if sd_fp16 is not None and int(sd_fp16) == 1:
+            x = self.get_proxy(ascend_op.Cast, (x, get_ascend_dtype(torch.float16)))
         return self.get_proxy(ascend_op.SoftmaxV2, (x, dim))
 
     @register_conversion(torch.ops.aten.sum.default)
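The switch is read once at module import time. Setting `SD_FP16=1` flips the `keep_dtype` flag handed to `BatchMatMul` to 0 (via `sd_fp16 ^ 1`) and additionally casts the softmax input to fp16; leaving it unset keeps the previous full-precision lowering. A minimal sketch of just that decision logic, outside dicp:

```python
import os

# Mirrors the module-level switch in conversion.py: unset or "0" keeps the old
# behaviour, "1" opts in to the fp16-friendly stable-diffusion lowering.
sd_fp16 = int(os.environ.get("SD_FP16", 0))

# keep_dtype handed to BatchMatMul: the xor simply inverts the switch,
# so SD_FP16=1 -> keep_dtype=0 and SD_FP16 unset/0 -> keep_dtype=1.
keep_dtype = sd_fp16 ^ 1

# The softmax conversion only inserts an fp16 Cast when the switch is on.
cast_softmax_to_fp16 = (sd_fp16 == 1)

print(f"SD_FP16={sd_fp16} -> keep_dtype={keep_dtype}, "
      f"cast softmax input to fp16: {cast_softmax_to_fp16}")
```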

dicp/dicp/vendor/AscendGraph/opset_convert.py

+5 -1

@@ -3,6 +3,7 @@
 from dicp.dynamo_bridge.compile_fx import is_torch_210
 from dicp.vendor.AscendGraph.ascend_op import MatMul, CastToCpu, IdentityInp
 from dicp.vendor.AscendGraph.conversion import AtenToAscendTransformer
+from ...dynamo_bridge.graph import GraphTransformer
 
 if is_torch_210:
     from dicp.dynamo_bridge.op_transformer import BackendPatternMatcherTransformer
@@ -18,7 +19,7 @@ def transform(self, gm: torch.fx.graph_module):
         for n in gm.graph.nodes:
             if hasattr(n, 'op') and n.op == 'placeholder':
                 fake_tensor = n.meta['val']
-                memo = fake_tensor.fake_mode.fake_tensor_converter.tensor_memo
+                memo = fake_tensor.fake_mode.fake_tensor_converter.tensor_memo
                 for key in memo:
                     if id(memo[key].fake_device) == id(fake_tensor.fake_device):
                         memory_format = torch_dipu.get_native_memory_format(key())
@@ -86,6 +87,9 @@ def ascendgraph_opset_convert(
 
     # For bug in pytorch
     # Avoid for dynamic shape
+    gt = GraphTransformer(gm, "ascendgraph")
+    gt.infer_shape_dtype()
+    gm = gt.gm
     if is_torch_210 and not symint_in_inputs(list(gm.graph.nodes)):
         gm = BackendPatternMatcherTransformer(
             ascend_pattern_matcher, ascend_patterns_cls_list).transform(gm)
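With this change, shape/dtype inference runs on the FX graph before the Ascend pattern matcher, presumably so that patterns which depend on node metadata (such as the fused BMM pattern below) see it already populated. dicp's own pass is `GraphTransformer.infer_shape_dtype`; as a rough, self-contained analogue of what such a pass produces, torch.fx's generic `ShapeProp` (a stock PyTorch utility, not part of dicp) fills in per-node tensor metadata:

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp


class M(torch.nn.Module):
    def forward(self, x, y):
        # A bmm-on-transposed-operand graph, similar in spirit to the
        # stable-diffusion attention subgraphs targeted by this commit.
        return torch.bmm(x, y.transpose(1, 2))


gm = symbolic_trace(M())
# Propagate shapes/dtypes through the traced graph with example inputs.
ShapeProp(gm).propagate(torch.randn(2, 5, 8), torch.randn(2, 7, 8))

for node in gm.graph.nodes:
    tm = node.meta.get("tensor_meta")
    if tm is not None:
        print(node.name, tuple(tm.shape), tm.dtype)
```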

dicp/dicp/vendor/AscendGraph/pattern_replacement.py

+23

@@ -44,6 +44,8 @@ def pattern(self, repeat, dim, input_shape, empty_device, view_1_shape,
     def replacement(self, repeat, dim):
         return torch.ops.aten.repeat_interleave.self_int(self, repeat, dim)
 
+Muls = torch.fx.wrap(ascend_op.Muls.get_singleton())
+Shape = torch.fx.wrap(ascend_op.Shape.get_singleton())
 Const = torch.fx.wrap(ascend_op.Const.get_singleton())
 Transpose = torch.fx.wrap(ascend_op.Transpose.get_singleton())
 Identity = torch.fx.wrap(ascend_op.Identity.get_singleton())
@@ -71,6 +73,27 @@ def replacement(x1, x2, dtype):
         return BatchMatMul(x1, reshape, adj_x1=False, adj_x2=True)
 
 
+@register_ascend_pattern
+class FuseBmmTransposeMulsPattern(BackendPatternBase):
+    @staticmethod
+    def pattern(x1, x2, c1, c2):
+        transpose = Transpose(x2, c1)
+        muls = Muls(transpose, 0.3535533905932738)
+        identity = Identity(muls, None)
+        identity1 = Identity(identity, None)
+        reshape = Reshape(identity1, c2)
+        return BatchMatMul(x1, reshape, False, False, 0)
+
+    @staticmethod
+    def replacement(x1, x2, c1, c2):
+        x2 = Reshape(x2, c2)
+        perm = Permute(x2, [0, 2, 1])
+        shape = Shape(perm)
+        reshape = Reshape(x2, shape)
+        muls = Muls(reshape, 0.3535533905932738)
+        return BatchMatMul(x1, muls, adj_x1=False, adj_x2=True, keep_dtype=0)
+
+
 # @pandaoxin negotiate with @tangzhiyi
 # another submit would implement
 # @register_ascend_pattern
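The new pattern rewrites a `Transpose` + `Muls` + `Identity` + `Reshape` chain feeding a plain `BatchMatMul` into a single `BatchMatMul` with `adj_x2=True` acting on the scaled operand, relying on the fact that multiplying by a scalar commutes with the transpose. The exact `Reshape`/`Shape` plumbing and the `c1`/`c2` constants depend on the traced stable-diffusion graph, but the algebraic core can be checked in plain PyTorch (the constant 0.3535533905932738 is 1/sqrt(8)):

```python
# Quick numerical check (plain PyTorch, outside dicp) of the identity the
# fusion relies on: scaling by a constant commutes with the batch transpose,
#   x1 @ (s * x2^T) == x1 @ (s * x2)^T,
# so the Transpose/Muls chain can be folded into one BatchMatMul with
# adj_x2=True on the scaled operand.
import torch

s = 0.3535533905932738  # == 1 / sqrt(8)
x1 = torch.randn(2, 5, 8)
x2 = torch.randn(2, 7, 8)

# Before the rewrite: x2 is transposed and scaled, then fed to a plain bmm.
before = torch.bmm(x1, s * x2.transpose(1, 2))

# After the rewrite: x2 is scaled first and the transpose is folded into the
# matmul itself (adj_x2=True corresponds to transposing x2's last two dims).
after = torch.bmm(x1, (s * x2).transpose(1, 2))

assert torch.allclose(before, after)
print("max abs diff:", (before - after).abs().max().item())
```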

dicp/test/model/test_stable_diffusion.py

+2

@@ -39,6 +39,8 @@ def test_inference(
         prompt = "A photo of an astronaut riding a horse on mars."
         utils.update_dynamo_config(dynamic=dynamic)
         torch_dipu.dipu.set_device(device)
+        if backend == "ascendgraph":
+            os.environ["SD_FP16"] = "1"
 
         # CPU
         torch.manual_seed(1)
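Outside the test, the switch can be enabled the same way; since `conversion.py` reads `SD_FP16` once at module import time, it has to be set before the dicp AscendGraph conversion module is first loaded, i.e. before the model is compiled with the `ascendgraph` backend. A minimal sketch (the pipeline construction itself is elided):

```python
import os

# Opt in to the fp16 bmm/softmax lowering for stable diffusion on Ascend.
# conversion.py reads SD_FP16 once at import time, so set it before the
# first compilation triggers that import.
os.environ["SD_FP16"] = "1"

# ... then build the stable-diffusion pipeline and compile it, e.g.
# unet = torch.compile(unet, backend="ascendgraph", dynamic=False)
```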
