Skip to content

Commit 9653817

Browse files
committed
init
1 parent 1f1fda3 commit 9653817

File tree

1 file changed

+87
-47
lines changed
  • src/nncf/quantization/algorithms/weight_compression

1 file changed

+87
-47
lines changed

src/nncf/quantization/algorithms/weight_compression/algorithm.py

Lines changed: 87 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -508,9 +508,9 @@ def _get_primary_config(self, group_size: int) -> WeightCompressionConfig:
508508
def _set_weight_compression_config(
509509
self,
510510
ratio_defining_params: list[WeightCompressionParameters],
511+
primary_precision_weight_params: list[WeightCompressionParameters],
511512
model: TModel,
512513
graph: NNCFGraph,
513-
statistics_points: StatisticPointsContainer,
514514
group_size_values: dict[str, int],
515515
) -> None:
516516
"""
@@ -520,16 +520,8 @@ def _set_weight_compression_config(
520520
backup precisions.
521521
:param model: The model.
522522
:param graph: The model graph associated with the model.
523-
:param statistics_points: Statistics points.
524523
:param group_size_values: A dictionary mapping weight names to their group size values.
525524
"""
526-
if self._ratio < 1 and len(ratio_defining_params) > 0:
527-
primary_precision_weight_params = self._mixed_precision_algo.apply(
528-
model, graph, statistics_points, weight_params=ratio_defining_params
529-
)
530-
else:
531-
primary_precision_weight_params = ratio_defining_params
532-
533525
for weight_param in primary_precision_weight_params:
534526
weight_param.compression_config = self._get_primary_config(group_size_values[weight_param.weight_name])
535527

