Skip to content

Commit 76bab31

Browse files
committed
Add methods to ImportGraphDefOptions from TensorFlow 1.1
1 parent f5d83c3 commit 76bab31

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

src/graph.rs

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,68 @@ impl ImportGraphDefOptions {
6666
}
6767
Ok(())
6868
}
69+
70+
/// Set any imported nodes with input `src_name:src_index` to have that input
71+
/// replaced with `dst`. `src_name` refers to a node in the graph to be imported,
72+
/// `dst` references a node already existing in the graph being imported into.
73+
pub fn add_input_mapping(&mut self,
74+
src_name: &str,
75+
src_index: usize,
76+
dst: &Output)
77+
-> std::result::Result<(), NulError> {
78+
let s = CString::new(src_name)?;
79+
unsafe {
80+
tf::TF_ImportGraphDefOptionsAddInputMapping(self.inner,
81+
s.as_ptr(),
82+
src_index as c_int,
83+
dst.to_c());
84+
}
85+
Ok(())
86+
}
87+
88+
/// Set any imported nodes with control input `src_name` to have that input
89+
/// replaced with `dst`. `src_name` refers to a node in the graph to be imported,
90+
/// `dst` references an operation already existing in the graph being imported
91+
/// into.
92+
pub fn remap_control_dependency(&mut self,
93+
src_name: &str,
94+
dst: &Operation)
95+
-> std::result::Result<(), NulError> {
96+
let s = CString::new(src_name)?;
97+
unsafe {
98+
tf::TF_GraphImportGraphDefOptionsRemapControlDependency(self.inner,
99+
s.as_ptr(),
100+
dst.inner);
101+
}
102+
Ok(())
103+
}
104+
105+
/// Cause the imported graph to have a control dependency on `oper`. `oper`
106+
/// should exist in the graph being imported into.
107+
pub fn add_control_dependency(&mut self, oper: &Operation) {
108+
unsafe {
109+
tf::TF_ImportGraphDefOptionsAddControlDependency(self.inner, oper.inner);
110+
}
111+
}
112+
113+
/// Add an output in `graph_def` to be returned via the `return_outputs` output
114+
/// parameter of `import_graph_def()`. If the output is remapped via an input
115+
/// mapping, the corresponding existing tensor in `graph` will be returned.
116+
pub fn add_return_output(&mut self,
117+
oper_name: &str,
118+
index: usize)
119+
-> std::result::Result<(), NulError> {
120+
let s = CString::new(oper_name)?;
121+
unsafe {
122+
tf::TF_ImportGraphDefOptionsAddReturnOutput(self.inner, s.as_ptr(), index as c_int);
123+
}
124+
Ok(())
125+
}
126+
127+
/// Returns the number of return outputs added via `add_return_output()`.
128+
pub fn num_return_outputs(&self) -> usize {
129+
unsafe { tf::TF_ImportGraphDefOptionsNumReturnOutputs(self.inner) as usize }
130+
}
69131
}
70132

71133
////////////////////////

0 commit comments

Comments
 (0)