Skip to content

Commit d1a1ca6

Browse files
committed
Move is_weight_compression_needed function to common
1 parent 6f37e09 commit d1a1ca6

File tree

6 files changed

+29
-73
lines changed

6 files changed

+29
-73
lines changed

src/nncf/experimental/torch/fx/quantization/backend_parameters.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

src/nncf/experimental/torch/fx/quantization/quantize_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from nncf.common.logging import nncf_logger
2626
from nncf.common.quantization.structs import QuantizationPreset
2727
from nncf.data import Dataset
28-
from nncf.experimental.torch.fx.quantization.backend_parameters import is_weight_compression_needed
2928
from nncf.experimental.torch.fx.transformations import DuplicateDQPassNoAnnotations
3029
from nncf.experimental.torch.fx.transformations import apply_quantization_transformations
3130
from nncf.experimental.torch.fx.transformations import compress_post_quantize_transformation
@@ -39,6 +38,7 @@
3938
from nncf.parameters import TargetDevice
4039
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
4140
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
41+
from nncf.quantization.advanced_parameters import is_weight_compression_needed
4242
from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
4343
from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression
4444
from nncf.scopes import IgnoredScope

src/nncf/openvino/quantization/backend_parameters.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

src/nncf/openvino/quantization/quantize_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727
from nncf.openvino.graph.model_utils import remove_friendly_name_duplicates
2828
from nncf.openvino.graph.nncf_graph_builder import GraphConverter
2929
from nncf.openvino.graph.node_utils import get_number_if_op
30-
from nncf.openvino.quantization.backend_parameters import BackendParameters
31-
from nncf.openvino.quantization.backend_parameters import is_weight_compression_needed
3230
from nncf.openvino.quantization.quantize_ifmodel import apply_algorithm_if_bodies
3331
from nncf.openvino.rt_info import dump_parameters
3432
from nncf.parameters import BackupMode
@@ -42,7 +40,9 @@
4240
from nncf.quantization.advanced_parameters import AdvancedAccuracyRestorerParameters
4341
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
4442
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
43+
from nncf.quantization.advanced_parameters import BackendParameters
4544
from nncf.quantization.advanced_parameters import convert_to_dict_recursively
45+
from nncf.quantization.advanced_parameters import is_weight_compression_needed
4646
from nncf.quantization.algorithms.accuracy_control.algorithm import QuantizationAccuracyRestorer
4747
from nncf.quantization.algorithms.accuracy_control.algorithm import calculate_accuracy_drop
4848
from nncf.quantization.algorithms.accuracy_control.evaluator import Evaluator

src/nncf/quantization/advanced_parameters.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,30 @@ class AdvancedQuantizationParameters:
287287
backend_params: dict[str, Any] = field(default_factory=dict)
288288

289289

290+
class BackendParameters:
291+
COMPRESS_WEIGHTS = "compress_weights"
292+
STAT_REQUESTS_NUMBER = "stat_requests_number"
293+
EVAL_REQUESTS_NUMBER = "eval_requests_number"
294+
ACTIVATIONS = "activations"
295+
WEIGHTS = "weights"
296+
LEVEL_LOW = "level_low"
297+
LEVEL_HIGH = "level_high"
298+
299+
300+
def is_weight_compression_needed(advanced_parameters: Optional["AdvancedQuantizationParameters"]) -> bool:
301+
"""
302+
Determine whether weight compression is needed based on advanced quantization parameters.
303+
304+
If `advanced_parameters` or its `backend_params` are not provided, defaults to True.
305+
306+
:param advanced_parameters: Advanced quantization parameters.
307+
:return: True if weight compression is needed, False otherwise.
308+
"""
309+
if advanced_parameters is not None and advanced_parameters.backend_params is not None:
310+
return bool(advanced_parameters.backend_params.get(BackendParameters.COMPRESS_WEIGHTS, True))
311+
return True
312+
313+
290314
@api()
291315
@dataclass
292316
class AdvancedAWQParameters:

tests/torch2/fx/test_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@
3333
from nncf.common.utils.os import safe_open
3434
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
3535
from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
36-
from nncf.experimental.torch.fx.quantization.backend_parameters import FXBackendParameters
3736
from nncf.experimental.torch.fx.transformations import DEQUANTIZE_NODE_TARGETS
3837
from nncf.experimental.torch.fx.transformations import _get_node_inputs
3938
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
39+
from nncf.quantization.advanced_parameters import BackendParameters
4040
from tests.cross_fw.shared.nx_graph import compare_nx_graph_with_reference
4141
from tests.cross_fw.shared.paths import TEST_ROOT
4242
from tests.torch import test_models
@@ -219,7 +219,7 @@ def transform_fn(data_item):
219219
calibration_dataset = nncf.Dataset([example_input], transform_fn)
220220

221221
quantization_parameters["advanced_parameters"] = AdvancedQuantizationParameters(
222-
disable_bias_correction=True, backend_params={FXBackendParameters.COMPRESS_WEIGHTS: compress_weights}
222+
disable_bias_correction=True, backend_params={BackendParameters.COMPRESS_WEIGHTS: compress_weights}
223223
)
224224
quantization_parameters["subset_size"] = 1
225225

0 commit comments

Comments
 (0)