@@ -115,6 +115,7 @@ def test_version_convert_compatible(self):
115
115
)
116
116
model = ir .serde .deserialize_model (model_proto )
117
117
target_version = 19
118
+ self .assertEqual (model .opset_imports ["" ], 18 )
118
119
version_converter .convert_version (model , target_version = target_version )
119
120
120
121
self .assertEqual (model .graph .node (0 ).op_type , "Constant" )
@@ -123,6 +124,7 @@ def test_version_convert_compatible(self):
123
124
self .assertEqual (model .graph .node (1 ).version , 19 )
124
125
self .assertEqual (model .graph .node (4 ).op_type , "MatMul" )
125
126
self .assertEqual (model .graph .node (4 ).version , 19 )
127
+ self .assertEqual (model .opset_imports ["" ], 19 )
126
128
127
129
128
130
class VersionConverter19to20Test (unittest .TestCase ):
@@ -142,6 +144,7 @@ def test_version_convert_compatible(self):
142
144
)
143
145
model = ir .serde .deserialize_model (model_proto )
144
146
target_version = 20
147
+ self .assertEqual (model .opset_imports ["" ], 18 )
145
148
version_converter .convert_version (model , target_version = target_version )
146
149
147
150
self .assertEqual (model .graph .node (0 ).op_type , "Constant" )
@@ -153,6 +156,7 @@ def test_version_convert_compatible(self):
153
156
self .assertEqual (model .graph .node (3 ).op_type , "DFT" )
154
157
self .assertEqual (model .graph .node (3 ).version , 20 )
155
158
self .assertEqual (len (model .graph .node (3 ).inputs ), 2 )
159
+ self .assertEqual (model .opset_imports ["" ], 20 )
156
160
157
161
def test_version_convert_gridsample_linear (self ):
158
162
model_proto = onnx .parser .parse_model (
@@ -175,6 +179,7 @@ def test_version_convert_gridsample_linear(self):
175
179
self .assertEqual (model .graph .node (4 ).attributes ["mode" ].value , "bilinear" )
176
180
177
181
target_version = 20
182
+ self .assertEqual (model .opset_imports ["" ], 18 )
178
183
version_converter .convert_version (model , target_version = target_version )
179
184
180
185
self .assertEqual (model .graph .node (0 ).op_type , "Constant" )
@@ -184,6 +189,7 @@ def test_version_convert_gridsample_linear(self):
184
189
self .assertEqual (model .graph .node (4 ).op_type , "GridSample" )
185
190
self .assertEqual (model .graph .node (4 ).version , 20 )
186
191
self .assertEqual (model .graph .node (4 ).attributes ["mode" ].value , "linear" )
192
+ self .assertEqual (model .opset_imports ["" ], 20 )
187
193
188
194
def test_version_convert_gridsample_cubic (self ):
189
195
model_proto = onnx .parser .parse_model (
@@ -206,6 +212,7 @@ def test_version_convert_gridsample_cubic(self):
206
212
self .assertEqual (model .graph .node (4 ).attributes ["mode" ].value , "bicubic" )
207
213
208
214
target_version = 20
215
+ self .assertEqual (model .opset_imports ["" ], 18 )
209
216
version_converter .convert_version (model , target_version = target_version )
210
217
211
218
self .assertEqual (model .graph .node (0 ).op_type , "Constant" )
@@ -215,6 +222,7 @@ def test_version_convert_gridsample_cubic(self):
215
222
self .assertEqual (model .graph .node (4 ).op_type , "GridSample" )
216
223
self .assertEqual (model .graph .node (4 ).version , 20 )
217
224
self .assertEqual (model .graph .node (4 ).attributes ["mode" ].value , "cubic" )
225
+ self .assertEqual (model .opset_imports ["" ], 20 )
218
226
219
227
def test_version_convert_inline (self ):
220
228
model_proto = onnx .parser .parse_model (
@@ -238,6 +246,7 @@ def test_version_convert_inline(self):
238
246
)
239
247
model = ir .serde .deserialize_model (model_proto )
240
248
target_version = 20
249
+ self .assertEqual (model .opset_imports ["" ], 18 )
241
250
version_converter .convert_version (model , target_version = target_version )
242
251
243
252
self .assertEqual (model .graph .node (0 ).op_type , "Constant" )
@@ -250,6 +259,7 @@ def test_version_convert_inline(self):
250
259
self .assertEqual (model .graph .node (6 ).op_type , "DFT" )
251
260
self .assertEqual (model .graph .node (6 ).version , 20 )
252
261
self .assertEqual (len (model .graph .node (6 ).inputs ), 2 )
262
+ self .assertEqual (model .opset_imports ["" ], 20 )
253
263
254
264
255
265
class VersionConverter20to21Test (unittest .TestCase ):
@@ -267,6 +277,7 @@ def test_version_groupnorm(self):
267
277
)
268
278
model = ir .serde .deserialize_model (model_proto )
269
279
target_version = 21
280
+ self .assertEqual (model .opset_imports ["" ], 18 )
270
281
version_converter .convert_version (model , target_version = target_version )
271
282
272
283
self .assertEqual (model .graph .node (3 ).op_type , "Reshape" )
@@ -283,6 +294,7 @@ def test_version_groupnorm(self):
283
294
self .assertEqual (model .graph .node (8 ).version , 21 )
284
295
self .assertEqual (model .graph .node (9 ).op_type , "GroupNormalization" )
285
296
self .assertEqual (model .graph .node (9 ).version , 21 )
297
+ self .assertEqual (model .opset_imports ["" ], 21 )
286
298
287
299
def test_version_groupnorm_no_bias (self ):
288
300
model_proto = onnx .parser .parse_model (
@@ -298,10 +310,12 @@ def test_version_groupnorm_no_bias(self):
298
310
)
299
311
model = ir .serde .deserialize_model (model_proto )
300
312
target_version = 21
313
+ self .assertEqual (model .opset_imports ["" ], 18 )
301
314
version_converter .convert_version (model , target_version = target_version )
302
315
303
316
self .assertEqual (model .graph .node (0 ).op_type , "GroupNormalization" )
304
317
self .assertEqual (model .graph .node (0 ).version , 20 )
318
+ self .assertEqual (model .opset_imports ["" ], 21 )
305
319
306
320
307
321
class VersionConverter23to24Test (unittest .TestCase ):
0 commit comments