5
5
import unittest
6
6
7
7
import onnx .defs
8
+ import pytest
8
9
9
10
from onnxscript import ir , version_converter
10
11
@@ -41,18 +42,19 @@ def test_upstream_coverage(self):
41
42
self .assertEqual (domain , "" )
42
43
self .assertIn ((name , upgrade_version ), op_upgrades )
43
44
44
- def test_version_convert_non_standard_onnx_domain (self ):
45
+ @pytest .mark .xfail (reason = "TODO: Cleanup error status API." )
46
+ def test_version_convert_no_source_version (self ):
45
47
model = ir .from_onnx_text (
46
48
"""
47
49
<ir_version: 7, opset_import: [ "local" : 1]>
48
50
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output)
49
51
{
50
- shape_a = Constant<value: tensor = int64[5 ] {1, 4, 512, 512}>()
52
+ shape_a = Constant<value: tensor = int64[4 ] {1, 4, 512, 512}>()
51
53
reshape_x = Reshape (input_x, shape_a)
52
- shape_b = Constant<value: tensor = int64[5 ] {1, 4, 1024, 1024}>()
54
+ shape_b = Constant<value: tensor = int64[4 ] {1, 4, 1024, 1024}>()
53
55
reshape_y = Reshape (input_x, shape_b)
54
56
gridsample = GridSample <mode = "bilinear"> (reshape_x, reshape_y)
55
- shape_c = Constant<value: tensor = int64[4 ] {4, 1024, 1024}>()
57
+ shape_c = Constant<value: tensor = int64[3 ] {4, 1024, 1024}>()
56
58
output = Reshape (gridsample, shape_c)
57
59
}
58
60
"""
@@ -63,16 +65,9 @@ def test_version_convert_non_standard_onnx_domain(self):
63
65
target_version = 20
64
66
version_converter .convert_version (model , target_version = target_version )
65
67
66
- self .assertEqual (model .graph .node (0 ).op_type , "Constant" )
67
- self .assertEqual (model .graph .node (0 ).version , None )
68
- self .assertEqual (model .graph .node (1 ).op_type , "Reshape" )
69
- self .assertEqual (model .graph .node (1 ).version , None )
70
- self .assertEqual (model .graph .node (4 ).op_type , "GridSample" )
71
- self .assertEqual (model .graph .node (4 ).version , None )
72
- self .assertEqual (model .graph .node (4 ).attributes ["mode" ].value , "bilinear" )
73
-
74
68
75
69
class VersionConverter18to17Test (unittest .TestCase ):
70
+ @pytest .mark .xfail (strict = True , reason = "Version downgrade not yet supported." )
76
71
def test_version_convert_compatible (self ):
77
72
model = ir .from_onnx_text (
78
73
"""
@@ -112,6 +107,7 @@ def test_version_convert_compatible(self):
112
107
)
113
108
target_version = 19
114
109
version_converter .convert_version (model , target_version = target_version )
110
+ self .assertEqual (model .opset_imports ["" ], target_version )
115
111
116
112
self .assertEqual (model .graph .node (0 ).op_type , "Constant" )
117
113
self .assertEqual (model .graph .node (0 ).version , 19 )
@@ -138,6 +134,7 @@ def test_version_convert_compatible(self):
138
134
)
139
135
target_version = 20
140
136
version_converter .convert_version (model , target_version = target_version )
137
+ self .assertEqual (model .opset_imports ["" ], target_version )
141
138
142
139
self .assertEqual (model .graph .node (0 ).op_type , "Constant" )
143
140
self .assertEqual (model .graph .node (0 ).version , 20 )
@@ -170,6 +167,7 @@ def test_version_convert_gridsample_linear(self):
170
167
171
168
target_version = 20
172
169
version_converter .convert_version (model , target_version = target_version )
170
+ self .assertEqual (model .opset_imports ["" ], target_version )
173
171
174
172
self .assertEqual (model .graph .node (0 ).op_type , "Constant" )
175
173
self .assertEqual (model .graph .node (0 ).version , 20 )
@@ -200,6 +198,7 @@ def test_version_convert_gridsample_cubic(self):
200
198
201
199
target_version = 20
202
200
version_converter .convert_version (model , target_version = target_version )
201
+ self .assertEqual (model .opset_imports ["" ], target_version )
203
202
204
203
self .assertEqual (model .graph .node (0 ).op_type , "Constant" )
205
204
self .assertEqual (model .graph .node (0 ).version , 20 )
@@ -231,6 +230,7 @@ def test_version_convert_inline(self):
231
230
)
232
231
target_version = 20
233
232
version_converter .convert_version (model , target_version = target_version )
233
+ self .assertEqual (model .opset_imports ["" ], target_version )
234
234
235
235
self .assertEqual (model .graph .node (0 ).op_type , "Constant" )
236
236
self .assertEqual (model .graph .node (0 ).version , 20 )
@@ -259,6 +259,7 @@ def test_version_groupnorm(self):
259
259
)
260
260
target_version = 21
261
261
version_converter .convert_version (model , target_version = target_version )
262
+ self .assertEqual (model .opset_imports ["" ], target_version )
262
263
263
264
self .assertEqual (model .graph .node (3 ).op_type , "Reshape" )
264
265
self .assertEqual (model .graph .node (3 ).version , 21 )
@@ -289,12 +290,14 @@ def test_version_groupnorm_no_bias(self):
289
290
)
290
291
target_version = 21
291
292
version_converter .convert_version (model , target_version = target_version )
293
+ self .assertEqual (model .opset_imports ["" ], target_version )
292
294
293
295
self .assertEqual (model .graph .node (0 ).op_type , "GroupNormalization" )
294
296
self .assertEqual (model .graph .node (0 ).version , 20 )
295
297
296
298
297
299
class VersionConverter23to24Test (unittest .TestCase ):
300
+ @pytest .mark .xfail (strict = True , reason = "Version upgrade beyond 23 not yet supported." )
298
301
def test_version_convert_compatible (self ):
299
302
model = ir .from_onnx_text (
300
303
"""
0 commit comments