@@ -276,14 +276,19 @@ def test_extract_constant_to_initializer(self): # type: () -> None
276
276
[helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (1 , 5 , 3 , 3 )),
277
277
helper .make_tensor_value_info ("Y" , TensorProto .FLOAT , (16 , 5 , 3 , 3 ))],
278
278
[helper .make_tensor_value_info ("B" , TensorProto .FLOAT , (1 , 16 , 3 , 3 ))],
279
- value_info = [
280
- helper .make_tensor_value_info ("A" , TensorProto .FLOAT , (16 , 1 , 1 )),
281
- ]
282
279
)
283
280
optimized_model = self ._optimized (graph , ["extract_constant_to_initializer" ])
284
- assert len (list (optimized_model .graph .initializer )) == 1
285
- assert len (list (optimized_model .graph .node )) == 2
286
- assert "A" in [i .name for i in optimized_model .graph .initializer ]
281
+ self .assertEqual (
282
+ set (vi .name for vi in optimized_model .graph .input ),
283
+ {'X' , 'Y' , 'A' })
284
+
285
+ self .assertEqual (len (optimized_model .graph .initializer ), 1 )
286
+ init = optimized_model .graph .initializer [0 ]
287
+ self .assertEqual (init .name , 'A' )
288
+ self .assertEqual (init .dims , [16 ])
289
+ self .assertEqual (init .data_type , TensorProto .FLOAT )
290
+
291
+ self .assertEqual ([n .op_type for n in optimized_model .graph .node ], ['Conv' , 'Add' ])
287
292
288
293
def test_fuse_transpose (self ): # type: () -> None
289
294
nodes = [helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = [1 , 0 , 2 ]),
0 commit comments