Skip to content

Commit 41ddd70

Browse files
authored
Merge pull request #6 from dskkato/use_signature
Use signature to extract op names
2 parents f7f91d0 + 0169e9a commit 41ddd70

11 files changed

+52
-25
lines changed

examples/mnist_savedmodel.rs

+9-4
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,15 @@ fn main() -> Result<(), Box<dyn Error>> {
3737

3838
// Load the saved model exported by regression_savedmodel.py.
3939
let mut graph = Graph::new();
40-
let session =
41-
SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?.session;
42-
let op_x = graph.operation_by_name_required("serving_default_sequential_input")?;
43-
let op_predict = graph.operation_by_name_required("StatefulPartitionedCall")?;
40+
let bundle =
41+
SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?;
42+
let session = &bundle.session;
43+
44+
let signature = bundle.meta_graph_def().get_signature("serving_default")?;
45+
let input_info = signature.get_input("input")?;
46+
let op_x = graph.operation_by_name_required(&input_info.name().name)?;
47+
let output_info = signature.get_output("output")?;
48+
let op_predict = graph.operation_by_name_required(&output_info.name().name)?;
4449

4550
// Train the model (e.g. for fine tuning).
4651
let mut args = SessionRunArgs::new();
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
4.0663176e-06, 1.4199884e-07, 9.556003e-05, 0.00065914105, 2.260991e-07, 4.076631e-06, 2.5459945e-09, 0.99904054, 1.5654963e-05, 0.00018059688
1+
3.112342e-05, 8.721303e-08, 0.0005018024, 0.0003709061, 1.6482764e-08, 1.8595395e-06, 1.3620006e-09, 0.999046, 5.4331244e-06, 4.2815118e-05

examples/mnist_savedmodel/mnist_savedmodel.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@
3131
model.fit(x_train, y_train, epochs=1)
3232

3333
# convert output type through softmax so that it can be interpreted as probability
34-
probability_model = tf.keras.Sequential([model, tf.keras.layers.Softmax(name="output")])
34+
inputs = tf.keras.Input((28, 28), name="input", dtype=tf.float32)
35+
x = model(inputs)
36+
outputs = tf.keras.layers.Softmax(name="output")(x)
37+
38+
probability_model = tf.keras.Model(inputs=inputs, outputs=outputs)
3539

3640
# dump expected values to compare Rust's outputs
3741
with open("examples/mnist_savedmodel/expected_values.txt", "w") as f:
@@ -45,6 +49,6 @@
4549
logdir = "logs/mnist_savedmodel"
4650
writer = tf.summary.create_file_writer(logdir)
4751
tf.summary.trace_on()
48-
values = probability_model(x_test[:1, :, :])
52+
values = probability_model.predict(x_test[:1, :, :])
4953
with writer.as_default():
5054
tf.summary.trace_export("Default", step=0)
-478 Bytes
Binary file not shown.
Binary file not shown.
-20 Bytes
Binary file not shown.

examples/regression_savedmodel.rs

+23-9
Original file line numberDiff line numberDiff line change
@@ -41,27 +41,41 @@ fn main() -> Result<(), Box<dyn Error>> {
4141

4242
// Load the saved model exported by regression_savedmodel.py.
4343
let mut graph = Graph::new();
44-
let session =
45-
SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?.session;
46-
let op_x = graph.operation_by_name_required("train_x")?;
47-
let op_y = graph.operation_by_name_required("train_y")?;
48-
let op_train = graph.operation_by_name_required("StatefulPartitionedCall")?;
49-
let op_w = graph.operation_by_name_required("StatefulPartitionedCall_1")?;
50-
let op_b = graph.operation_by_name_required("StatefulPartitionedCall_1")?;
44+
let bundle =
45+
SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?;
46+
let session = &bundle.session;
47+
48+
let train_signature = bundle.meta_graph_def().get_signature("train")?;
49+
let x_info = train_signature.get_input("x")?;
50+
let y_info = train_signature.get_input("y")?;
51+
let train_info = train_signature.get_output("train")?;
52+
let op_x = graph.operation_by_name_required(&x_info.name().name)?;
53+
let op_y = graph.operation_by_name_required(&y_info.name().name)?;
54+
let op_train = graph.operation_by_name_required(&train_info.name().name)?;
55+
let w_info = bundle
56+
.meta_graph_def()
57+
.get_signature("w")?
58+
.get_output("output")?;
59+
let op_w = graph.operation_by_name_required(&w_info.name().name)?;
60+
let b_info = bundle
61+
.meta_graph_def()
62+
.get_signature("b")?
63+
.get_output("output")?;
64+
let op_b = graph.operation_by_name_required(&b_info.name().name)?;
5165

5266
// Train the model (e.g. for fine tuning).
5367
let mut train_step = SessionRunArgs::new();
5468
train_step.add_feed(&op_x, 0, &x);
5569
train_step.add_feed(&op_y, 0, &y);
56-
train_step.request_fetch(&op_train, 0);
70+
train_step.add_target(&op_train);
5771
for _ in 0..steps {
5872
session.run(&mut train_step)?;
5973
}
6074

6175
// Grab the data out of the session.
6276
let mut output_step = SessionRunArgs::new();
6377
let w_ix = output_step.request_fetch(&op_w, 0);
64-
let b_ix = output_step.request_fetch(&op_b, 1);
78+
let b_ix = output_step.request_fetch(&op_b, 0);
6579
session.run(&mut output_step)?;
6680

6781
// Check our results.

examples/regression_savedmodel/regression_savedmodel.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -14,35 +14,39 @@ def __call__(self, x):
1414
return y_hat
1515

1616
@tf.function
17-
def get_weights(self):
18-
return self.w, self.b
17+
def get_w(self):
18+
return {"output": self.w}
19+
20+
@tf.function
21+
def get_b(self):
22+
return {"output": self.b}
1923

2024
@tf.function
2125
def train(self, x, y):
2226
with tf.GradientTape() as tape:
2327
y_hat = self(x)
2428
loss = tf.reduce_mean(tf.square(y_hat - y))
2529
grads = tape.gradient(loss, self.trainable_variables)
26-
_ = self.optimizer.apply_gradients(
27-
zip(grads, self.trainable_variables), name="train"
28-
)
29-
return loss
30+
_ = self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
31+
return {"train": loss}
3032

3133

3234
model = LinearRegresstion()
3335

3436
x = tf.TensorSpec([None], tf.float32, name="x")
3537
y = tf.TensorSpec([None], tf.float32, name="y")
3638
train = model.train.get_concrete_function(x, y)
37-
weights = model.get_weights.get_concrete_function()
39+
w = model.get_w.get_concrete_function()
40+
b = model.get_b.get_concrete_function()
3841

3942
directory = "examples/regression_savedmodel"
40-
signatures = {"train": train, "weights": weights}
43+
signatures = {"train": train, "w": w, "b": b}
4144
tf.saved_model.save(model, directory, signatures=signatures)
4245

4346
# export graph info to TensorBoard
4447
logdir = "logs/regression_savedmodel"
4548
writer = tf.summary.create_file_writer(logdir)
4649
with writer.as_default():
4750
tf.summary.graph(train.graph)
48-
tf.summary.graph(weights.graph)
51+
tf.summary.graph(w.graph)
52+
tf.summary.graph(b.graph)
-5.14 KB
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)