@@ -279,10 +279,23 @@ def export_testcase(
279
279
os .makedirs (out_dir , exist_ok = True )
280
280
if isinstance (args , torch .Tensor ):
281
281
args = args ,
282
- input_names = kwargs .pop (
283
- 'input_names' ,
284
- ['input_{}' .format (i ) for i in range (len (args ))])
285
- assert len (input_names ) == len (args )
282
+
283
+ # We unroll list args and generate names for each tensor.
284
+ gen_input_names = []
285
+ unrolled_args = []
286
+
287
+ def append_input_name (prefix : str , arg : Any ) -> None :
288
+ if isinstance (arg , list ):
289
+ for i , a in enumerate (arg ):
290
+ append_input_name (prefix + f"_{ i } " , a )
291
+ else :
292
+ gen_input_names .append (prefix )
293
+ unrolled_args .append (arg )
294
+ for i , arg in enumerate (args ):
295
+ append_input_name (f"input_{ i } " , arg )
296
+
297
+ input_names = kwargs .pop ('input_names' , gen_input_names )
298
+ assert len (input_names ) == len (unrolled_args )
286
299
assert not isinstance (args , torch .Tensor )
287
300
288
301
onnx_graph , outs = _export (
@@ -302,7 +315,7 @@ def export_testcase(
302
315
if used_input .name not in initializer_names :
303
316
used_input_index_list .append (input_names .index (used_input .name ))
304
317
input_names = [input_names [i ] for i in used_input_index_list ]
305
- args = [args [i ] for i in used_input_index_list ]
318
+ unrolled_args = [unrolled_args [i ] for i in used_input_index_list ]
306
319
307
320
output_path = os .path .join (out_dir , 'model.onnx' )
308
321
is_on_memory = True
@@ -341,7 +354,7 @@ def write_to_pb(f: str, tensor: torch.Tensor, name: Optional[str] = None) -> Non
341
354
os .makedirs (data_set_path , exist_ok = True )
342
355
for pb_name in glob .glob (os .path .join (data_set_path , "*.pb" )):
343
356
os .remove (pb_name )
344
- for i , (arg , name ) in enumerate (zip (args , input_names )):
357
+ for i , (arg , name ) in enumerate (zip (unrolled_args , input_names )):
345
358
f = os .path .join (data_set_path , 'input_{}.pb' .format (i ))
346
359
write_to_pb (f , arg , name )
347
360
0 commit comments