Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 135 additions & 3 deletions onnxruntime/core/optimizer/layer_norm_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "core/optimizer/utils.h"
#include "float.h"
#include <algorithm>
#include <cstdint>

Check warning on line 9 in onnxruntime/core/optimizer/layer_norm_fusion.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: layer_norm_fusion.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/optimizer/layer_norm_fusion.cc:9: Found C++ system header after other header. Should be: layer_norm_fusion.h, c system, c++ system, other. [build/include_order] [4]
#include <deque>

using namespace ONNX_NAMESPACE;
Expand Down Expand Up @@ -78,6 +79,98 @@
return axes_values;
};

static bool TryGetScalarInitializerAsDouble(const Graph& graph, const NodeArg& node_arg, double& value) {
const auto* tensor_proto = graph_utils::GetConstantInitializer(graph, node_arg.Name());
if (tensor_proto == nullptr) {
return false;
}

Initializer initializer{graph, *tensor_proto, graph.ModelPath()};
if (initializer.size() != 1) {
return false;
}

switch (tensor_proto->data_type()) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
value = static_cast<double>(initializer.data<float>()[0]);
return true;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
value = static_cast<double>(initializer.data<MLFloat16>()[0]);
return true;
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
value = initializer.data<double>()[0];
return true;
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
value = static_cast<double>(initializer.data<BFloat16>()[0]);
return true;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
value = static_cast<double>(initializer.data<int8_t>()[0]);
return true;
case ONNX_NAMESPACE::TensorProto_DataType_INT16:
value = static_cast<double>(initializer.data<int16_t>()[0]);
return true;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
value = static_cast<double>(initializer.data<int32_t>()[0]);
return true;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
value = static_cast<double>(initializer.data<int64_t>()[0]);
return true;
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
value = static_cast<double>(initializer.data<uint8_t>()[0]);
return true;
case ONNX_NAMESPACE::TensorProto_DataType_UINT16:
value = static_cast<double>(initializer.data<uint16_t>()[0]);
return true;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
value = static_cast<double>(initializer.data<uint32_t>()[0]);
return true;
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
value = static_cast<double>(initializer.data<uint64_t>()[0]);
return true;
default:
Comment thread
the0cp marked this conversation as resolved.
return false;
}
}

static bool IsPowExponentTwo(const Graph& graph, const Node& pow_node) {
const auto& pow_inputs = pow_node.InputDefs();
if (pow_inputs.size() < 2 || pow_inputs[1] == nullptr) {
return false;
}

double exponent_value = 0.0;
if (TryGetScalarInitializerAsDouble(graph, *pow_inputs[1], exponent_value)) {
return exponent_value == 2.0;
}

const Node* exponent_input_node = graph_utils::GetInputNode(pow_node, 1);
if (exponent_input_node == nullptr ||
!graph_utils::IsSupportedOptypeVersionAndDomain(*exponent_input_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) ||
exponent_input_node->InputDefs().empty() || exponent_input_node->InputDefs()[0] == nullptr) {
return false;
}

return TryGetScalarInitializerAsDouble(graph, *exponent_input_node->InputDefs()[0], exponent_value) &&
exponent_value == 2.0;
}

static const NodeArg* GetOtherAddInput(const Node& add_node, const NodeArg& known_input) {
const auto& add_inputs = add_node.InputDefs();
if (add_inputs.size() < 2) {
return nullptr;
}

if (add_inputs[0] != nullptr && add_inputs[0]->Name() == known_input.Name()) {
return add_inputs[1];
}

if (add_inputs[1] != nullptr && add_inputs[1]->Name() == known_input.Name()) {
return add_inputs[0];
}

return nullptr;
}

