|
4 | 4 | from __future__ import unicode_literals
|
5 | 5 |
|
6 | 6 | import os
|
| 7 | +import re |
7 | 8 | import unittest
|
8 | 9 |
|
9 | 10 | import onnx.backend.test
|
10 | 11 |
|
| 12 | +from onnx import defs |
| 13 | + |
| 14 | +from onnx_tf import opset_version |
11 | 15 | from onnx_tf.backend import TensorflowBackend
|
12 | 16 | from onnx_tf.common.legacy import legacy_onnx_pre_ver
|
13 | 17 | from onnx_tf.common.legacy import legacy_opset_pre_ver
|
14 | 18 |
|
| 19 | +def get_onnxtf_supported_ops(): |
| 20 | + return opset_version.backend_opset_version |
| 21 | + |
| 22 | +def get_onnx_supported_ops(): |
| 23 | + onnx_opset_dict = {} |
| 24 | + for schema in defs.get_all_schemas(): |
| 25 | + op = schema.name |
| 26 | + onnx_opset_dict[op] = schema.since_version |
| 27 | + return onnx_opset_dict |
| 28 | + |
| 29 | +def skip_not_implemented_ops_test(test): |
| 30 | + onnxtf_ops_list = get_onnxtf_supported_ops() |
| 31 | + onnx_ops_list = get_onnx_supported_ops() |
| 32 | + for op in onnx_ops_list: |
| 33 | + if op in onnxtf_ops_list: |
| 34 | + if onnx_ops_list[op] not in onnxtf_ops_list[op]: |
| 35 | + test.exclude(r'[a-z,_]*' + op.lower() + '[a-z,_]*') |
| 36 | + else: |
| 37 | + test.exclude(r'[a-z,_]*' + op.lower() + '[a-z,_]*') |
| 38 | + return test |
| 39 | + |
15 | 40 | # This is a pytest magic variable to load extra plugins
|
16 | 41 | pytest_plugins = 'onnx.backend.test.report',
|
17 | 42 |
|
18 | 43 | backend_test = onnx.backend.test.BackendTest(TensorflowBackend, __name__)
|
19 | 44 |
|
| 45 | +# exclude tests of not-implemented-ops |
| 46 | +backend_test = skip_not_implemented_ops_test(backend_test) |
| 47 | + |
| 48 | +# manually exclude tests of not-implemented-ops that are using "short name" in their testcase name |
| 49 | +# need to remove these lines once those ops support are added into onnx-tf |
| 50 | +# temporary exclude StringNormalizer test |
| 51 | +backend_test.exclude(r'[a-z,_]*strnorm[a-z,_]*') |
| 52 | +# temporary exclude MeanVarianceNormalization test |
| 53 | +backend_test.exclude(r'[a-z,_]*mvn[a-z,_]*') |
| 54 | + |
20 | 55 | # https://github.com/onnx/onnx/issues/349
|
21 | 56 | backend_test.exclude(r'[a-z,_]*GLU[a-z,_]*')
|
22 | 57 |
|
|
0 commit comments