Skip to content

Commit 2247ed8

Browse files
committed
In checkpoints, change the specification of variables to Variable objects rather than variable names.
1 parent 013a829 commit 2247ed8

File tree

1 file changed

+21
-29
lines changed

1 file changed

+21
-29
lines changed

src/checkpoint.rs

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
//! let mut scope = Scope::new_root_scope();
77
//! // add operations to define the graph
88
//! // ...
9-
//! // let "w" and "b" the name of the variables that we wish to save
9+
//! // let w and b the variables that we wish to save
1010
//! let mut checkpoint_maker = CheckpointMaker::new(scope.new_sub_scope("checkpoint"),
11-
//! vec![String::from("w"), String::from("b")].into_boxed_slice(),
11+
//! vec![w.clone(), b.clone()].into_boxed_slice(),
1212
//! );
1313
//! let session = Session::new(&SessionOptions::new(), &scope.graph())?;
1414
//! // run some training
@@ -19,7 +19,7 @@
1919
//! let new_session = Session::new(&SessionOptions::new(), &scope.graph())?;
2020
//! checkpoint_maker.save(&new_session, "data/checkpoint")?;
2121
use crate::option_insert_result::OptionInsertWithResult;
22-
use crate::{ops, Operation, Scope, Session, SessionRunArgs, Status, Tensor};
22+
use crate::{ops, Operation, Scope, Session, SessionRunArgs, Status, Tensor, Variable};
2323

2424
#[derive(Debug)]
2525
struct SaveRestoreOps {
@@ -35,25 +35,25 @@ struct SaveRestoreOps {
3535
#[derive(Debug)]
3636
pub struct CheckpointMaker {
3737
scope: Scope,
38-
variables: Box<[String]>,
38+
variables: Box<[Variable]>,
3939
save_restore_ops: Option<SaveRestoreOps>,
4040
}
4141

4242
impl CheckpointMaker {
4343
/// Creates a new CheckpointMaker for a Scope, with a list of variables to save/restore.
4444
/// The scope is used to modify the graph to add the save and restore ops.
4545
///
46-
/// In order to provide a scope for the CheckpointMaker one can use scope.new_sub_scope("")
47-
/// as Scope does not support the Clone trait at present
48-
pub fn new(scope: Scope, variables: Box<[String]>) -> CheckpointMaker {
46+
/// In order to provide a scope for the CheckpointMaker one can use scope.new_sub_scope("checkpoint")
47+
/// in order to create the nodes with scoped names
48+
pub fn new(scope: Scope, variables: Box<[Variable]>) -> CheckpointMaker {
4949
CheckpointMaker {
5050
scope,
5151
variables,
5252
save_restore_ops: None,
5353
}
5454
}
5555

56-
fn make_all_variable_ops(&mut self) -> Result<Vec<Operation>, Status> {
56+
/* fn make_all_variable_ops(&mut self) -> Result<Vec<Operation>, Status> {
5757
let graph = self.scope.graph();
5858
Ok(self
5959
.variables
@@ -62,7 +62,7 @@ impl CheckpointMaker {
6262
Ok(graph.operation_by_name_required(v.as_str())?.clone())
6363
})
6464
.collect::<Result<Vec<_>, Status>>()?)
65-
}
65+
}*/
6666

6767
/// Add save and restore ops to the graph
6868
fn build_save_ops(&mut self) -> Result<SaveRestoreOps, Status> {
@@ -77,16 +77,17 @@ impl CheckpointMaker {
7777
(prefix_save_op, op)
7878
} else {
7979
let all_variable_ops =
80-
all_variable_ops_opt.get_or_insert_with_result(|| self.make_all_variable_ops())?;
80+
all_variable_ops_opt.get_or_insert_with(
81+
|| self.variables.iter().map(|v| v.output.operation.clone() ).collect::<Vec<_>>());
8182
let prefix_save = ops::Placeholder::new()
8283
.dtype(crate::DataType::String)
8384
.build(&mut self.scope.with_op_name("prefix_save"))?;
8485
let tensor_names = ops::constant(
85-
&self
86+
self
8687
.variables
8788
.iter()
88-
.map(|v| (*v).to_string())
89-
.collect::<Vec<_>>()[..],
89+
.map(|v| String::from(v.name()))
90+
.collect::<Vec<_>>().as_slice(),
9091
&mut self.scope,
9192
)?;
9293
let shape_and_slices = ops::constant(
@@ -126,14 +127,15 @@ impl CheckpointMaker {
126127
(the_prefix_restore, op)
127128
} else {
128129
let all_variable_ops =
129-
all_variable_ops_opt.get_or_insert_with_result(|| self.make_all_variable_ops())?;
130+
all_variable_ops_opt.get_or_insert_with(
131+
|| self.variables.iter().map(|v| v.output.operation.clone() ).collect::<Vec<_>>());
130132
let prefix_restore = ops::Placeholder::new()
131133
.dtype(crate::DataType::String)
132134
.build(&mut self.scope.with_op_name("prefix_restore"))?;
133135
let all_var_names = self
134136
.variables
135137
.iter()
136-
.map(|v| v.to_string())
138+
.map(|v| v.name.clone())
137139
.collect::<Vec<_>>();
138140
let tensor_names = ops::constant(&all_var_names[..], &mut self.scope)?;
139141
let shape_and_slices = ops::constant(
@@ -158,10 +160,7 @@ impl CheckpointMaker {
158160
drop(g);
159161
let mut restore_var_ops = Vec::<Operation>::new();
160162
for (i, var) in self.variables.iter().enumerate() {
161-
let var_op = self
162-
.scope
163-
.graph()
164-
.operation_by_name_required(var.as_str())?;
163+
let var_op = var.output.operation.clone();
165164
restore_var_ops.push(ops::assign(
166165
var_op,
167166
crate::Output {
@@ -357,16 +356,9 @@ mod tests {
357356
&[11.0, 12.0, 13.6, 17.1, 18.4, 19.5],
358357
];
359358
assign_variables(&first_session, &first_scope_data, &assign_data, &new_values)?;
360-
let variable_names = first_scope_data
361-
.variables
362-
.as_ref()
363-
.iter()
364-
.map(|v| String::from(v.name()))
365-
.collect::<Vec<_>>()
366-
.into_boxed_slice();
367359
let mut checkpoint = CheckpointMaker::new(
368-
first_scope_data.scope.new_sub_scope(""),
369-
variable_names.clone(),
360+
first_scope_data.scope.new_sub_scope("checkpoint"),
361+
Box::from(first_scope_data.variables.clone()),
370362
);
371363
let temp_dir = tempdir::TempDir::new("test-tensorflow")?;
372364
let checkpoint_path = temp_dir.path().join("checkpoint-vars");
@@ -380,7 +372,7 @@ mod tests {
380372
variables: second_variables,
381373
} = create_scope()?;
382374
let second_session = Session::new(&SessionOptions::new(), &second_scope.graph())?;
383-
let mut second_checkpoint = CheckpointMaker::new(second_scope, variable_names);
375+
let mut second_checkpoint = CheckpointMaker::new(second_scope, Box::new(second_variables.clone()));
384376
second_checkpoint.restore(&second_session, checkpoint_path_str.as_str())?;
385377
check_variables(&second_session, &second_variables, &new_values)?;
386378
Ok(())

0 commit comments

Comments
 (0)