1616import onnx
1717
1818import nncf
19+ from nncf .common .factory import NNCFGraphFactory
1920from nncf .common .graph .model_transformer import ModelTransformer
2021from nncf .common .graph .transformations .commands import TargetType
2122from nncf .common .graph .transformations .layout import TransformationLayout
2223from nncf .onnx .graph .node_utils import get_input_edge
24+ from nncf .onnx .graph .node_utils import get_input_edges_mapping
2325from nncf .onnx .graph .onnx_helper import get_children
2426from nncf .onnx .graph .onnx_helper import get_children_node_mapping
2527from nncf .onnx .graph .onnx_helper import get_edge_dtype
2628from nncf .onnx .graph .onnx_helper import get_edge_info_mapping
2729from nncf .onnx .graph .onnx_helper import get_name_to_node_map
2830from nncf .onnx .graph .onnx_helper import get_node_index
31+ from nncf .onnx .graph .onnx_helper import get_parents_node_mapping
2932from nncf .onnx .graph .onnx_helper import get_tensor
3033from nncf .onnx .graph .transformations .commands import ONNXInitializerUpdateCommand
3134from nncf .onnx .graph .transformations .commands import ONNXModelExtractionCommand
@@ -55,8 +58,8 @@ def __init__(self, model: onnx.ModelProto, inplace: bool = False):
5558 self .onnx_model_extractor = onnx .utils .Extractor (inferred_model )
5659 self ._inplace = inplace
5760
61+ @staticmethod
5862 def _get_target_edge (
59- self ,
6063 port_id : int ,
6164 node_name : str ,
6265 transform_type : TargetType ,
@@ -381,10 +384,25 @@ def _apply_initializer_update_transformations(
381384 :return: Copy of original model with updated biases.
382385 """
383386 name_to_node_map = get_name_to_node_map (model )
387+ output_name_to_node_map = get_parents_node_mapping (model )
388+
384389 for transformation in transformations :
385390 node = name_to_node_map [transformation .target_point .target_node_name ]
386- initializer_name = node .input [transformation .target_point .port_id ]
387- set_initializer (initializer_name , model , transformation .new_value )
391+ # NOTE: An `input_name` is either the name of an initializer or the name of a `Constant` operation output
392+ input_name = node .input [transformation .target_point .port_id ]
393+
394+ constant_node = output_name_to_node_map .get (input_name , None )
395+
396+ if constant_node is None :
397+ set_initializer (input_name , model , transformation .new_value )
398+ else :
399+ for attr in constant_node .attribute :
400+ if attr .name == "value" :
401+ array = transformation .new_value .astype (onnx .helper .tensor_dtype_to_np_dtype (attr .t .data_type ))
402+ tensor_proto = onnx .numpy_helper .from_array (array )
403+ attr .t .CopyFrom (tensor_proto )
404+ break
405+
388406 return model
389407
390408 def _apply_model_extraction_transformation (self , transformation : ONNXModelExtractionCommand ) -> onnx .ModelProto :
@@ -484,16 +502,25 @@ def _apply_multiply_insertion_transformations(
484502 :returns: Transformed model with Multiply nodes.
485503 """
486504 node_name_to_node = get_name_to_node_map (model )
505+ # TODO(andrey-churkin): Optimize it
506+ graph = NNCFGraphFactory .create (model )
507+ input_edges_mapping = get_input_edges_mapping (graph )
487508
488509 for transformation in transformations :
510+ port_id = transformation .target_point .port_id
489511 target_node_name = transformation .target_point .target_node_name
490- target_output_port = transformation .target_point .port_id
491- target_node = node_name_to_node [target_node_name ]
492- output_tensor_name = target_node .output [target_output_port ]
512+ transform_type = transformation .target_point .type
513+ output_tensor_name = ONNXModelTransformer ._get_target_edge (
514+ port_id , target_node_name , transform_type , node_name_to_node , input_edges_mapping
515+ )
516+
517+ # TODO(andrey-churkin): Check type of `transformation.scale_value`
493518
494519 # Create a new initializer for the scale constant
495520 scale_tensor_name = f"{ transformation .multiply_node_name } _scale"
496- scale_tensor = onnx .numpy_helper .from_array (transformation .scale_value , name = scale_tensor_name )
521+ scale_tensor = onnx .numpy_helper .from_array (
522+ transformation .scale_value .astype (np .float32 ), name = scale_tensor_name
523+ )
497524 model .graph .initializer .append (scale_tensor )
498525
499526 # Create a new Multiply node
@@ -505,7 +532,8 @@ def _apply_multiply_insertion_transformations(
505532 name = transformation .multiply_node_name ,
506533 )
507534 target_index = get_node_index (model , target_node_name )
508- model .graph .node .insert (target_index + 1 , mul_node )
535+ insert_index = 0 if target_index is None else target_index + 1
536+ model .graph .node .insert (insert_index , mul_node )
509537
510538 for name in transformation .destination_node_names :
511539 node = node_name_to_node [name ]
@@ -524,6 +552,11 @@ def set_initializer(initializer_name: str, model: onnx.ModelProto, new_value: np
524552 :param new_value: New value for the initializer tensor.
525553 """
526554 initializer = get_tensor (model , initializer_name )
555+
556+ required_dtype = onnx .helper .tensor_dtype_to_np_dtype (initializer .data_type )
557+ if new_value .dtype != required_dtype :
558+ new_value = new_value .astype (required_dtype )
559+
527560 new_tensor = onnx .numpy_helper .from_array (new_value , initializer_name )
528561 initializer .CopyFrom (new_tensor )
529562
0 commit comments