Skip to content

Commit eabbfbc

Browse files
committed
Automatically skip unit tests of not implemented operators in onnx-tf
1 parent 663502d commit eabbfbc

File tree

4 files changed

+37
-4
lines changed

4 files changed

+37
-4
lines changed

doc/CLI.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ optional arguments:
2222
### Convert:
2323

2424
#### From ONNX to Tensorflow:
25-
`onnx-tf convert -t tf -i /path/to/input.onnx -o /path/to/output.pb`
25+
`onnx-tf convert -i /path/to/input.onnx -o /path/to/output.pb`
2626

2727
More information: `onnx-tf convert -h`
2828
```

doc/CLI_template.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ More information: `onnx-tf -h`
1414
### Convert:
1515

1616
#### From ONNX to Tensorflow:
17-
`onnx-tf convert -t tf -i /path/to/input.onnx -o /path/to/output.pb`
17+
`onnx-tf convert -i /path/to/input.onnx -o /path/to/output.pb`
1818

1919
More information: `onnx-tf convert -h`
2020
```

test/backend/test_onnx_backend.py

+35
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,54 @@
44
from __future__ import unicode_literals
55

66
import os
7+
import re
78
import unittest
89

910
import onnx.backend.test
1011

12+
from onnx import defs
13+
14+
from onnx_tf import opset_version
1115
from onnx_tf.backend import TensorflowBackend
1216
from onnx_tf.common.legacy import legacy_onnx_pre_ver
1317
from onnx_tf.common.legacy import legacy_opset_pre_ver
1418

19+
def get_onnxtf_supported_ops():
20+
return opset_version.backend_opset_version
21+
22+
def get_onnx_supported_ops():
23+
onnx_opset_dict = {}
24+
for schema in defs.get_all_schemas():
25+
op = schema.name
26+
onnx_opset_dict[op] = schema.since_version
27+
return onnx_opset_dict
28+
29+
def skip_not_implemented_ops_test(test):
30+
onnxtf_ops_list = get_onnxtf_supported_ops()
31+
onnx_ops_list = get_onnx_supported_ops()
32+
for op in onnx_ops_list:
33+
if op in onnxtf_ops_list:
34+
if onnx_ops_list[op] not in onnxtf_ops_list[op]:
35+
test.exclude(r'[a-z,_]*' + op.lower() + '[a-z,_]*')
36+
else:
37+
test.exclude(r'[a-z,_]*' + op.lower() + '[a-z,_]*')
38+
return test
39+
1540
# This is a pytest magic variable to load extra plugins
1641
pytest_plugins = 'onnx.backend.test.report',
1742

1843
backend_test = onnx.backend.test.BackendTest(TensorflowBackend, __name__)
1944

45+
# exclude tests of not-implemented-ops
46+
backend_test = skip_not_implemented_ops_test(backend_test)
47+
48+
# manually exclude tests of not-implemented-ops that are using "short name" in their testcase name
49+
# need to remove these lines once those ops support are added into onnx-tf
50+
# temporary exclude StringNormalizer test
51+
backend_test.exclude(r'[a-z,_]*strnorm[a-z,_]*')
52+
# temporary exclude MeanVarianceNormalization test
53+
backend_test.exclude(r'[a-z,_]*mvn[a-z,_]*')
54+
2055
# https://github.com/onnx/onnx/issues/349
2156
backend_test.exclude(r'[a-z,_]*GLU[a-z,_]*')
2257

test/test_cli.py

-2
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,6 @@ def test_convert_to_tf(self):
6868
subprocess.check_call([
6969
"onnx-tf",
7070
"convert",
71-
"-t",
72-
"tf",
7371
"-i",
7472
os.path.join(model_dir, '{}.onnx'.format(model_name)),
7573
"-o",

0 commit comments

Comments
 (0)