Skip to content

Commit 9f73a9c

Browse files
committed
Formatting and comment fixes
1 parent 2247ed8 commit 9f73a9c

File tree

1 file changed

+60
-64
lines changed

1 file changed

+60
-64
lines changed

src/checkpoint.rs

Lines changed: 60 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,3 @@
1-
//! This module supports saving and restoring variables using Tensorflow checkpoints in SaveV2 format.
2-
//! First, the user creates a [`CheckpointMaker`] attached to an [`Scope`] indicating the list of variables to be saved/restored.
3-
//! The CheckpointMaker lazily modifies the graph creating the nodes needed for saving/restoring.
4-
//! When one wants to save/restore from or into a session, one calls the save/restore methods
5-
//! # Example
6-
//! let mut scope = Scope::new_root_scope();
7-
//! // add operations to define the graph
8-
//! // ...
9-
//! // let w and b the variables that we wish to save
10-
//! let mut checkpoint_maker = CheckpointMaker::new(scope.new_sub_scope("checkpoint"),
11-
//! vec![w.clone(), b.clone()].into_boxed_slice(),
12-
//! );
13-
//! let session = Session::new(&SessionOptions::new(), &scope.graph())?;
14-
//! // run some training
15-
//! // ...
16-
//! // to save the training
17-
//! checkpoint_maker.save(&session, "data/checkpoint")?;
18-
//! // then we restore in a different session to continue there
19-
//! let new_session = Session::new(&SessionOptions::new(), &scope.graph())?;
20-
//! checkpoint_maker.save(&new_session, "data/checkpoint")?;
211
use crate::option_insert_result::OptionInsertWithResult;
222
use crate::{ops, Operation, Scope, Session, SessionRunArgs, Status, Tensor, Variable};
233

@@ -29,9 +9,29 @@ struct SaveRestoreOps {
299
restore_op: Operation,
3010
}
3111

