Skip to content

Fixes to version converter #2318

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions onnxscript/version_converter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import onnx

import onnxscript.ir.passes
import onnxscript.ir.passes.common
from onnxscript import ir
from onnxscript.ir.passes.common import _c_api_utils
Expand Down
99 changes: 53 additions & 46 deletions onnxscript/version_converter/_version_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,25 @@
SUPPORTED_MIN_ONNX_OPSET = 18


def _get_onnx_opset_version(model: ir.Model) -> int | None:
"""Get the ONNX opset version imported by the model."""
model_version1 = model.opset_imports.get("")
model_version2 = model.opset_imports.get("ai.onnx")
if model_version1 is not None and model_version2 is not None:
if model_version1 != model_version2:
raise ValueError(

Check warning on line 29 in onnxscript/version_converter/_version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/_version_converter.py#L29

Added line #L29 was not covered by tests
f"Model imports multiple onnx opsets: {model_version1} and {model_version2}."
)
return model_version1 or model_version2


def _set_onnx_opset_version(model: ir.Model, version: int) -> None:
"""Set the ONNX opset version imported by the model."""
if "ai.onnx" in model.opset_imports:
del model.opset_imports["ai.onnx"]

Check warning on line 38 in onnxscript/version_converter/_version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/_version_converter.py#L38

Added line #L38 was not covered by tests
model.opset_imports[""] = version


class VersionConverterError(RuntimeError):
"""Raised when an node's version cannot be upgraded/downgraded successfully."""

Expand Down Expand Up @@ -215,25 +234,15 @@


class _VersionConverter:
opset_imports: dict[str, int]
model_version: int

def __init__(self, target_version: int):
self.target_version = target_version

def _upgrade_version(self, node: ir.Node, opset_version: int, up_conversion: bool) -> None:
if up_conversion is True:
node.version = opset_version + 1
else:
node.version = opset_version - 1
self._target_version = target_version

def process_node(
self, node: ir.Node, opset_version: int, up_conversion: bool = True
self, node: ir.Node, from_version: int, up_conversion: bool = True
) -> Replacement | None:
if node.domain != "":
return None
assert node.domain == ""
adapter = registry.lookup_adapters(
node.domain, node.op_type, opset_version, up_conversion
node.domain, node.op_type, from_version, up_conversion
)
if adapter is None:
return None
Expand Down Expand Up @@ -264,67 +273,65 @@
self,
node: ir.Node,
root: ir.Graph | ir.Function,
opset_version: int,
from_version: int,
up_conversion: bool = True,
) -> None:
replacement = self.process_node(node, opset_version, up_conversion)
if up_conversion:
to_version = from_version + 1
else:
to_version = from_version - 1

Check warning on line 282 in onnxscript/version_converter/_version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/_version_converter.py#L282

Added line #L282 was not covered by tests
replacement = self.process_node(node, from_version, up_conversion)
if replacement is None:
# No change. Process attributes.
for attr in node.attributes.values():
self.visit_attribute(attr)
return None
node.version = to_version
else:
for new_node in replacement.new_nodes:
# TODO: control-flow
new_node.version = to_version
self.replace_node(node, replacement, root)
return None

def visit_graph(self, graph: ir.Graph) -> None:
if self.target_version > SUPPORTED_MAX_ONNX_OPSET:
logger.warning(
"Conversion to target opset: %s not currently supported.",
self.target_version,
)
return None
for node in graph:
up_conversion = True
if node.version is None:
node.version = self.model_version
if node.domain != "":
continue

Check warning on line 298 in onnxscript/version_converter/_version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/_version_converter.py#L298

