4
4
5
5
from __future__ import annotations
6
6
7
+ import copy
7
8
import dataclasses
8
9
import functools
9
10
import logging
@@ -220,38 +221,6 @@ class _VersionConverter:
220
221
def __init__ (self , target_version : int ):
221
222
self .target_version = target_version
222
223
223
- def _maybe_set_opset_version (
224
- self , model_or_function : ir .Model | ir .Function , domain : str , version : int | None
225
- ):
226
- """Set the opset version for the domain."""
227
- current_version = model_or_function .opset_imports .get (domain )
228
- if version is None or current_version is None :
229
- return
230
- if domain == "" :
231
- model_or_function .opset_imports [domain ] = max (version , current_version )
232
- return
233
- elif domain == "ai.onnx" :
234
- model_or_function .opset_imports [domain ] = max (version , current_version )
235
- return
236
- else :
237
- return
238
-
239
- def _update_opset_imports (self , model : ir .Model ) -> None :
240
- """Collect all opsets used and add opset imports to the model and functions."""
241
- for node in ir .traversal .RecursiveGraphIterator (model .graph ):
242
- domain = node .domain
243
- self ._maybe_set_opset_version (model , domain , node .version )
244
-
245
- for function in model .functions .values ():
246
- for node in ir .traversal .RecursiveGraphIterator (function ):
247
- domain = node .domain
248
- self ._maybe_set_opset_version (function , domain , node .version )
249
- for domain , version in function .opset_imports .items ():
250
- # Add all opsets used in the function to the model, because ONNX Runtime
251
- # does not handle adding the opset imports to the model after inlining during inference.
252
- # This should happen after all opsets are collected for the function from its nodes.
253
- self ._maybe_set_opset_version (model , domain , version )
254
-
255
224
def _upgrade_version (self , node : ir .Node , opset_version : int , up_conversion : bool ) -> None :
256
225
if up_conversion is True :
257
226
node .version = opset_version + 1
@@ -315,32 +284,40 @@ def visit_graph(self, graph: ir.Graph) -> None:
315
284
self .target_version ,
316
285
)
317
286
return None
318
- for node in graph :
319
- up_conversion = True
320
- if node .version is None :
321
- node .version = self .model_version
322
- # Iterate each node from current node version -> target version
323
- # and updating node based on the correct adapter
324
- # Up-conversion [ver->ver+1] or down-conversion [ver->ver-1]
325
- # TODO(shubhambhokare1): Remove once down-conversion adapters are supoorted
326
- if self .target_version < node .version :
327
- up_conversion = False
328
- logger .warning (
329
- "Target opset: %s less than %s, downstream version conversion not currently handled." ,
330
- self .target_version ,
331
- self .model_version ,
332
- )
333
- return None
334
- for opset_version in range (node .version , self .target_version ):
287
+
288
+ # TODO(shubhambhokare1): Support down-conversion
289
+ while self .model_version < self .target_version :
290
+ pre_conversion_graph = copy .copy (graph )
291
+ # Up-convert each node in the graph from opset_version -> opset_version + 1
292
+ # or down-convert from opset_version -> opset_version - 1
293
+ # Return non-converted graph if any node fails to convert.
294
+ for node in graph :
295
+ up_conversion = True
296
+ if node .version is None :
297
+ node .version = self .model_version
298
+ if self .target_version < node .version :
299
+ up_conversion = False
300
+ # TODO(shubhambhokare1): Remove once down-conversion adapters are supoorted
301
+ logger .warning (
302
+ "Target opset: %s less than %s, downstream version conversion not currently handled." ,
303
+ self .target_version ,
304
+ node .version ,
305
+ )
306
+ graph = pre_conversion_graph
307
+ return None
335
308
try :
336
- self .visit_node (node , graph , opset_version , up_conversion )
337
- self ._upgrade_version (node , opset_version , up_conversion )
309
+ self .visit_node (node , graph , self . model_version , up_conversion )
310
+ self ._upgrade_version (node , self . model_version , up_conversion )
338
311
except VersionConverterError as e :
339
312
logger .warning (
340
313
"Skipping version conversion for node %s due to exception: %s" ,
341
314
node .op_type ,
342
315
e ,
343
316
)
317
+ graph = pre_conversion_graph
318
+ return None
319
+ self .model_version += 1
320
+ del pre_conversion_graph
344
321
return None
345
322
346
323
def visit_model (self , model : ir .Model ) -> None :
@@ -353,7 +330,7 @@ def visit_model(self, model: ir.Model) -> None:
353
330
self .model_version = model_version
354
331
self .visit_graph (model .graph )
355
332
# Finally, update the opset imports for the model
356
- self ._update_opset_imports ( model )
333
+ model . opset_imports [ "" ] = self .model_version
357
334
return None
358
335
359
336
0 commit comments