Skip to content

Commit c46b8f0

Browse files
authored
Run all models in modelzoo with sample test data (onnx#841)
1. Updated test_modelzoo.py to run all models with the provided test data in model zoo Signed-off-by: Winnie Tsang <[email protected]>
1 parent 80d2806 commit c46b8f0

File tree

1 file changed

+201
-27
lines changed

1 file changed

+201
-27
lines changed

test/test_modelzoo.py

+201-27
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@
1414

1515
import argparse
1616
import datetime
17+
import glob
1718
import math
1819
import os
1920
import platform
2021
import shutil
2122
import subprocess
2223
import sys
24+
import tarfile
2325
import tempfile
2426

27+
import numpy as np
2528
import onnx
2629
import tensorflow as tf
2730
import onnx_tf
@@ -97,6 +100,40 @@ def summary(self):
97100
' (dry_run)' if _CFG['dry_run'] else '')
98101

99102

103+
def _get_model_and_test_data():
104+
"""Get the filename of the model and directory of the test data set"""
105+
onnx_model = None
106+
test_data_set = []
107+
for root, dirs, files in os.walk(_CFG['untar_directory']):
108+
for dir_name in dirs:
109+
if dir_name.startswith('test_data_set_'):
110+
test_data_set.append(os.path.join(root, dir_name))
111+
for file_name in files:
112+
if file_name.endswith('.onnx') and not file_name.startswith('.'):
113+
onnx_model = os.path.join(root, file_name)
114+
elif (file_name.startswith('input_') and file_name.endswith('.pb') and
115+
len(test_data_set) == 0):
116+
# data files are not in test_data_set_* but in the same
117+
# directory of onnx file
118+
test_data_set.append(root)
119+
elif file_name.startswith('test_data_') and file_name.endswith('.npz'):
120+
test_data_file = os.path.join(root, file_name)
121+
test_data_dir = os.path.join(root, file_name.split('.')[0])
122+
new_test_data_file = os.path.join(test_data_dir, file_name)
123+
os.mkdir(test_data_dir)
124+
os.rename(test_data_file, new_test_data_file)
125+
test_data_set.append(test_data_dir)
126+
return onnx_model, test_data_set
127+
128+
129+
def _extract_model_and_test_data(file_path):
130+
"""Extract all files in the tar.gz to test_model_and_data_dir"""
131+
tar = tarfile.open(file_path, "r:gz")
132+
tar.extractall(_CFG['untar_directory'])
133+
tar.close()
134+
return _get_model_and_test_data()
135+
136+
100137
def _pull_model_file(file_path):
101138
"""Use Git LFS to pull down a large file.
102139
@@ -123,15 +160,28 @@ def _pull_model_file(file_path):
123160
shell=True,
124161
check=True,
125162
stdout=subprocess.DEVNULL)
126-
new_size = os.stat(model_path).st_size
163+
if file_path.endswith('.tar.gz'):
164+
onnx_model, test_data_set = _extract_model_and_test_data(model_path)
165+
else:
166+
onnx_model = model_path
167+
test_data_set = []
168+
new_size = os.stat(onnx_model).st_size
127169
pulled = new_size != file_size
128170
file_size = new_size
129-
return (file_size, pulled)
171+
else:
172+
# model file is pulled already
173+
if file_path.endswith('.tar.gz'):
174+
onnx_model, test_data_set = _extract_model_and_test_data(model_path)
175+
else:
176+
onnx_model = model_path
177+
test_data_set = []
178+
return (file_size, pulled), onnx_model, test_data_set
130179

131180

132181
def _revert_model_pointer(file_path):
133182
"""Remove downloaded model, revert to pointer, remove cached file."""
134183
cmd_args = ('rm -f {0} && '
184+
'git reset HEAD {0} && '
135185
'git checkout {0} && '
136186
'rm -f $(find . | grep $(grep oid {0} | cut -d ":" -f 2))'
137187
).format(file_path)
@@ -147,17 +197,24 @@ def _include_model(file_path):
147197
return True
148198
for item in _CFG['include']:
149199
if (file_path.startswith(item) or file_path.endswith(item + '.onnx') or
150-
'/{}/model/'.format(item) in file_path):
200+
'/{}/model/'.format(item) in file_path or
201+
'/{}/models/'.format(item) in file_path):
151202
return True
152203
return False
153204

154205

155206
def _has_models(dir_path):
156-
for item in os.listdir(os.path.join(_CFG['models_dir'], dir_path, 'model')):
157-
if item.endswith('.onnx'):
158-
file_path = os.path.join(dir_path, 'model', item)
159-
if _include_model(file_path):
160-
return True
207+
for m_dir in ['model', 'models']:
208+
# in age_gender there are 2 different models in there so the
209+
# directory is "models" instead of "model" like the rest of
210+
# the other models
211+
model_dir = os.path.join(_CFG['models_dir'], dir_path, m_dir)
212+
if os.path.exists(model_dir):
213+
for item in os.listdir(model_dir):
214+
if item.endswith('.onnx'):
215+
file_path = os.path.join(dir_path, model_dir, item)
216+
if _include_model(file_path):
217+
return True
161218
return False
162219

163220

@@ -187,6 +244,7 @@ def _report_check_model(model):
187244
onnx.checker.check_model(model)
188245
return ''
189246
except Exception as ex:
247+
_del_location(_CFG['untar_directory'])
190248
first_line = str(ex).strip().split('\n')[0].strip()
191249
return '{}: {}'.format(type(ex).__name__, first_line)
192250

@@ -195,53 +253,157 @@ def _report_convert_model(model):
195253
"""Test conversion and returns a report string."""
196254
try:
197255
tf_rep = onnx_tf.backend.prepare(model)
198-
tf_rep.export_graph(_CFG['output_filename'])
199-
_del_location(_CFG['output_filename'])
256+
tf_rep.export_graph(_CFG['output_directory'])
200257
return ''
201258
except Exception as ex:
202-
_del_location(_CFG['output_filename'])
203-
strack_trace = str(ex).strip().split('\n')
204-
if len(strack_trace) > 1:
205-
err_msg = strack_trace[-1].strip()
259+
_del_location(_CFG['untar_directory'])
260+
_del_location(_CFG['output_directory'])
261+
stack_trace = str(ex).strip().split('\n')
262+
if len(stack_trace) > 1:
263+
err_msg = stack_trace[-1].strip()
206264
# OpUnsupportedException gets raised as a RuntimeError
207265
if 'OP_UNSUPPORTED_EXCEPT' in str(ex):
208266
err_msg = err_msg.replace(type(ex).__name__, 'OpUnsupportedException')
209267
return err_msg
210-
return '{}: {}'.format(type(ex).__name__, strack_trace[0].strip())
268+
return '{}: {}'.format(type(ex).__name__, stack_trace[0].strip())
269+
270+
271+
def _get_inputs_outputs_pb(tf_rep, data_dir):
272+
"""Get the input and reference output tensors"""
273+
inputs = {}
274+
inputs_num = len(glob.glob(os.path.join(data_dir, 'input_*.pb')))
275+
for i in range(inputs_num):
276+
input_file = os.path.join(data_dir, 'input_{}.pb'.format(i))
277+
tensor = onnx.TensorProto()
278+
with open(input_file, 'rb') as f:
279+
tensor.ParseFromString(f.read())
280+
tensor.name = tensor.name if tensor.name else tf_rep.inputs[i]
281+
inputs[tensor.name] = onnx.numpy_helper.to_array(tensor)
282+
ref_outputs = {}
283+
ref_outputs_num = len(glob.glob(os.path.join(data_dir, 'output_*.pb')))
284+
for i in range(ref_outputs_num):
285+
output_file = os.path.join(data_dir, 'output_{}.pb'.format(i))
286+
tensor = onnx.TensorProto()
287+
with open(output_file, 'rb') as f:
288+
tensor.ParseFromString(f.read())
289+
tensor.name = tensor.name if tensor.name else tf_rep.outputs[i]
290+
ref_outputs[tensor.name] = onnx.numpy_helper.to_array(tensor)
291+
return inputs, ref_outputs
292+
293+
294+
def _get_inputs_outputs_npz(tf_rep, data_dir):
295+
"""Get the input and reference output tensors"""
296+
npz_file = os.path.join(data_dir, '{}.npz'.format(data_dir.split('/')[-1]))
297+
data = np.load(npz_file, encoding='bytes')
298+
inputs = {}
299+
ref_outputs = {}
300+
for i in range(len(tf_rep.inputs)):
301+
inputs[tf_rep.inputs[i]] = data['inputs'][i]
302+
for i in range(len(tf_rep.outputs)):
303+
ref_outputs[tf_rep.outputs[i]] = data['outputs'][i]
304+
return inputs, ref_outputs
305+
306+
307+
def _get_inputs_and_ref_outputs(tf_rep, data_dir):
308+
"""Get the input and reference output tensors"""
309+
if len(glob.glob(os.path.join(data_dir, 'input_*.pb'))) > 0:
310+
inputs, ref_outputs = _get_inputs_outputs_pb(tf_rep, data_dir)
311+
else:
312+
inputs, ref_outputs = _get_inputs_outputs_npz(tf_rep, data_dir)
313+
return inputs, ref_outputs
314+
315+
316+
def _assert_outputs(outputs, ref_outputs, rtol, atol):
317+
np.testing.assert_equal(len(outputs), len(ref_outputs))
318+
for key in outputs.keys():
319+
np.testing.assert_equal(outputs[key].dtype, ref_outputs[key].dtype)
320+
if ref_outputs[key].dtype == np.object:
321+
np.testing.assert_array_equal(outputs[key], ref_outputs[key])
322+
else:
323+
np.testing.assert_allclose(outputs[key],
324+
ref_outputs[key],
325+
rtol=rtol,
326+
atol=atol)
327+
328+
329+
def _report_run_model(model, data_set):
330+
"""Run the model and returns a report string."""
331+
try:
332+
tf_rep = onnx_tf.backend.prepare(model)
333+
for data in data_set:
334+
inputs, ref_outputs = _get_inputs_and_ref_outputs(tf_rep, data)
335+
outputs = tf_rep.run(inputs)
336+
outputs_dict = {}
337+
for i in range(len(tf_rep.outputs)):
338+
outputs_dict[tf_rep.outputs[i]] = outputs[i]
339+
_assert_outputs(outputs_dict, ref_outputs, rtol=1e-3, atol=1e-3)
340+
except Exception as ex:
341+
stack_trace = str(ex).strip().split('\n')
342+
if len(stack_trace) > 1:
343+
if ex.__class__ == AssertionError:
344+
return stack_trace[:5]
345+
else:
346+
return stack_trace[-1].strip()
347+
return '{}: {}'.format(type(ex).__name__, stack_trace[0].strip())
348+
finally:
349+
_del_location(_CFG['untar_directory'])
350+
_del_location(_CFG['output_directory'])
211351

212352

213353
def _report_model(file_path, results=Results(), onnx_model_count=1):
214354
"""Generate a report status for a single model, and append it to results."""
215-
size_pulled = _pull_model_file(file_path)
355+
size_pulled, onnx_model, test_data_set = _pull_model_file(file_path)
216356
if _CFG['dry_run']:
217357
ir_version = ''
218358
opset_version = ''
219359
check_err = ''
220360
convert_err = ''
361+
ran_err = ''
221362
emoji_validated = ''
222363
emoji_converted = ''
364+
emoji_ran = ''
223365
emoji_overall = ':heavy_minus_sign:'
224366
results.skip_count += 1
225367
else:
226368
if _CFG['verbose']:
227369
print('Testing', file_path)
228-
model = onnx.load(os.path.join(_CFG['models_dir'], file_path))
370+
model = onnx.load(onnx_model)
229371
ir_version = model.ir_version
230372
opset_version = model.opset_import[0].version
231373
check_err = _report_check_model(model)
232374
convert_err = '' if check_err else _report_convert_model(model)
375+
run_err = '' if convert_err or len(
376+
test_data_set) == 0 else _report_run_model(model, test_data_set)
233377

234-
if not check_err and not convert_err:
378+
if (not check_err and not convert_err and not run_err and
379+
len(test_data_set) > 0):
235380
# https://github-emoji-list.herokuapp.com/
236-
# validation and conversion passed
381+
# ran successfully
237382
emoji_validated = ':ok:'
238383
emoji_converted = ':ok:'
384+
emoji_ran = ':ok:'
239385
emoji_overall = ':heavy_check_mark:'
240386
results.pass_count += 1
387+
elif (not check_err and not convert_err and not run_err and
388+
len(test_data_set) == 0):
389+
# validation & conversion passed but no test data available
390+
emoji_validated = ':ok:'
391+
emoji_converted = ':ok:'
392+
emoji_ran = 'No test data provided in model zoo'
393+
emoji_overall = ':warning:'
394+
results.warn_count += 1
395+
elif not check_err and not convert_err:
396+
# validation & conversion passed but failed to run
397+
emoji_validated = ':ok:'
398+
emoji_converted = ':ok:'
399+
emoji_ran = run_err
400+
emoji_overall = ':x:'
401+
results.fail_count += 1
241402
elif not check_err:
242403
# validation pass, but conversion did not
243404
emoji_validated = ':ok:'
244405
emoji_converted = convert_err
406+
emoji_ran = ':heavy_minus_sign:'
245407
if ('BackendIsNotSupposedToImplementIt' in convert_err or
246408
'OpUnsupportedException' in convert_err):
247409
# known limitations
@@ -257,14 +419,18 @@ def _report_model(file_path, results=Results(), onnx_model_count=1):
257419
# validation failed
258420
emoji_validated = check_err
259421
emoji_converted = ':heavy_minus_sign:'
422+
emoji_ran = ':heavy_minus_sign:'
260423
emoji_overall = ':x:'
261424
results.fail_count += 1
262425

263-
results.append_detail('{} | {}. {} | {} | {} | {} | {} | {}'.format(
264-
emoji_overall, onnx_model_count, file_path[file_path.rindex('/') + 1:],
426+
results.append_detail('{} | {}. {} | {} | {} | {} | {} | {} | {}'.format(
427+
emoji_overall, onnx_model_count,
428+
file_path[file_path.rindex('/') + 1:file_path.index('.')],
265429
_size_with_units(size_pulled[0]), ir_version, opset_version,
266-
emoji_validated, emoji_converted))
430+
emoji_validated, emoji_converted, emoji_ran))
267431

432+
if len(test_data_set) == 0:
433+
_del_location(_CFG['output_directory'])
268434
if size_pulled[1]:
269435
# only remove model if it was pulled above on-demand
270436
_revert_model_pointer(file_path)
@@ -291,7 +457,8 @@ def _configure(models_dir='models',
291457
_configure_env()
292458

293459
norm_output_dir = os.path.normpath(output_dir)
294-
_CFG['output_filename'] = os.path.join(norm_output_dir, 'tmp_model.pb')
460+
_CFG['untar_directory'] = os.path.join(norm_output_dir, 'test_model_and_data')
461+
_CFG['output_directory'] = os.path.join(norm_output_dir, 'test_model_pb')
295462
_CFG['report_filename'] = os.path.join(norm_output_dir,
296463
_CFG['report_filename'])
297464

@@ -354,13 +521,14 @@ def modelzoo_report(models_dir='models',
354521

355522
_configure(models_dir, output_dir, include, verbose, dry_run)
356523
_del_location(_CFG['report_filename'])
357-
_del_location(_CFG['output_filename'])
524+
_del_location(_CFG['output_directory'])
525+
_del_location(_CFG['untar_directory'])
358526

359527
# run tests first, but append to report after summary
360528
results = Results()
361529
for root, subdir, files in os.walk(_CFG['models_dir']):
362530
subdir.sort()
363-
if 'model' in subdir:
531+
if 'model' in subdir or 'models' in subdir:
364532
dir_path = os.path.relpath(root, _CFG['models_dir'])
365533
if _has_models(dir_path):
366534
results.model_count += 1
@@ -371,15 +539,21 @@ def modelzoo_report(models_dir='models',
371539
results.append_detail('')
372540
results.append_detail(
373541
'Status | Model | Size | IR | Opset | ONNX Checker | '
374-
'ONNX-TF Converted')
542+
'ONNX-TF Converted | ONNX-TF Ran')
375543
results.append_detail(
376544
'------ | ----- | ---- | -- | ----- | ------------ | '
377-
'---------')
545+
'----------------- | -----------')
378546
onnx_model_count = 0
547+
file_path = ''
379548
for item in sorted(files):
380549
if item.endswith('.onnx'):
381550
file_path = os.path.relpath(os.path.join(root, item),
382551
_CFG['models_dir'])
552+
# look for gz file for this model
553+
gzfile_path = file_path.replace('.onnx', '.tar.gz')
554+
gzfile_path = os.path.join(_CFG['models_dir'], gzfile_path)
555+
if gzfile_path in glob.glob(gzfile_path):
556+
file_path = os.path.relpath(gzfile_path, _CFG['models_dir'])
383557
if _include_model(file_path):
384558
onnx_model_count += 1
385559
results.total_count += 1

0 commit comments

Comments
 (0)