Skip to content

Commit 15fffa5

Browse files
authored
Merge pull request tensorflow#93 from iduartgomez/fix-tensorshape
fix tensor_shape method for Graph
2 parents 81456d3 + f0b26d2 commit 15fffa5

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

src/graph.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ impl Graph {
257257
tf::TF_GraphGetTensorShape(self.gimpl.inner,
258258
output.to_c(),
259259
dims.as_mut_ptr(),
260-
dims.len() as c_int,
260+
n,
261261
status.inner());
262262
if status.is_ok() {
263263
dims.set_len(n as usize);
@@ -1039,4 +1039,34 @@ mod tests {
10391039
let status = g.import_graph_def(&[], &opts);
10401040
assert!(status.is_ok());
10411041
}
1042+
1043+
#[test]
1044+
fn test_get_tensor_shape() {
1045+
fn constant<T: TensorType>(graph: &mut Graph, name: &str, value: Tensor<T>) -> Operation {
1046+
let mut c = graph.new_operation("Const", name).unwrap();
1047+
c.set_attr_tensor("value", value).unwrap();
1048+
c.set_attr_type("dtype", T::data_type()).unwrap();
1049+
c.finish().unwrap()
1050+
}
1051+
1052+
let mut graph = Graph::new();
1053+
let x_init = Tensor::<i32>::new(&[3, 3]);
1054+
let x = constant(&mut graph, "x/assign_0", x_init);
1055+
assert_eq!(1, x.num_outputs());
1056+
assert_eq!(x.output_type(0), DataType::Int32);
1057+
let dims = graph
1058+
.num_dims(Output {
1059+
operation: x.clone(),
1060+
index: 0,
1061+
})
1062+
.unwrap();
1063+
assert_eq!(dims, 2);
1064+
let shape = graph
1065+
.tensor_shape(Output {
1066+
operation: x.clone(),
1067+
index: 0,
1068+
})
1069+
.unwrap();
1070+
assert_eq!(shape, Shape(Some(vec![Some(3_i64), Some(3_i64)])));
1071+
}
10421072
}

0 commit comments

Comments
 (0)