|
| 1 | +use std::error::Error; |
| 2 | +use std::fs::File; |
| 3 | +use std::io::Read; |
| 4 | +use std::path::Path; |
| 5 | +use std::result::Result; |
| 6 | +use tensorflow::Code; |
| 7 | +use tensorflow::Graph; |
| 8 | +use tensorflow::ImportGraphDefOptions; |
| 9 | +use tensorflow::Session; |
| 10 | +use tensorflow::SessionOptions; |
| 11 | +use tensorflow::SessionRunArgs; |
| 12 | +use tensorflow::Status; |
| 13 | +use tensorflow::Tensor; |
| 14 | + |
| 15 | +use ndarray; |
| 16 | + |
| 17 | +use image::io::Reader as ImageReader; |
| 18 | +use image::GenericImageView; |
| 19 | + |
| 20 | +fn main() -> Result<(), Box<dyn Error>> { |
| 21 | + let filename = "examples/zenn/mobilenetv3large.pb"; |
| 22 | + if !Path::new(filename).exists() { |
| 23 | + return Err(Box::new( |
| 24 | + Status::new_set( |
| 25 | + Code::NotFound, |
| 26 | + &format!( |
| 27 | + "Run 'python examples/zenn/zenn.py' to generate {} \ |
| 28 | + and try again.", |
| 29 | + filename |
| 30 | + ), |
| 31 | + ) |
| 32 | + .unwrap(), |
| 33 | + )); |
| 34 | + } |
| 35 | + |
| 36 | + // Create input variables for our addition |
| 37 | + let mut x = Tensor::new(&[1, 224, 224, 3]); |
| 38 | + let img = ImageReader::open("examples/zenn/sample.png")?.decode()?; |
| 39 | + for (i, (_, _, pixel)) in img.pixels().enumerate() { |
| 40 | + x[3 * i] = pixel.0[0] as f32; |
| 41 | + x[3 * i + 1] = pixel.0[1] as f32; |
| 42 | + x[3 * i + 2] = pixel.0[2] as f32; |
| 43 | + } |
| 44 | + |
| 45 | + // Load the computation graph defined by addition.py. |
| 46 | + let mut graph = Graph::new(); |
| 47 | + let mut proto = Vec::new(); |
| 48 | + File::open(filename)?.read_to_end(&mut proto)?; |
| 49 | + graph.import_graph_def(&proto, &ImportGraphDefOptions::new())?; |
| 50 | + let session = Session::new(&SessionOptions::new(), &graph)?; |
| 51 | + |
| 52 | + // Run the graph. |
| 53 | + let mut args = SessionRunArgs::new(); |
| 54 | + args.add_feed(&graph.operation_by_name_required("x")?, 0, &x); |
| 55 | + let output = args.request_fetch(&graph.operation_by_name_required("Identity")?, 0); |
| 56 | + session.run(&mut args)?; |
| 57 | + |
| 58 | + // Check our results. |
| 59 | + let output: Tensor<f32> = args.fetch(output)?; |
| 60 | + let res: ndarray::Array<f32, _> = output.into(); |
| 61 | + println!("{:?}", res); |
| 62 | + |
| 63 | + Ok(()) |
| 64 | +} |
0 commit comments