Skip to content

Commit e9dfa4d

Browse files
committed
Add Graph::import_graph_def_with_return_outputs
1 parent 76bab31 commit e9dfa4d

File tree

3 files changed

+79
-39
lines changed

3 files changed

+79
-39
lines changed

src/expr.rs

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ pub trait ExprImpl<T: TensorType>: Display + Debug {
9494
/// rather than creating child operations itself.
9595
fn create_operation(&self,
9696
graph: &mut Graph,
97-
children: &[Rc<Operation>],
97+
children: &[Operation],
9898
id_gen: &mut FnMut() -> String)
9999
-> Result<Operation, Status>;
100100

@@ -113,7 +113,7 @@ impl<T: TensorType> ExprImpl<T> for T {
113113

114114
fn create_operation(&self,
115115
graph: &mut Graph,
116-
_children: &[Rc<Operation>],
116+
_children: &[Operation],
117117
id_gen: &mut FnMut() -> String)
118118
-> Result<Operation, Status> {
119119
let mut nd = try!(graph.new_operation("Const", &id_gen()));
@@ -197,11 +197,11 @@ macro_rules! impl_bin_op {
197197
vec![Box::new(self.left.clone()), Box::new(self.right.clone())]
198198
}
199199

200-
fn create_operation(&self, graph: &mut Graph, children: &[Rc<Operation>],
200+
fn create_operation(&self, graph: &mut Graph, children: &[Operation],
201201
id_gen: &mut FnMut() -> String) -> Result<Operation, Status> {
202202
let mut nd = try!(graph.new_operation($tf_op, &id_gen()));
203-
nd.add_input(Output {operation: &children[0], index: 0});
204-
nd.add_input(Output {operation: &children[1], index: 0});
203+
nd.add_input(Output {operation: children[0].clone(), index: 0});
204+
nd.add_input(Output {operation: children[1].clone(), index: 0});
205205
nd.finish()
206206
}
207207

@@ -287,18 +287,18 @@ impl<T: TensorType> ExprImpl<T> for TruncateDiv<T> {
287287

288288
fn create_operation(&self,
289289
graph: &mut Graph,
290-
children: &[Rc<Operation>],
290+
children: &[Operation],
291291
id_gen: &mut FnMut() -> String)
292292
-> Result<Operation, Status> {
293293
let mut nd = try!(graph.new_operation("TruncateDiv", &id_gen()));
294294
nd.add_input(Output {
295-
operation: &children[0],
296-
index: 0,
297-
});
295+
operation: children[0].clone(),
296+
index: 0,
297+
});
298298
nd.add_input(Output {
299-
operation: &children[1],
300-
index: 0,
301-
});
299+
operation: children[1].clone(),
300+
index: 0,
301+
});
302302
nd.finish()
303303
}
304304

@@ -351,14 +351,14 @@ impl<T: TensorType> ExprImpl<T> for Neg<T> {
351351

352352
fn create_operation(&self,
353353
graph: &mut Graph,
354-
children: &[Rc<Operation>],
354+
children: &[Operation],
355355
id_gen: &mut FnMut() -> String)
356356
-> Result<Operation, Status> {
357357
let mut nd = try!(graph.new_operation("Neg", &id_gen()));
358358
nd.add_input(Output {
359-
operation: &children[0],
360-
index: 0,
361-
});
359+
operation: children[0].clone(),
360+
index: 0,
361+
});
362362
nd.finish()
363363
}
364364

