
Commit a86f7d3

tianleiwu authored and guschmue committed
Stable Diffusion 3.x and Flux Optimization (#22986)
### Description

This change depends on the following PR:
- #23297

Optimize the ONNX pipeline for Stable Diffusion 3.x and Flux 1.0 models (fp32 or fp16):
- [x] Update optimize_pipeline script
- [x] Update benchmark script
- [x] Update document about Stable Diffusion 3.x and Flux 1.0 models
- [x] Add graph optimizations for MMDiT model
  - [x] FastGelu fusion
  - [x] RMSNorm fusion
  - [x] MultiHeadAttention fusion
- [x] Add graph optimizations for Flux transformer models
  - [x] MultiHeadAttention fusion
- [x] Update graph optimizations for T5
- [x] Add tests

Example: optimize the ONNX pipeline for Stable Diffusion 3.x and Flux 1.0 models:
```
python optimize_pipeline.py -i ./flux1_schnell_onnx/fp32 -o ./flux1_schnell_onnx/fp16 --float16

Optimize flux1_schnell_onnx/fp32/transformer/model.onnx ...
Fused LayerNormalization: 115
Fused SimplifiedLayerNormalization: 152
Fused FastGelu: 76
Fused MultiHeadAttention: 57
```

### H100 Benchmark Results

* GPU: NVIDIA H100 80GB HBM3
* Image Size: 1024x1024
* Batch Size: 1

| Model | Steps | Precision | Engine | Latency (Seconds) | GPU Memory (MB) |
| -- | -- | -- | -- | -- | -- |
| Flux 1.0 Dev | 50 | BF16 | Torch 2.5.1 (compile) | 8.198 | 37,603 |
| Flux 1.0 Dev | 50 | FP16+BF16 | Optimum (ORT) | 10.762 | 41,469 |
| Flux 1.0 Dev | 50 | FP16+FP32 | Optimum (ORT) | 10.891 | 43,545 |
| Flux 1.0 Dev | 50 | BF16 | Torch 2.5.1 (eager) | 12.339 | 36,651 |
| Flux 1.0 Schnell | 4 | BF16 | Torch 2.5.1 (compile) | 0.775 | 37,857 |
| Flux 1.0 Schnell | 4 | FP16+BF16 | Optimum (ORT) | 0.931 | 41,433 |
| Flux 1.0 Schnell | 4 | FP16+FP32 | Optimum (ORT) | 0.939 | 43,809 |
| Flux 1.0 Schnell | 4 | BF16 | Torch 2.5.1 (eager) | 1.120 | 36,629 |
| SD 3.5 Large | 50 | BF16 | Torch 2.5.1 (compile) | 7.466 | 32,217 |
| SD 3.5 Large | 50 | FP16+BF16 | Optimum (ORT) | 10.275 | 36,609 |
| SD 3.5 Large | 50 | FP16+FP32 | Optimum (ORT) | 10.283 | 36,729 |
| SD 3.5 Large | 50 | BF16 | Torch 2.5.1 (eager) | 11.615 | 31,517 |
| SD 3.5 Medium | 50 | BF16 | Torch 2.5.1 (compile) | 3.240 | 21,143 |
| SD 3.5 Medium | 50 | FP16+BF16 | Optimum (ORT) | 4.799 | 25,097 |
| SD 3.5 Medium | 50 | FP16+FP32 | Optimum (ORT) | 4.838 | 25,109 |
| SD 3.5 Medium | 50 | BF16 | Torch 2.5.1 (eager) | 5.582 | 20,489 |

### A100 Benchmark Results

* GPU: A100-SXM4-80GB
* Image Size: 1024x1024
* Batch Size: 1

| Model | Steps | Precision | Engine | Latency (Seconds) | GPU Memory (MB) |
| -- | -- | -- | -- | -- | -- |
| Flux 1.0 Dev | 50 | BF16 | Torch 2.5.1 (compile) | 17.593 | 37,723 |
| Flux 1.0 Dev | 50 | FP16+BF16 | Optimum (ORT) | 21.918 | 41,348 |
| Flux 1.0 Dev | 50 | FP16+FP32 | Optimum (ORT) | 22.060 | 44,860 |
| Flux 1.0 Dev | 50 | BF16 | Torch 2.5.1 (eager) | 24.267 | 36,847 |
| Flux 1.0 Schnell | 4 | BF16 | Torch 2.5.1 (compile) | 1.627 | 37,881 |
| Flux 1.0 Schnell | 4 | FP16+BF16 | Optimum (ORT) | 1.884 | 41,537 |
| Flux 1.0 Schnell | 4 | FP16+FP32 | Optimum (ORT) | 1.902 | 44,858 |
| Flux 1.0 Schnell | 4 | BF16 | Torch 2.5.1 (eager) | 2.162 | 36,831 |
| SD 3.5 Large | 50 | BF16 | Torch 2.5.1 (compile) | 15.881 | 32,307 |
| SD 3.5 Large | 50 | FP16+FP32 | Optimum (ORT) | 19.837 | 36,451 |
| SD 3.5 Large | 50 | FP16+BF16 | Optimum (ORT) | 19.964 | 36,461 |
| SD 3.5 Large | 50 | BF16 | Torch 2.5.1 (eager) | 22.477 | 31,513 |
| SD 3.5 Medium | 50 | BF16 | Torch 2.5.1 (compile) | 6.476 | 21,341 |
| SD 3.5 Medium | 50 | FP16+FP32 | Optimum (ORT) | 8.775 | 25,183 |
| SD 3.5 Medium | 50 | BF16 | Torch 2.5.1 (eager) | 10.057 | 20,433 |

### Future Work

* Triton kernel for matrix multiplication and auto tuning.
* FP8/Int8 quantization

### Motivation and Context

SD 3.5 Architecture: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/resolve/main/mmdit-x.png
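For reference, the per-model flow behind `optimize_pipeline.py` can be approximated with the `onnxruntime.transformers` Python API. The sketch below is illustrative only; the `model_type` value `"mmdit"` and the paths are assumptions, not necessarily what the script uses:

```python
# Hedged sketch: fp32 -> optimized fp16 conversion of one pipeline component,
# roughly what optimize_pipeline.py performs per model in the ONNX pipeline.
# The model_type "mmdit" is assumed here for illustration only.
from onnxruntime.transformers import optimizer

m = optimizer.optimize_model(
    "./flux1_schnell_onnx/fp32/transformer/model.onnx",
    model_type="mmdit",  # assumed name for the SD 3.x / Flux transformer model type
    opt_level=0,         # rely on the Python fusion passes (FastGelu, RMSNorm, MultiHeadAttention, ...)
)
m.convert_float_to_float16(keep_io_types=True)  # fp16 weights, keep fp32 graph inputs/outputs
m.save_model_to_file("./flux1_schnell_onnx/fp16/transformer/model.onnx")
print(m.get_fused_operator_statistics())        # e.g. FastGelu / MultiHeadAttention fusion counts
```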
1 parent 8e4253d commit a86f7d3

19 files changed: +2089 −525 lines

onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu (+25 −36)
```diff
@@ -125,42 +125,31 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
   bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
   bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);
 
-  if (data.bias == nullptr) {
-    assert(nullptr == fused_runner);
-    // For quantized attention, bias has been added so only need transpose here.
-    // gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH
-    assert(qk_head_size == v_head_size);
-    int matrix_to_trans = (past_present_share_buffer ? 1 : 3);
-    ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads,
-                                       max_threads_per_block, false, data.gemm_buffer, qkv, 3));
-    data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
-  } else {
-    // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2)
-    // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3)
-    // For unfused kernel, transpose to 3xBxNxSxH (format 1)
-    // For fused causal kernel, use format 1 since we need have K and V to update present state,
-    // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
-    const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1));
-    data.qkv_format = use_fused_kernel
-                          ? AttentionQkvFormat::QKV_BSN3H
-                          : (use_flash_or_efficient_attention
-                                 ? AttentionQkvFormat::Q_K_V_BSNH
-                                 : (use_fused_causal
-                                        ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH
-                                        : AttentionQkvFormat::Q_K_V_BNSH));
-
-    // For fused causal, we will update gemm_buffer with bias directly.
-    T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr;
-
-    int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3);
-    // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v
-    // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H)
-    LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
-                           batch_size, sequence_length, num_heads, qk_head_size,
-                           data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
-                           3, parameters.do_rotary, parameters.rotary_embedding,
-                           parameters.past_sequence_length);
-  }
+  // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2)
+  // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3)
+  // For unfused kernel, transpose to 3xBxNxSxH (format 1)
+  // For fused causal kernel, use format 1 since we need have K and V to update present state,
+  // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
+  const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1));
+  data.qkv_format = use_fused_kernel
+                        ? AttentionQkvFormat::QKV_BSN3H
+                        : (use_flash_or_efficient_attention
+                               ? AttentionQkvFormat::Q_K_V_BSNH
+                               : (use_fused_causal
+                                      ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH
+                                      : AttentionQkvFormat::Q_K_V_BNSH));
+
+  // For fused causal, we will update gemm_buffer with bias directly.
+  T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr;
+
+  int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3);
+  // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v
+  // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H)
+  LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
+                         batch_size, sequence_length, num_heads, qk_head_size,
+                         data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
+                         3, parameters.do_rotary, parameters.rotary_embedding,
+                         parameters.past_sequence_length);
 
   return Status::OK();
 }
```
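For intuition, the QKV layout names in the comments above (format 1/2/3) correspond to simple permutations of the packed gemm_buffer of shape BxSx3xNxH. A small numpy sketch, illustrative only and not part of the change:

```python
# Hedged sketch of the QKV layout conversions described in the CUDA comments above.
import numpy as np

B, S, N, H = 2, 4, 3, 8  # batch, sequence length, num heads, head size (toy values)
gemm_buffer = np.random.randn(B, S, 3, N, H).astype(np.float32)  # packed Q/K/V after the QKV GEMM

qkv_bsn3h = gemm_buffer.transpose(0, 1, 3, 2, 4)  # format 2 (fused TRT attention): B x S x N x 3 x H
qkv_3bsnh = gemm_buffer.transpose(2, 0, 1, 3, 4)  # format 3 (flash / memory efficient): 3 x B x S x N x H
qkv_3bnsh = gemm_buffer.transpose(2, 0, 3, 1, 4)  # format 1 (unfused kernel): 3 x B x N x S x H

print(qkv_bsn3h.shape, qkv_3bsnh.shape, qkv_3bnsh.shape)
```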

onnxruntime/python/tools/transformers/compare_bert_results.py (+10 −2)
```diff
@@ -37,16 +37,23 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-1, atol=1e-3):
     # Validate the output of baseline and treatment, to make sure the results are similar.
     diff_count = 0
     max_abs_diff = 0
+    max_diff_percentage = 0
+    case_passed = True
     for test_case_id, results in enumerate(baseline_results):
-        case_passed = True
         for i in range(len(results)):
             treatment_output = treatment_results[test_case_id][i]
-            abs_diff = np.amax(np.abs(treatment_output - results[i]))
+            abs_diff_tensor = np.abs(treatment_output - results[i])
+            abs_diff = np.amax(abs_diff_tensor)
             if verbose and abs_diff > atol:
                 print("abs_diff", abs_diff)
                 print("treatment", treatment_output)
                 print("baseline", results[i])
 
+            count_exceeding = np.sum(abs_diff_tensor > atol)
+            total_elements = abs_diff_tensor.size
+            percentage_exceeding = (count_exceeding / total_elements) * 100
+            max_diff_percentage = max(max_diff_percentage, percentage_exceeding)
+
             max_abs_diff = max(max_abs_diff, abs_diff)
             if not np.allclose(results[i].tolist(), treatment_output.tolist(), rtol=rtol, atol=atol):
                 if case_passed:
@@ -66,6 +73,7 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-1, atol=1e-3):
                 )
 
     print(f"maximum absolute difference={max_abs_diff}")
+    print(f"maximum percentage of elements that exceeds atol={atol} is {max_diff_percentage:.3f}%")
     return max_abs_diff, case_passed
 
```
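The change above also reports, per run, the largest percentage of output elements whose absolute difference from the baseline exceeds `atol`. A standalone sketch of that metric on made-up data:

```python
# Hedged sketch of the tolerance-percentage metric added above, using toy data.
import numpy as np

baseline = np.array([1.00, 2.00, 3.00, 4.00], dtype=np.float32)
treatment = np.array([1.00, 2.05, 3.00, 4.20], dtype=np.float32)
atol = 1e-1

abs_diff_tensor = np.abs(treatment - baseline)  # ~[0.0, 0.05, 0.0, 0.2]
max_abs_diff = np.amax(abs_diff_tensor)         # ~0.2
percentage_exceeding = np.sum(abs_diff_tensor > atol) / abs_diff_tensor.size * 100  # 25.0

print(f"maximum absolute difference={max_abs_diff}")
print(f"maximum percentage of elements that exceeds atol={atol} is {percentage_exceeding:.3f}%")
```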

onnxruntime/python/tools/transformers/fusion_attention.py (−39)
```diff
@@ -355,45 +355,6 @@ def split_kv(self, present_k_name: str, present_v_name: str, kv_node: str):
         self.node_name_to_graph_name[gather_k_name] = self.this_graph_name
         self.node_name_to_graph_name[gather_v_name] = self.this_graph_name
 
-    def transpose_kv(self, past_k: str, past_v: str):
-        """Transpose past_k and past_v from (B,N,P,H) to (B,P,N,H)
-
-        Args:
-            past_k (str): name of past K value of shape (B,N,P,H)
-            past_v (str): name of past V value of shape (B,N,P,H)
-
-        Returns:
-            past_k_transpose (str): name of past K value of shape (B,P,N,H)
-            past_v_transpose (str): name of past V value of shape (B,P,N,H)
-        """
-        past_k_transpose = (past_k + "_transposed").replace(".", "_")
-        past_v_transpose = (past_v + "_transposed").replace(".", "_")
-        transpose_k_name = self.model.create_node_name("Transpose")
-        transpose_v_name = self.model.create_node_name("Transpose")
-
-        transpose_k = helper.make_node(
-            "Transpose",
-            inputs=[past_k],
-            outputs=[past_k_transpose],
-            name=transpose_k_name,
-            perm=[0, 2, 1, 3],
-        )
-        transpose_v = helper.make_node(
-            "Transpose",
-            inputs=[past_v],
-            outputs=[past_v_transpose],
-            name=transpose_v_name,
-            perm=[0, 2, 1, 3],
-        )
-
-        # Add reshape nodes to graph
-        self.nodes_to_add.append(transpose_k)
-        self.nodes_to_add.append(transpose_v)
-        self.node_name_to_graph_name[transpose_k_name] = self.this_graph_name
-        self.node_name_to_graph_name[transpose_v_name] = self.this_graph_name
-
-        return past_k_transpose, past_v_transpose
-
     def create_combined_qkv_bias(
         self,
         q_add: NodeProto,
```

onnxruntime/python/tools/transformers/fusion_fastgelu.py (+122)
```diff
@@ -26,6 +26,9 @@ def fuse(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
         if self.fuse_3(tanh_node, input_name_to_nodes, output_name_to_node):
             return
 
+        if self.fuse_4(tanh_node, input_name_to_nodes, output_name_to_node):
+            return
+
     def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> Optional[bool]:
         """
         Fuse Gelu with tanh into one node:
@@ -358,3 +361,122 @@ def fuse_3(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict
         self.nodes_to_add.append(fused_node)
         self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
         return True
+
+    def fuse_4(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
+        """
+        This pattern is from stable diffusion 3.5 model.
+        Fuse Gelu with tanh into one node:
+              +-----------------+------------------+
+              |                 |                  |
+              |                 v                  v
+        [root] ==> Mul --> Mul --> Mul -----> Add --> Mul --> Tanh --> Add -----> Mul --> Mul -->
+              |            (A=0.0447)              (A=0.7978)          (A=1)        ^    (A=0.5)
+              |                                                                     |
+              +---------------------------------------------------------------------+
+        Note that constant input for Add and Mul could be first or second input.
+        """
+        if tanh_node.output[0] not in input_name_to_nodes:
+            return
+
+        children = input_name_to_nodes[tanh_node.output[0]]
+        if len(children) != 1 or children[0].op_type != "Add":
+            return
+        add_after_tanh = children[0]
+
+        if not self.model.has_constant_input(add_after_tanh, 1.0):
+            return
+
+        if add_after_tanh.output[0] not in input_name_to_nodes:
+            return
+        children = input_name_to_nodes[add_after_tanh.output[0]]
+        if len(children) != 1 or children[0].op_type != "Mul":
+            return
+        mul_after_tanh = children[0]
+
+        if mul_after_tanh.output[0] not in input_name_to_nodes:
+            return
+        children = input_name_to_nodes[mul_after_tanh.output[0]]
+        if len(children) != 1 or children[0].op_type != "Mul":
+            return
+        mul_half = children[0]
+        if not self.model.has_constant_input(mul_half, 0.5):
+            return
+
+        root_input = mul_after_tanh.input[0 if mul_after_tanh.input[1] == add_after_tanh.output[0] else 1]
+
+        mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
+        if mul_before_tanh is None:
+            return
+
+        i = self.model.find_constant_input(mul_before_tanh, 0.7978, delta=0.0001)
+        if i < 0:
+            return
+
+        add_before_tanh = self.model.match_parent(mul_before_tanh, "Add", 0 if i == 1 else 1, output_name_to_node)
+        if add_before_tanh is None:
+            return
+
+        if add_before_tanh.input[0] == root_input:
+            another = 1
+        elif add_before_tanh.input[1] == root_input:
+            another = 0
+        else:
+            return
+
+        mul_after_pow = self.model.match_parent(add_before_tanh, "Mul", another, output_name_to_node)
+        if mul_after_pow is None:
+            return
+
+        i = self.model.find_constant_input(mul_after_pow, 0.0447, delta=0.0001)
+        if i < 0:
+            return
+
+        mul = self.model.match_parent(mul_after_pow, "Mul", 0 if i == 1 else 1, output_name_to_node)
+        if mul is None:
+            return
+
+        if mul.input[0] == root_input:
+            another = 1
+        elif mul.input[1] == root_input:
+            another = 0
+        else:
+            return
+
+        mul2 = self.model.match_parent(mul, "Mul", another, output_name_to_node)
+        if mul2 is None:
+            return
+
+        if mul2.input[0] != root_input or mul2.input[1] != root_input:
+            return
+
+        subgraph_nodes = [
+            mul2,
+            mul,
+            mul_after_pow,
+            add_before_tanh,
+            mul_before_tanh,
+            tanh_node,
+            add_after_tanh,
+            mul_after_tanh,
+            mul_half,
+        ]
+
+        if not self.model.is_safe_to_fuse_nodes(
+            subgraph_nodes,
+            [mul_half.output[0]],
+            input_name_to_nodes,
+            output_name_to_node,
+        ):
+            return
+
+        self.nodes_to_remove.extend(subgraph_nodes)
+        fused_node = helper.make_node(
+            "FastGelu",
+            inputs=[root_input],
+            outputs=mul_half.output,
+            name=self.model.create_node_name("FastGelu"),
+        )
+        fused_node.domain = "com.microsoft"
+        self.nodes_to_add.append(fused_node)
+        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
+        return True
```
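The subgraph matched by `fuse_4` is the tanh approximation of GELU with the constants checked above (0.0447, 0.7978, 1, 0.5). Here is a small numpy sketch of the math that the single FastGelu node replaces, illustrative only:

```python
# Hedged sketch: the tanh-approximate GELU computed by the matched subgraph.
# Chain: Mul(x, x) -> Mul(., x) -> Mul 0.0447 -> Add x -> Mul 0.7978 -> Tanh -> Add 1 -> Mul x -> Mul 0.5
import numpy as np

def fast_gelu_reference(x: np.ndarray) -> np.ndarray:
    inner = 0.7978 * (x + 0.0447 * x * x * x)
    return 0.5 * x * (1.0 + np.tanh(inner))

x = np.linspace(-3.0, 3.0, 7, dtype=np.float32)
print(fast_gelu_reference(x))
```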

onnxruntime/python/tools/transformers/fusion_group_norm.py (+3 −1)
```diff
@@ -84,6 +84,7 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
         instance_norm_scale = self.model.get_constant_value(instance_norm.input[1])
         if instance_norm_scale is None or len(instance_norm_scale.shape) != 1:
             return
+        num_groups = int(instance_norm_scale.shape[0])
 
         instance_norm_bias = self.model.get_constant_value(instance_norm.input[2])
         if instance_norm_bias is None or instance_norm_scale.shape != instance_norm_scale.shape:
@@ -156,7 +157,8 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
         )
 
         new_node.attribute.extend(instance_norm.attribute)
-        new_node.attribute.extend([helper.make_attribute("groups", 32)])
+
+        new_node.attribute.extend([helper.make_attribute("groups", num_groups)])
         new_node.attribute.extend([helper.make_attribute("activation", 1 if has_swish_activation else 0)])
 
         if not self.channels_last:
```
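The group count is now derived from the length of the InstanceNormalization scale rather than being hard-coded to 32: after the fusion's reshape, the scale carries one value per group. A minimal sketch, illustrative only:

```python
# Hedged sketch: deriving the GroupNorm "groups" attribute from the
# InstanceNormalization scale length, as the fusion above now does.
import numpy as np
from onnx import helper

instance_norm_scale = np.ones((32,), dtype=np.float32)  # example: one scale value per group
num_groups = int(instance_norm_scale.shape[0])          # 32 here, but no longer hard-coded

groups_attr = helper.make_attribute("groups", num_groups)
print(groups_attr)
```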
