Skip to content

Commit 38e6d28

Browse files
Update opset imports in version_converter
1 parent 0e09e58 commit 38e6d28

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

onnxscript/version_converter/_version_converter.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,38 @@ class _VersionConverter:
220220
def __init__(self, target_version: int):
221221
self.target_version = target_version
222222

223+
def _maybe_set_opset_version(
224+
self, model_or_function: ir.Model | ir.Function, domain: str, version: int | None
225+
):
226+
"""Set the opset version for the domain."""
227+
current_version = model_or_function.opset_imports.get(domain)
228+
if version is None or current_version is None:
229+
return
230+
if domain == "":
231+
model_or_function.opset_imports[domain] = max(version, current_version)
232+
return
233+
elif domain == "ai.onnx":
234+
model_or_function.opset_imports[domain] = max(version, current_version)
235+
return
236+
else:
237+
return
238+
239+
def _update_opset_imports(self, model: ir.Model) -> None:
240+
"""Collect all opsets used and add opset imports to the model and functions."""
241+
for node in ir.traversal.RecursiveGraphIterator(model.graph):
242+
domain = node.domain
243+
self._maybe_set_opset_version(model, domain, node.version)
244+
245+
for function in model.functions.values():
246+
for node in ir.traversal.RecursiveGraphIterator(function):
247+
domain = node.domain
248+
self._maybe_set_opset_version(function, domain, node.version)
249+
for domain, version in function.opset_imports.items():
250+
# Add all opsets used in the function to the model, because ONNX Runtime
251+
# does not handle adding the opset imports to the model after inlining during inference.
252+
# This should happen after all opsets are collected for the function from its nodes.
253+
self._maybe_set_opset_version(model, domain, version)
254+
223255
def _upgrade_version(self, node: ir.Node, opset_version: int, up_conversion: bool) -> None:
224256
if up_conversion is True:
225257
node.version = opset_version + 1
@@ -320,6 +352,8 @@ def visit_model(self, model: ir.Model) -> None:
320352
return None
321353
self.model_version = model_version
322354
self.visit_graph(model.graph)
355+
# Finally, update the opset imports for the model
356+
self._update_opset_imports(model)
323357
return None
324358

325359

onnxscript/version_converter/_version_converter_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def test_version_convert_compatible(self):
115115
)
116116
model = ir.serde.deserialize_model(model_proto)
117117
target_version = 19
118+
self.assertEqual(model.opset_imports[""], 18)
118119
version_converter.convert_version(model, target_version=target_version)
119120

120121
self.assertEqual(model.graph.node(0).op_type, "Constant")
@@ -123,6 +124,7 @@ def test_version_convert_compatible(self):
123124
self.assertEqual(model.graph.node(1).version, 19)
124125
self.assertEqual(model.graph.node(4).op_type, "MatMul")
125126
self.assertEqual(model.graph.node(4).version, 19)
127+
self.assertEqual(model.opset_imports[""], 19)
126128

127129

128130
class VersionConverter19to20Test(unittest.TestCase):
@@ -142,6 +144,7 @@ def test_version_convert_compatible(self):
142144
)
143145
model = ir.serde.deserialize_model(model_proto)
144146
target_version = 20
147+
self.assertEqual(model.opset_imports[""], 18)
145148
version_converter.convert_version(model, target_version=target_version)
146149

147150
self.assertEqual(model.graph.node(0).op_type, "Constant")
@@ -153,6 +156,7 @@ def test_version_convert_compatible(self):
153156
self.assertEqual(model.graph.node(3).op_type, "DFT")
154157
self.assertEqual(model.graph.node(3).version, 20)
155158
self.assertEqual(len(model.graph.node(3).inputs), 2)
159+
self.assertEqual(model.opset_imports[""], 20)
156160

