Skip to content

Commit bef451e

Browse files
authored
[Data-Aware WC] Fix Statistics Processing for Data Aware Methods (#3752)
### Changes Fix logic to sample activation values from the collected statistics which resolves a bug. ### Reason for changes An error was introduced with the addition of SE support for 3D weights #3706. This was spotted by the failing test https://github.com/openvinotoolkit/nncf/actions/runs/19549771001 ### Related tickets <!--- Post the numerical ID of the ticket, if available --> ### Tests 1. Extended Scale estimation test `test_scale_estimation()` in `tests/cross_fw/test_templates/template_test_weights_compression.py` to include the case where calibration dataset size > scale estimation subset size 2. Added `test_process_stats` in weights compression template test for testing the process_stats function. Example Test: https://github.com/openvinotoolkit/nncf/actions/runs/19580896526 WC Conformance Test: https://github.com/openvinotoolkit/nncf/actions/runs/19577215441 Precision Type | Filter | Value | Stderr -- | -- | -- | -- INT4 SYM Per-Channel (with Scale estimation) After Fix | flexible-extract | 0.61 | 0.0490   | strict-match | 0.40 | 0.0492 INT4 SYM Per-Channel (with Scale estimation) Before fix | flexible-extract | 0.66 | 0.0476   | strict-match | 0.38 | 0.0488 INT4 SYM Per-Channel | flexible-extract | 0.77 | 0.0423   | strict-match | 0.28 | 0.0451 FP16 | flexible-extract | 0.91 | 0.0288   | strict-match | 0.86 | 0.0349
1 parent 31084d0 commit bef451e

File tree

6 files changed

+610
-266
lines changed

6 files changed

+610
-266
lines changed

src/nncf/quantization/algorithms/weight_compression/activation_stats.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,10 @@ def process_stats(stats: WCTensorStatistic, subset_size: int) -> tuple[Tensor, T
4141

4242
# Prevent high memory and time consumption by sampling
4343
if X_full.shape[sample_axis] > subset_size:
44-
# Activations were reduced across all but the last dimension
4544
lens = [reduce(mul, shape[:-1], 1) for shape in stats.shape_values]
4645
step = X_full.shape[sample_axis] // subset_size
47-
sorted_idxs = [i[0] for i in sorted(enumerate(lens), key=lambda x: -x[1])][::step]
48-
idxs = [idx for idx in sorted_idxs if idx < X_full.shape[sample_axis]][:subset_size]
49-
50-
# Create index slices for all dimensions except the last one
51-
# This works for both 2D and 3D (and theoretically any dimensionality)
52-
index_slices = [slice(None)] * (len(X_full.shape) - 1) + [idxs]
53-
X = X_full[tuple(index_slices)]
46+
idxs = [i[0] for i in sorted(enumerate(lens), key=lambda x: -x[1])][::step]
47+
X = X_full[..., idxs]
5448
else:
5549
X = X_full
5650

tests/cross_fw/test_templates/template_test_weights_compression.py

Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import math
1212
from abc import ABC
1313
from abc import abstractmethod
14+
from functools import reduce
15+
from operator import mul
1416
from typing import Any, Callable, Optional, TypeVar
1517
from unittest.mock import patch
1618

@@ -28,6 +30,8 @@
2830
from nncf.quantization import compress_weights
2931
from nncf.quantization.advanced_parameters import AdvancedAWQParameters as AWQParams
3032
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters as CompressionParams
33+
from nncf.quantization.algorithms.weight_compression.activation_stats import WCTensorStatistic
34+
from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats
3135
from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression
3236
from nncf.quantization.algorithms.weight_compression.awq import AWQ
3337
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
@@ -240,20 +244,25 @@ def get_moe_model_for_test_scale_estimation() -> TModel:
240244

241245
@staticmethod
242246
@abstractmethod
243-
def get_moe_scale_estimation_ref() -> TTensor:
247+
def get_moe_scale_estimation_ref(check_sampling_activation_stats_flow: bool) -> TTensor:
244248
"""
249+
:param check_sampling_activation_stats_flow: whether we are checking the flow with sampling when processing
250+
activation statistics
245251
Returns the reference output of calculate_quantization_params for MoE model.
246252
"""
247253

248254
@staticmethod
249255
@abstractmethod
250-
def get_scale_estimation_ref() -> TTensor:
256+
def get_scale_estimation_ref(check_sampling_activation_stats_flow: bool) -> TTensor:
251257
"""
258+
:param check_sampling_activation_stats_flow: whether we are checking the flow with sampling when processing
259+
activation statistics
252260
Returns the reference output of calculate_quantization_params of ScaleEstimation.
253261
"""
254262

255263
@pytest.mark.parametrize("is_moe", [False, True])
256-
def test_scale_estimation(self, mocker, is_moe):
264+
@pytest.mark.parametrize("check_sampling_activation_stats_flow", [False, True])
265+
def test_scale_estimation(self, mocker, is_moe, check_sampling_activation_stats_flow):
257266
"""Checks that scales match the reference."""
258267
calc_q_params_spy = mocker.spy(ScaleEstimation, "calculate_quantization_params")
259268

@@ -264,9 +273,15 @@ def test_scale_estimation(self, mocker, is_moe):
264273
model = self.get_model_for_test_scale_estimation()
265274
input = np.arange(0, 4 * 8, dtype=np.float32).reshape(1, 4, 8)
266275

267-
# prepare dataset with one input tensor
276+
# prepare dataset of size subset_size with input tensors
277+
subset_size = 2 if check_sampling_activation_stats_flow else 1
278+
# make sure that subset size for SE < subset size for statistics collection.
279+
# This is to test the Optimized statistics processing flow which samples only a few data
280+
# points in nncf/quantization/algorithms/weight_compression/activation_stats.py
281+
se_subset_size = subset_size // 2 if check_sampling_activation_stats_flow else subset_size
268282
input = self.to_tensor(input)
269-
dataset = Dataset([input], self.get_transform_func())
283+
284+
dataset = Dataset([input + i for i in range(subset_size)], self.get_transform_func())
270285

271286
with SpyWeightCompressionStatisticsContext(mocker):
272287
_ = compress_weights(
@@ -277,15 +292,18 @@ def test_scale_estimation(self, mocker, is_moe):
277292
scale_estimation=True,
278293
all_layers=True,
279294
dataset=dataset,
295+
subset_size=subset_size,
296+
advanced_parameters=nncf.AdvancedCompressionParameters(
297+
scale_estimation_params=nncf.AdvancedScaleEstimationParameters(subset_size=se_subset_size)
298+
),
280299
)
281300

282301
computed_scale = calc_q_params_spy.spy_return[0]
283302

284303
if is_moe:
285-
reference = self.get_moe_scale_estimation_ref()
304+
reference = self.get_moe_scale_estimation_ref(check_sampling_activation_stats_flow)
286305
else:
287-
reference = self.get_scale_estimation_ref()
288-
306+
reference = self.get_scale_estimation_ref(check_sampling_activation_stats_flow)
289307
assert fns.allclose(Tensor(reference), computed_scale)
290308

291309
@staticmethod
@@ -643,3 +661,46 @@ def get_transform_func() -> Optional[Callable[..., Any]]:
643661
@staticmethod
644662
def get_reduction_axes() -> int:
645663
return 1
664+
665+
@pytest.mark.parametrize(
666+
"mean_values_shape,num_samples,subset_size,expected_s_shape,expected_X_shape,expected_indices",
667+
[
668+
# 2D Activations
669+
((8,), 10, 5, (8,), (8, 5), [0, 2, 4, 6, 8]),
670+
((8,), 5, 10, (8,), (8, 5), [0, 1, 2, 3, 4]),
671+
((8,), 12, 5, (8,), (8, 6), [0, 2, 4, 6, 8, 10]),
672+
# 3D Activations
673+
((4, 8), 10, 5, (4, 8), (4, 8, 5), [0, 2, 4, 6, 8]),
674+
((4, 8), 5, 10, (4, 8), (4, 8, 5), [0, 1, 2, 3, 4]),
675+
((4, 8), 25, 8, (4, 8), (4, 8, 9), [0, 3, 6, 9, 12, 15, 18, 21, 24]),
676+
],
677+
)
678+
def test_process_stats(
679+
self, mean_values_shape, num_samples, subset_size, expected_s_shape, expected_X_shape, expected_indices
680+
):
681+
total_elements = reduce(mul, mean_values_shape, 1)
682+
mean_values = [
683+
Tensor(np.arange(i * total_elements, (i + 1) * total_elements, dtype=np.float32).reshape(mean_values_shape))
684+
for i in range(num_samples)
685+
]
686+
shape_values = [(1,) + mean_values_shape for _ in range(num_samples)]
687+
688+
stats = WCTensorStatistic(mean_values=mean_values, shape_values=shape_values)
689+
690+
s, X = process_stats(stats, subset_size)
691+
692+
assert s.shape == expected_s_shape, f"Expected s shape {expected_s_shape}, got {s.shape}"
693+
assert X.shape == expected_X_shape, f"Expected X shape {expected_X_shape}, got {X.shape}"
694+
695+
X_full_list = [mean_values[i] for i in range(num_samples)]
696+
X_full = fns.stack(X_full_list)
697+
axes = list(range(1, len(X_full.shape))) + [0]
698+
X_full_transposed = fns.transpose(X_full, axes=axes)
699+
700+
for idx, sample_idx in enumerate(expected_indices):
701+
expected_sample = X_full_transposed[..., sample_idx]
702+
actual_sample = X[..., idx]
703+
assert fns.all(actual_sample == expected_sample)
704+
705+
expected_s = fns.max(fns.abs(X_full_transposed), axis=-1)
706+
assert fns.all(s == expected_s)

tests/onnx/quantization/test_weights_compression.py

Lines changed: 134 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -484,78 +484,150 @@ def get_moe_model_for_test_scale_estimation() -> onnx.ModelProto:
484484
return mb.build(opset_version=21)
485485

486486
@staticmethod
487-
def get_scale_estimation_ref():
488-
return np.array(
489-
[
490-
[[0.473328]],
491-
[[0.929023]],
492-
[[1.446527]],
493-
[[1.920595]],
494-
[[2.517053]],
495-
[[3.030101]],
496-
[[3.584278]],
497-
[[4.04351]],
498-
[[4.620007]],
499-
[[5.165322]],
500-
[[5.710637]],
501-
[[6.122580]],
502-
[[6.655914]],
503-
[[7.237173]],
504-
[[7.722581]],
505-
[[8.255914]],
506-
]
507-
).T
487+
def get_scale_estimation_ref(check_sampling_activation_stats_flow):
488+
return (
489+
np.array(
490+
[
491+
[[0.473328]],
492+
[[0.929023]],
493+
[[1.446527]],
494+
[[1.920595]],
495+
[[2.517054]],
496+
[[3.030102]],
497+
[[3.584279]],
498+
[[4.043509]],
499+
[[4.620008]],
500+
[[5.165322]],
501+
[[5.710637]],
502+
[[6.122581]],
503+
[[6.655914]],
504+
[[7.237174]],
505+
[[7.722580]],
506+
[[8.255914]],
507+
]
508+
).T,
509+
np.array(
510+
[
511+
[[0.47344488]],
512+
[[0.9287766]],
513+
[[1.4463282]],
514+
[[1.920052]],
515+
[[2.5167778]],
516+
[[3.02987]],
517+
[[3.5842714]],
518+
[[4.0429296]],
519+
[[4.619769]],
520+
[[5.165224]],
521+
[[5.7106786]],
522+
[[6.121212]],
523+
[[6.654546]],
524+
[[7.2366524]],
525+
[[7.7212124]],
526+
[[8.254545]],
527+
]
528+
).T,
529+
)[check_sampling_activation_stats_flow]
508530

509531
@staticmethod
510-
def get_moe_scale_estimation_ref():
511-
return np.array(
512-
[
532+
def get_moe_scale_estimation_ref(check_sampling_activation_stats_flow):
533+
return (
534+
np.array(
513535
[
514536
[
515537
[
516-
7.573249,
517-
7.4666667,
518-
7.4666667,
519-
7.4666667,
520-
7.4666667,
521-
7.260152,
522-
7.4666667,
523-
7.4666667,
524-
7.4666667,
525-
7.4666667,
526-
7.3082952,
527-
7.846745,
528-
7.223278,
529-
7.271495,
530-
7.420518,
531-
7.4666667,
538+
[
539+
7.5732,
540+
7.4667,
541+
7.4667,
542+
7.4667,
543+
7.4667,
544+
7.2602,
545+
7.4667,
546+
7.4667,
547+
7.4667,
548+
7.4667,
549+
7.3083,
550+
7.8467,
551+
7.2233,
552+
7.2715,
553+
7.4205,
554+
7.4667,
555+
]
556+
]
557+
],
558+
[
559+
[
560+
[
561+
14.8205,
562+
14.9032,
563+
14.9858,
564+
15.0685,
565+
15.1512,
566+
14.3400,
567+
14.4173,
568+
14.4945,
569+
14.5718,
570+
14.6491,
571+
14.7264,
572+
14.8037,
573+
14.8810,
574+
14.9583,
575+
15.0355,
576+
15.1128,
577+
]
532578
]
533-
]
534-
],
579+
],
580+
]
581+
),
582+
np.array(
535583
[
536584
[
537585
[
538-
14.820505,
539-
14.903171,
540-
14.985837,
541-
15.068501,
542-
15.151169,
543-
14.339979,
544-
14.417264,
545-
14.494548,
546-
14.571833,
547-
14.649117,
548-
14.726402,
549-
14.803687,
550-
14.880971,
551-
14.958257,
552-
15.035541,
553-
15.112826,
586+
[
587+
7.575118,
588+
7.4666667,
589+
7.4666667,
590+
7.4666667,
591+
7.4666667,
592+
7.254837,
593+
7.4666667,
594+
7.4666667,
595+
7.4666667,
596+
7.4666667,
597+
7.495066,
598+
7.850108,
599+
7.219489,
600+
7.2685375,
601+
7.418597,
602+
7.4666667,
603+
]
554604
]
555-
]
556-
],
557-
]
558-
)
605+
],
606+
[
607+
[
608+
[
609+
14.820066,
610+
14.902746,
611+
14.985427,
612+
15.068108,
613+
15.150787,
614+
14.3391285,
615+
14.416424,
616+
14.493721,
617+
14.571016,
618+
14.648311,
619+
14.725608,
620+
14.802904,
621+
14.8801985,
622+
14.957496,
623+
15.034791,
624+
15.112087,
625+
]
626+
]
627+
],
628+
]
629+
),
630+
)[check_sampling_activation_stats_flow]
559631

560632
@staticmethod
561633
def get_orig_weight(model: onnx.ModelProto) -> Tensor:

0 commit comments

Comments
 (0)