Skip to content

Commit ee15246

Browse files
committed
Generate name for each member of list arg
1 parent c40cd02 commit ee15246

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

Diff for: pytorch_pfn_extras/onnx/export_testcase.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,22 @@ def export_testcase(
279279
os.makedirs(out_dir, exist_ok=True)
280280
if isinstance(args, torch.Tensor):
281281
args = args,
282-
input_names = kwargs.pop(
283-
'input_names',
284-
['input_{}'.format(i) for i in range(len(args))])
285-
assert len(input_names) == len(args)
282+
283+
# We unroll list args and generate names for each tensor.
284+
gen_input_names = []
285+
unrolled_args = []
286+
def append_input_name(prefix, arg) -> None:
287+
if isinstance(arg, list):
288+
for i, a in enumerate(arg):
289+
append_input_name(prefix + f"_{i}", a)
290+
else:
291+
gen_input_names.append(prefix)
292+
unrolled_args.append(arg)
293+
for i, arg in enumerate(args):
294+
append_input_name(f"input_{i}", arg)
295+
296+
input_names = kwargs.pop('input_names', gen_input_names)
297+
assert len(input_names) == len(unrolled_args)
286298
assert not isinstance(args, torch.Tensor)
287299

288300
onnx_graph, outs = _export(
@@ -302,7 +314,7 @@ def export_testcase(
302314
if used_input.name not in initializer_names:
303315
used_input_index_list.append(input_names.index(used_input.name))
304316
input_names = [input_names[i] for i in used_input_index_list]
305-
args = [args[i] for i in used_input_index_list]
317+
unrolled_args = [unrolled_args[i] for i in used_input_index_list]
306318

307319
output_path = os.path.join(out_dir, 'model.onnx')
308320
is_on_memory = True
@@ -341,7 +353,7 @@ def write_to_pb(f: str, tensor: torch.Tensor, name: Optional[str] = None) -> Non
341353
os.makedirs(data_set_path, exist_ok=True)
342354
for pb_name in glob.glob(os.path.join(data_set_path, "*.pb")):
343355
os.remove(pb_name)
344-
for i, (arg, name) in enumerate(zip(args, input_names)):
356+
for i, (arg, name) in enumerate(zip(unrolled_args, input_names)):
345357
f = os.path.join(data_set_path, 'input_{}.pb'.format(i))
346358
write_to_pb(f, arg, name)
347359

0 commit comments

Comments
 (0)