File tree Expand file tree Collapse file tree 1 file changed +31
-1
lines changed Expand file tree Collapse file tree 1 file changed +31
-1
lines changed Original file line number Diff line number Diff line change @@ -257,7 +257,7 @@ impl Graph {
257
257
tf:: TF_GraphGetTensorShape ( self . gimpl . inner ,
258
258
output. to_c ( ) ,
259
259
dims. as_mut_ptr ( ) ,
260
- dims . len ( ) as c_int ,
260
+ n ,
261
261
status. inner ( ) ) ;
262
262
if status. is_ok ( ) {
263
263
dims. set_len ( n as usize ) ;
@@ -1039,4 +1039,34 @@ mod tests {
1039
1039
let status = g. import_graph_def ( & [ ] , & opts) ;
1040
1040
assert ! ( status. is_ok( ) ) ;
1041
1041
}
1042
+
1043
+ #[ test]
1044
+ fn test_get_tensor_shape ( ) {
1045
+ fn constant < T : TensorType > ( graph : & mut Graph , name : & str , value : Tensor < T > ) -> Operation {
1046
+ let mut c = graph. new_operation ( "Const" , name) . unwrap ( ) ;
1047
+ c. set_attr_tensor ( "value" , value) . unwrap ( ) ;
1048
+ c. set_attr_type ( "dtype" , T :: data_type ( ) ) . unwrap ( ) ;
1049
+ c. finish ( ) . unwrap ( )
1050
+ }
1051
+
1052
+ let mut graph = Graph :: new ( ) ;
1053
+ let x_init = Tensor :: < i32 > :: new ( & [ 3 , 3 ] ) ;
1054
+ let x = constant ( & mut graph, "x/assign_0" , x_init) ;
1055
+ assert_eq ! ( 1 , x. num_outputs( ) ) ;
1056
+ assert_eq ! ( x. output_type( 0 ) , DataType :: Int32 ) ;
1057
+ let dims = graph
1058
+ . num_dims ( Output {
1059
+ operation : x. clone ( ) ,
1060
+ index : 0 ,
1061
+ } )
1062
+ . unwrap ( ) ;
1063
+ assert_eq ! ( dims, 2 ) ;
1064
+ let shape = graph
1065
+ . tensor_shape ( Output {
1066
+ operation : x. clone ( ) ,
1067
+ index : 0 ,
1068
+ } )
1069
+ . unwrap ( ) ;
1070
+ assert_eq ! ( shape, Shape ( Some ( vec![ Some ( 3_i64 ) , Some ( 3_i64 ) ] ) ) ) ;
1071
+ }
1042
1072
}
You can’t perform that action at this time.
0 commit comments