@@ -409,7 +409,7 @@ impl<T: TensorType> ExprImpl<T> for Variable<T> {
409409

410410
fn create_operation(&self,
411411
graph: &mut Graph,
412-
_children: &[Rc<Operation>],
412+
_children: &[Operation],
413413
_id_gen: &mut FnMut() -> String)
414414
-> Result<Operation, Status> {
415415
let mut nd = try!(graph.new_operation("Variable", &self.name));
@@ -469,7 +469,7 @@ impl<T: TensorType> ExprImpl<T> for Placeholder<T> {
469469

470470
fn create_operation(&self,
471471
graph: &mut Graph,
472-
_children: &[Rc<Operation>],
472+
_children: &[Operation],
473473
_id_gen: &mut FnMut() -> String)
474474
-> Result<Operation, Status> {
475475
let mut nd = try!(graph.new_operation("Placeholder", &self.name));
@@ -523,18 +523,18 @@ impl<T: TensorType> ExprImpl<T> for Assign<T> {
523523

524524
fn create_operation(&self,
525525
graph: &mut Graph,
526-
children: &[Rc<Operation>],
526+
children: &[Operation],
527527
id_gen: &mut FnMut() -> String)
528528
-> Result<Operation, Status> {
529529
let mut nd = try!(graph.new_operation("Assign", &id_gen()));
530530
nd.add_input(Output {
531-
operation: &children[0],
532-
index: 0,
533-
});
531+
operation: children[0].clone(),
532+
index: 0,
533+
});
534534
nd.add_input(Output {
535-
operation: &children[1],
536-
index: 0,
537-
});
535+
operation: children[1].clone(),
536+
index: 0,
537+
});
538538
nd.finish()
539539
}
540540

@@ -563,7 +563,7 @@ pub trait AnyExpr: Debug {
563563
/// rather than creating child operations itself.
564564
fn create_operation(&self,
565565
graph: &mut Graph,
566-
children: &[Rc<Operation>],
566+
children: &[Operation],
567567
id_gen: &mut FnMut() -> String)
568568
-> Result<Operation, Status>;
569569

@@ -586,7 +586,7 @@ impl<T: TensorType> AnyExpr for Expr<T> {
586586

587587
fn create_operation(&self,
588588
graph: &mut Graph,
589-
children: &[Rc<Operation>],
589+
children: &[Operation],
590590
id_gen: &mut FnMut() -> String)
591591
-> Result<Operation, Status> {
592592
self.expr.create_operation(graph, children, id_gen)
@@ -620,7 +620,7 @@ impl Hash for Key {
620620
#[derive(Debug)]
621621
pub struct Compiler<'l> {
622622
graph: &'l mut Graph,
623-
operations: HashMap<Key, Rc<Operation>>,
623+
operations: HashMap<Key, Operation>,
624624
next_id: i32,
625625
}
626626

@@ -635,12 +635,12 @@ impl<'l> Compiler<'l> {
635635
}
636636

637637
/// Compiles the expression.
638-
pub fn compile<T: TensorType>(&mut self, expr: Expr<T>) -> Result<Rc<Operation>, Status> {
638+
pub fn compile<T: TensorType>(&mut self, expr: Expr<T>) -> Result<Operation, Status> {
639639
self.compile_any(Box::new(expr))
640640
}
641641

642642
/// Compiles the expression.
643-
pub fn compile_any(&mut self, expr: Box<AnyExpr>) -> Result<Rc<Operation>, Status> {
643+
pub fn compile_any(&mut self, expr: Box<AnyExpr>) -> Result<Operation, Status> {
644644
let mut child_operations = vec![];
645645
for child in expr.children() {
646646
let key = Key(child.clone_box());
@@ -661,7 +661,7 @@ impl<'l> Compiler<'l> {
661661
id
662662
});
663663
self.next_id = next_id;
664-
let operation = Rc::new(try!(result));
664+
let operation = result?;
665665
self.operations.insert(Key(expr), operation.clone());
666666
Ok(operation)
667667
}

src/graph.rs

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ impl Graph {
248248
/// * `output` is not in `graph`.
249249
pub fn tensor_shape(&self, output: Output) -> Result<Shape> {
250250
let mut status = Status::new();
251-
let n = try!(self.num_dims(output));
251+
let n = try!(self.num_dims(output.clone()));
252252
if n == -1 {
253253
return Ok(Shape(None));
254254
}
@@ -283,6 +283,36 @@ impl Graph {
283283
status.into_result()
284284
}
285285
}
286+
287+
/// Import the graph serialized in `graph_def`.
288+
///
289+
/// `num_return_outputs` must be the number of return outputs added (i.e. the
290+
/// result of TF_ImportGraphDefOptionsNumReturnOutputs()). If
291+
/// `num_return_outputs` is non-zero, `return_outputs` must be of length
292+
/// `num_return_outputs`. Otherwise it can be null.
293+
pub fn import_graph_def_with_return_outputs(&mut self,
294+
graph_def: &[u8],
295+
options: &ImportGraphDefOptions)
296+
-> Result<Vec<Output>> {
297+
let buf = Buffer::from(graph_def);
298+
let mut status = Status::new();
299+
let mut c_return_outputs = Vec::new();
300+
let n = options.num_return_outputs();
301+
unsafe {
302+
c_return_outputs.set_len(n);
303+
tf::TF_GraphImportGraphDefWithReturnOutputs(self.gimpl.inner,
304+
buf.inner(),
305+
options.inner,
306+
c_return_outputs.as_mut_ptr(),
307+
n as c_int,
308+
status.inner());
309+
}
310+
status.into_result()?;
311+
Ok(c_return_outputs
312+
.iter()
313+
.map(|x| Output::from_c(self, x))
314+
.collect())
315+
}
286316
}
287317

288318
impl GraphTrait for Graph {
@@ -324,7 +354,7 @@ impl<'a> Iterator for OperationIter<'a> {
324354

325355
/// An `Operation` is a node in a `Graph`.
326356
/// It is a computation which accepts inputs and produces outputs.
327-
#[derive(Debug)]
357+
#[derive(Debug,Clone)]
328358
pub struct Operation {
329359
inner: *mut tf::TF_Operation,
330360
gimpl: Arc<GraphImpl>,
@@ -556,22 +586,32 @@ impl<'a> Input<'a> {
556586

557587
/// A `Output` is one end of a graph edge.
558588
/// It holds an operation and an index into the outputs of that operation.
559-
#[derive(Debug,Copy,Clone)]
560-
pub struct Output<'a> {
589+
#[derive(Debug,Clone)]
590+
pub struct Output {
561591
/// Operation the edge connects to.
562-
pub operation: &'a Operation,
592+
pub operation: Operation,
563593

564594
/// Index into either the outputs of the operation.
565595
pub index: c_int,
566596
}
567597

568-
impl<'a> Output<'a> {
598+
impl Output {
569599
fn to_c(&self) -> tf::TF_Output {
570600
tf::TF_Output {
571601
oper: self.operation.inner,
572602
index: self.index,
573603
}
574604
}
605+
606+
fn from_c(graph: &Graph, output: &tf::TF_Output) -> Self {
607+
Output {
608+
operation: Operation {
609+
inner: output.oper,
610+
gimpl: graph.gimpl.clone(),
611+
},
612+
index: output.index,
613+
}
614+
}
575615
}
576616

577617
////////////////////////

src/session.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,11 +301,11 @@ mod tests {
301301
let y = {
302302
let mut nd = g.new_operation("Mul", "y").unwrap();
303303
nd.add_input(Output {
304-
operation: &two,
304+
operation: two,
305305
index: 0,
306306
});
307307
nd.add_input(Output {
308-
operation: &x,
308+
operation: x.clone(),
309309
index: 0,
310310
});
311311
nd.finish().unwrap()

0 commit comments

Comments
 (0)