Skip to content

Commit 6c73264

Browse files
committed
Generate name for each member of list arg
1 parent c40cd02 commit 6c73264

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

pytorch_pfn_extras/onnx/export_testcase.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,23 @@ def export_testcase(
279279
os.makedirs(out_dir, exist_ok=True)
280280
if isinstance(args, torch.Tensor):
281281
args = args,
282-
input_names = kwargs.pop(
283-
'input_names',
284-
['input_{}'.format(i) for i in range(len(args))])
285-
assert len(input_names) == len(args)
282+
283+
# We unroll list args and generate names for each tensor.
284+
gen_input_names = []
285+
unrolled_args = []
286+
287+
def append_input_name(prefix: str, arg: Any) -> None:
288+
if isinstance(arg, list):
289+
for i, a in enumerate(arg):
290+
append_input_name(prefix + f"_{i}", a)
291+
else:
292+
gen_input_names.append(prefix)
293+
unrolled_args.append(arg)
294+
for i, arg in enumerate(args):
295+
append_input_name(f"input_{i}", arg)
296+
297+
input_names = kwargs.pop('input_names', gen_input_names)
298+
assert len(input_names) == len(unrolled_args)
286299
assert not isinstance(args, torch.Tensor)
287300

288301
onnx_graph, outs = _export(
@@ -302,7 +315,7 @@ def export_testcase(
302315
if used_input.name not in initializer_names:
303316
used_input_index_list.append(input_names.index(used_input.name))
304317
input_names = [input_names[i] for i in used_input_index_list]
305-
args = [args[i] for i in used_input_index_list]
318+
unrolled_args = [unrolled_args[i] for i in used_input_index_list]
306319

307320
output_path = os.path.join(out_dir, 'model.onnx')
308321
is_on_memory = True
@@ -341,7 +354,7 @@ def write_to_pb(f: str, tensor: torch.Tensor, name: Optional[str] = None) -> Non
341354
os.makedirs(data_set_path, exist_ok=True)
342355
for pb_name in glob.glob(os.path.join(data_set_path, "*.pb")):
343356
os.remove(pb_name)
344-
for i, (arg, name) in enumerate(zip(args, input_names)):
357+
for i, (arg, name) in enumerate(zip(unrolled_args, input_names)):
345358
f = os.path.join(data_set_path, 'input_{}.pb'.format(i))
346359
write_to_pb(f, arg, name)
347360

0 commit comments

Comments
 (0)