Skip to content

Commit f7f91d0

Browse files
authored
Merge pull request #5 from dskkato/zenn
add zenn book example
2 parents 9573a4c + 0fa6911 commit f7f91d0

7 files changed

+96
-1
lines changed

Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ edition = "2018"
77

88
[dependencies]
99
image = "0.23.14"
10+
ndarray = "0.15.3"
1011
random = "0.12.2"
11-
tensorflow = "0.17.0"
12+
tensorflow = {version="0.17.0", features = ["ndarray"]}

examples/zenn.rs

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
use std::error::Error;
2+
use std::fs::File;
3+
use std::io::Read;
4+
use std::path::Path;
5+
use std::result::Result;
6+
use tensorflow::Code;
7+
use tensorflow::Graph;
8+
use tensorflow::ImportGraphDefOptions;
9+
use tensorflow::Session;
10+
use tensorflow::SessionOptions;
11+
use tensorflow::SessionRunArgs;
12+
use tensorflow::Status;
13+
use tensorflow::Tensor;
14+
15+
use ndarray;
16+
17+
use image::io::Reader as ImageReader;
18+
use image::GenericImageView;
19+
20+
fn main() -> Result<(), Box<dyn Error>> {
21+
let filename = "examples/zenn/mobilenetv3large.pb";
22+
if !Path::new(filename).exists() {
23+
return Err(Box::new(
24+
Status::new_set(
25+
Code::NotFound,
26+
&format!(
27+
"Run 'python examples/zenn/zenn.py' to generate {} \
28+
and try again.",
29+
filename
30+
),
31+
)
32+
.unwrap(),
33+
));
34+
}
35+
36+
// Create input variables for our addition
37+
let mut x = Tensor::new(&[1, 224, 224, 3]);
38+
let img = ImageReader::open("examples/zenn/sample.png")?.decode()?;
39+
for (i, (_, _, pixel)) in img.pixels().enumerate() {
40+
x[3 * i] = pixel.0[0] as f32;
41+
x[3 * i + 1] = pixel.0[1] as f32;
42+
x[3 * i + 2] = pixel.0[2] as f32;
43+
}
44+
45+
// Load the computation graph defined by addition.py.
46+
let mut graph = Graph::new();
47+
let mut proto = Vec::new();
48+
File::open(filename)?.read_to_end(&mut proto)?;
49+
graph.import_graph_def(&proto, &ImportGraphDefOptions::new())?;
50+
let session = Session::new(&SessionOptions::new(), &graph)?;
51+
52+
// Run the graph.
53+
let mut args = SessionRunArgs::new();
54+
args.add_feed(&graph.operation_by_name_required("x")?, 0, &x);
55+
let output = args.request_fetch(&graph.operation_by_name_required("Identity")?, 0);
56+
session.run(&mut args)?;
57+
58+
// Check our results.
59+
let output: Tensor<f32> = args.fetch(output)?;
60+
let res: ndarray::Array<f32, _> = output.into();
61+
println!("{:?}", res);
62+
63+
Ok(())
64+
}

examples/zenn/imagenet_class_index.json

+1
Large diffs are not rendered by default.

examples/zenn/mobilenetv3large.pb

21.2 MB
Binary file not shown.
172 KB
Loading

examples/zenn/sample.png

118 KB
Loading

examples/zenn/zenn.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import tensorflow as tf
2+
from tensorflow.python.framework.convert_to_constants import (
3+
convert_variables_to_constants_v2,
4+
)
5+
from numpy.testing import assert_almost_equal
6+
7+
# default input shape 224x224x3
8+
model = tf.keras.applications.MobileNetV3Large()
9+
10+
x = tf.TensorSpec(model.input_shape, tf.float32, name="x")
11+
concrete_function = tf.function(lambda x: model(x)).get_concrete_function(x)
12+
# now all variables are converted to constants.
13+
# if this step is omitted, dumped graph does not include trained weights
14+
frozen_model = convert_variables_to_constants_v2(concrete_function)
15+
directory = "examples/zenn"
16+
tf.io.write_graph(frozen_model.graph, directory, "mobilenetv3large.pb", as_text=False)
17+
18+
# original prediction
19+
buf = tf.io.read_file("examples/zenn/sample.png")
20+
img = tf.image.decode_png(buf)
21+
sample = img[tf.newaxis, :, :, :]
22+
original_pred = model(sample)
23+
24+
frozen_model_pred = frozen_model(tf.cast(sample, tf.float32))[0]
25+
assert_almost_equal(original_pred.numpy(), frozen_model_pred.numpy())
26+
decoded = tf.keras.applications.mobilenet_v3.decode_predictions(original_pred.numpy())[
27+
0
28+
]
29+
print(decoded)

0 commit comments

Comments
 (0)