Skip to content

Commit 701d5ec

Browse files
authored
Merge pull request #84 from adamcrume/master
Add graph def import options and clean up a few things
2 parents 8efd409 + c5b31d8 commit 701d5ec

File tree

9 files changed

+170
-74
lines changed

9 files changed

+170
-74
lines changed

Cargo.toml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,16 @@ random = "0.12"
2121
tensorflow_unstable = []
2222

2323
[workspace]
24+
25+
[[example]]
26+
name = "addition"
27+
28+
[[example]]
29+
name = "expressions"
30+
required-features = ["tensorflow_unstable"]
31+
32+
[[example]]
33+
name = "regression"
34+
35+
[[example]]
36+
name = "regression_savedmodel"

examples/addition.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ fn run() -> Result<(), Box<Error>> {
5959
try!(session.run(&mut step));
6060

6161
// Check our results.
62-
let z_res: i32 = try!(step.take_output(z)).data()[0];
62+
let z_res: i32 = step.take_output(z)?[0];
6363
println!("{:?}", z_res);
6464

6565
Ok(())

examples/expressions.rs

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ use std::error::Error;
55
use std::result::Result;
66
use std::process::exit;
77
use tensorflow::Code;
8-
// Workaround for https://github.com/rust-lang/cargo/issues/1570
9-
#[cfg(feature = "tensorflow_unstable")]
108
use tensorflow::expr::{Compiler, Placeholder};
119
use tensorflow::Graph;
1210
use tensorflow::Session;
@@ -60,16 +58,6 @@ impl Checker {
6058
}
6159
}
6260

