Commit a390263

Move is_weight_compression_needed function to common

1 parent 6f37e09

7 files changed: +23, -60 lines

src/nncf/experimental/torch/fx/quantization/backend_parameters.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

src/nncf/experimental/torch/fx/quantization/quantize_model.py

Lines changed: 1 addition & 2 deletions
@@ -25,7 +25,6 @@
 from nncf.common.logging import nncf_logger
 from nncf.common.quantization.structs import QuantizationPreset
 from nncf.data import Dataset
-from nncf.experimental.torch.fx.quantization.backend_parameters import is_weight_compression_needed
 from nncf.experimental.torch.fx.transformations import DuplicateDQPassNoAnnotations
 from nncf.experimental.torch.fx.transformations import apply_quantization_transformations
 from nncf.experimental.torch.fx.transformations import compress_post_quantize_transformation
@@ -93,7 +92,7 @@ def quantize_impl(
     nncf_graph = NNCFGraphFactory.create(copied_model)
     quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset)

-    if is_weight_compression_needed(advanced_parameters):
+    if advanced_parameters.is_weight_compression_needed():
         compress_post_quantize_transformation(quantized_model)
     else:
         fq_weights_transformation(quantized_model)

src/nncf/openvino/quantization/backend_parameters.py

Lines changed: 0 additions & 18 deletions
@@ -9,29 +9,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional
-
-from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
-

 class BackendParameters:
-    COMPRESS_WEIGHTS = "compress_weights"
     STAT_REQUESTS_NUMBER = "stat_requests_number"
     EVAL_REQUESTS_NUMBER = "eval_requests_number"
     ACTIVATIONS = "activations"
     WEIGHTS = "weights"
     LEVEL_LOW = "level_low"
     LEVEL_HIGH = "level_high"
-
-
-def is_weight_compression_needed(advanced_parameters: Optional[AdvancedQuantizationParameters]) -> bool:
-    """
-    Determines whether weight compression is needed based on the provided
-    advanced quantization parameters.
-
-    :param advanced_parameters: Advanced quantization parameters.
-    :return: True if weight compression is needed, False otherwise.
-    """
-    if advanced_parameters is not None and advanced_parameters.backend_params is not None:
-        return advanced_parameters.backend_params.get(BackendParameters.COMPRESS_WEIGHTS, True)
-    return True

src/nncf/openvino/quantization/quantize_model.py

Lines changed: 4 additions & 6 deletions
@@ -27,8 +27,6 @@
 from nncf.openvino.graph.model_utils import remove_friendly_name_duplicates
 from nncf.openvino.graph.nncf_graph_builder import GraphConverter
 from nncf.openvino.graph.node_utils import get_number_if_op
-from nncf.openvino.quantization.backend_parameters import BackendParameters
-from nncf.openvino.quantization.backend_parameters import is_weight_compression_needed
 from nncf.openvino.quantization.quantize_ifmodel import apply_algorithm_if_bodies
 from nncf.openvino.rt_info import dump_parameters
 from nncf.parameters import BackupMode
@@ -123,7 +121,7 @@ def _extract_all_subgraphs(model: ov.Model, current_id: str) -> None:
         quantization_algorithm, model, graphs, main_model_graph_id, calibration_dataset, subset_size, 1
     )

-    if is_weight_compression_needed(advanced_parameters):
+    if advanced_parameters.is_weight_compression_needed():
         compress_quantize_weights_transformation(quantized_model)

     dump_parameters(
@@ -170,7 +168,7 @@ def native_quantize_impl(
     warning_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS)
     quantized_model = quantization_algorithm.apply(model, graph, dataset=calibration_dataset)

-    if is_weight_compression_needed(advanced_parameters):
+    if advanced_parameters.is_weight_compression_needed():
         compress_quantize_weights_transformation(quantized_model)

     dump_parameters(
@@ -211,13 +209,13 @@ def quantize_with_accuracy_control_impl(
     if advanced_accuracy_restorer_parameters is None:
         advanced_accuracy_restorer_parameters = AdvancedAccuracyRestorerParameters()

-    compress_weights = is_weight_compression_needed(advanced_quantization_parameters)
+    compress_weights = advanced_quantization_parameters.is_weight_compression_needed()

     if advanced_quantization_parameters is None:
         copied_parameters = AdvancedQuantizationParameters()
     else:
         copied_parameters = deepcopy(advanced_quantization_parameters)
-    copied_parameters.backend_params[BackendParameters.COMPRESS_WEIGHTS] = False
+    copied_parameters.backend_params[AdvancedQuantizationParameters.COMPRESS_WEIGHTS] = False

     quantized_model = quantize_impl(
         model=model,
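
Note that the removed helper accepted an Optional[AdvancedQuantizationParameters] and defaulted to compressing weights, while the replacement is an instance method. A call site that may still hold None would need its own guard; a minimal sketch of such a guard (illustrative only, not part of this commit):

compress_weights = (
    advanced_parameters.is_weight_compression_needed()
    if advanced_parameters is not None
    else True  # keep the old default: compress weights unless explicitly disabled
)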

src/nncf/quantization/advanced_parameters.py

Lines changed: 16 additions & 0 deletions
@@ -286,6 +286,22 @@ class AdvancedQuantizationParameters:
     # Backend specific parameters
     backend_params: dict[str, Any] = field(default_factory=dict)

+    # Backend parameter names
+    COMPRESS_WEIGHTS = "compress_weights"
+
+    def is_weight_compression_needed(self) -> bool:
+        """
+        Determine whether weight compression is needed based on advanced quantization parameters.
+
+        If `advanced_parameters` or its `backend_params` are not provided, defaults to True.
+
+        :param advanced_parameters: Advanced quantization parameters.
+        :return: True if weight compression is needed, False otherwise.
+        """
+        if self.backend_params is not None:
+            return bool(self.backend_params.get(AdvancedQuantizationParameters.COMPRESS_WEIGHTS, True))
+        return True
+

 @api()
 @dataclass
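
For reference, a minimal sketch of how the relocated helper behaves after this commit. The import path and attribute names are taken from the diff above; the assertions themselves are illustrative and not part of the change:

from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters

# With no backend parameter set, weight compression stays enabled by default.
params = AdvancedQuantizationParameters()
assert params.is_weight_compression_needed() is True

# Compression can be switched off through backend_params.
params = AdvancedQuantizationParameters(
    backend_params={AdvancedQuantizationParameters.COMPRESS_WEIGHTS: False}
)
assert params.is_weight_compression_needed() is False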

src/nncf/version.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "2.19.0"
+__version__ = "2.19.0.dev0+6f37e09d9dirty"


 BKC_TORCH_SPEC = "==2.8.*"

tests/torch2/fx/test_models.py

Lines changed: 1 addition & 2 deletions
@@ -33,7 +33,6 @@
 from nncf.common.utils.os import safe_open
 from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
 from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
-from nncf.experimental.torch.fx.quantization.backend_parameters import FXBackendParameters
 from nncf.experimental.torch.fx.transformations import DEQUANTIZE_NODE_TARGETS
 from nncf.experimental.torch.fx.transformations import _get_node_inputs
 from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
@@ -219,7 +218,7 @@ def transform_fn(data_item):
     calibration_dataset = nncf.Dataset([example_input], transform_fn)

     quantization_parameters["advanced_parameters"] = AdvancedQuantizationParameters(
-        disable_bias_correction=True, backend_params={FXBackendParameters.COMPRESS_WEIGHTS: compress_weights}
+        disable_bias_correction=True, backend_params={AdvancedQuantizationParameters.COMPRESS_WEIGHTS: compress_weights}
     )
     quantization_parameters["subset_size"] = 1
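
A usage sketch in the spirit of the updated test, wiring the new constant into a quantize call. The nncf.quantize invocation, `model`, and `calibration_dataset` are assumptions for illustration and are not part of this diff:

import nncf
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters

advanced_parameters = AdvancedQuantizationParameters(
    disable_bias_correction=True,
    backend_params={AdvancedQuantizationParameters.COMPRESS_WEIGHTS: False},
)

# `model` and `calibration_dataset` are assumed to be prepared by the caller.
quantized_model = nncf.quantize(
    model,
    calibration_dataset,
    subset_size=1,
    advanced_parameters=advanced_parameters,
)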