/**
Layer Normalization will fuse LayerNormalization into one node :
+---------------------+
Expand Down Expand Up @@ -555,10 +648,12 @@
Node& pow_node = *p_pow;
ORT_RETURN_IF_ERROR(Recurse(pow_node, modified, graph_level, logger));

// Only the Pow base/output type must be supported by SimplifiedLayerNormalization. The exponent can
// be an integer scalar 2 per the Pow schema and is validated separately.
if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow_node, "Pow", {7, 12, 13, 15}) ||
!graph_utils::IsSupportedProvider(pow_node, GetCompatibleExecutionProviders()) ||
!optimizer_utils::CheckOutputEdges(graph, pow_node, 1) || graph.NodeProducesGraphOutput(pow_node) ||
!IsSupportedDataType(pow_node)) {
!IsSupportedDataType(pow_node, 1) || !IsPowExponentTwo(graph, pow_node)) {
continue;
}
nodes_to_remove.push_back(pow_node);
Expand Down Expand Up @@ -590,6 +685,11 @@
}
nodes_to_remove.push_back(add_node);

const NodeArg* epsilon_input = GetOtherAddInput(add_node, *reduce_mean_node.MutableOutputDefs()[0]);
if (epsilon_input == nullptr) {
continue;
}

const Node* p_sqrt = graph_utils::FirstChildByType(add_node, "Sqrt");
if (p_sqrt == nullptr) {
continue;
Expand Down Expand Up @@ -728,7 +828,7 @@

// Get constant "epsilon" from "Add" node if available. Else, default value will be used.
const ONNX_NAMESPACE::TensorProto* tensor_proto =
graph_utils::GetConstantInitializer(graph, add_node.MutableInputDefs()[1]->Name());
graph_utils::GetConstantInitializer(graph, epsilon_input->Name());
if (tensor_proto != nullptr && tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
Initializer initializer{graph, *tensor_proto, graph.ModelPath()};
// epsilon must be a scalar/1-element tensor; fall back to default otherwise.
Expand All @@ -752,11 +852,43 @@
// Assign provider to this new node. Provider should be same as the provider for old node.
layer_norm_node.SetExecutionProviderType(reduce_mean_node.GetExecutionProviderType());

// move input edges to add (first in list) across to the layer_norm_node.
// FinalizeNodeFusion moves every input edge of the first node by NodeArg name. Disconnect inputs
// that the replacement does not use, such as a Pow exponent produced by a mixed-precision Cast.
// Keep track of their producers so they can be removed if this fusion makes them dead.
InlinedHashSet<NodeIndex> unused_input_node_indices;
// Pow may follow a leading Cast and not be the first node finalized. Track its exponent producer
// explicitly because removing Pow will disconnect that edge without moving it to the replacement.
if (const Node* pow_exponent_input_node = graph_utils::GetInputNode(pow_node, 1)) {
unused_input_node_indices.insert(pow_exponent_input_node->Index());
}

const auto first_node_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(nodes_to_remove.front().get());
for (const auto& input_edge : first_node_input_edges) {
const bool is_replacement_input =
std::any_of(layer_norm_input_defs.cbegin(), layer_norm_input_defs.cend(),
[&input_edge](const NodeArg* input) { return input->Name() == input_edge.arg_name; });
if (!is_replacement_input) {
unused_input_node_indices.insert(input_edge.src_node);
graph.RemoveEdge(input_edge.src_node, input_edge.dst_node,
input_edge.src_arg_index, input_edge.dst_arg_index);
}
}

// move input edges from the first node in nodes_to_remove to layer_norm_node.
// move output definitions and output edges from mul_node (last in list) to layer_norm_node.
// remove all the other nodes.
graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, layer_norm_node);

// Remove unused input producers and any newly dead upstream nodes only after their final consumer is
// fused. A producer can be shared by multiple matched subgraphs, so it must remain while it still has users.
for (const NodeIndex unused_input_node_index : unused_input_node_indices) {
Comment thread
the0cp marked this conversation as resolved.
Node* unused_input_node = graph.GetNode(unused_input_node_index);
if (unused_input_node != nullptr && unused_input_node->GetOutputEdgesCount() == 0 &&
!graph.NodeProducesGraphOutput(*unused_input_node)) {
graph_utils::RemoveNodesWithOneOutputBottomUp(graph, *unused_input_node);
}
}

#ifdef ENABLE_TRAINING_CORE
// add one extra output def, so we have 2 output defs that match what gradient builder expected
layer_norm_node.MutableOutputDefs().push_back(
Expand Down
Loading
Loading