157161
def test_version_convert_gridsample_linear(self):
158162
model_proto = onnx.parser.parse_model(
@@ -175,6 +179,7 @@ def test_version_convert_gridsample_linear(self):
175179
self.assertEqual(model.graph.node(4).attributes["mode"].value, "bilinear")
176180

177181
target_version = 20
182+
self.assertEqual(model.opset_imports[""], 18)
178183
version_converter.convert_version(model, target_version=target_version)
179184

180185
self.assertEqual(model.graph.node(0).op_type, "Constant")
@@ -184,6 +189,7 @@ def test_version_convert_gridsample_linear(self):
184189
self.assertEqual(model.graph.node(4).op_type, "GridSample")
185190
self.assertEqual(model.graph.node(4).version, 20)
186191
self.assertEqual(model.graph.node(4).attributes["mode"].value, "linear")
192+
self.assertEqual(model.opset_imports[""], 20)
187193

188194
def test_version_convert_gridsample_cubic(self):
189195
model_proto = onnx.parser.parse_model(
@@ -206,6 +212,7 @@ def test_version_convert_gridsample_cubic(self):
206212
self.assertEqual(model.graph.node(4).attributes["mode"].value, "bicubic")
207213

208214
target_version = 20
215+
self.assertEqual(model.opset_imports[""], 18)
209216
version_converter.convert_version(model, target_version=target_version)
210217

211218
self.assertEqual(model.graph.node(0).op_type, "Constant")
@@ -215,6 +222,7 @@ def test_version_convert_gridsample_cubic(self):
215222
self.assertEqual(model.graph.node(4).op_type, "GridSample")
216223
self.assertEqual(model.graph.node(4).version, 20)
217224
self.assertEqual(model.graph.node(4).attributes["mode"].value, "cubic")
225+
self.assertEqual(model.opset_imports[""], 20)
218226

219227
def test_version_convert_inline(self):
220228
model_proto = onnx.parser.parse_model(
@@ -238,6 +246,7 @@ def test_version_convert_inline(self):
238246
)
239247
model = ir.serde.deserialize_model(model_proto)
240248
target_version = 20
249+
self.assertEqual(model.opset_imports[""], 18)
241250
version_converter.convert_version(model, target_version=target_version)
242251

243252
self.assertEqual(model.graph.node(0).op_type, "Constant")
@@ -250,6 +259,7 @@ def test_version_convert_inline(self):
250259
self.assertEqual(model.graph.node(6).op_type, "DFT")
251260
self.assertEqual(model.graph.node(6).version, 20)
252261
self.assertEqual(len(model.graph.node(6).inputs), 2)
262+
self.assertEqual(model.opset_imports[""], 20)
253263

254264

255265
class VersionConverter20to21Test(unittest.TestCase):
@@ -267,6 +277,7 @@ def test_version_groupnorm(self):
267277
)
268278
model = ir.serde.deserialize_model(model_proto)
269279
target_version = 21
280+
self.assertEqual(model.opset_imports[""], 18)
270281
version_converter.convert_version(model, target_version=target_version)
271282

272283
self.assertEqual(model.graph.node(3).op_type, "Reshape")
@@ -283,6 +294,7 @@ def test_version_groupnorm(self):
283294
self.assertEqual(model.graph.node(8).version, 21)
284295
self.assertEqual(model.graph.node(9).op_type, "GroupNormalization")
285296
self.assertEqual(model.graph.node(9).version, 21)
297+
self.assertEqual(model.opset_imports[""], 21)
286298

287299
def test_version_groupnorm_no_bias(self):
288300
model_proto = onnx.parser.parse_model(
@@ -298,10 +310,12 @@ def test_version_groupnorm_no_bias(self):
298310
)
299311
model = ir.serde.deserialize_model(model_proto)
300312
target_version = 21
313+
self.assertEqual(model.opset_imports[""], 18)
301314
version_converter.convert_version(model, target_version=target_version)
302315

303316
self.assertEqual(model.graph.node(0).op_type, "GroupNormalization")
304317
self.assertEqual(model.graph.node(0).version, 20)
318+
self.assertEqual(model.opset_imports[""], 21)
305319

306320

307321
class VersionConverter23to24Test(unittest.TestCase):

0 commit comments

Comments
 (0)