Skip to content

Commit 4faee2e

Browse files
Fix issue in constant-propagation inside function subgraph (microsoft#16330)
### Description The SequenceMap function-op has a graph-attribute. ORT's constant-folding optimization may identify constant-expressions inside the subgraph and promote them to constants, stored as initializers in the main graph. When it does this, the optimization updates the subgraph to remove the corresponding nodes. When we expand a SequenceMap node by inlining its function-expansion, we need to use this updated subgraph. However, the existing code uses the original graph-attribute (GraphProto), instead of regenerating it from the modified subgraph. This results in producing a graph with duplicate definitions for the constant-folded variable, resulting in an error during graph-resolve. This PR fixes this issue (just a single line fix), and adds a test-case to cover this scenario. --------- Signed-off-by: Ganesan Ramalingam <[email protected]> Co-authored-by: Suryaprakash Shanmugam <[email protected]>
1 parent ea43671 commit 4faee2e

File tree

4 files changed

+33
-4
lines changed

4 files changed

+33
-4
lines changed

onnxruntime/core/graph/graph.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot
585585
// Check if this node has a schema defined function proto.
586586
if (op_->HasContextDependentFunction()) {
587587
NodeProto node_proto;
588-
ToProto(node_proto);
588+
ToProto(node_proto, true);
589589
std::vector<TypeProto> input_types;
590590
for (size_t i = 0, n = InputDefs().size(); i < n; i++) {
591591
auto p_node_arg = InputDefs().at(i);

onnxruntime/core/providers/openvino/ov_versions/capability.cc

+7-3
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,15 @@ std::vector<std::unique_ptr<ComputeCapability>> GetCapability::Execute() {
7575
}
7676

7777
const auto& nodes = graph_viewer_.GetNodesInTopologicalOrder();
78+
79+
// Handle cases where lone, reoccuring Ops in smaller models cannot be supported in OpenVINO
80+
// If only a node of the same lone,unsupported type is present, then do not proceed with the subgraph
81+
const auto& node = graph_viewer_.GetNode(nodes[0]);
82+
if (data_ops_->IsOpSupportedOnlyInModel(node->OpType()))
83+
return result;
84+
7885
// Nodes that work well in models but not as a single node
7986
if (nodes.size() == 1) {
80-
const auto& node = graph_viewer_.GetNode(nodes[0]);
81-
if (data_ops_->IsOpSupportedOnlyInModel(node->OpType()))
82-
return result;
8387
// If reshape is not an intermediate node, shape needs to be an initializer
8488
if (data_ops_->SpecialConditionForClusterSizeOne(ng_required_initializers, node)) {
8589
return result;

onnxruntime/core/providers/openvino/ov_versions/data_ops.cc

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ namespace openvino_ep {
3030

3131
// Ops which are supported only in models(as intermediate nodes) and not in unit tests
3232
std::set<std::string> ops_supported_only_in_model = {
33+
"Add",
3334
"Cast",
3435
"Concat",
3536
"ConstantOfShape",

onnxruntime/test/framework/function_test.cc

+24
Original file line numberDiff line numberDiff line change
@@ -506,5 +506,29 @@ TEST(FunctionTest, UnusedFunctionInputs) {
506506
Check(code, "x", {1.0, 2.0, 3.0}, "y", {1.0, 4.0, 9.0});
507507
}
508508

509+
// Test constant-folding inside a sub-graph is handled correctly
510+
// for functions that are inlined.
511+
TEST(FunctionTest, ConstantFoldingInSubGraph) {
512+
const char* code = R"(
513+
<ir_version: 8, opset_import: [ "" : 17 ]>
514+
agraph (float[N] X) => (float[M] Y) {
515+
seq1 = SequenceConstruct(X, X, X)
516+
seq2 = SequenceMap (seq1) <body =
517+
add1 (float[K] Z) => (float[K] W) {
518+
C1 = Constant <value = float {1.0}> ()
519+
C2 = Constant <value = float {1.0}> ()
520+
# C is a constant, which will be constant-folded into an initializer out of the sub-graph.
521+
C = Add (C1, C2)
522+
# After optimization, only following Add will be left in this sub-graph.
523+
W = Add (Z, C)
524+
}
525+
>
526+
Y = ConcatFromSequence <axis=0> (seq2)
527+
}
528+
)";
529+
530+
Check(code, "X", {1.0, 2.0, 3.0}, "Y", {3.0, 4.0, 5.0, 3.0, 4.0, 5.0, 3.0, 4.0, 5.0});
531+
}
532+
509533
} // namespace test
510534
} // namespace onnxruntime

0 commit comments

Comments
 (0)