1
- //! This module supports saving and restoring variables using Tensorflow checkpoints in SaveV2 format
2
-
1
+ //! This module supports saving and restoring variables using Tensorflow checkpoints in SaveV2 format.
2
+ //! First, the user creates a [`CheckpointMaker`] attached to an [`Scope`] indicating the list of variables to be saved/restored.
3
+ //! The CheckpointMaker lazily modifies the graph creating the nodes needed for saving/restoring.
4
+ //! When one wants to save/restore from or into a session, one calls the save/restore methods
5
+ //! # Example
6
+ //! let mut scope = Scope::new_root_scope();
7
+ //! // add operations to define the graph
8
+ //! // ...
9
+ //! // let "w" and "b" the name of the variables that we wish to save
10
+ //! let mut checkpoint_maker = CheckpointMaker::new(scope.new_sub_scope("checkpoint"),
11
+ //! vec![String::from("w"), String::from("b")].into_boxed_slice(),
12
+ //! );
13
+ //! let session = Session::new(&SessionOptions::new(), &scope.graph())?;
14
+ //! // run some training
15
+ //! // ...
16
+ //! // to save the training
17
+ //! checkpoint_maker.save(&session, "data/checkpoint")?;
18
+ //! // then we restore in a different session to continue there
19
+ //! let new_session = Session::new(&SessionOptions::new(), &scope.graph())?;
20
+ //! checkpoint_maker.save(&new_session, "data/checkpoint")?;
3
21
use crate :: option_insert_result:: OptionInsertWithResult ;
4
22
use crate :: { ops, Operation , Scope , Session , SessionRunArgs , Status , Tensor } ;
5
23
6
24
#[ derive( Debug ) ]
7
25
struct SaveRestoreOps {
8
- pub prefix_save : Operation ,
9
- pub prefix_restore : Operation ,
10
- pub save_op : Operation ,
11
- pub restore_op : Operation ,
26
+ prefix_save : Operation ,
27
+ prefix_restore : Operation ,
28
+ save_op : Operation ,
29
+ restore_op : Operation ,
12
30
}
13
31
14
- /// Checkpointing and restoring struct
32
+ /// Checkpointing and restoring support for Tensorflow.
33
+ /// This struct is manages a scope, adds lazily the Tensorflow ops
34
+ /// to perform the save/restore operations
15
35
#[ derive( Debug ) ]
16
36
pub struct CheckpointMaker {
17
37
scope : Scope ,
@@ -33,19 +53,20 @@ impl CheckpointMaker {
33
53
}
34
54
}
35
55
56
+ fn make_all_variable_ops ( & mut self ) -> Result < Vec < Operation > , Status > {
57
+ let graph = self . scope . graph ( ) ;
58
+ Ok ( self
59
+ . variables
60
+ . iter ( )
61
+ . map ( |v : & String | -> Result < Operation , Status > {
62
+ Ok ( graph. operation_by_name_required ( v. as_str ( ) ) ?. clone ( ) )
63
+ } )
64
+ . collect :: < Result < Vec < _ > , Status > > ( ) ?)
65
+ }
66
+
36
67
/// Add save and restore ops to the graph
37
68
fn build_save_ops ( & mut self ) -> Result < SaveRestoreOps , Status > {
38
69
let mut all_variable_ops_opt: Option < Vec < Operation > > = None ;
39
- fn make_all_variable_ops ( myself : & mut CheckpointMaker ) -> Result < Vec < Operation > , Status > {
40
- let graph = myself. scope . graph ( ) ;
41
- Ok ( myself
42
- . variables
43
- . iter ( )
44
- . map ( |v : & String | -> Result < Operation , Status > {
45
- Ok ( graph. operation_by_name_required ( v. as_str ( ) ) ?. clone ( ) )
46
- } )
47
- . collect :: < Result < Vec < _ > , Status > > ( ) ?)
48
- }
49
70
50
71
let existing_save_op = self . scope . graph ( ) . operation_by_name ( "save" ) ?;
51
72
let ( prefix_save, save_op) = if let Some ( op) = existing_save_op {
@@ -56,7 +77,7 @@ impl CheckpointMaker {
56
77
( prefix_save_op, op)
57
78
} else {
58
79
let all_variable_ops =
59
- all_variable_ops_opt. get_or_insert_with_result ( || make_all_variable_ops ( self ) ) ?;
80
+ all_variable_ops_opt. get_or_insert_with_result ( || self . make_all_variable_ops ( ) ) ?;
60
81
let prefix_save = ops:: Placeholder :: new ( )
61
82
. dtype ( crate :: DataType :: String )
62
83
. build ( & mut self . scope . with_op_name ( "prefix_save" ) ) ?;
@@ -105,39 +126,37 @@ impl CheckpointMaker {
105
126
( the_prefix_restore, op)
106
127
} else {
107
128
let all_variable_ops =
108
- all_variable_ops_opt. get_or_insert_with_result ( || make_all_variable_ops ( self ) ) ?;
129
+ all_variable_ops_opt. get_or_insert_with_result ( || self . make_all_variable_ops ( ) ) ?;
109
130
let prefix_restore = ops:: Placeholder :: new ( )
110
131
. dtype ( crate :: DataType :: String )
111
132
. build ( & mut self . scope . with_op_name ( "prefix_restore" ) ) ?;
112
- let restore_op = {
113
- let all_var_names = self
133
+ let all_var_names = self
134
+ . variables
135
+ . iter ( )
136
+ . map ( |v| v. to_string ( ) )
137
+ . collect :: < Vec < _ > > ( ) ;
138
+ let tensor_names = ops:: constant ( & all_var_names[ ..] , & mut self . scope ) ?;
139
+ let shape_and_slices = ops:: constant (
140
+ & self
114
141
. variables
115
142
. iter ( )
116
- . map ( |v| v. to_string ( ) )
117
- . collect :: < Vec < _ > > ( ) ;
118
- let tensor_names = ops:: constant ( & all_var_names[ ..] , & mut self . scope ) ?;
119
- let shape_and_slices = ops:: constant (
120
- & self
121
- . variables
122
- . iter ( )
123
- . map ( |_| "" . to_string ( ) )
124
- . collect :: < Vec < _ > > ( ) [ ..] ,
125
- & mut self . scope ,
126
- ) ?;
127
- let mut g = self . scope . graph_mut ( ) ;
128
- let mut nd = g. new_operation ( "RestoreV2" , "restore" ) ?;
129
- nd. add_input ( prefix_restore. clone ( ) ) ;
130
- nd. add_input ( tensor_names) ;
131
- nd. add_input ( shape_and_slices) ;
132
- let dtypes = all_variable_ops
133
- . iter ( )
134
- . map ( |v| v. get_attr_type ( "dtype" ) )
135
- . collect :: < Result < Vec < _ > , Status > > ( ) ?;
136
- nd. set_attr_type_list ( "dtypes" , & dtypes[ ..] ) ?;
137
- nd. finish ( ) ?
138
- } ;
139
- {
140
- let mut restore_var_ops = Vec :: < Operation > :: new ( ) ;
143
+ . map ( |_| "" . to_string ( ) )
144
+ . collect :: < Vec < _ > > ( ) [ ..] ,
145
+ & mut self . scope ,
146
+ ) ?;
147
+ let mut g = self . scope . graph_mut ( ) ;
148
+ let mut nd = g. new_operation ( "RestoreV2" , "restore" ) ?;
149
+ nd. add_input ( prefix_restore. clone ( ) ) ;
150
+ nd. add_input ( tensor_names) ;
151
+ nd. add_input ( shape_and_slices) ;
152
+ let dtypes = all_variable_ops
153
+ . iter ( )
154
+ . map ( |v| v. get_attr_type ( "dtype" ) )
155
+ . collect :: < Result < Vec < _ > , Status > > ( ) ?;
156
+ nd. set_attr_type_list ( "dtypes" , & dtypes[ ..] ) ?;
157
+ let restore_op = nd. finish ( ) ?;
158
+ drop ( g) ;
159
+ let mut restore_var_ops = Vec :: < Operation > :: new ( ) ;
141
160
for ( i, var) in self . variables . iter ( ) . enumerate ( ) {
142
161
let var_op = self
143
162
. scope
@@ -157,7 +176,6 @@ impl CheckpointMaker {
157
176
no_op = no_op. add_control_input ( op) ;
158
177
}
159
178
( prefix_restore, no_op. build ( & mut self . scope ) ?)
160
- }
161
179
} ;
162
180
Ok ( SaveRestoreOps {
163
181
prefix_save,
@@ -236,8 +254,8 @@ mod tests {
236
254
}
237
255
238
256
struct MyScopeData {
239
- pub scope : Scope ,
240
- pub variables : [ Variable ; 3 ] ,
257
+ scope : Scope ,
258
+ variables : [ Variable ; 3 ] ,
241
259
}
242
260
243
261
// Initialize a scope and place same variables in it
0 commit comments