Skip to content

Commit 1ca1bd2

Browse files
update version_converter loop order
1 parent 9a71e8f commit 1ca1bd2

File tree

2 files changed

+30
-53
lines changed

2 files changed

+30
-53
lines changed

onnxscript/version_converter/_version_converter.py

Lines changed: 29 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from __future__ import annotations
66

7+
import copy
78
import dataclasses
89
import functools
910
import logging
@@ -220,38 +221,6 @@ class _VersionConverter:
220221
def __init__(self, target_version: int):
221222
self.target_version = target_version
222223

223-
def _maybe_set_opset_version(
224-
self, model_or_function: ir.Model | ir.Function, domain: str, version: int | None
225-
):
226-
"""Set the opset version for the domain."""
227-
current_version = model_or_function.opset_imports.get(domain)
228-
if version is None or current_version is None:
229-
return
230-
if domain == "":
231-
model_or_function.opset_imports[domain] = max(version, current_version)
232-
return
233-
elif domain == "ai.onnx":
234-
model_or_function.opset_imports[domain] = max(version, current_version)
235-
return
236-
else:
237-
return
238-
239-
def _update_opset_imports(self, model: ir.Model) -> None:
240-
"""Collect all opsets used and add opset imports to the model and functions."""
241-
for node in ir.traversal.RecursiveGraphIterator(model.graph):
242-
domain = node.domain
243-
self._maybe_set_opset_version(model, domain, node.version)
244-
245-
for function in model.functions.values():
246-
for node in ir.traversal.RecursiveGraphIterator(function):
247-
domain = node.domain
248-
self._maybe_set_opset_version(function, domain, node.version)
249-
for domain, version in function.opset_imports.items():
250-
# Add all opsets used in the function to the model, because ONNX Runtime
251-
# does not handle adding the opset imports to the model after inlining during inference.
252-
# This should happen after all opsets are collected for the function from its nodes.
253-
self._maybe_set_opset_version(model, domain, version)
254-
255224
def _upgrade_version(self, node: ir.Node, opset_version: int, up_conversion: bool) -> None:
256225
if up_conversion is True:
257226
node.version = opset_version + 1
@@ -315,32 +284,40 @@ def visit_graph(self, graph: ir.Graph) -> None:
315284
self.target_version,
316285
)
317286
return None
318-
for node in graph:
319-
up_conversion = True
320-
if node.version is None:
321-
node.version = self.model_version
322-
# Iterate each node from current node version -> target version
323-
# and updating node based on the correct adapter
324-
# Up-conversion [ver->ver+1] or down-conversion [ver->ver-1]
325-
# TODO(shubhambhokare1): Remove once down-conversion adapters are supoorted
326-
if self.target_version < node.version:
327-
up_conversion = False
328-
logger.warning(
329-
"Target opset: %s less than %s, downstream version conversion not currently handled.",
330-
self.target_version,
331-
self.model_version,
332-
)
333-
return None
334-
for opset_version in range(node.version, self.target_version):
287+
288+
# TODO(shubhambhokare1): Support down-conversion
289+
while self.model_version < self.target_version:
290+
pre_conversion_graph = copy.copy(graph)
291+
# Up-convert each node in the graph from opset_version -> opset_version + 1
292+
# or down-convert from opset_version -> opset_version - 1
293+
# Return non-converted graph if any node fails to convert.
294+
for node in graph:
295+
up_conversion = True
296+
if node.version is None:
297+
node.version = self.model_version
298+
if self.target_version < node.version:
299+
up_conversion = False
300+
# TODO(shubhambhokare1): Remove once down-conversion adapters are supoorted
301+
logger.warning(
302+
"Target opset: %s less than %s, downstream version conversion not currently handled.",
303+
self.target_version,
304+
node.version,
305+
)
306+
graph = pre_conversion_graph
307+
return None
335308
try:
336-
self.visit_node(node, graph, opset_version, up_conversion)
337-
self._upgrade_version(node, opset_version, up_conversion)
309+
self.visit_node(node, graph, self.model_version, up_conversion)
310+
self._upgrade_version(node, self.model_version, up_conversion)
338311
except VersionConverterError as e:
339312
logger.warning(
340313
"Skipping version conversion for node %s due to exception: %s",
341314
node.op_type,
342315
e,
343316
)
317+
graph = pre_conversion_graph
318+
return None
319+
self.model_version += 1
320+
del pre_conversion_graph
344321
return None
345322

346323
def visit_model(self, model: ir.Model) -> None:
@@ -353,7 +330,7 @@ def visit_model(self, model: ir.Model) -> None:
353330
self.model_version = model_version
354331
self.visit_graph(model.graph)
355332
# Finally, update the opset imports for the model
356-
self._update_opset_imports(model)
333+
model.opset_imports[""] = self.model_version
357334
return None
358335

359336

onnxscript/version_converter/_version_converter_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def test_version_groupnorm_no_bias(self):
315315

316316
self.assertEqual(model.graph.node(0).op_type, "GroupNormalization")
317317
self.assertEqual(model.graph.node(0).version, 20)
318-
self.assertEqual(model.opset_imports[""], 21)
318+
self.assertEqual(model.opset_imports[""], 20)
319319

320320

321321
class VersionConverter23to24Test(unittest.TestCase):

0 commit comments

Comments
 (0)