11 changes: 5 additions & 6 deletions src/nncf/openvino/optimized_functions/functions.py
@@ -108,16 +108,15 @@ def do_float_quantization(
) -> tuple[Tensor, Tensor, Tensor]:
"""
Computes quantization scale if not provided, and performs corresponding float weight quantization.
NF4 format uses 16 levels in [-1, 1] range, while MXFP4 uses 16 levels in [-6, 6].
NF4 format uses 16 levels in [-1, 1] range, while FP4/MXFP4 uses 16 levels in [-6, 6].

:param weight: Weight array to compress.
:param config: Weight compression configuration.
:param reduction_axes: Axes, along which to reduce (collect) different statistics.
:param precomputed_scale: Optional precomputed scale.
:return: Returns quantized (for MXFP8_E4M3, FP4 and FP8_E4M3 normalized)
weight tensor and corresponding scale tensor.
:return: Returns quantized weight tensor and corresponding scale tensor.
[Collaborator Author] MXFP8_E4M3 and FP8_E4M3 are actually not supported by optimized compression, so there is no need to mention them here.
"""
assert config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4]
assert config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4, CompressWeightsMode.FP4]

weight_shape = weight.shape
scale_shape = None if precomputed_scale is None else precomputed_scale.shape
@@ -129,7 +128,7 @@ def do_float_quantization(
if weight.backend == TensorBackend.ov:
# Return ov tensors in target precision to seamlessly insert them into openvino model later
ov_model_params.return_ov_tensors = True
weight_dtype = TensorDataType.f4e2m1 if config.mode == CompressWeightsMode.MXFP4 else TensorDataType.nf4
weight_dtype = TensorDataType.nf4 if config.mode == CompressWeightsMode.NF4 else TensorDataType.f4e2m1
ov_model_params.output_dtypes.update({"compressed_weight": weight_dtype})

model = get_float_quantization_model(
@@ -235,7 +234,7 @@ def float_quantize_dequantize_weight(
:param return_compressed_weight: If True, besides decompressed weight will also return compressed weight and scale.
:return: Dequantized weight tensor or a tuple containing the decompressed weight, compressed weight and scale.
"""
assert config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4]
assert config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4, CompressWeightsMode.FP4]

# When reduction axes are not provided, assuming that the weights are already reshaped
if config.group_size != -1 and reduction_axes is not None:
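For intuition about what `do_float_quantization` produces for the modes asserted above, here is a minimal NumPy sketch (an illustration only, not the NNCF code) of grid-based float quantization: weights are divided by a per-row scale and snapped to the nearest representable level. The f4e2m1 grid below is the standard E2M1 value set, assumed to match the `F4E2M1_QUANTILES` constant renamed later in this diff; NF4 works the same way with its 16 levels in [-1, 1].

```python
import numpy as np

# Standard E2M1 value set (assumed to match F4E2M1_QUANTILES below): 16 levels in [-6, 6].
F4E2M1_GRID = np.array(
    [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, -0.0,
      0.0,  0.5,  1.0,  1.5,  2.0,  3.0,  4.0,  6.0],
    dtype=np.float32,
)

def quantize_to_grid(weight: np.ndarray, scale: np.ndarray, grid: np.ndarray) -> np.ndarray:
    """Divide by the per-row scale and snap each value to the nearest grid level."""
    norm = weight / scale
    idx = np.abs(norm[..., None] - grid).argmin(axis=-1)
    return grid[idx]

weight = np.array([[0.33, -2.4, 1.9, 0.9]], dtype=np.float32)
scale = np.abs(weight).max(axis=-1, keepdims=True) / 6.0  # FP4/MXFP4 map the absmax to grid value 6.0
print(quantize_to_grid(weight, scale, F4E2M1_GRID) * scale)  # ≈ [[0.4, -2.4, 1.6, 0.8]] after dequantization
```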
16 changes: 11 additions & 5 deletions src/nncf/openvino/optimized_functions/models.py
@@ -286,8 +286,7 @@ def get_float_quantization_model(
reduction_axes: Optional[ReductionAxes] = None,
) -> Union[ModelCallable, ModelAsNodes]:
"""
Get a model that compresses weights to float (currently nf4 or mxfp4) destination type using the given
configuration.
Get a model that compresses weights to float destination type using the given configuration.

:param ov_model_params: OV model parameters.
:param config: Compression configuration.
@@ -319,7 +318,7 @@ def get_float_quantize_dequantize_weight_model(
return_compressed_weight: Optional[bool] = False,
) -> ModelCallable:
"""
Get a model that performs float (currently only nf4) compression and decompression of the given weight.
Get a model that performs float compression and decompression of the given weight.

:param ov_model_params: OV model parameters.
:param config: Compression configuration.
@@ -572,7 +571,7 @@ def _build_float_quantization_model(
reduction_axes: Optional[ReductionAxes] = None,
return_nodes: bool = False,
) -> Union[ModelCallable, ModelAsNodes]:
assert config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4]
assert config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4, CompressWeightsMode.FP4]

default_input_dtypes = {"scale": TensorDataType.float32}
default_output_dtypes = {"compressed_weight": TensorDataType.float32, "scale": TensorDataType.float32}
@@ -626,8 +625,15 @@ def _build_float_quantization_model(
eps = np.finfo(np.float32).eps
scale = opset.select(opset.less(opset.abs(scale), eps), eps, scale)

# Equals 1.0 for NF4
FP_MAX_VALS = {
CompressWeightsMode.MXFP4: 6.0,
CompressWeightsMode.FP4: 6.0,
}
if config.mode in FP_MAX_VALS:
scale = divide_op(scale, opset.constant(FP_MAX_VALS[config.mode], ov.Type.f32))

if config.mode == CompressWeightsMode.MXFP4:
scale = scale / opset.constant(6.0, ov.Type.f32)
scale = opset.log(scale) / opset.log(opset.constant(2.0, ov.Type.f32))
scale = opset.ceil(scale)
scale = opset.clamp(scale, -127.0, 127.0)
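The hunk above builds this scale handling as OpenVINO graph nodes and is cut off before the rounded exponent is turned back into a scale value. Roughly equivalent NumPy (a sketch under the assumption that the final step is `2 ** exponent`): both FP4 and MXFP4 divide the scale by the grid maximum 6.0, and MXFP4 additionally rounds it up to a power of two whose exponent fits the f8e8m0 range.

```python
import numpy as np

def mxfp4_scale(abs_max: np.ndarray) -> np.ndarray:
    eps = np.finfo(np.float32).eps
    scale = np.where(np.abs(abs_max) < eps, eps, abs_max)      # guard against zero scales
    scale = scale / 6.0                                         # shared FP4/MXFP4 step: absmax maps to 6.0
    exponent = np.clip(np.ceil(np.log2(scale)), -127.0, 127.0)  # round up to a power of two, clamp to e8m0 range
    return np.float32(2.0) ** exponent                          # power-of-two scale, representable as f8e8m0

print(mxfp4_scale(np.array([7.3], dtype=np.float32)))           # -> [2.]  (2 ** ceil(log2(7.3 / 6)))
```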
@@ -33,7 +33,7 @@
dtype=np.float32,
)

MXFP4_QUANTILES = np.array(
F4E2M1_QUANTILES = np.array(
[Collaborator Author] mxfp4 is a compression format with an f4e2m1 weight, an f8e8m0 scale, and a group size of 32, while this grid is defined by the f4e2m1 data type alone, irrespective of the other parameters, hence the rename.
[
-6.0,
-4.0,
@@ -100,4 +100,4 @@
)


CENTER_OF_MXFP4_QUANTILES = (MXFP4_QUANTILES[1:] + MXFP4_QUANTILES[:-1]) / 2
CENTER_OF_F4E2M1_QUANTILES = (F4E2M1_QUANTILES[1:] + F4E2M1_QUANTILES[:-1]) / 2
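The quantile array is truncated in the view above; for reference, the full grid it presumably holds is the standard E2M1 value set, and the center constant is simply the midpoints between consecutive levels, used to find the nearest quantile via a sorted search.

```python
import numpy as np

# Presumed full contents of F4E2M1_QUANTILES: the 16 representable f4e2m1 values.
F4E2M1_QUANTILES = np.array(
    [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, -0.0,
      0.0,  0.5,  1.0,  1.5,  2.0,  3.0,  4.0,  6.0],
    dtype=np.float32,
)
# Midpoints between neighbouring levels; searching a normalized weight against
# these centers yields the index of its nearest quantile.
CENTER_OF_F4E2M1_QUANTILES = (F4E2M1_QUANTILES[1:] + F4E2M1_QUANTILES[:-1]) / 2
```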
@@ -19,9 +19,9 @@
from nncf.errors import UnsupportedModelError
from nncf.parameters import CompressWeightsMode
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
from nncf.quantization.algorithms.weight_compression.constants import CENTER_OF_MXFP4_QUANTILES
from nncf.quantization.algorithms.weight_compression.constants import CENTER_OF_F4E2M1_QUANTILES
from nncf.quantization.algorithms.weight_compression.constants import CENTER_OF_NF4_QUANTILES
from nncf.quantization.algorithms.weight_compression.constants import MXFP4_QUANTILES
from nncf.quantization.algorithms.weight_compression.constants import F4E2M1_QUANTILES
from nncf.quantization.algorithms.weight_compression.constants import NF4_QUANTILES
from nncf.quantization.algorithms.weight_compression.parameters import CompressedWeight
from nncf.quantization.fake_quantize import calculate_scale_zero_point
@@ -32,6 +32,16 @@

ReductionAxes = Union[int, tuple[int, ...]]


OPTIMIZED_COMPRESSION_COMPATIBLE_MODES = (
CompressWeightsMode.INT8_ASYM,
CompressWeightsMode.INT8_SYM,
CompressWeightsMode.INT4_ASYM,
CompressWeightsMode.INT4_SYM,
CompressWeightsMode.NF4,
CompressWeightsMode.MXFP4,
CompressWeightsMode.FP4,
)
MIN_INPUT_SIZE_FOR_OPTIMIZED_COMPRESSION = 10000


@@ -168,7 +178,7 @@ def do_float_quantization(
weight, reduction_axes = reshape_weight_for_grouped_quantization(weight, reduction_axes, config.group_size)

# Optimized implementation
if config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4] and _can_run_optimized(weight):
if _can_run_optimized(weight, config.mode):
from nncf.openvino.optimized_functions import do_float_quantization as do_float_quantization_ov

return do_float_quantization_ov(weight, config, reduction_axes, precomputed_scale)
@@ -183,7 +193,7 @@ def do_float_quantization(
if scale is None:
scale = calculate_float_quantization_params(weight, reduction_axes, config)
norm_weight = _calculate_normalized_weight(weight, scale)
if config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4]:
if config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4, CompressWeightsMode.FP4]:
if original_weight_backend == TensorBackend.ov:
# Can convert through OpenVINO and return OpenVINO-native nf4/f4e2m1 tensor
target_dtype = TensorDataType.nf4 if config.mode == CompressWeightsMode.NF4 else TensorDataType.f4e2m1
@@ -209,7 +219,7 @@ def float_quantize_dequantize_weight(
) -> Union[Tensor, tuple[Tensor, Tensor, Tensor]]:
"""
First quantizes the given weight tensor to float dtype and then dequantizes it back to obtain float32 values.
MXFP8_E4M3, FP8_E4M3 and FP4 modes currently are not supported.
MXFP8_E4M3 and FP8_E4M3 modes currently are not supported.

:param weight: The weight tensor to quantize-dequantize.
:param config: Compression configuration.
@@ -221,12 +231,13 @@
assert config.mode in [
CompressWeightsMode.NF4,
CompressWeightsMode.MXFP4,
CompressWeightsMode.FP4,
CompressWeightsMode.CODEBOOK,
CompressWeightsMode.CB4_F8E4M3,
]

# Optimized implementation
if config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4] and _can_run_optimized(weight):
if _can_run_optimized(weight, config.mode):
from nncf.openvino.optimized_functions import (
float_quantize_dequantize_weight as float_quantize_dequantize_weight_ov,
)
@@ -302,7 +313,7 @@ def get_integer_quantization_error(
:return: The quantity characterizing the error of integer quantization.
"""
# Optimized implementation
if _can_run_optimized(weight):
if _can_run_optimized(weight, config.mode):
from nncf.openvino.optimized_functions import (
get_integer_quantization_error as get_integer_quantization_error_ov,
)
@@ -439,7 +450,7 @@ def do_integer_quantization(
weight, reduction_axes = reshape_weight_for_grouped_quantization(weight, reduction_axes, config.group_size)

# Optimized implementation
if _can_run_optimized(weight):
if _can_run_optimized(weight, config.mode):
from nncf.openvino.optimized_functions import do_integer_quantization as do_integer_quantization_ov

return do_integer_quantization_ov(weight, config, reduction_axes, precomputed_scale, precomputed_zero_point)
@@ -488,7 +499,7 @@ def integer_quantize_dequantize_weight(
(and zero point).
"""
# Optimized implementation
if _can_run_optimized(weight):
if _can_run_optimized(weight, config.mode):
from nncf.openvino.optimized_functions import (
integer_quantize_dequantize_weight as integer_quantize_dequantize_weight_ov,
)
@@ -520,14 +531,14 @@ def _calculate_float_quantized_weight(norm_weight: Tensor, mode: CompressWeights
:param norm_weight: Normalized weight tensor to quantize.
:return: Tensor with floating-point values, where each of them corresponds to 1 out of 16 quants.
"""
assert mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4]
quantiles_np = NF4_QUANTILES if mode == CompressWeightsMode.NF4 else MXFP4_QUANTILES
quantile_centers_np = CENTER_OF_NF4_QUANTILES if mode == CompressWeightsMode.NF4 else CENTER_OF_MXFP4_QUANTILES
assert mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4, CompressWeightsMode.FP4]
quantiles_np = NF4_QUANTILES if mode == CompressWeightsMode.NF4 else F4E2M1_QUANTILES
quantile_centers_np = CENTER_OF_NF4_QUANTILES if mode == CompressWeightsMode.NF4 else CENTER_OF_F4E2M1_QUANTILES
quantile_centers = fns.from_numpy(quantile_centers_np, backend=norm_weight.backend)
indexes = fns.searchsorted(quantile_centers, norm_weight)
quantiles = fns.from_numpy(quantiles_np, backend=indexes.backend)

if mode == CompressWeightsMode.MXFP4:
if mode in [CompressWeightsMode.MXFP4, CompressWeightsMode.FP4]:
# If in-between two quantiles, round to the nearest even quantile.
shifted_indexes = fns.clip(indexes + 1, 0, quantiles.size - 1)
dist_left = fns.abs(norm_weight - quantiles[indexes])
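The tail of this hunk is cut off, but the in-code comment states the intent: when a normalized value lands exactly between two f4e2m1 levels, FP4/MXFP4 pick the level whose E2M1 mantissa bit is zero, i.e. IEEE round-half-to-even. A small standalone illustration of that rule (an assumption about the behaviour, not a copy of the truncated NNCF code), using the positive half of the grid:

```python
import numpy as np

grid = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)
even_levels = {0.0, 1.0, 2.0, 4.0}  # E2M1 values whose mantissa bit is 0

def snap_half_even(value: float) -> float:
    """Snap to the nearest grid level; on an exact tie, prefer the 'even' level."""
    idx = int(np.abs(grid - value).argmin())
    for nbr in (idx - 1, idx + 1):
        # detect an exact tie with a neighbouring level and resolve it toward the even one
        if 0 <= nbr < grid.size and abs(grid[nbr] - value) == abs(grid[idx] - value):
            return float(grid[nbr]) if float(grid[nbr]) in even_levels else float(grid[idx])
    return float(grid[idx])

print(snap_half_even(2.5), snap_half_even(3.5), snap_half_even(5.0))  # -> 2.0 4.0 4.0
```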
@@ -639,11 +650,12 @@ def _calculate_integer_quantized_weight(
return compressed_weights


def _can_run_optimized(inp: Tensor) -> bool:
def _can_run_optimized(inp: Tensor, mode: CompressWeightsMode) -> bool:
if (
inp.backend in [TensorBackend.ov, TensorBackend.numpy]
and inp.size >= MIN_INPUT_SIZE_FOR_OPTIMIZED_COMPRESSION
and os.environ.get("NNCF_DISABLE_OPTIMIZED_COMPRESSION") is None
and mode in OPTIMIZED_COMPRESSION_COMPATIBLE_MODES
):
if is_openvino_available():
from nncf.openvino.cpu_info import is_arm_cpu
34 changes: 33 additions & 1 deletion tests/openvino/native/quantization/test_weights_compression.py
@@ -1409,7 +1409,39 @@ def test_int_compressed_weighs_range(mode, data):
"neg": [-8.0, -8.0, -6.0, -4.0, -4.0, -3.0, -2.0, -1.0, -0.0],
"pos": [-0.0, 1.0, 2.0, 3.0, 4.0, 4.0, 6.0, 8.0, 8.0],
"neg-pos": [-8.0, -8.0, -6.0, -4.0, -4.0, -3.0, -2.0, -1.0, -0.0, 1.0, 2.0, 3.0, 4.0, 4.0, 6.0, 8.0],
}
},
CompressWeightsMode.FP4: {
"neg": [
-8.0,
-8.0,
-5.333333492279053,
-5.333333492279053,
-4.0,
-2.6666667461395264,
-2.0,
-1.3333333730697632,
-0.0,
],
"pos": [-0.0, 1.3333333730697632, 2.0, 2.6666667461395264, 4.0, 5.333333492279053, 5.333333492279053, 8.0, 8.0],
"neg-pos": [
-8.0,
-8.0,
-5.333333492279053,
-5.333333492279053,
-4.0,
-2.6666667461395264,
-2.0,
-1.3333333730697632,
-0.0,
1.3333333730697632,
2.0,
2.6666667461395264,
4.0,
5.333333492279053,
5.333333492279053,
8.0,
],
},
}
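The repeated 5.333333… values in the new FP4 expectations follow directly from the scale handling shown earlier, assuming the test rows span integers up to an absolute maximum of 8.0: FP4 keeps the exact scale absmax / 6 (no power-of-two rounding), so each f4e2m1 level L dequantizes to L * 8 / 6, whereas MXFP4 rounds the scale up to 2.0 and therefore lands back on whole numbers.

```python
# Worked check of the new FP4 expectations (assumes the per-row absolute maximum is 8.0):
scale = 8.0 / 6.0      # FP4 scale: absmax divided by the grid maximum 6.0
print(3.0 * scale)     # 4.0
print(4.0 * scale)     # 5.333333... -> the repeated in-between value
print(6.0 * scale)     # 8.0         -> the grid maximum maps back onto the absmax
```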


@@ -69,7 +69,9 @@ class QuantizationTask(Enum):

FP4_COMPRESSION_CONFIGS = [
WeightCompressionConfig(CompressWeightsMode.NF4),
WeightCompressionConfig(CompressWeightsMode.FP4),
WeightCompressionConfig(CompressWeightsMode.NF4, group_size=2),
WeightCompressionConfig(CompressWeightsMode.FP4, group_size=2),
WeightCompressionConfig(CompressWeightsMode.MXFP4, group_size=32),
]
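These configs mirror what a user would request through the public API; a hypothetical end-to-end call (assuming `nncf.compress_weights` accepts the new FP4 mode the same way it already accepts NF4 and MXFP4, and with `model.xml` standing in for a real OpenVINO model):

```python
import nncf
import openvino as ov

model = ov.Core().read_model("model.xml")     # placeholder model path
compressed_model = nncf.compress_weights(
    model,
    mode=nncf.CompressWeightsMode.FP4,        # the mode exercised by the new configs above
    group_size=2,                             # matches WeightCompressionConfig(CompressWeightsMode.FP4, group_size=2)
)
```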

@@ -377,14 +379,16 @@ def get_input_node_data(node: ov.Node, input_id: int) -> Tensor:
or compression_kwargs.get("lora_correction")
)

if config.mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM, CompressWeightsMode.MXFP4]:
if config.mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM] and weight_dtype in [
TensorDataType.f8e4m3,
TensorDataType.f8e5m2,
]:
if is_data_aware and config.mode in [
CompressWeightsMode.INT8_ASYM,
CompressWeightsMode.INT8_SYM,
CompressWeightsMode.MXFP4,
CompressWeightsMode.FP4,
]:
pytest.skip("Data-aware compression is not supported for INT8, MXFP4, FP4 modes.")
if config.mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM]:
if weight_dtype in [TensorDataType.f8e4m3, TensorDataType.f8e5m2]:
pytest.skip("INT8 compression is not supported for f8 dtypes.")
if is_data_aware:
pytest.skip("Data-aware compression is not supported for INT8 or MXFP4 modes.")
else:
compression_kwargs["all_layers"] = True
