@@ -508,9 +508,9 @@ def _get_primary_config(self, group_size: int) -> WeightCompressionConfig:
     def _set_weight_compression_config(
         self,
         ratio_defining_params: list[WeightCompressionParameters],
+        primary_precision_weight_params: list[WeightCompressionParameters],
         model: TModel,
         graph: NNCFGraph,
-        statistics_points: StatisticPointsContainer,
         group_size_values: dict[str, int],
     ) -> None:
         """
@@ -520,16 +520,8 @@ def _set_weight_compression_config(
             backup precisions.
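+        :param primary_precision_weight_params: Weight parameters to be compressed with the primary precision.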
         :param model: The model.
         :param graph: The model graph associated with the model.
-        :param statistics_points: Statistics points.
         :param group_size_values: A dictionary mapping weight names to their group size values.
         """
-        if self._ratio < 1 and len(ratio_defining_params) > 0:
-            primary_precision_weight_params = self._mixed_precision_algo.apply(
-                model, graph, statistics_points, weight_params=ratio_defining_params
-            )
-        else:
-            primary_precision_weight_params = ratio_defining_params
-
         for weight_param in primary_precision_weight_params:
             weight_param.compression_config = self._get_primary_config(group_size_values[weight_param.weight_name])
 
@@ -805,6 +797,49 @@ def collect_weight_compression_statistics(
         statistics = self._get_statistics_for_weights_compression(matmul_input_to_output_nodes_map, statistic_points)
         return statistics, statistic_points
 
+    def is_last_layer_skipped(
+        self,
+        skipped_params: list[WeightCompressionParameters],
+        nodes_to_compress: list[NNCFNode],
+    ) -> bool:
+        """
+        Returns True if the last node in nodes_to_compress will be skipped (not compressed).
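+
+        :param skipped_params: Compression parameters of the weights that will not be compressed.
+        :param nodes_to_compress: Nodes selected for weight compression.
+        :return: True if the last node to compress is among the skipped parameters, False otherwise.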
807+ """
808+ if not (nodes_to_compress and skipped_params ):
809+ return False
810+ last_node = nodes_to_compress [- 1 ]
811+ return any (param .node_with_weight == last_node for param in skipped_params )
812+
813+ def get_skipped_weight_compression_parameters (
814+ self ,
815+ model : TModel ,
816+ graph : NNCFGraph ,
817+ nodes_to_compress : list [NNCFNode ],
818+ ) -> list [WeightCompressionParameters ]:
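+        """
+        Collects the compression parameters of weights that will be skipped: duplicated (shared) weights,
+        weights of nodes excluded via the ignored scope, and weights with unsupported data types.
+
+        :param model: The model.
+        :param graph: The model graph associated with the model.
+        :param nodes_to_compress: Nodes selected for weight compression.
+        :return: A list of WeightCompressionParameters for the skipped weights.
+        """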
+        skipped_weight_params: list[WeightCompressionParameters] = []
+        weight_names = set()
+        ignored_names = self.get_ignored_node_names(graph)
+
+        for node in nodes_to_compress:
+            is_target_node = should_consider_scope(node.node_name, ignored_names)
+            for weight_name, weight_port_id in self._backend_entity.get_weight_names_and_port_ids(node, graph):
+                is_duplicate = weight_name in weight_names
+                weight_dtype = self._backend_entity.get_weight_dtype(node, weight_port_id, model, graph)
+                weight_shape = self._backend_entity.get_weight_shape(node, weight_port_id, graph)
+                reduction_axes = self._backend_entity.get_reduction_axes(node, weight_port_id, graph)
+
+                wc_config = None
+                should_skip = (
+                    is_duplicate
+                    or not is_target_node
+                    or not self.is_weight_compression_supported(weight_dtype, self._mode)
+                )
+                if should_skip:
+                    skipped_weight_params.append(
+                        WeightCompressionParameters(
+                            weight_name, node, weight_port_id, weight_dtype, weight_shape, reduction_axes, wc_config
+                        )
+                    )
+                weight_names.add(weight_name)
+
+        return skipped_weight_params
+
     def get_weight_compression_parameters(
         self,
         model: TModel,
@@ -829,7 +864,6 @@ def get_weight_compression_parameters(
             collected statistics.
         """
         all_weight_params: list[WeightCompressionParameters] = []
-        skipped_weight_params: list[WeightCompressionParameters] = []
 
         weight_names = set()
         is_last_layer_skipped = False
@@ -885,43 +919,8 @@ def get_weight_compression_parameters(
                     weight_names.add(weight_name)
                 else:
                     is_last_layer_skipped = is_last_layer
-                    skipped_weight_params.append(
-                        WeightCompressionParameters(
-                            weight_name, node, weight_port_id, weight_dtype, weight_shape, reduction_axes, wc_config
-                        )
-                    )
-
-        # Get subset of nodes to define compression ratio
-        ratio_defining_params = self._get_ratio_defining_params(all_weight_params, is_last_layer_skipped)
-
-        # Handle group size fallback modes
-        if self._group_size_fallback_mode == GroupSizeFallbackMode.IGNORE:
-            all_weight_params, ratio_defining_params, skipped_weight_params = self._handle_ignore_group_size_fallback(
-                all_weight_params, ratio_defining_params, skipped_weight_params
-            )
-        if self._group_size_fallback_mode == GroupSizeFallbackMode.ADJUST:
-            ratio_defining_params, group_size_values = self._handle_adjust_group_size_fallback(ratio_defining_params)
-        else:
-            group_size_values = {w_params.weight_name: self._group_size for w_params in ratio_defining_params}
-
-        # Collect statistics for the weights compression
-        weight_params = ratio_defining_params if self._backup_mode == BackupMode.NONE else all_weight_params
-        statistics, statistic_points = self.collect_weight_compression_statistics(
-            model, graph, dataset, weight_params, statistic_points
-        )
-
-        # Set weight compression configuration
-        self._set_weight_compression_config(ratio_defining_params, model, graph, statistic_points, group_size_values)
-
-        # Print statistics
-        nncf_logger.info(
-            self._get_bitwidth_distribution_str(all_weight_params, ratio_defining_params, skipped_weight_params)
-        )
-
-        # Filter all_weight_params and by excluding nodes that should remain in their original floating-point precision
-        all_weight_params = list(filter(lambda w_params: w_params.compression_config is not None, all_weight_params))
 
-        return all_weight_params, statistics
+        return all_weight_params
 
     def apply_wc_algos(
         self,
@@ -1013,10 +1012,51 @@ def apply(
     ) -> TModel:
         self.set_backend_entity(model)
         nodes_to_compress = self.get_nodes_to_compress(graph)
+
         # Get processed weight compression parameters ready for compression
-        all_weight_params, statistics = self.get_weight_compression_parameters(
+        all_weight_params = self.get_weight_compression_parameters(
             model, graph, nodes_to_compress, statistic_points, dataset
         )
+
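+        # Collect statistics for the weights compression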
+        statistics, statistic_points = self.collect_weight_compression_statistics(
+            model, graph, dataset, all_weight_params, statistic_points
+        )
+
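+        # Determine which weights are skipped and whether the last layer is among them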
+        skipped_weight_params = self.get_skipped_weight_compression_parameters(model, graph, nodes_to_compress)
+        is_last_layer_skipped = self.is_last_layer_skipped(skipped_weight_params, nodes_to_compress)
+
+        # Get subset of nodes to define compression ratio
+        ratio_defining_params = self._get_ratio_defining_params(all_weight_params, is_last_layer_skipped)
+
+        # Handle group size fallback modes
+        if self._group_size_fallback_mode == GroupSizeFallbackMode.IGNORE:
+            all_weight_params, ratio_defining_params, skipped_weight_params = self._handle_ignore_group_size_fallback(
+                all_weight_params, ratio_defining_params, skipped_weight_params
+            )
+        if self._group_size_fallback_mode == GroupSizeFallbackMode.ADJUST:
+            ratio_defining_params, group_size_values = self._handle_adjust_group_size_fallback(ratio_defining_params)
+        else:
+            group_size_values = {w_params.weight_name: self._group_size for w_params in ratio_defining_params}
+
+        weight_params = ratio_defining_params if self._backup_mode == BackupMode.NONE else all_weight_params
+
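+        # Select the weights that receive the primary precision (via the mixed precision algorithm when ratio < 1)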
+        if self._ratio < 1 and len(ratio_defining_params) > 0:
+            primary_precision_weight_params = self._mixed_precision_algo.apply(
+                model, graph, statistic_points, weight_params=ratio_defining_params
+            )
+        else:
+            primary_precision_weight_params = ratio_defining_params
+
+        # Set weight compression configuration
+        self._set_weight_compression_config(
+            ratio_defining_params, primary_precision_weight_params, model, graph, group_size_values
+        )
+
+        # Print statistics
+        nncf_logger.info(
+            self._get_bitwidth_distribution_str(all_weight_params, ratio_defining_params, skipped_weight_params)
+        )
+
+        # Filter all_weight_params by excluding nodes that should remain in their original floating-point precision
+        all_weight_params = list(filter(lambda w_params: w_params.compression_config is not None, all_weight_params))
+
         transformed_model = self.apply_wc_algos(model, graph, all_weight_params, statistics, dataset)
 
         return transformed_model