6
6
//! let mut scope = Scope::new_root_scope();
7
7
//! // add operations to define the graph
8
8
//! // ...
9
- //! // let "w" and "b" the name of the variables that we wish to save
9
+ //! // let w and b the variables that we wish to save
10
10
//! let mut checkpoint_maker = CheckpointMaker::new(scope.new_sub_scope("checkpoint"),
11
- //! vec![String::from("w" ), String::from("b" )].into_boxed_slice(),
11
+ //! vec![w.clone( ), b.clone( )].into_boxed_slice(),
12
12
//! );
13
13
//! let session = Session::new(&SessionOptions::new(), &scope.graph())?;
14
14
//! // run some training
19
19
//! let new_session = Session::new(&SessionOptions::new(), &scope.graph())?;
20
20
//! checkpoint_maker.save(&new_session, "data/checkpoint")?;
21
21
use crate :: option_insert_result:: OptionInsertWithResult ;
22
- use crate :: { ops, Operation , Scope , Session , SessionRunArgs , Status , Tensor } ;
22
+ use crate :: { ops, Operation , Scope , Session , SessionRunArgs , Status , Tensor , Variable } ;
23
23
24
24
#[ derive( Debug ) ]
25
25
struct SaveRestoreOps {
@@ -35,25 +35,25 @@ struct SaveRestoreOps {
35
35
#[ derive( Debug ) ]
36
36
pub struct CheckpointMaker {
37
37
scope : Scope ,
38
- variables : Box < [ String ] > ,
38
+ variables : Box < [ Variable ] > ,
39
39
save_restore_ops : Option < SaveRestoreOps > ,
40
40
}
41
41
42
42
impl CheckpointMaker {
43
43
/// Creates a new CheckpointMaker for a Scope, with a list of variables to save/restore.
44
44
/// The scope is used to modify the graph to add the save and restore ops.
45
45
///
46
- /// In order to provide a scope for the CheckpointMaker one can use scope.new_sub_scope("")
47
- /// as Scope does not support the Clone trait at present
48
- pub fn new ( scope : Scope , variables : Box < [ String ] > ) -> CheckpointMaker {
46
+ /// In order to provide a scope for the CheckpointMaker one can use scope.new_sub_scope("checkpoint ")
47
+ /// in order to create the nodes with scoped names
48
+ pub fn new ( scope : Scope , variables : Box < [ Variable ] > ) -> CheckpointMaker {
49
49
CheckpointMaker {
50
50
scope,
51
51
variables,
52
52
save_restore_ops : None ,
53
53
}
54
54
}
55
55
56
- fn make_all_variable_ops ( & mut self ) -> Result < Vec < Operation > , Status > {
56
+ /* fn make_all_variable_ops(&mut self) -> Result<Vec<Operation>, Status> {
57
57
let graph = self.scope.graph();
58
58
Ok(self
59
59
.variables
@@ -62,7 +62,7 @@ impl CheckpointMaker {
62
62
Ok(graph.operation_by_name_required(v.as_str())?.clone())
63
63
})
64
64
.collect::<Result<Vec<_>, Status>>()?)
65
- }
65
+ }*/
66
66
67
67
/// Add save and restore ops to the graph
68
68
fn build_save_ops ( & mut self ) -> Result < SaveRestoreOps , Status > {
@@ -77,16 +77,17 @@ impl CheckpointMaker {
77
77
( prefix_save_op, op)
78
78
} else {
79
79
let all_variable_ops =
80
- all_variable_ops_opt. get_or_insert_with_result ( || self . make_all_variable_ops ( ) ) ?;
80
+ all_variable_ops_opt. get_or_insert_with (
81
+ || self . variables . iter ( ) . map ( |v| v. output . operation . clone ( ) ) . collect :: < Vec < _ > > ( ) ) ;
81
82
let prefix_save = ops:: Placeholder :: new ( )
82
83
. dtype ( crate :: DataType :: String )
83
84
. build ( & mut self . scope . with_op_name ( "prefix_save" ) ) ?;
84
85
let tensor_names = ops:: constant (
85
- & self
86
+ self
86
87
. variables
87
88
. iter ( )
88
- . map ( |v| ( * v ) . to_string ( ) )
89
- . collect :: < Vec < _ > > ( ) [ .. ] ,
89
+ . map ( |v| String :: from ( v . name ( ) ) )
90
+ . collect :: < Vec < _ > > ( ) . as_slice ( ) ,
90
91
& mut self . scope ,
91
92
) ?;
92
93
let shape_and_slices = ops:: constant (
@@ -126,14 +127,15 @@ impl CheckpointMaker {
126
127
( the_prefix_restore, op)
127
128
} else {
128
129
let all_variable_ops =
129
- all_variable_ops_opt. get_or_insert_with_result ( || self . make_all_variable_ops ( ) ) ?;
130
+ all_variable_ops_opt. get_or_insert_with (
131
+ || self . variables . iter ( ) . map ( |v| v. output . operation . clone ( ) ) . collect :: < Vec < _ > > ( ) ) ;
130
132
let prefix_restore = ops:: Placeholder :: new ( )
131
133
. dtype ( crate :: DataType :: String )
132
134
. build ( & mut self . scope . with_op_name ( "prefix_restore" ) ) ?;
133
135
let all_var_names = self
134
136
. variables
135
137
. iter ( )
136
- . map ( |v| v. to_string ( ) )
138
+ . map ( |v| v. name . clone ( ) )
137
139
. collect :: < Vec < _ > > ( ) ;
138
140
let tensor_names = ops:: constant ( & all_var_names[ ..] , & mut self . scope ) ?;
139
141
let shape_and_slices = ops:: constant (
@@ -158,10 +160,7 @@ impl CheckpointMaker {
158
160
drop ( g) ;
159
161
let mut restore_var_ops = Vec :: < Operation > :: new ( ) ;
160
162
for ( i, var) in self . variables . iter ( ) . enumerate ( ) {
161
- let var_op = self
162
- . scope
163
- . graph ( )
164
- . operation_by_name_required ( var. as_str ( ) ) ?;
163
+ let var_op = var. output . operation . clone ( ) ;
165
164
restore_var_ops. push ( ops:: assign (
166
165
var_op,
167
166
crate :: Output {
@@ -357,16 +356,9 @@ mod tests {
357
356
& [ 11.0 , 12.0 , 13.6 , 17.1 , 18.4 , 19.5 ] ,
358
357
] ;
359
358
assign_variables ( & first_session, & first_scope_data, & assign_data, & new_values) ?;
360
- let variable_names = first_scope_data
361
- . variables
362
- . as_ref ( )
363
- . iter ( )
364
- . map ( |v| String :: from ( v. name ( ) ) )
365
- . collect :: < Vec < _ > > ( )
366
- . into_boxed_slice ( ) ;
367
359
let mut checkpoint = CheckpointMaker :: new (
368
- first_scope_data. scope . new_sub_scope ( "" ) ,
369
- variable_names . clone ( ) ,
360
+ first_scope_data. scope . new_sub_scope ( "checkpoint " ) ,
361
+ Box :: from ( first_scope_data . variables . clone ( ) ) ,
370
362
) ;
371
363
let temp_dir = tempdir:: TempDir :: new ( "test-tensorflow" ) ?;
372
364
let checkpoint_path = temp_dir. path ( ) . join ( "checkpoint-vars" ) ;
@@ -380,7 +372,7 @@ mod tests {
380
372
variables : second_variables,
381
373
} = create_scope ( ) ?;
382
374
let second_session = Session :: new ( & SessionOptions :: new ( ) , & second_scope. graph ( ) ) ?;
383
- let mut second_checkpoint = CheckpointMaker :: new ( second_scope, variable_names ) ;
375
+ let mut second_checkpoint = CheckpointMaker :: new ( second_scope, Box :: new ( second_variables . clone ( ) ) ) ;
384
376
second_checkpoint. restore ( & second_session, checkpoint_path_str. as_str ( ) ) ?;
385
377
check_variables ( & second_session, & second_variables, & new_values) ?;
386
378
Ok ( ( ) )
0 commit comments