32-
/// Checkpointing and restoring support for Tensorflow.
33-
/// This struct is manages a scope, adds lazily the Tensorflow ops
34-
/// to perform the save/restore operations
12+
/// This struct supports saving and restoring variables using Tensorflow checkpoints in SaveV2 format.
13+
/// First, the user creates a [`CheckpointMaker`] attached to an [`Scope`] indicating the list of variables to be saved/restored.
14+
/// The CheckpointMaker lazily modifies the graph creating the nodes needed for saving/restoring.
15+
/// When one wants to save/restore from or into a session, one calls the save/restore methods
16+
/// # Example
17+
/// ```
18+
/// let mut scope = Scope::new_root_scope();
19+
/// // add operations to define the graph
20+
/// // ...
21+
/// // let w and b the variables that we wish to save
22+
/// let mut checkpoint_maker = CheckpointMaker::new(scope.new_sub_scope("checkpoint"),
23+
/// vec![w.clone(), b.clone()].into_boxed_slice(),
24+
/// );
25+
/// let session = Session::new(&SessionOptions::new(), &scope.graph())?;
26+
/// // run some training
27+
/// // ...
28+
/// // to save the training
29+
/// checkpoint_maker.save(&session, "data/checkpoint")?;
30+
/// // then we restore in a different session to continue there
31+
/// let new_session = Session::new(&SessionOptions::new(), &scope.graph())?;
32+
/// checkpoint_maker.save(&new_session, "data/checkpoint")?;
33+
/// ```
34+
///
3535
#[derive(Debug)]
3636
pub struct CheckpointMaker {
3737
scope: Scope,
@@ -44,7 +44,7 @@ impl CheckpointMaker {
4444
/// The scope is used to modify the graph to add the save and restore ops.
4545
///
4646
/// In order to provide a scope for the CheckpointMaker one can use scope.new_sub_scope("checkpoint")
47-
/// in order to create the nodes with scoped names
47+
/// in order to create the nodes with scoped names.
4848
pub fn new(scope: Scope, variables: Box<[Variable]>) -> CheckpointMaker {
4949
CheckpointMaker {
5050
scope,
@@ -53,18 +53,7 @@ impl CheckpointMaker {
5353
}
5454
}
5555

56-
/* fn make_all_variable_ops(&mut self) -> Result<Vec<Operation>, Status> {
57-
let graph = self.scope.graph();
58-
Ok(self
59-
.variables
60-
.iter()
61-
.map(|v: &String| -> Result<Operation, Status> {
62-
Ok(graph.operation_by_name_required(v.as_str())?.clone())
63-
})
64-
.collect::<Result<Vec<_>, Status>>()?)
65-
}*/
66-
67-
/// Add save and restore ops to the graph
56+
// Add save and restore ops to the graph.
6857
fn build_save_ops(&mut self) -> Result<SaveRestoreOps, Status> {
6958
let mut all_variable_ops_opt: Option<Vec<Operation>> = None;
7059

@@ -76,18 +65,21 @@ impl CheckpointMaker {
7665
.operation_by_name_required("prefix_save")?;
7766
(prefix_save_op, op)
7867
} else {
79-
let all_variable_ops =
80-
all_variable_ops_opt.get_or_insert_with(
81-
|| self.variables.iter().map(|v| v.output.operation.clone() ).collect::<Vec<_>>());
68+
let all_variable_ops = all_variable_ops_opt.get_or_insert_with(|| {
69+
self.variables
70+
.iter()
71+
.map(|v| v.output.operation.clone())
72+
.collect::<Vec<_>>()
73+
});
8274
let prefix_save = ops::Placeholder::new()
8375
.dtype(crate::DataType::String)
8476
.build(&mut self.scope.with_op_name("prefix_save"))?;
8577
let tensor_names = ops::constant(
86-
self
87-
.variables
78+
self.variables
8879
.iter()
8980
.map(|v| String::from(v.name()))
90-
.collect::<Vec<_>>().as_slice(),
81+
.collect::<Vec<_>>()
82+
.as_slice(),
9183
&mut self.scope,
9284
)?;
9385
let shape_and_slices = ops::constant(
@@ -126,9 +118,12 @@ impl CheckpointMaker {
126118
.operation_by_name_required("prefix_restore")?;
127119
(the_prefix_restore, op)
128120
} else {
129-
let all_variable_ops =
130-
all_variable_ops_opt.get_or_insert_with(
131-
|| self.variables.iter().map(|v| v.output.operation.clone() ).collect::<Vec<_>>());
121+
let all_variable_ops = all_variable_ops_opt.get_or_insert_with(|| {
122+
self.variables
123+
.iter()
124+
.map(|v| v.output.operation.clone())
125+
.collect::<Vec<_>>()
126+
});
132127
let prefix_restore = ops::Placeholder::new()
133128
.dtype(crate::DataType::String)
134129
.build(&mut self.scope.with_op_name("prefix_restore"))?;
@@ -159,22 +154,22 @@ impl CheckpointMaker {
159154
let restore_op = nd.finish()?;
160155
drop(g);
161156
let mut restore_var_ops = Vec::<Operation>::new();
162-
for (i, var) in self.variables.iter().enumerate() {
163-
let var_op = var.output.operation.clone();
164-
restore_var_ops.push(ops::assign(
165-
var_op,
166-
crate::Output {
167-
operation: restore_op.clone(),
168-
index: i as i32,
169-
},
170-
&mut self.scope.new_sub_scope(format!("restore{}", i).as_str()),
171-
)?);
172-
}
173-
let mut no_op = ops::NoOp::new();
174-
for op in restore_var_ops {
175-
no_op = no_op.add_control_input(op);
176-
}
177-
(prefix_restore, no_op.build(&mut self.scope)?)
157+
for (i, var) in self.variables.iter().enumerate() {
158+
let var_op = var.output.operation.clone();
159+
restore_var_ops.push(ops::assign(
160+
var_op,
161+
crate::Output {
162+
operation: restore_op.clone(),
163+
index: i as i32,
164+
},
165+
&mut self.scope.new_sub_scope(format!("restore{}", i).as_str()),
166+
)?);
167+
}
168+
let mut no_op = ops::NoOp::new();
169+
for op in restore_var_ops {
170+
no_op = no_op.add_control_input(op);
171+
}
172+
(prefix_restore, no_op.build(&mut self.scope)?)
178173
};
179174
Ok(SaveRestoreOps {
180175
prefix_save,
@@ -194,7 +189,7 @@ impl CheckpointMaker {
194189
Ok(save_r_op)
195190
}
196191

197-
/// Save the variables listed in this CheckpointMaker to the checkpoint with base filename backup_filename_base
192+
/// Save the variables listed in this CheckpointMaker to the checkpoint with base filename backup_filename_base.
198193
pub fn save(&mut self, session: &Session, backup_filename_base: &str) -> Result<(), Status> {
199194
let save_restore_ops = self.get_save_operation()?;
200195
let prefix_arg = Tensor::from(backup_filename_base.to_string());
@@ -206,7 +201,7 @@ impl CheckpointMaker {
206201
}
207202

208203
/// Restore into the session the variables listed in this CheckpointMaker from the checkpoint
209-
/// in path_base
204+
/// in path_base.
210205
pub fn restore(&mut self, session: &Session, path_base: &str) -> Result<(), Status> {
211206
let save_restore_ops = self.get_save_operation()?;
212207
let prefix_arg = Tensor::from(path_base.to_string());
@@ -372,7 +367,8 @@ mod tests {
372367
variables: second_variables,
373368
} = create_scope()?;
374369
let second_session = Session::new(&SessionOptions::new(), &second_scope.graph())?;
375-
let mut second_checkpoint = CheckpointMaker::new(second_scope, Box::new(second_variables.clone()));
370+
let mut second_checkpoint =
371+
CheckpointMaker::new(second_scope, Box::new(second_variables.clone()));
376372
second_checkpoint.restore(&second_session, checkpoint_path_str.as_str())?;
377373
check_variables(&second_session, &second_variables, &new_values)?;
378374
Ok(())

0 commit comments

Comments
 (0)