27
27
#include " core/common/span_utils.h"
28
28
#include " core/common/status.h"
29
29
#include " core/common/logging/logging.h"
30
+ #include " core/framework/ort_value.h"
30
31
#include " core/framework/prepacked_weights_container.h"
31
32
#include " core/graph/onnx_protobuf.h"
32
33
#include " core/graph/basic_types.h"
39
40
#include " core/graph/node_arg.h"
40
41
#include " core/graph/ort_format_load_options.h"
41
42
43
+ // Type from the Model Editor API in the ORT C API, so it can't be in a namespace
44
+ struct OrtGraph;
45
+
42
46
namespace onnxruntime {
43
47
class Graph;
44
48
struct IndexedSubGraph;
@@ -763,6 +767,10 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
763
767
*/
764
768
bool GetInitializedTensor (const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const ;
765
769
770
+ /* * Populate `value` if an externally allocated OrtValue exists for an initializer with the given name.
+ * NOTE(review): presumably returns true when `value` was populated and false otherwise - confirm against the definition.
771
+ */
772
+ bool GetOrtValueInitializer (const std::string& name, OrtValue& value) const ;
773
+
766
774
/* * Gets all the initializer tensors in this Graph.
Returns a const reference to the internal name_to_initial_tensor_ set; no copy is made. */
767
775
const InitializedTensorSet& GetAllInitializedTensors () const noexcept { return name_to_initial_tensor_; }
768
776
@@ -1430,6 +1438,16 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
1430
1438
const OrtFormatLoadOptions& load_options,
1431
1439
const logging::Logger& logger, std::unique_ptr<Graph>& graph);
1432
1440
1441
+ static Status LoadFromModelEditorApiModel (const OrtGraph& api_graph,
1442
+ const Model& owning_model,
1443
+ const std::unordered_map<std::string, int >& domain_to_version,
1444
+ IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
1445
+ bool strict_shape_type_inference,
1446
+ const logging::Logger& logger,
1447
+ std::unique_ptr<Graph>& graph);
1448
+
1449
+ Status UpdateUsingModelEditorApiModel (const OrtModel& api_model);
1450
+
1433
1451
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1434
1452
const RuntimeOptimizationRecordContainer& RuntimeOptimizations () const {
1435
1453
return runtime_optimizations_;
@@ -1630,7 +1648,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
1630
1648
// Implementation for initializer replacement
1631
1649
Status ReplaceInitializedTensorImpl (ONNX_NAMESPACE::TensorProto new_initializer, bool is_external);
1632
1650
1633
- std::vector<NodeArg*> CreateNodeArgs (const google::protobuf::RepeatedPtrField<std::string>& names,
1651
+ template <typename StringRange> // any range usable as a range-based-for range-initializer whose elements are std::string
1652
+ std::vector<NodeArg*> CreateNodeArgs (const StringRange& names,
1634
1653
const ArgNameToTypeMap& name_to_type_map);
1635
1654
1636
1655
void ToGraphProtoInternal (ONNX_NAMESPACE::GraphProto& graph_proto) const ;
@@ -1694,6 +1713,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
1694
1713
return nodes_[node_index].get ();
1695
1714
}
1696
1715
1716
+ Status LoadFromModelEditorApiModel (const OrtGraph& api_graph, bool updating_existing_graph = false );
1717
+
1697
1718
const Model& owning_model_;
1698
1719
1699
1720
// GraphProto to store name, version, initializer.
@@ -1708,6 +1729,12 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
1708
1729
1709
1730
InitializedTensorSet name_to_initial_tensor_;
1710
1731
1732
+ // Initializers that are external to the Graph,
1733
+ // e.g. created from existing memory using CreateTensorWithDataAndDeleterAsOrtValue in the ORT API.
1734
+ // As we need to convert to TensorProto for the optimizers to work, and must also keep the deleter information,
1735
+ // we store them in the Graph instance and retrieve them during session state finalization.
1736
+ std::unordered_map<std::string, OrtValue> ortvalue_initializers_;
1737
+
1711
1738
std::unordered_set<std::reference_wrapper<const std::string>,
1712
1739
std::hash<std::string>, std::equal_to<std::string>>
1713
1740
sparse_tensor_names_;
@@ -1744,6 +1771,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
1744
1771
// In some cases, a fused sub-graph will occur multiple times in one model, so we use a map
1745
1772
// to store reusable schemas for lookup.
1746
1773
InlinedHashMap<std::string, std::reference_wrapper<ONNX_NAMESPACE::OpSchema>> reusable_fused_schema_map_;
1774
+
1747
1775
#endif // !defined(ORT_MINIMAL_BUILD)
1748
1776
1749
1777
// Graph nodes.
@@ -1806,7 +1834,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
1806
1834
std::unordered_map<std::string, std::unordered_set<NodeIndex>> node_arg_to_consumer_nodes_;
1807
1835
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
1808
1836
1809
- const std::unordered_map<std::string, int > domain_to_version_;
1837
+ std::unordered_map<std::string, int > domain_to_version_;
1810
1838
1811
1839
// Model IR version.
1812
1840
Version ir_version_{ONNX_NAMESPACE::Version::IR_VERSION};
0 commit comments