Added line #L298 was not covered by tests
node_version = node.version or self._default_onnx_opset
if node_version is None:
raise VersionConverterError(f"Node {node} has no version.")
# Iterate each node from current node version -> target version
# and updating node based on the correct adapter
# Up-conversion [ver->ver+1] or down-conversion [ver->ver-1]
# TODO(shubhambhokare1): Remove once down-conversion adapters are supoorted
if self.target_version < node.version:
up_conversion = False
logger.warning(
"Target opset: %s less than %s, downstream version conversion not currently handled.",
self.target_version,
self.model_version,
if self._target_version < node_version:
raise VersionConverterError(

Check warning on line 307 in onnxscript/version_converter/_version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/_version_converter.py#L307

Added line #L307 was not covered by tests
f"Target opset: {self._target_version} less than node version: {node.version}, "
"downstream version conversion not currently handled."
)
return None
for opset_version in range(node.version, self.target_version):
for from_version in range(node_version, self._target_version):
try:
self.visit_node(node, graph, opset_version, up_conversion)
self._upgrade_version(node, opset_version, up_conversion)
self.visit_node(node, graph, from_version, up_conversion=True)
except VersionConverterError as e:
logger.warning(
"Skipping version conversion for node %s due to exception: %s",
node.op_type,
e,
)
return None

def visit_model(self, model: ir.Model) -> None:
self.opset_imports = model.opset_imports
model_version = self.opset_imports.get("")
if model_version is None:
model_version = model.opset_imports.get("ai.onnx")
if model_version is None:
return None
self.model_version = model_version
self._default_onnx_opset = _get_onnx_opset_version(model)
self.visit_graph(model.graph)
return None
_set_onnx_opset_version(model, self._target_version)


def convert_version(model: ir.Model, target_version: int) -> None:
"""Convert the model to the specified ONNX opset version."""
if (target_version > SUPPORTED_MAX_ONNX_OPSET) or (
target_version < SUPPORTED_MIN_ONNX_OPSET
):
raise ValueError(
f"Target opset version {target_version} is not supported. "
f"Supported range: {SUPPORTED_MIN_ONNX_OPSET} to {SUPPORTED_MAX_ONNX_OPSET}."
)
version_converter = _VersionConverter(target_version=target_version)
version_converter.visit_model(model)
27 changes: 15 additions & 12 deletions onnxscript/version_converter/_version_converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import unittest

import onnx.defs
import pytest

from onnxscript import ir, version_converter

Expand Down Expand Up @@ -41,18 +42,19 @@ def test_upstream_coverage(self):
self.assertEqual(domain, "")
self.assertIn((name, upgrade_version), op_upgrades)

def test_version_convert_non_standard_onnx_domain(self):
@pytest.mark.xfail(reason="TODO: Cleanup error status API.")
def test_version_convert_no_source_version(self):
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "local" : 1]>
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output)
{
shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512}>()
shape_a = Constant<value: tensor = int64[4] {1, 4, 512, 512}>()
reshape_x = Reshape (input_x, shape_a)
shape_b = Constant<value: tensor = int64[5] {1, 4, 1024, 1024}>()
shape_b = Constant<value: tensor = int64[4] {1, 4, 1024, 1024}>()
reshape_y = Reshape (input_x, shape_b)
gridsample = GridSample <mode = "bilinear"> (reshape_x, reshape_y)
shape_c = Constant<value: tensor = int64[4] {4, 1024, 1024}>()
shape_c = Constant<value: tensor = int64[3] {4, 1024, 1024}>()
output = Reshape (gridsample, shape_c)
}
"""
Expand All @@ -63,16 +65,9 @@ def test_version_convert_non_standard_onnx_domain(self):
target_version = 20
version_converter.convert_version(model, target_version=target_version)

self.assertEqual(model.graph.node(0).op_type, "Constant")
self.assertEqual(model.graph.node(0).version, None)
self.assertEqual(model.graph.node(1).op_type, "Reshape")
self.assertEqual(model.graph.node(1).version, None)
self.assertEqual(model.graph.node(4).op_type, "GridSample")
self.assertEqual(model.graph.node(4).version, None)
self.assertEqual(model.graph.node(4).attributes["mode"].value, "bilinear")


class VersionConverter18to17Test(unittest.TestCase):
@pytest.mark.xfail(strict=True, reason="Version downgrade not yet supported.")
def test_version_convert_compatible(self):
model = ir.from_onnx_text(
"""
Expand Down Expand Up @@ -112,6 +107,7 @@ def test_version_convert_compatible(self):
)
target_version = 19
version_converter.convert_version(model, target_version=target_version)
self.assertEqual(model.opset_imports[""], target_version)

self.assertEqual(model.graph.node(0).op_type, "Constant")
self.assertEqual(model.graph.node(0).version, 19)
Expand All @@ -138,6 +134,7 @@ def test_version_convert_compatible(self):
)
target_version = 20
version_converter.convert_version(model, target_version=target_version)
self.assertEqual(model.opset_imports[""], target_version)

self.assertEqual(model.graph.node(0).op_type, "Constant")
self.assertEqual(model.graph.node(0).version, 20)
Expand Down Expand Up @@ -170,6 +167,7 @@ def test_version_convert_gridsample_linear(self):

target_version = 20
version_converter.convert_version(model, target_version=target_version)
self.assertEqual(model.opset_imports[""], target_version)

self.assertEqual(model.graph.node(0).op_type, "Constant")
self.assertEqual(model.graph.node(0).version, 20)
Expand Down Expand Up @@ -200,6 +198,7 @@ def test_version_convert_gridsample_cubic(self):

target_version = 20
version_converter.convert_version(model, target_version=target_version)
self.assertEqual(model.opset_imports[""], target_version)

self.assertEqual(model.graph.node(0).op_type, "Constant")
self.assertEqual(model.graph.node(0).version, 20)
Expand Down Expand Up @@ -231,6 +230,7 @@ def test_version_convert_inline(self):
)
target_version = 20
version_converter.convert_version(model, target_version=target_version)
self.assertEqual(model.opset_imports[""], target_version)

self.assertEqual(model.graph.node(0).op_type, "Constant")
self.assertEqual(model.graph.node(0).version, 20)
Expand Down Expand Up @@ -259,6 +259,7 @@ def test_version_groupnorm(self):
)
target_version = 21
version_converter.convert_version(model, target_version=target_version)
self.assertEqual(model.opset_imports[""], target_version)

self.assertEqual(model.graph.node(3).op_type, "Reshape")
self.assertEqual(model.graph.node(3).version, 21)
Expand Down Expand Up @@ -289,12 +290,14 @@ def test_version_groupnorm_no_bias(self):
)
target_version = 21
version_converter.convert_version(model, target_version=target_version)
self.assertEqual(model.opset_imports[""], target_version)

self.assertEqual(model.graph.node(0).op_type, "GroupNormalization")
self.assertEqual(model.graph.node(0).version, 20)


class VersionConverter23to24Test(unittest.TestCase):
@pytest.mark.xfail(strict=True, reason="Version upgrade beyond 23 not yet supported.")
def test_version_convert_compatible(self):
model = ir.from_onnx_text(
"""
Expand Down
Loading