@@ -805,6 +797,49 @@ def collect_weight_compression_statistics(
805797
statistics = self._get_statistics_for_weights_compression(matmul_input_to_output_nodes_map, statistic_points)
806798
return statistics, statistic_points
807799

800+
def is_last_layer_skipped(
801+
self,
802+
skipped_params: list[WeightCompressionParameters],
803+
nodes_to_compress: list[NNCFNode],
804+
) -> bool:
805+
"""
806+
Returns True if the final node in nodes_to_compress appears in skipped_params, i.e. it was skipped rather than compressed.
807+
"""
808+
if not (nodes_to_compress and skipped_params):
809+
return False
810+
last_node = nodes_to_compress[-1]
811+
return any(param.node_with_weight == last_node for param in skipped_params)
812+
813+
def get_skipped_weight_compression_parameters(
814+
self,
815+
model: TModel,
816+
graph: NNCFGraph,
817+
nodes_to_compress: list[NNCFNode],
818+
) -> list[WeightCompressionParameters]:
819+
skipped_weight_params: list[WeightCompressionParameters] = []
820+
weight_names = set()
821+
ignored_names = self.get_ignored_node_names(graph)
822+
823+
for i, node in enumerate(nodes_to_compress):
824+
is_target_node = should_consider_scope(node.node_name, ignored_names)
825+
for weight_name, weight_port_id in self._backend_entity.get_weight_names_and_port_ids(node, graph):
826+
is_duplicate = weight_name in weight_names
827+
weight_dtype = self._backend_entity.get_weight_dtype(node, weight_port_id, model, graph)
828+
weight_shape = self._backend_entity.get_weight_shape(node, weight_port_id, graph)
829+
reduction_axes = self._backend_entity.get_reduction_axes(node, weight_port_id, graph)
830+
831+
wc_config = None
832+
should_skip = is_duplicate or (not is_target_node) or (not self.is_weight_compression_supported(weight_dtype, self._mode))
833+
if should_skip:
834+
skipped_weight_params.append(
835+
WeightCompressionParameters(
836+
weight_name, node, weight_port_id, weight_dtype, weight_shape, reduction_axes, wc_config
837+
)
838+
)
839+
weight_names.add(weight_name)
840+
841+
return skipped_weight_params
842+
808843
def get_weight_compression_parameters(
809844
self,
810845
model: TModel,
@@ -829,7 +864,6 @@ def get_weight_compression_parameters(
829864
collected statistics.
830865
"""
831866
all_weight_params: list[WeightCompressionParameters] = []
832-
skipped_weight_params: list[WeightCompressionParameters] = []
833867

834868
weight_names = set()
835869
is_last_layer_skipped = False
@@ -885,43 +919,8 @@ def get_weight_compression_parameters(
885919
weight_names.add(weight_name)
886920
else:
887921
is_last_layer_skipped = is_last_layer
888-
skipped_weight_params.append(
889-
WeightCompressionParameters(
890-
weight_name, node, weight_port_id, weight_dtype, weight_shape, reduction_axes, wc_config
891-
)
892-
)
893-
894-
# Get subset of nodes to define compression ratio
895-
ratio_defining_params = self._get_ratio_defining_params(all_weight_params, is_last_layer_skipped)
896-
897-
# Handle group size fallback modes
898-
if self._group_size_fallback_mode == GroupSizeFallbackMode.IGNORE:
899-
all_weight_params, ratio_defining_params, skipped_weight_params = self._handle_ignore_group_size_fallback(
900-
all_weight_params, ratio_defining_params, skipped_weight_params
901-
)
902-
if self._group_size_fallback_mode == GroupSizeFallbackMode.ADJUST:
903-
ratio_defining_params, group_size_values = self._handle_adjust_group_size_fallback(ratio_defining_params)
904-
else:
905-
group_size_values = {w_params.weight_name: self._group_size for w_params in ratio_defining_params}
906-
907-
# Collect statistics for the weights compression
908-
weight_params = ratio_defining_params if self._backup_mode == BackupMode.NONE else all_weight_params
909-
statistics, statistic_points = self.collect_weight_compression_statistics(
910-
model, graph, dataset, weight_params, statistic_points
911-
)
912-
913-
# Set weight compression configuration
914-
self._set_weight_compression_config(ratio_defining_params, model, graph, statistic_points, group_size_values)
915-
916-
# Print statistics
917-
nncf_logger.info(
918-
self._get_bitwidth_distribution_str(all_weight_params, ratio_defining_params, skipped_weight_params)
919-
)
920-
921-
# Filter all_weight_params by excluding nodes that should remain in their original floating-point precision
922-
all_weight_params = list(filter(lambda w_params: w_params.compression_config is not None, all_weight_params))
923922

924-
return all_weight_params, statistics
923+
return all_weight_params
925924

926925
def apply_wc_algos(
927926
self,
@@ -1013,10 +1012,51 @@ def apply(
10131012
) -> TModel:
10141013
self.set_backend_entity(model)
10151014
nodes_to_compress = self.get_nodes_to_compress(graph)
1015+
10161016
# Get processed weight compression parameters ready for compression
1017-
all_weight_params, statistics = self.get_weight_compression_parameters(
1017+
all_weight_params = self.get_weight_compression_parameters(
10181018
model, graph, nodes_to_compress, statistic_points, dataset
10191019
)
1020+
1021+
statistics, statistic_points = self.collect_weight_compression_statistics(
1022+
model, graph, dataset, all_weight_params, statistic_points
1023+
)
1024+
1025+
skipped_weight_params = self.get_skipped_weight_compression_parameters(model, graph, nodes_to_compress)
1026+
is_last_layer_skipped = self.is_last_layer_skipped(skipped_weight_params, nodes_to_compress)
1027+
1028+
# Get subset of nodes to define compression ratio
1029+
ratio_defining_params = self._get_ratio_defining_params(all_weight_params, is_last_layer_skipped)
1030+
1031+
# Handle group size fallback modes
1032+
if self._group_size_fallback_mode == GroupSizeFallbackMode.IGNORE:
1033+
all_weight_params, ratio_defining_params, skipped_weight_params = self._handle_ignore_group_size_fallback(
1034+
all_weight_params, ratio_defining_params, skipped_weight_params
1035+
)
1036+
if self._group_size_fallback_mode == GroupSizeFallbackMode.ADJUST:
1037+
ratio_defining_params, group_size_values = self._handle_adjust_group_size_fallback(ratio_defining_params)
1038+
else:
1039+
group_size_values = {w_params.weight_name: self._group_size for w_params in ratio_defining_params}
1040+
1041+
weight_params = ratio_defining_params if self._backup_mode == BackupMode.NONE else all_weight_params
1042+
1043+
1044+
if self._ratio < 1 and len(ratio_defining_params) > 0:
1045+
primary_precision_weight_params = self._mixed_precision_algo.apply(
1046+
model, graph, statistic_points, weight_params=ratio_defining_params
1047+
)
1048+
else:
1049+
primary_precision_weight_params = ratio_defining_params
1050+
# Set weight compression configuration
1051+
self._set_weight_compression_config(ratio_defining_params, primary_precision_weight_params, model, graph, group_size_values)
1052+
1053+
# Print statistics
1054+
nncf_logger.info(
1055+
self._get_bitwidth_distribution_str(all_weight_params, ratio_defining_params, skipped_weight_params)
1056+
)
1057+
# Filter all_weight_params by excluding nodes that should remain in their original floating-point precision
1058+
all_weight_params = list(filter(lambda w_params: w_params.compression_config is not None, all_weight_params))
1059+
10201060
transformed_model = self.apply_wc_algos(model, graph, all_weight_params, statistics, dataset)
10211061

10221062
return transformed_model

0 commit comments

Comments
 (0)