Skip to content

Commit 013a829

Browse files
committed
Address reviewer's requests
1 parent 568ff56 commit 013a829

File tree

2 files changed

+70
-52
lines changed

2 files changed

+70
-52
lines changed

src/checkpoint.rs

Lines changed: 67 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,37 @@
1-
//! This module supports saving and restoring variables using Tensorflow checkpoints in SaveV2 format
2-
1+
//! This module supports saving and restoring variables using Tensorflow checkpoints in SaveV2 format.
2+
//! First, the user creates a [`CheckpointMaker`] attached to an [`Scope`] indicating the list of variables to be saved/restored.
3+
//! The CheckpointMaker lazily modifies the graph creating the nodes needed for saving/restoring.
4+
//! When one wants to save/restore from or into a session, one calls the save/restore methods
5+
//! # Example
6+
//! let mut scope = Scope::new_root_scope();
7+
//! // add operations to define the graph
8+
//! // ...
9+
//! // let "w" and "b" the name of the variables that we wish to save
10+
//! let mut checkpoint_maker = CheckpointMaker::new(scope.new_sub_scope("checkpoint"),
11+
//! vec![String::from("w"), String::from("b")].into_boxed_slice(),
12+
//! );
13+
//! let session = Session::new(&SessionOptions::new(), &scope.graph())?;
14+
//! // run some training
15+
//! // ...
16+
//! // to save the training
17+
//! checkpoint_maker.save(&session, "data/checkpoint")?;
18+
//! // then we restore in a different session to continue there
19+
//! let new_session = Session::new(&SessionOptions::new(), &scope.graph())?;
20+
//! checkpoint_maker.save(&new_session, "data/checkpoint")?;
321
use crate::option_insert_result::OptionInsertWithResult;
422
use crate::{ops, Operation, Scope, Session, SessionRunArgs, Status, Tensor};
523

624
#[derive(Debug)]
725
struct SaveRestoreOps {
8-
pub prefix_save: Operation,
9-
pub prefix_restore: Operation,
10-
pub save_op: Operation,
11-
pub restore_op: Operation,
26+
prefix_save: Operation,
27+
prefix_restore: Operation,
28+
save_op: Operation,
29+
restore_op: Operation,
1230
}
1331

14-
/// Checkpointing and restoring struct
32+
/// Checkpointing and restoring support for Tensorflow.
33+
/// This struct is manages a scope, adds lazily the Tensorflow ops
34+
/// to perform the save/restore operations
1535
#[derive(Debug)]
1636
pub struct CheckpointMaker {
1737
scope: Scope,
@@ -33,19 +53,20 @@ impl CheckpointMaker {
3353
}
3454
}
3555

