1
- //! This module supports saving and restoring variables using Tensorflow checkpoints in SaveV2 format.
2
- //! First, the user creates a [`CheckpointMaker`] attached to an [`Scope`] indicating the list of variables to be saved/restored.
3
- //! The CheckpointMaker lazily modifies the graph creating the nodes needed for saving/restoring.
4
- //! When one wants to save/restore from or into a session, one calls the save/restore methods
5
- //! # Example
6
- //! let mut scope = Scope::new_root_scope();
7
- //! // add operations to define the graph
8
- //! // ...
9
- //! // let w and b the variables that we wish to save
10
- //! let mut checkpoint_maker = CheckpointMaker::new(scope.new_sub_scope("checkpoint"),
11
- //! vec![w.clone(), b.clone()].into_boxed_slice(),
12
- //! );
13
- //! let session = Session::new(&SessionOptions::new(), &scope.graph())?;
14
- //! // run some training
15
- //! // ...
16
- //! // to save the training
17
- //! checkpoint_maker.save(&session, "data/checkpoint")?;
18
- //! // then we restore in a different session to continue there
19
- //! let new_session = Session::new(&SessionOptions::new(), &scope.graph())?;
20
- //! checkpoint_maker.save(&new_session, "data/checkpoint")?;
21
1
use crate :: option_insert_result:: OptionInsertWithResult ;
22
2
use crate :: { ops, Operation , Scope , Session , SessionRunArgs , Status , Tensor , Variable } ;
23
3
@@ -29,9 +9,29 @@ struct SaveRestoreOps {
29
9
restore_op : Operation ,
30
10
}
31
11
32
- /// Checkpointing and restoring support for Tensorflow.
33
- /// This struct is manages a scope, adds lazily the Tensorflow ops
34
- /// to perform the save/restore operations
12
+ /// This struct supports saving and restoring variables using Tensorflow checkpoints in SaveV2 format.
13
+ /// First, the user creates a [`CheckpointMaker`] attached to a [`Scope`] indicating the list of variables to be saved/restored.
14
+ /// The CheckpointMaker lazily modifies the graph creating the nodes needed for saving/restoring.
15
+ /// When one wants to save/restore from or into a session, one calls the save/restore methods.
16
+ /// # Example
17
+ /// ```
18
+ /// let mut scope = Scope::new_root_scope();
19
+ /// // add operations to define the graph
20
+ /// // ...
21
+ /// // let w and b be the variables that we wish to save
22
+ /// let mut checkpoint_maker = CheckpointMaker::new(scope.new_sub_scope("checkpoint"),
23
+ /// vec![w.clone(), b.clone()].into_boxed_slice(),
24
+ /// );
25
+ /// let session = Session::new(&SessionOptions::new(), &scope.graph())?;
26
+ /// // run some training
27
+ /// // ...
28
+ /// // to save the training
29
+ /// checkpoint_maker.save(&session, "data/checkpoint")?;
30
+ /// // then we restore in a different session to continue there
31
+ /// let new_session = Session::new(&SessionOptions::new(), &scope.graph())?;
32
+ /// checkpoint_maker.restore(&new_session, "data/checkpoint")?;
33
+ /// ```
34
+ ///
35
35
#[ derive( Debug ) ]
36
36
pub struct CheckpointMaker {
37
37
scope : Scope ,
@@ -44,7 +44,7 @@ impl CheckpointMaker {
44
44
/// The scope is used to modify the graph to add the save and restore ops.
45
45
///
46
46
/// In order to provide a scope for the CheckpointMaker one can use scope.new_sub_scope("checkpoint")
47
- /// in order to create the nodes with scoped names
47
+ /// in order to create the nodes with scoped names.
48
48
pub fn new ( scope : Scope , variables : Box < [ Variable ] > ) -> CheckpointMaker {
49
49
CheckpointMaker {
50
50
scope,
@@ -53,18 +53,7 @@ impl CheckpointMaker {
53
53
}
54
54
}
55
55
56
- /* fn make_all_variable_ops(&mut self) -> Result<Vec<Operation>, Status> {
57
- let graph = self.scope.graph();
58
- Ok(self
59
- .variables
60
- .iter()
61
- .map(|v: &String| -> Result<Operation, Status> {
62
- Ok(graph.operation_by_name_required(v.as_str())?.clone())
63
- })
64
- .collect::<Result<Vec<_>, Status>>()?)
65
- }*/
66
-
67
- /// Add save and restore ops to the graph
56
+ // Add save and restore ops to the graph.
68
57
fn build_save_ops ( & mut self ) -> Result < SaveRestoreOps , Status > {
69
58
let mut all_variable_ops_opt: Option < Vec < Operation > > = None ;
70
59
@@ -76,18 +65,21 @@ impl CheckpointMaker {
76
65
. operation_by_name_required ( "prefix_save" ) ?;
77
66
( prefix_save_op, op)
78
67
} else {
79
- let all_variable_ops =
80
- all_variable_ops_opt. get_or_insert_with (
81
- || self . variables . iter ( ) . map ( |v| v. output . operation . clone ( ) ) . collect :: < Vec < _ > > ( ) ) ;
68
+ let all_variable_ops = all_variable_ops_opt. get_or_insert_with ( || {
69
+ self . variables
70
+ . iter ( )
71
+ . map ( |v| v. output . operation . clone ( ) )
72
+ . collect :: < Vec < _ > > ( )
73
+ } ) ;
82
74
let prefix_save = ops:: Placeholder :: new ( )
83
75
. dtype ( crate :: DataType :: String )
84
76
. build ( & mut self . scope . with_op_name ( "prefix_save" ) ) ?;
85
77
let tensor_names = ops:: constant (
86
- self
87
- . variables
78
+ self . variables
88
79
. iter ( )
89
80
. map ( |v| String :: from ( v. name ( ) ) )
90
- . collect :: < Vec < _ > > ( ) . as_slice ( ) ,
81
+ . collect :: < Vec < _ > > ( )
82
+ . as_slice ( ) ,
91
83
& mut self . scope ,
92
84
) ?;
93
85
let shape_and_slices = ops:: constant (
@@ -126,9 +118,12 @@ impl CheckpointMaker {
126
118
. operation_by_name_required ( "prefix_restore" ) ?;
127
119
( the_prefix_restore, op)
128
120
} else {
129
- let all_variable_ops =
130
- all_variable_ops_opt. get_or_insert_with (
131
- || self . variables . iter ( ) . map ( |v| v. output . operation . clone ( ) ) . collect :: < Vec < _ > > ( ) ) ;
121
+ let all_variable_ops = all_variable_ops_opt. get_or_insert_with ( || {
122
+ self . variables
123
+ . iter ( )
124
+ . map ( |v| v. output . operation . clone ( ) )
125
+ . collect :: < Vec < _ > > ( )
126
+ } ) ;
132
127
let prefix_restore = ops:: Placeholder :: new ( )
133
128
. dtype ( crate :: DataType :: String )
134
129
. build ( & mut self . scope . with_op_name ( "prefix_restore" ) ) ?;
@@ -159,22 +154,22 @@ impl CheckpointMaker {
159
154
let restore_op = nd. finish ( ) ?;
160
155
drop ( g) ;
161
156
let mut restore_var_ops = Vec :: < Operation > :: new ( ) ;
162
- for ( i, var) in self . variables . iter ( ) . enumerate ( ) {
163
- let var_op = var. output . operation . clone ( ) ;
164
- restore_var_ops. push ( ops:: assign (
165
- var_op,
166
- crate :: Output {
167
- operation : restore_op. clone ( ) ,
168
- index : i as i32 ,
169
- } ,
170
- & mut self . scope . new_sub_scope ( format ! ( "restore{}" , i) . as_str ( ) ) ,
171
- ) ?) ;
172
- }
173
- let mut no_op = ops:: NoOp :: new ( ) ;
174
- for op in restore_var_ops {
175
- no_op = no_op. add_control_input ( op) ;
176
- }
177
- ( prefix_restore, no_op. build ( & mut self . scope ) ?)
157
+ for ( i, var) in self . variables . iter ( ) . enumerate ( ) {
158
+ let var_op = var. output . operation . clone ( ) ;
159
+ restore_var_ops. push ( ops:: assign (
160
+ var_op,
161
+ crate :: Output {
162
+ operation : restore_op. clone ( ) ,
163
+ index : i as i32 ,
164
+ } ,
165
+ & mut self . scope . new_sub_scope ( format ! ( "restore{}" , i) . as_str ( ) ) ,
166
+ ) ?) ;
167
+ }
168
+ let mut no_op = ops:: NoOp :: new ( ) ;
169
+ for op in restore_var_ops {
170
+ no_op = no_op. add_control_input ( op) ;
171
+ }
172
+ ( prefix_restore, no_op. build ( & mut self . scope ) ?)
178
173
} ;
179
174
Ok ( SaveRestoreOps {
180
175
prefix_save,
@@ -194,7 +189,7 @@ impl CheckpointMaker {
194
189
Ok ( save_r_op)
195
190
}
196
191
197
- /// Save the variables listed in this CheckpointMaker to the checkpoint with base filename backup_filename_base
192
+ /// Save the variables listed in this CheckpointMaker to the checkpoint with base filename backup_filename_base.
198
193
pub fn save ( & mut self , session : & Session , backup_filename_base : & str ) -> Result < ( ) , Status > {
199
194
let save_restore_ops = self . get_save_operation ( ) ?;
200
195
let prefix_arg = Tensor :: from ( backup_filename_base. to_string ( ) ) ;
@@ -206,7 +201,7 @@ impl CheckpointMaker {
206
201
}
207
202
208
203
/// Restore into the session the variables listed in this CheckpointMaker from the checkpoint
209
- /// in path_base
204
+ /// in path_base.
210
205
pub fn restore ( & mut self , session : & Session , path_base : & str ) -> Result < ( ) , Status > {
211
206
let save_restore_ops = self . get_save_operation ( ) ?;
212
207
let prefix_arg = Tensor :: from ( path_base. to_string ( ) ) ;
@@ -372,7 +367,8 @@ mod tests {
372
367
variables : second_variables,
373
368
} = create_scope ( ) ?;
374
369
let second_session = Session :: new ( & SessionOptions :: new ( ) , & second_scope. graph ( ) ) ?;
375
- let mut second_checkpoint = CheckpointMaker :: new ( second_scope, Box :: new ( second_variables. clone ( ) ) ) ;
370
+ let mut second_checkpoint =
371
+ CheckpointMaker :: new ( second_scope, Box :: new ( second_variables. clone ( ) ) ) ;
376
372
second_checkpoint. restore ( & second_session, checkpoint_path_str. as_str ( ) ) ?;
377
373
check_variables ( & second_session, & second_variables, & new_values) ?;
378
374
Ok ( ( ) )
0 commit comments