Skip to content

Commit 8a30046

Browse files
Add round to nearest logic for numpy case
1 parent 17a3aec commit 8a30046

File tree

8 files changed

+54
-10
lines changed

8 files changed

+54
-10
lines changed

nncf/quantization/algorithms/weight_compression/weight_lowering.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,19 @@ def _calculate_float_quantized_weight(norm_weight: Tensor, mode: CompressWeights
554554
quantile_centers = fns.from_numpy(quantile_centers_np, backend=norm_weight.backend)
555555
indexes = fns.searchsorted(quantile_centers, norm_weight)
556556
quantiles = fns.from_numpy(quantiles_np, backend=indexes.backend)
557+
558+
if mode == CompressWeightsMode.E2M1:
559+
# Round to the nearest even quantile
560+
shifted_indexes = fns.clip(indexes + 1, 0, quantiles.size - 1)
561+
left = quantiles[indexes]
562+
right = quantiles[shifted_indexes]
563+
dist_left = fns.abs(norm_weight - left)
564+
dist_right = fns.abs(norm_weight - right)
565+
choose_right = fns.logical_or(
566+
dist_right < dist_left, fns.logical_and(dist_left == dist_right, (shifted_indexes + 1) % 2 == 0)
567+
)
568+
indexes = fns.where(choose_right, shifted_indexes, indexes)
569+
557570
quantized_weight = quantiles[indexes]
558571
return quantized_weight
559572

nncf/tensor/functions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from nncf.tensor.functions.numeric import isempty as isempty
3737
from nncf.tensor.functions.numeric import item as item
3838
from nncf.tensor.functions.numeric import log2 as log2
39+
from nncf.tensor.functions.numeric import logical_and as logical_and
3940
from nncf.tensor.functions.numeric import logical_or as logical_or
4041
from nncf.tensor.functions.numeric import masked_mean as masked_mean
4142
from nncf.tensor.functions.numeric import masked_median as masked_median

nncf/tensor/functions/numeric.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,18 @@ def logical_or(x1: Tensor, x2: Tensor) -> Tensor:
612612
"""
613613

614614

615+
@tensor_dispatcher
616+
def logical_and(x1: Tensor, x2: Tensor) -> Tensor:
617+
"""
618+
Computes the element-wise logical AND of the given input tensors.
619+
Zeros are treated as False and nonzeros are treated as True.
620+
621+
:param x1: The input tensor.
622+
:param x2: The tensor to compute and with.
623+
:return: Result of elementwise and operation between input_ and other tensor.
624+
"""
625+
626+
615627
@tensor_dispatcher
616628
def masked_mean(x: Tensor, mask: Tensor, axis: T_AXIS, keepdims: bool = False) -> Tensor:
617629
"""

nncf/tensor/functions/numpy_numeric.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,11 @@ def _(x1: T_NUMPY_ARRAY, x2: T_NUMPY_ARRAY) -> T_NUMPY_ARRAY:
348348
return np.logical_or(x1, x2)
349349

350350

351+
@numeric.logical_and.register
352+
def _(x1: T_NUMPY_ARRAY, x2: T_NUMPY_ARRAY) -> T_NUMPY_ARRAY:
353+
return np.logical_and(x1, x2)
354+
355+
351356
@numeric.masked_mean.register
352357
def _(
353358
x: T_NUMPY_ARRAY,

nncf/tensor/functions/tf_numeric.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,12 @@ def _(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
417417
return tf.logical_or(x1, x2)
418418

419419

420+
@numeric.logical_and.register
421+
def _(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
422+
with tf.device(x1.device):
423+
return tf.logical_and(x1, x2)
424+
425+
420426
@numeric.masked_mean.register
421427
def _(
422428
x: tf.Tensor, mask: Optional[tf.Tensor], axis: Optional[Union[int, tuple[int, ...]]], keepdims: bool = False

nncf/tensor/functions/torch_numeric.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,11 @@ def _(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
366366
return torch.logical_or(x1, x2)
367367

368368

369+
@numeric.logical_and.register
370+
def _(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
371+
return torch.logical_and(x1, x2)
372+
373+
369374
@numeric.masked_mean.register
370375
def _(x: torch.Tensor, mask: Optional[torch.Tensor], axis: T_AXIS, keepdims: bool = False) -> torch.Tensor:
371376
if mask is None:

nncf/tensor/tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,9 @@ def __ifloordiv__(self, other: Union[Tensor, T_NUMBER]) -> Tensor:
144144
self._data //= unwrap_tensor_data(other)
145145
return self
146146

147+
def __mod__(self, other: Union[Tensor, T_NUMBER]) -> Tensor:
148+
return cast(Tensor, _call_function("_binary_op_nowarn", self, other, operator.mod))
149+
147150
def __matmul__(self, other: Union[Tensor, T_NUMBER]) -> Tensor:
148151
return Tensor(self.data @ unwrap_tensor_data(other))
149152

tests/openvino/optimized_functions/test_compression_functions.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -151,19 +151,18 @@ def test_optimized_compression_is_disabled(weight_shape, is_disabled, quantizati
151151
reason="Due to a bug in CPU plugin compression models can fail at compilation on ARM CPUs. Ticket: 164135.",
152152
)
153153
@pytest.mark.parametrize("weight_shape", [WEIGHT_SHAPE], ids=[""])
154-
# @pytest.mark.parametrize("config", COMPRESSION_CONFIGS, ids=[str(c) for c in COMPRESSION_CONFIGS])
155-
@pytest.mark.parametrize("config", FP4_COMPRESSION_CONFIGS[-2:])
154+
@pytest.mark.parametrize("config", COMPRESSION_CONFIGS, ids=[str(c) for c in COMPRESSION_CONFIGS])
156155
@pytest.mark.parametrize(
157156
("quantization_task", "tensor_backend"),
158157
[
159158
(QuantizationTask.Q, TensorBackend.numpy),
160-
# (QuantizationTask.Q, "auto"),
159+
(QuantizationTask.Q, "auto"),
161160
# NumPy backend should support OV tensors as inputs only for quantization task
162-
# (QuantizationTask.Q, TensorBackend.ov),
163-
# (QuantizationTask.Q_DQ, TensorBackend.numpy),
164-
# (QuantizationTask.Q_DQ, "auto"),
165-
# (QuantizationTask.Q_DQ_RQ, TensorBackend.numpy),
166-
# (QuantizationTask.Q_DQ_RQ, "auto"),
161+
(QuantizationTask.Q, TensorBackend.ov),
162+
(QuantizationTask.Q_DQ, TensorBackend.numpy),
163+
(QuantizationTask.Q_DQ, "auto"),
164+
(QuantizationTask.Q_DQ_RQ, TensorBackend.numpy),
165+
(QuantizationTask.Q_DQ_RQ, "auto"),
167166
],
168167
)
169168
@pytest.mark.parametrize("dtype", [TensorDataType.float32, TensorDataType.float16, TensorDataType.bfloat16])
@@ -525,8 +524,8 @@ def format_list_of_floats(lst):
525524
f"NumPy result: {format_list_of_floats(numpy_result.data[not_equal_mask])}\n"
526525
)
527526
if "input" in results[ComputationBackend.OV] and "input" in results[ComputationBackend.NumPy]:
528-
numpy_input = results[ComputationBackend.NumPy]['input'].data
529-
ov_input = results[ComputationBackend.OV]['input'].data
527+
numpy_input = results[ComputationBackend.NumPy]["input"].data
528+
ov_input = results[ComputationBackend.OV]["input"].data
530529
np.testing.assert_allclose(numpy_input, ov_input, atol=0, rtol=0)
531530
msg += f"Input values : {format_list_of_floats(numpy_input[not_equal_mask])}\n"
532531
misaligned_groups_mask = np.any(not_equal_mask, axis=-1)

0 commit comments

Comments
 (0)