56+
fn make_all_variable_ops(&mut self) -> Result<Vec<Operation>, Status> {
57+
let graph = self.scope.graph();
58+
Ok(self
59+
.variables
60+
.iter()
61+
.map(|v: &String| -> Result<Operation, Status> {
62+
Ok(graph.operation_by_name_required(v.as_str())?.clone())
63+
})
64+
.collect::<Result<Vec<_>, Status>>()?)
65+
}
66+
3667
/// Add save and restore ops to the graph
3768
fn build_save_ops(&mut self) -> Result<SaveRestoreOps, Status> {
3869
let mut all_variable_ops_opt: Option<Vec<Operation>> = None;
39-
fn make_all_variable_ops(myself: &mut CheckpointMaker) -> Result<Vec<Operation>, Status> {
40-
let graph = myself.scope.graph();
41-
Ok(myself
42-
.variables
43-
.iter()
44-
.map(|v: &String| -> Result<Operation, Status> {
45-
Ok(graph.operation_by_name_required(v.as_str())?.clone())
46-
})
47-
.collect::<Result<Vec<_>, Status>>()?)
48-
}
4970

5071
let existing_save_op = self.scope.graph().operation_by_name("save")?;
5172
let (prefix_save, save_op) = if let Some(op) = existing_save_op {
@@ -56,7 +77,7 @@ impl CheckpointMaker {
5677
(prefix_save_op, op)
5778
} else {
5879
let all_variable_ops =
59-
all_variable_ops_opt.get_or_insert_with_result(|| make_all_variable_ops(self))?;
80+
all_variable_ops_opt.get_or_insert_with_result(|| self.make_all_variable_ops())?;
6081
let prefix_save = ops::Placeholder::new()
6182
.dtype(crate::DataType::String)
6283
.build(&mut self.scope.with_op_name("prefix_save"))?;
@@ -105,39 +126,37 @@ impl CheckpointMaker {
105126
(the_prefix_restore, op)
106127
} else {
107128
let all_variable_ops =
108-
all_variable_ops_opt.get_or_insert_with_result(|| make_all_variable_ops(self))?;
129+
all_variable_ops_opt.get_or_insert_with_result(|| self.make_all_variable_ops())?;
109130
let prefix_restore = ops::Placeholder::new()
110131
.dtype(crate::DataType::String)
111132
.build(&mut self.scope.with_op_name("prefix_restore"))?;
112-
let restore_op = {
113-
let all_var_names = self
133+
let all_var_names = self
134+
.variables
135+
.iter()
136+
.map(|v| v.to_string())
137+
.collect::<Vec<_>>();
138+
let tensor_names = ops::constant(&all_var_names[..], &mut self.scope)?;
139+
let shape_and_slices = ops::constant(
140+
&self
114141
.variables
115142
.iter()
116-
.map(|v| v.to_string())
117-
.collect::<Vec<_>>();
118-
let tensor_names = ops::constant(&all_var_names[..], &mut self.scope)?;
119-
let shape_and_slices = ops::constant(
120-
&self
121-
.variables
122-
.iter()
123-
.map(|_| "".to_string())
124-
.collect::<Vec<_>>()[..],
125-
&mut self.scope,
126-
)?;
127-
let mut g = self.scope.graph_mut();
128-
let mut nd = g.new_operation("RestoreV2", "restore")?;
129-
nd.add_input(prefix_restore.clone());
130-
nd.add_input(tensor_names);
131-
nd.add_input(shape_and_slices);
132-
let dtypes = all_variable_ops
133-
.iter()
134-
.map(|v| v.get_attr_type("dtype"))
135-
.collect::<Result<Vec<_>, Status>>()?;
136-
nd.set_attr_type_list("dtypes", &dtypes[..])?;
137-
nd.finish()?
138-
};
139-
{
140-
let mut restore_var_ops = Vec::<Operation>::new();
143+
.map(|_| "".to_string())
144+
.collect::<Vec<_>>()[..],
145+
&mut self.scope,
146+
)?;
147+
let mut g = self.scope.graph_mut();
148+
let mut nd = g.new_operation("RestoreV2", "restore")?;
149+
nd.add_input(prefix_restore.clone());
150+
nd.add_input(tensor_names);
151+
nd.add_input(shape_and_slices);
152+
let dtypes = all_variable_ops
153+
.iter()
154+
.map(|v| v.get_attr_type("dtype"))
155+
.collect::<Result<Vec<_>, Status>>()?;
156+
nd.set_attr_type_list("dtypes", &dtypes[..])?;
157+
let restore_op = nd.finish()?;
158+
drop(g);
159+
let mut restore_var_ops = Vec::<Operation>::new();
141160
for (i, var) in self.variables.iter().enumerate() {
142161
let var_op = self
143162
.scope
@@ -157,7 +176,6 @@ impl CheckpointMaker {
157176
no_op = no_op.add_control_input(op);
158177
}
159178
(prefix_restore, no_op.build(&mut self.scope)?)
160-
}
161179
};
162180
Ok(SaveRestoreOps {
163181
prefix_save,
@@ -236,8 +254,8 @@ mod tests {
236254
}
237255

238256
struct MyScopeData {
239-
pub scope: Scope,
240-
pub variables: [Variable; 3],
257+
scope: Scope,
258+
variables: [Variable; 3],
241259
}
242260

243261
// Initialize a scope and place same variables in it

src/option_insert_result.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
// Similar to Option<T>.get_or_insert_with, for a function that returns a result.
22
pub trait OptionInsertWithResult<T> {
3-
fn get_or_insert_with_result<F, E>(&mut self, f: F) -> Result<&T, E>
3+
fn get_or_insert_with_result<F, E>(&mut self, f: F) -> Result<&mut T, E>
44
where
55
F: FnOnce() -> Result<T, E>;
66
}
77

88
impl<T> OptionInsertWithResult<T> for Option<T> {
9-
fn get_or_insert_with_result<F, E>(&mut self, f: F) -> Result<&T, E>
9+
fn get_or_insert_with_result<F, E>(&mut self, f: F) -> Result<&mut T, E>
1010
where
1111
F: FnOnce() -> Result<T, E>,
1212
{
1313
if self.is_none() {
1414
*self = Some(f()?);
1515
}
16-
Ok(self.as_ref().unwrap())
16+
Ok(self.as_mut().unwrap())
1717
}
1818
}

0 commit comments

Comments
 (0)