@@ -66,6 +66,68 @@ impl ImportGraphDefOptions {
66
66
}
67
67
Ok ( ( ) )
68
68
}
69
+
70
+ /// Set any imported nodes with input `src_name:src_index` to have that input
71
+ /// replaced with `dst`. `src_name` refers to a node in the graph to be imported,
72
+ /// `dst` references a node already existing in the graph being imported into.
73
+ pub fn add_input_mapping ( & mut self ,
74
+ src_name : & str ,
75
+ src_index : usize ,
76
+ dst : & Output )
77
+ -> std:: result:: Result < ( ) , NulError > {
78
+ let s = CString :: new ( src_name) ?;
79
+ unsafe {
80
+ tf:: TF_ImportGraphDefOptionsAddInputMapping ( self . inner ,
81
+ s. as_ptr ( ) ,
82
+ src_index as c_int ,
83
+ dst. to_c ( ) ) ;
84
+ }
85
+ Ok ( ( ) )
86
+ }
87
+
88
+ /// Set any imported nodes with control input `src_name` to have that input
89
+ /// replaced with `dst`. `src_name` refers to a node in the graph to be imported,
90
+ /// `dst` references an operation already existing in the graph being imported
91
+ /// into.
92
+ pub fn remap_control_dependency ( & mut self ,
93
+ src_name : & str ,
94
+ dst : & Operation )
95
+ -> std:: result:: Result < ( ) , NulError > {
96
+ let s = CString :: new ( src_name) ?;
97
+ unsafe {
98
+ tf:: TF_GraphImportGraphDefOptionsRemapControlDependency ( self . inner ,
99
+ s. as_ptr ( ) ,
100
+ dst. inner ) ;
101
+ }
102
+ Ok ( ( ) )
103
+ }
104
+
105
+ /// Cause the imported graph to have a control dependency on `oper`. `oper`
106
+ /// should exist in the graph being imported into.
107
+ pub fn add_control_dependency ( & mut self , oper : & Operation ) {
108
+ unsafe {
109
+ tf:: TF_ImportGraphDefOptionsAddControlDependency ( self . inner , oper. inner ) ;
110
+ }
111
+ }
112
+
113
+ /// Add an output in `graph_def` to be returned via the `return_outputs` output
114
+ /// parameter of `import_graph_def()`. If the output is remapped via an input
115
+ /// mapping, the corresponding existing tensor in `graph` will be returned.
116
+ pub fn add_return_output ( & mut self ,
117
+ oper_name : & str ,
118
+ index : usize )
119
+ -> std:: result:: Result < ( ) , NulError > {
120
+ let s = CString :: new ( oper_name) ?;
121
+ unsafe {
122
+ tf:: TF_ImportGraphDefOptionsAddReturnOutput ( self . inner , s. as_ptr ( ) , index as c_int ) ;
123
+ }
124
+ Ok ( ( ) )
125
+ }
126
+
127
+ /// Returns the number of return outputs added via `add_return_output()`.
128
+ pub fn num_return_outputs ( & self ) -> usize {
129
+ unsafe { tf:: TF_ImportGraphDefOptionsNumReturnOutputs ( self . inner ) as usize }
130
+ }
69
131
}
70
132
71
133
////////////////////////
0 commit comments