Skip to content

Commit 9341a32

Browse files
authored
[DRAFT] Fixes to version converter (#2318)
Redo of PR #2295 as discussed there. * Ensure opset_imports is updated when version converter is applied TODO (in a separate PR): * Cleanup error status API (and return value) --------- Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 2dd6b2d commit 9341a32

File tree

3 files changed

+69
-58
lines changed

3 files changed

+69
-58
lines changed

onnxscript/version_converter/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import onnx
1313

14+
import onnxscript.ir.passes
1415
import onnxscript.ir.passes.common
1516
from onnxscript import ir
1617
from onnxscript.ir.passes.common import _c_api_utils

onnxscript/version_converter/_version_converter.py

Lines changed: 53 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,25 @@
2020
SUPPORTED_MIN_ONNX_OPSET = 18
2121

2222

23+
def _get_onnx_opset_version(model: ir.Model) -> int | None:
24+
"""Get the ONNX opset version imported by the model."""
25+
model_version1 = model.opset_imports.get("")
26+
model_version2 = model.opset_imports.get("ai.onnx")
27+
if model_version1 is not None and model_version2 is not None:
28+
if model_version1 != model_version2:
29+
raise ValueError(
30+
f"Model imports multiple onnx opsets: {model_version1} and {model_version2}."
31+
)
32+
return model_version1 or model_version2
33+
34+
35+
def _set_onnx_opset_version(model: ir.Model, version: int) -> None:
36+
"""Set the ONNX opset version imported by the model."""
37+
if "ai.onnx" in model.opset_imports:
38+
del model.opset_imports["ai.onnx"]
39+
model.opset_imports[""] = version
40+
41+
2342
class VersionConverterError(RuntimeError):
2443
"""Raised when an node's version cannot be upgraded/downgraded successfully."""
2544

@@ -215,25 +234,15 @@ def groupnormalization_20_21(node: ir.Node, op):
215234

216235

217236
class _VersionConverter:
218-
opset_imports: dict[str, int]
219-
model_version: int
220-
221237
def __init__(self, target_version: int):
222-
self.target_version = target_version
223-
224-
def _upgrade_version(self, node: ir.Node, opset_version: int, up_conversion: bool) -> None:
225-
if up_conversion is True:
226-
node.version = opset_version + 1
227-
else:
228-
node.version = opset_version - 1
238+
self._target_version = target_version
229239

230240
def process_node(
231-
self, node: ir.Node, opset_version: int, up_conversion: bool = True
241+
self, node: ir.Node, from_version: int, up_conversion: bool = True
232242
) -> Replacement | None:
233-
if node.domain != "":
234-
return None
243+
assert node.domain == ""
235244
adapter = registry.lookup_adapters(
236-
node.domain, node.op_type, opset_version, up_conversion
245+
node.domain, node.op_type, from_version, up_conversion
237246
)
238247
if adapter is None:
239248
return None
@@ -264,67 +273,65 @@ def visit_node(
264273
self,
265274
node: ir.Node,
266275
root: ir.Graph | ir.Function,
267-
opset_version: int,
276+
from_version: int,
268277
up_conversion: bool = True,
269278
) -> None:
270-
replacement = self.process_node(node, opset_version, up_conversion)
279+
if up_conversion:
280+
to_version = from_version + 1
281+
else:
282+
to_version = from_version - 1
283+
replacement = self.process_node(node, from_version, up_conversion)
271284
if replacement is None:
272285
# No change. Process attributes.
273286
for attr in node.attributes.values():
274287
self.visit_attribute(attr)
275-
return None
288+
node.version = to_version
276289
else:
290+
for new_node in replacement.new_nodes:
291+
# TODO: control-flow
292+
new_node.version = to_version
277293
self.replace_node(node, replacement, root)
278-
return None
279294

280295
def visit_graph(self, graph: ir.Graph) -> None:
281-
if self.target_version > SUPPORTED_MAX_ONNX_OPSET:
282-
logger.warning(
283-
"Conversion to target opset: %s not currently supported.",
284-
self.target_version,
285-
)
286-
return None
287296
for node in graph:
288-
up_conversion = True
289-
if node.version is None:
290-
node.version = self.model_version
297+
if node.domain != "":
298+
continue
299+
node_version = node.version or self._default_onnx_opset
300+
if node_version is None:
301+
raise VersionConverterError(f"Node {node} has no version.")
291302
# Iterate each node from current node version -> target version
292303
# and updating node based on the correct adapter
293304
# Up-conversion [ver->ver+1] or down-conversion [ver->ver-1]
294305
# TODO(shubhambhokare1): Remove once down-conversion adapters are supoorted
295-
if self.target_version < node.version:
296-
up_conversion = False
297-
logger.warning(
298-
"Target opset: %s less than %s, downstream version conversion not currently handled.",
299-
self.target_version,
300-
self.model_version,
306+
if self._target_version < node_version:
307+
raise VersionConverterError(
308+
f"Target opset: {self._target_version} less than node version: {node.version}, "
309+
"downstream version conversion not currently handled."
301310
)
302-
return None
303-
for opset_version in range(node.version, self.target_version):
311+
for from_version in range(node_version, self._target_version):
304312
try:
305-
self.visit_node(node, graph, opset_version, up_conversion)
306-
self._upgrade_version(node, opset_version, up_conversion)
313+
self.visit_node(node, graph, from_version, up_conversion=True)
307314
except VersionConverterError as e:
308315
logger.warning(
309316
"Skipping version conversion for node %s due to exception: %s",
310317
node.op_type,
311318
e,
312319
)
313-
return None
314320

315321
def visit_model(self, model: ir.Model) -> None:
316-
self.opset_imports = model.opset_imports
317-
model_version = self.opset_imports.get("")
318-
if model_version is None:
319-
model_version = model.opset_imports.get("ai.onnx")
320-
if model_version is None:
321-
return None
322-
self.model_version = model_version
322+
self._default_onnx_opset = _get_onnx_opset_version(model)
323323
self.visit_graph(model.graph)
324-
return None
324+
_set_onnx_opset_version(model, self._target_version)
325325

326326

327327
def convert_version(model: ir.Model, target_version: int) -> None:
328328
"""Convert the model to the specified ONNX opset version."""
329+
if (target_version > SUPPORTED_MAX_ONNX_OPSET) or (
330+
target_version < SUPPORTED_MIN_ONNX_OPSET
331+
):
332+
raise ValueError(
333+
f"Target opset version {target_version} is not supported. "
334+
f"Supported range: {SUPPORTED_MIN_ONNX_OPSET} to {SUPPORTED_MAX_ONNX_OPSET}."
335+
)
329336
version_converter = _VersionConverter(target_version=target_version)
330337
version_converter.visit_model(model)

onnxscript/version_converter/_version_converter_test.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import unittest
66

77
import onnx.defs
8+
import pytest
89

910
from onnxscript import ir, version_converter
1011

@@ -41,18 +42,19 @@ def test_upstream_coverage(self):
4142
self.assertEqual(domain, "")
4243
self.assertIn((name, upgrade_version), op_upgrades)
4344

44-
def test_version_convert_non_standard_onnx_domain(self):
45+
@pytest.mark.xfail(reason="TODO: Cleanup error status API.")
46+
def test_version_convert_no_source_version(self):
4547
model = ir.from_onnx_text(
4648
"""
4749
<ir_version: 7, opset_import: [ "local" : 1]>
4850
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output)
4951
{
50-
shape_a = Constant<value: tensor = int64[5] {1, 4, 512, 512}>()
52+
shape_a = Constant<value: tensor = int64[4] {1, 4, 512, 512}>()
5153
reshape_x = Reshape (input_x, shape_a)
52-
shape_b = Constant<value: tensor = int64[5] {1, 4, 1024, 1024}>()
54+
shape_b = Constant<value: tensor = int64[4] {1, 4, 1024, 1024}>()
5355
reshape_y = Reshape (input_x, shape_b)
5456
gridsample = GridSample <mode = "bilinear"> (reshape_x, reshape_y)
55-
shape_c = Constant<value: tensor = int64[4] {4, 1024, 1024}>()
57+
shape_c = Constant<value: tensor = int64[3] {4, 1024, 1024}>()
5658
output = Reshape (gridsample, shape_c)
5759
}
5860
"""
@@ -63,16 +65,9 @@ def test_version_convert_non_standard_onnx_domain(self):
6365
target_version = 20
6466
version_converter.convert_version(model, target_version=target_version)
6567

66-
self.assertEqual(model.graph.node(0).op_type, "Constant")
67-
self.assertEqual(model.graph.node(0).version, None)
68-
self.assertEqual(model.graph.node(1).op_type, "Reshape")
69-
self.assertEqual(model.graph.node(1).version, None)
70-
self.assertEqual(model.graph.node(4).op_type, "GridSample")
71-
self.assertEqual(model.graph.node(4).version, None)
72-
self.assertEqual(model.graph.node(4).attributes["mode"].value, "bilinear")
73-
7468

7569
class VersionConverter18to17Test(unittest.TestCase):
70+
@pytest.mark.xfail(strict=True, reason="Version downgrade not yet supported.")
7671
def test_version_convert_compatible(self):
7772
model = ir.from_onnx_text(
7873
"""
@@ -112,6 +107,7 @@ def test_version_convert_compatible(self):
112107
)
113108
target_version = 19
114109
version_converter.convert_version(model, target_version=target_version)
110+
self.assertEqual(model.opset_imports[""], target_version)
115111

116112
self.assertEqual(model.graph.node(0).op_type, "Constant")
117113
self.assertEqual(model.graph.node(0).version, 19)
@@ -138,6 +134,7 @@ def test_version_convert_compatible(self):
138134
)
139135
target_version = 20
140136
version_converter.convert_version(model, target_version=target_version)
137+
self.assertEqual(model.opset_imports[""], target_version)
141138

142139
self.assertEqual(model.graph.node(0).op_type, "Constant")
143140
self.assertEqual(model.graph.node(0).version, 20)
@@ -170,6 +167,7 @@ def test_version_convert_gridsample_linear(self):
170167

171168
target_version = 20
172169
version_converter.convert_version(model, target_version=target_version)
170+
self.assertEqual(model.opset_imports[""], target_version)
173171

174172
self.assertEqual(model.graph.node(0).op_type, "Constant")
175173
self.assertEqual(model.graph.node(0).version, 20)
@@ -200,6 +198,7 @@ def test_version_convert_gridsample_cubic(self):
200198

201199
target_version = 20
202200
version_converter.convert_version(model, target_version=target_version)
201+
self.assertEqual(model.opset_imports[""], target_version)
203202

204203
self.assertEqual(model.graph.node(0).op_type, "Constant")
205204
self.assertEqual(model.graph.node(0).version, 20)
@@ -231,6 +230,7 @@ def test_version_convert_inline(self):
231230
)
232231
target_version = 20
233232
version_converter.convert_version(model, target_version=target_version)
233+
self.assertEqual(model.opset_imports[""], target_version)
234234

235235
self.assertEqual(model.graph.node(0).op_type, "Constant")
236236
self.assertEqual(model.graph.node(0).version, 20)
@@ -259,6 +259,7 @@ def test_version_groupnorm(self):
259259
)
260260
target_version = 21
261261
version_converter.convert_version(model, target_version=target_version)
262+
self.assertEqual(model.opset_imports[""], target_version)
262263

263264
self.assertEqual(model.graph.node(3).op_type, "Reshape")
264265
self.assertEqual(model.graph.node(3).version, 21)
@@ -289,12 +290,14 @@ def test_version_groupnorm_no_bias(self):
289290
)
290291
target_version = 21
291292
version_converter.convert_version(model, target_version=target_version)
293+
self.assertEqual(model.opset_imports[""], target_version)
292294

293295
self.assertEqual(model.graph.node(0).op_type, "GroupNormalization")
294296
self.assertEqual(model.graph.node(0).version, 20)
295297

296298

297299
class VersionConverter23to24Test(unittest.TestCase):
300+
@pytest.mark.xfail(strict=True, reason="Version upgrade beyond 23 not yet supported.")
298301
def test_version_convert_compatible(self):
299302
model = ir.from_onnx_text(
300303
"""

0 commit comments

Comments
 (0)