Skip to content

Commit be7e4f1

Browse files
tianleiwuguschmue
authored andcommitted
Dump nodes with potential overflow in half conversion (#23363)
Add a tool to generate node_block_list used in [float16 conversion tool](https://github.com/microsoft/onnxruntime/blob/04030f64be10e020d3ac9aa5ba7d0f2917cbd14e/onnxruntime/python/tools/transformers/float16.py#L175). Previously, we have a feature to dump statistics data (like min, max) of each node input/output. However, it is time consuming to generate a list of nodes that need to be kept in float32 when model is large. This could help speed up the process by outputting a list of nodes that have potential overflow in float-to-half conversion. Usage is to build onnxruntime from source with ` --cmake_extra_defines onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS=1`, then set some environment variables before running float32 optimized onnx model like: ``` export ORT_DEBUG_NODE_IO_DUMP_HALF_CONVERSION_OVERFLOW=1 export ORT_DEBUG_NODE_IO_HALF_OVERFLOW_THRESHOLD=50000 python benchmark.py -e optimum --height 1024 --width 1024 --steps 3 -b 1 -v Flux.1D -p flux1_dev_onnx/fp32_opt --skip_warmup ``` The threshold `ORT_DEBUG_NODE_IO_HALF_OVERFLOW_THRESHOLD` shall be <= 65504. The default value is 50000 if the environment variable is not set. It is better to leave some margin if number of samples are not large enough in the test. As a demo, we add an option --skip_warmup to benchmark.py for Flux, so that we can reduce the time on dumping warm-up runs. Example snippet of stdout (each inference session has such a summary when session ended): ``` Total counter in node dumping: 141 Found 2 nodes cannot be converted to half precision due to potential input/output overflow. Operator frequencies for these nodes: Softmax : 1 MatMul : 1 # ------- # Example python script for float16 conversion # For details, search `node_block_list` in https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/float16.py # ------- from onnxruntime.transformers.onnx_model import OnnxModel m = OnnxModel(onnx.load('flux1_dev_onnx/fp32_opt/vae_decoder/model.onnx')) node_block_list = [ '/decoder/mid_block/attentions.0/Softmax', '/decoder/mid_block/attentions.0/MatMul', ] m.convert_float_to_float16(keep_io_types=False, node_block_list=node_block_list) m.save_model_to_file('fp16/optimized.onnx', use_external_data_format=False) ``` Then you can use the python script to convert corresponding model to float16. ### Motivation and Context It is a tool used to generate node_block_list used in float16 conversion of stable diffusion 3.x and flux models in #22986. In stable diffusion or Flux pipeline, there are multiple models and there could be multiple session runs for each model. Without a proper tool, it is time consuming to get node_block_list for each model.
1 parent 349d45a commit be7e4f1

File tree

5 files changed

+267
-61
lines changed

5 files changed

+267
-61
lines changed

onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc

+137-19
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,73 @@
2222
namespace onnxruntime {
2323
namespace utils {
2424

25+
void NodeDumpAnalysis::Add(const std::string& node_name, const std::string& op_type, bool is_half_overflow) {
26+
std::lock_guard<std::mutex> lock(set_mutex);
27+
if (is_half_overflow) {
28+
auto p = half_overflow_nodes.insert(node_name);
29+
if (p.second) { // insert succeeded
30+
++half_overflow_ops[op_type];
31+
}
32+
}
33+
34+
counter++;
35+
}
36+
37+
void NodeDumpAnalysis::PrintToStdOut(const std::string& model_path) {
38+
std::lock_guard<std::mutex> lock(set_mutex);
39+
if (counter == 0) {
40+
return;
41+
}
42+
43+
// We added counter twice per node (once for node inputs, once for node outputs), so we need to divide it by 2.
44+
counter /= 2;
45+
46+
std::cout << "Total counter in node dumping: " << counter << std::endl;
47+
48+
if (!half_overflow_nodes.empty()) {
49+
std::cout << "Found " << half_overflow_nodes.size() << " nodes cannot be converted to half precision due to potential input/output overflow." << std::endl;
50+
51+
if (half_overflow_nodes.count("") > 0) {
52+
std::cout << "Warning: some node name is empty and node_block_list is not completed. "
53+
<< "Please update the model to make sure each node has name then run this tool again!" << std::endl;
54+
}
55+
56+
// Sort and display the op frequency in the descending order
57+
std::cout << "Operator frequencies for these nodes:" << std::endl;
58+
std::vector<std::pair<std::string, int>> op_freq(half_overflow_ops.begin(), half_overflow_ops.end());
59+
std::sort(op_freq.begin(), op_freq.end(),
60+
[](const std::pair<std::string, int>& a, const std::pair<std::string, int>& b) {
61+
return b.second < a.second;
62+
});
63+
for (const auto& pair : op_freq) {
64+
std::cout << pair.first << " : " << pair.second << std::endl;
65+
}
66+
} else {
67+
std::cout << "No node has potential overflow during half conversion so node_block_list is empty." << std::endl;
68+
}
69+
70+
std::cout << "# -------" << std::endl;
71+
std::cout << "# Example python script for float16 conversion" << std::endl;
72+
std::cout << "# For details, search `node_block_list` in https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/float16.py" << std::endl;
73+
std::cout << "# -------" << std::endl;
74+
std::cout << "from onnxruntime.transformers.onnx_model import OnnxModel" << std::endl;
75+
std::cout << "m = OnnxModel(onnx.load('" << model_path << "'))" << std::endl;
76+
if (!half_overflow_nodes.empty()) {
77+
std::cout << "node_block_list = [" << std::endl;
78+
for (const auto& node : half_overflow_nodes) {
79+
if (!node.empty()) {
80+
std::cout << " '" << node << "'," << std::endl;
81+
}
82+
}
83+
std::cout << "]" << std::endl;
84+
std::cout << "m.convert_float_to_float16(keep_io_types=False, node_block_list=node_block_list)" << std::endl;
85+
} else {
86+
std::cout << "m.convert_float_to_float16(keep_io_types=False)" << std::endl;
87+
}
88+
89+
std::cout << "m.save_model_to_file('fp16/optimized.onnx', use_external_data_format=False)" << std::endl;
90+
}
91+
2592
namespace {
2693

2794
struct TensorMetadata {
@@ -59,10 +126,13 @@ bool FilterNode(const NodeDumpOptions& dump_options, const Node& node) {
59126
}
60127

61128
template <typename T>
62-
void DumpTensorToStdOut(const Tensor& tensor, const NodeDumpOptions& dump_options) {
63-
onnxruntime::utils::PrintCpuTensor<T>(tensor, dump_options.snippet_threshold, dump_options.snippet_edge_items);
64-
if (dump_options.dump_flags & NodeDumpOptions::DumpFlags::StatisticsData) {
65-
onnxruntime::utils::PrintCpuTensorStats<T>(tensor);
129+
void DumpTensorToStdOut(const Tensor& tensor, const NodeDumpOptions& dump_options, TensorStatisticsData& tensor_statistics) {
130+
if ((dump_options.dump_flags & NodeDumpOptions::DumpFlags::InputData) != 0) {
131+
onnxruntime::utils::PrintCpuTensor<T>(tensor, dump_options.snippet_threshold, dump_options.snippet_edge_items);
132+
}
133+
134+
if ((dump_options.dump_flags & NodeDumpOptions::DumpFlags::StatisticsData) != 0) {
135+
onnxruntime::utils::PrintCpuTensorStats<T>(tensor, tensor_statistics);
66136
}
67137
}
68138

@@ -295,10 +365,10 @@ void InsertNodePlacementToSqliteDb(const NodeDumpContext& dump_context, const No
295365

296366
void DumpCpuTensor(
297367
const NodeDumpOptions& dump_options,
298-
const Tensor& tensor, const TensorMetadata& tensor_metadata) {
368+
const Tensor& tensor, const TensorMetadata& tensor_metadata, TensorStatisticsData& tensor_statistics) {
299369
switch (dump_options.data_destination) {
300370
case NodeDumpOptions::DataDestination::StdOut: {
301-
DispatchOnTensorType(tensor.DataType(), DumpTensorToStdOut, tensor, dump_options);
371+
DispatchOnTensorType(tensor.DataType(), DumpTensorToStdOut, tensor, dump_options, tensor_statistics);
302372
break;
303373
}
304374
case NodeDumpOptions::DataDestination::TensorProtoFiles: {
@@ -321,15 +391,15 @@ void DumpCpuTensor(
321391

322392
void DumpTensor(
323393
const NodeDumpOptions& dump_options,
324-
const Tensor& tensor, TensorMetadata& tensor_metadata,
394+
const Tensor& tensor, TensorMetadata& tensor_metadata, TensorStatisticsData& tensor_statistics,
325395
const SessionState& session_state) {
326396
// check tensor is on CPU before dumping it
327397
auto& tensor_location = tensor.Location();
328398
if (tensor_location.device.Type() == OrtDevice::CPU ||
329399
tensor_location.mem_type == OrtMemTypeCPUInput ||
330400
tensor_location.mem_type == OrtMemTypeCPUOutput) {
331401
tensor_metadata.device_type = "CPU";
332-
DumpCpuTensor(dump_options, tensor, tensor_metadata);
402+
DumpCpuTensor(dump_options, tensor, tensor_metadata, tensor_statistics);
333403
} else {
334404
std::cout << tensor_location << "\n";
335405

@@ -345,7 +415,7 @@ void DumpTensor(
345415
auto status = data_transfer_mgr.CopyTensor(tensor, cpu_tensor);
346416
if (status == common::Status::OK()) {
347417
tensor_metadata.device_type = "GPU";
348-
DumpCpuTensor(dump_options, cpu_tensor, tensor_metadata);
418+
DumpCpuTensor(dump_options, cpu_tensor, tensor_metadata, tensor_statistics);
349419
} else {
350420
std::cout << " failed to transfer data to cpu.\n";
351421
}
@@ -383,6 +453,11 @@ const NodeDumpOptions& NodeDumpOptionsFromEnvironmentVariables() {
383453
if (ParseEnvironmentVariableWithDefault<bool>(env_vars::kDumpStatisticsData, false)) {
384454
opts.dump_flags |= NodeDumpOptions::DumpFlags::StatisticsData;
385455
}
456+
if (ParseEnvironmentVariableWithDefault<bool>(env_vars::kDumpHalfConversionOverflow, false)) {
457+
// Statistics data is required for half conversion overflow detection.
458+
opts.dump_flags |= NodeDumpOptions::DumpFlags::StatisticsData;
459+
opts.dump_flags |= NodeDumpOptions::DumpFlags::HalfConversionOverflow;
460+
}
386461

387462
opts.filter.name_pattern = Env::Default().GetEnvironmentVar(env_vars::kNameFilter);
388463
opts.filter.op_type_pattern = Env::Default().GetEnvironmentVar(env_vars::kOpTypeFilter);
@@ -402,6 +477,13 @@ const NodeDumpOptions& NodeDumpOptionsFromEnvironmentVariables() {
402477
opts.snippet_threshold = ParseEnvironmentVariableWithDefault<int>(env_vars::kSnippetThreshold, kDefaultSnippetThreshold);
403478
opts.snippet_edge_items = ParseEnvironmentVariableWithDefault<int>(env_vars::kSnippetEdgeItems, kDefaultSnippetEdgeItems);
404479

480+
constexpr int kMaxHalfThreshold = 65504;
481+
// The default value is set to have reasonable margin for input variance.
482+
int threshold = ParseEnvironmentVariableWithDefault<int>(env_vars::kHalfOverflowThreshold, 50000);
483+
ORT_ENFORCE(threshold > 0 && threshold <= kMaxHalfThreshold,
484+
debug_node_inputs_outputs_env_vars::kHalfOverflowThreshold, " shall be a positive integer <= ", kMaxHalfThreshold);
485+
opts.half_overflow_threshold = static_cast<float>(threshold);
486+
405487
if (ParseEnvironmentVariableWithDefault<bool>(env_vars::kAppendRankToFileName, false)) {
406488
std::string rank = Env::Default().GetEnvironmentVar("OMPI_COMM_WORLD_RANK");
407489
if (rank.empty()) {
@@ -452,7 +534,8 @@ void DumpNodeInputs(
452534
const NodeDumpContext& dump_context,
453535
const OpKernelContext& context,
454536
const Node& node,
455-
const SessionState& session_state) {
537+
const SessionState& session_state,
538+
NodeDumpAnalysis& dump_analysis) {
456539
const bool is_any_output_dumped = IsAnyOutputDumped(dump_options);
457540
if (!is_any_output_dumped) {
458541
return;
@@ -477,6 +560,9 @@ void DumpNodeInputs(
477560
const auto& input_defs = node.InputDefs();
478561
TensorMetadata tensor_metadata;
479562

563+
bool check_half_overflow = (dump_options.data_destination == NodeDumpOptions::DataDestination::StdOut) &&
564+
(dump_options.dump_flags & NodeDumpOptions::DumpFlags::HalfConversionOverflow) != 0;
565+
bool potential_half_overflow = false;
480566
for (auto i = 0, end = context.InputCount(); i < end; ++i) {
481567
if (input_defs[i]->Exists()) {
482568
std::cout << "Input " << i << " Name: " << input_defs[i]->Name() << "\n";
@@ -491,11 +577,20 @@ void DumpNodeInputs(
491577
const bool is_shape_set = (dump_options.dump_flags & NodeDumpOptions::DumpFlags::Shape) != 0;
492578
PrintIf(is_shape_set, MakeString(" Shape: ", shape, "\n"));
493579

494-
if ((dump_options.dump_flags & NodeDumpOptions::DumpFlags::InputData) != 0) {
580+
if ((dump_options.dump_flags & NodeDumpOptions::DumpFlags::InputData) != 0 || check_half_overflow) {
495581
tensor_metadata.name = input_defs[i]->Name();
496582
tensor_metadata.step = dump_context.iteration;
497583
tensor_metadata.consumer = node.Name() + ":" + std::to_string(i);
498-
DumpTensor(dump_options, *tensor, tensor_metadata, session_state);
584+
585+
TensorStatisticsData tensor_statistics;
586+
DumpTensor(dump_options, *tensor, tensor_metadata, tensor_statistics, session_state);
587+
588+
if (check_half_overflow && tensor_statistics.is_float) {
589+
float threshold = dump_options.half_overflow_threshold;
590+
if (tensor_statistics.float_min < -threshold || tensor_statistics.float_max > threshold) {
591+
potential_half_overflow = true;
592+
}
593+
}
499594
}
500595
} else {
501596
std::cout << " is empty optional tensor.\n";
@@ -511,22 +606,28 @@ void DumpNodeInputs(
511606
std::cout << "Input " << i << " is optional and was not provided.\n";
512607
}
513608
}
609+
610+
if (check_half_overflow) {
611+
dump_analysis.Add(node.Name(), node.OpType(), potential_half_overflow);
612+
}
514613
}
515614

516615
void DumpNodeInputs(
517616
const NodeDumpContext& dump_context,
518617
const OpKernelContext& context,
519618
const Node& node,
520-
const SessionState& session_state) {
521-
DumpNodeInputs(NodeDumpOptionsFromEnvironmentVariables(), dump_context, context, node, session_state);
619+
const SessionState& session_state,
620+
NodeDumpAnalysis& dump_analysis) {
621+
DumpNodeInputs(NodeDumpOptionsFromEnvironmentVariables(), dump_context, context, node, session_state, dump_analysis);
522622
}
523623

524624
void DumpNodeOutputs(
525625
const NodeDumpOptions& dump_options,
526626
const NodeDumpContext& dump_context,
527627
OpKernelContext& context,
528628
const Node& node,
529-
const SessionState& session_state) {
629+
const SessionState& session_state,
630+
NodeDumpAnalysis& dump_analysis) {
530631
const bool is_any_output_dumped = IsAnyOutputDumped(dump_options);
531632
if (!is_any_output_dumped) {
532633
return;
@@ -549,6 +650,9 @@ void DumpNodeOutputs(
549650
const auto& output_defs = node.OutputDefs();
550651
TensorMetadata tensor_metadata;
551652

653+
bool check_half_overflow = (dump_options.data_destination == NodeDumpOptions::DataDestination::StdOut) &&
654+
(dump_options.dump_flags & NodeDumpOptions::DumpFlags::HalfConversionOverflow) != 0;
655+
bool potential_half_overflow = false;
552656
for (auto i = 0, end = context.OutputCount(); i < end; ++i) {
553657
if (output_defs[i]->Exists()) {
554658
std::cout << "Output " << i << " Name: " << output_defs[i]->Name() << "\n";
@@ -562,11 +666,20 @@ void DumpNodeOutputs(
562666
const bool is_shape_set = (dump_options.dump_flags & NodeDumpOptions::DumpFlags::Shape) != 0;
563667
PrintIf(is_shape_set, MakeString(" Shape: ", shape, "\n"));
564668

565-
if ((dump_options.dump_flags & NodeDumpOptions::DumpFlags::OutputData) != 0) {
669+
if ((dump_options.dump_flags & NodeDumpOptions::DumpFlags::OutputData) != 0 || check_half_overflow) {
566670
tensor_metadata.name = output_defs[i]->Name();
567671
tensor_metadata.step = dump_context.iteration;
568672
tensor_metadata.producer = node.Name() + ":" + std::to_string(i);
569-
DumpTensor(dump_options, *tensor, tensor_metadata, session_state);
673+
674+
TensorStatisticsData tensor_statistics;
675+
DumpTensor(dump_options, *tensor, tensor_metadata, tensor_statistics, session_state);
676+
677+
if (check_half_overflow && tensor_statistics.is_float) {
678+
float threshold = dump_options.half_overflow_threshold;
679+
if (tensor_statistics.float_min < -threshold || tensor_statistics.float_max > threshold) {
680+
potential_half_overflow = true;
681+
}
682+
}
570683
}
571684
} else {
572685
std::cout << " is empty optional tensor.\n";
@@ -582,6 +695,10 @@ void DumpNodeOutputs(
582695
std::cout << "Output " << i << " is optional and was not produced.\n";
583696
}
584697

698+
if (check_half_overflow) {
699+
dump_analysis.Add(node.Name(), node.OpType(), potential_half_overflow);
700+
}
701+
585702
std::cout << std::endl;
586703
}
587704
}
@@ -590,8 +707,9 @@ void DumpNodeOutputs(
590707
const NodeDumpContext& dump_context,
591708
OpKernelContext& context,
592709
const Node& node,
593-
const SessionState& session_state) {
594-
DumpNodeOutputs(NodeDumpOptionsFromEnvironmentVariables(), dump_context, context, node, session_state);
710+
const SessionState& session_state,
711+
NodeDumpAnalysis& dump_analysis) {
712+
DumpNodeOutputs(NodeDumpOptionsFromEnvironmentVariables(), dump_context, context, node, session_state, dump_analysis);
595713
}
596714

597715
} // namespace utils

onnxruntime/core/framework/debug_node_inputs_outputs_utils.h

+32-5
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
#include "core/framework/op_kernel.h"
2020
#include "core/framework/session_state.h"
2121
#include "core/graph/graph.h"
22+
#include <unordered_set>
23+
#include <mutex>
24+
#include <string>
2225

2326
namespace onnxruntime {
2427
namespace utils {
@@ -37,6 +40,8 @@ constexpr const char* kDumpInputData = "ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA";
3740
constexpr const char* kDumpOutputData = "ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA";
3841
// Output statistics data like min, max, count of NaN, count of infinity etc.
3942
constexpr const char* kDumpStatisticsData = "ORT_DEBUG_NODE_IO_DUMP_STATISTICS_DATA";
43+
// Output node name when any float input or output exceeds a threshold for float16 conversion overflow.
44+
constexpr const char* kDumpHalfConversionOverflow = "ORT_DEBUG_NODE_IO_DUMP_HALF_CONVERSION_OVERFLOW";
4045

4146
// specify a node name filter to limit the nodes that are dumped
4247
// see NodeDumpOptions::FilterOptions
@@ -61,6 +66,10 @@ constexpr const char* kSnippetThreshold = "ORT_DEBUG_NODE_IO_SNIPPET_THRESHOLD";
6166
// Number of array items in snippet at beginning and end of each dimension (default 3)
6267
constexpr const char* kSnippetEdgeItems = "ORT_DEBUG_NODE_IO_SNIPPET_EDGE_ITEMS";
6368

69+
// Threshold for float to float16 conversion overflow detection (default 50000).
70+
// It is a positive integer that <= 65504, and it is recommended to add some margin for new inputs.
71+
constexpr const char* kHalfOverflowThreshold = "ORT_DEBUG_NODE_IO_HALF_OVERFLOW_THRESHOLD";
72+
6473
} // namespace debug_node_inputs_outputs_env_vars
6574

6675
constexpr char kFilterPatternDelimiter = ';';
@@ -73,7 +82,8 @@ struct NodeDumpOptions {
7382
OutputData = 1 << 2,
7483
NodePlacement = 1 << 3,
7584
StatisticsData = 1 << 4,
76-
AllData = Shape | InputData | OutputData | NodePlacement | StatisticsData,
85+
HalfConversionOverflow = 1 << 5,
86+
AllData = Shape | InputData | OutputData | NodePlacement | StatisticsData | HalfConversionOverflow,
7787
};
7888

7989
// specifies the information to dump per node
@@ -117,6 +127,9 @@ struct NodeDumpOptions {
117127

118128
// Number of array items in snippet at beginning and end of each dimension for Stdout.
119129
int snippet_edge_items;
130+
131+
// Threshold for float16 conversion overflow.
132+
float half_overflow_threshold;
120133
};
121134

122135
struct NodeDumpContext {
@@ -126,6 +139,16 @@ struct NodeDumpContext {
126139
size_t program_counter;
127140
};
128141

142+
// A session level analysis of node dumps. It can be used to collect some statistics or analysis during node dumps.
143+
struct NodeDumpAnalysis {
144+
std::mutex set_mutex;
145+
std::unordered_set<std::string> half_overflow_nodes;
146+
std::unordered_map<std::string, int> half_overflow_ops;
147+
int counter{0};
148+
void Add(const std::string& node_name, const std::string& op_name, bool is_half_overflow);
149+
void PrintToStdOut(const std::string& model_path);
150+
};
151+
129152
// gets NodeDumpOptions instance configured from environment variable values
130153
const NodeDumpOptions& NodeDumpOptionsFromEnvironmentVariables();
131154

@@ -135,27 +158,31 @@ void DumpNodeInputs(
135158
const NodeDumpContext& dump_context,
136159
const OpKernelContext& context,
137160
const Node& node,
138-
const SessionState& session_state);
161+
const SessionState& session_state,
162+
NodeDumpAnalysis& dump_analysis);
139163

140164
void DumpNodeInputs(
141165
const NodeDumpContext& dump_context,
142166
const OpKernelContext& context,
143167
const Node& node,
144-
const SessionState& session_state);
168+
const SessionState& session_state,
169+
NodeDumpAnalysis& dump_analysis);
145170

146171
// dumps outputs for a node
147172
void DumpNodeOutputs(
148173
const NodeDumpOptions& dump_options,
149174
const NodeDumpContext& dump_context,
150175
OpKernelContext& context,
151176
const Node& node,
152-
const SessionState& session_state);
177+
const SessionState& session_state,
178+
NodeDumpAnalysis& dump_analysis);
153179

154180
void DumpNodeOutputs(
155181
const NodeDumpContext& dump_context,
156182
OpKernelContext& context,
157183
const Node& node,
158-
const SessionState& session_state);
184+
const SessionState& session_state,
185+
NodeDumpAnalysis& dump_analysis);
159186

160187
} // namespace utils
161188
} // namespace onnxruntime

0 commit comments

Comments
 (0)