63-
// Workaround for https://github.com/rust-lang/cargo/issues/1570
64-
#[cfg(not(feature = "tensorflow_unstable"))]
65-
fn run() -> Result<(), Box<Error>> {
66-
println!("examples/expressions.rs is disabled because the `tensorflow_unstable` feature is \
67-
not enabled!");
68-
Ok(())
69-
}
70-
71-
// Workaround for https://github.com/rust-lang/cargo/issues/1570
72-
#[cfg(feature = "tensorflow_unstable")]
7361
fn run() -> Result<(), Box<Error>> {
7462
// Build the graph
7563
let mut g = Graph::new();
@@ -101,9 +89,8 @@ fn run() -> Result<(), Box<Error>> {
10189

10290
// Check our results.
10391
let output_tensor = try!(step.take_output::<f32>(output_token));
104-
let data = output_tensor.data();
10592
let mut checker = Checker::new(1e-3);
106-
checker.check("data[0]", 5.0, data[0]);
107-
checker.check("data[1]", 7.0, data[1]);
93+
checker.check("output_tensor[0]", 5.0, output_tensor[0]);
94+
checker.check("output_tensor[1]", 7.0, output_tensor[1]);
10895
checker.result()
10996
}

examples/regression.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ fn run() -> Result<(), Box<Error>> {
8989
try!(session.run(&mut output_step));
9090

9191
// Check our results.
92-
let w_hat: f32 = try!(output_step.take_output(w_ix)).data()[0];
93-
let b_hat: f32 = try!(output_step.take_output(b_ix)).data()[0];
92+
let w_hat: f32 = output_step.take_output(w_ix)?[0];
93+
let b_hat: f32 = output_step.take_output(b_ix)?[0];
9494
println!("Checking w: expected {}, got {}. {}",
9595
w,
9696
w_hat,

examples/regression_savedmodel.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ fn run() -> Result<(), Box<Error>> {
7878
try!(session.run(&mut output_step));
7979

8080
// Check our results.
81-
let w_hat: f32 = try!(output_step.take_output(w_ix)).data()[0];
82-
let b_hat: f32 = try!(output_step.take_output(b_ix)).data()[0];
81+
let w_hat: f32 = output_step.take_output(w_ix)?[0];
82+
let b_hat: f32 = output_step.take_output(b_ix)?[0];
8383
println!("Checking w: expected {}, got {}. {}",
8484
w,
8585
w_hat,

src/expr.rs

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ pub trait ExprImpl<T: TensorType>: Display + Debug {
9494
/// rather than creating child operations itself.
9595
fn create_operation(&self,
9696
graph: &mut Graph,
97-
children: &[Rc<Operation>],
97+
children: &[Operation],
9898
id_gen: &mut FnMut() -> String)
9999
-> Result<Operation, Status>;
100100

@@ -113,7 +113,7 @@ impl<T: TensorType> ExprImpl<T> for T {
113113

114114
fn create_operation(&self,
115115
graph: &mut Graph,
116-
_children: &[Rc<Operation>],
116+
_children: &[Operation],
117117
id_gen: &mut FnMut() -> String)
118118
-> Result<Operation, Status> {
119119
let mut nd = try!(graph.new_operation("Const", &id_gen()));
@@ -197,11 +197,11 @@ macro_rules! impl_bin_op {
197197
vec![Box::new(self.left.clone()), Box::new(self.right.clone())]
198198
}
199199

200-
fn create_operation(&self, graph: &mut Graph, children: &[Rc<Operation>],
200+
fn create_operation(&self, graph: &mut Graph, children: &[Operation],
201201
id_gen: &mut FnMut() -> String) -> Result<Operation, Status> {
202202
let mut nd = try!(graph.new_operation($tf_op, &id_gen()));
203-
nd.add_input(Output {operation: &children[0], index: 0});
204-
nd.add_input(Output {operation: &children[1], index: 0});
203+
nd.add_input(Output {operation: children[0].clone(), index: 0});
204+
nd.add_input(Output {operation: children[1].clone(), index: 0});
205205
nd.finish()
206206
}
207207

@@ -287,18 +287,18 @@ impl<T: TensorType> ExprImpl<T> for TruncateDiv<T> {
287287

288288
fn create_operation(&self,
289289
graph: &mut Graph,
290-
children: &[Rc<Operation>],
290+
children: &[Operation],
291291
id_gen: &mut FnMut() -> String)
292292
-> Result<Operation, Status> {
293293
let mut nd = try!(graph.new_operation("TruncateDiv", &id_gen()));
294294
nd.add_input(Output {
295-
operation: &children[0],
296-
index: 0,
297-
});
295+
operation: children[0].clone(),
296+
index: 0,
297+
});
298298
nd.add_input(Output {
299-
operation: &children[1],
300-
index: 0,
301-
});
299+
operation: children[1].clone(),
300+
index: 0,
301+
});
302302
nd.finish()
303303
}
304304

@@ -351,14 +351,14 @@ impl<T: TensorType> ExprImpl<T> for Neg<T> {
351351

352352
fn create_operation(&self,
353353
graph: &mut Graph,
354-
children: &[Rc<Operation>],
354+
children: &[Operation],
355355
id_gen: &mut FnMut() -> String)
356356
-> Result<Operation, Status> {
357357
let mut nd = try!(graph.new_operation("Neg", &id_gen()));
358358
nd.add_input(Output {
359-
operation: &children[0],
360-
index: 0,
361-
});
359+
operation: children[0].clone(),
360+
index: 0,
361+
});
362362
nd.finish()
363363
}
364364

@@ -409,7 +409,7 @@ impl<T: TensorType> ExprImpl<T> for Variable<T> {
409409

410410
fn create_operation(&self,
411411
graph: &mut Graph,
412-
_children: &[Rc<Operation>],
412+
_children: &[Operation],
413413
_id_gen: &mut FnMut() -> String)
414414
-> Result<Operation, Status> {
415415
let mut nd = try!(graph.new_operation("Variable", &self.name));
@@ -469,7 +469,7 @@ impl<T: TensorType> ExprImpl<T> for Placeholder<T> {
469469

470470
fn create_operation(&self,
471471
graph: &mut Graph,
472-
_children: &[Rc<Operation>],
472+
_children: &[Operation],
473473
_id_gen: &mut FnMut() -> String)
474474
-> Result<Operation, Status> {
475475
let mut nd = try!(graph.new_operation("Placeholder", &self.name));
@@ -523,18 +523,18 @@ impl<T: TensorType> ExprImpl<T> for Assign<T> {
523523

524524
fn create_operation(&self,
525525
graph: &mut Graph,
526-
children: &[Rc<Operation>],
526+
children: &[Operation],
527527
id_gen: &mut FnMut() -> String)
528528
-> Result<Operation, Status> {
529529
let mut nd = try!(graph.new_operation("Assign", &id_gen()));
530530
nd.add_input(Output {
531-
operation: &children[0],
532-
index: 0,
533-
});
531+
operation: children[0].clone(),
532+
index: 0,
533+
});
534534
nd.add_input(Output {
535-
operation: &children[1],
536-
index: 0,
537-
});
535+
operation: children[1].clone(),
536+
index: 0,
537+
});
538538
nd.finish()
539539
}
540540

@@ -563,7 +563,7 @@ pub trait AnyExpr: Debug {
563563
/// rather than creating child operations itself.
564564
fn create_operation(&self,
565565
graph: &mut Graph,
566-
children: &[Rc<Operation>],
566+
children: &[Operation],
567567
id_gen: &mut FnMut() -> String)
568568
-> Result<Operation, Status>;
569569

@@ -586,7 +586,7 @@ impl<T: TensorType> AnyExpr for Expr<T> {
586586

587587
fn create_operation(&self,
588588
graph: &mut Graph,
589-
children: &[Rc<Operation>],
589+
children: &[Operation],
590590
id_gen: &mut FnMut() -> String)
591591
-> Result<Operation, Status> {
592592
self.expr.create_operation(graph, children, id_gen)
@@ -620,7 +620,7 @@ impl Hash for Key {
620620
#[derive(Debug)]
621621
pub struct Compiler<'l> {
622622
graph: &'l mut Graph,
623-
operations: HashMap<Key, Rc<Operation>>,
623+
operations: HashMap<Key, Operation>,
624624
next_id: i32,
625625
}
626626

@@ -635,12 +635,12 @@ impl<'l> Compiler<'l> {
635635
}
636636

637637
/// Compiles the expression.
638-
pub fn compile<T: TensorType>(&mut self, expr: Expr<T>) -> Result<Rc<Operation>, Status> {
638+
pub fn compile<T: TensorType>(&mut self, expr: Expr<T>) -> Result<Operation, Status> {
639639
self.compile_any(Box::new(expr))
640640
}
641641

642642
/// Compiles the expression.
643-
pub fn compile_any(&mut self, expr: Box<AnyExpr>) -> Result<Rc<Operation>, Status> {
643+
pub fn compile_any(&mut self, expr: Box<AnyExpr>) -> Result<Operation, Status> {
644644
let mut child_operations = vec![];
645645
for child in expr.children() {
646646
let key = Key(child.clone_box());
@@ -661,7 +661,7 @@ impl<'l> Compiler<'l> {
661661
id
662662
});
663663
self.next_id = next_id;
664-
let operation = Rc::new(try!(result));
664+
let operation = result?;
665665
self.operations.insert(Key(expr), operation.clone());
666666
Ok(operation)
667667
}

0 commit comments

Comments
 (0)