Per-tensor QAT model Conv2d+BN+relu folding issue #1131


Description


Describe the bug
I need to train a QAT (per-tensor) model and then convert it to TFLite, but I get the "folding issue" described here.

System information

TensorFlow version (installed from source or binary): 2.15.0

TensorFlow Model Optimization version (installed from source or binary): 0.8.0

Python version: 3.10.12

Describe the expected behavior

A 1-layer CNN (Conv2D + BatchNormalization + ReLU) is folded and converted to TFLite after QAT in per-tensor mode, without splitting the computation graph into multiple "Quantize-Dequantize" parts.

Describe the current behavior

After folding a 1-layer CNN (Conv2D + BatchNormalization + ReLU), the folded layer is left unquantized.

Code to reproduce the issue

import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import \
    default_8bit_quantize_scheme

quantize_apply = tfmot.quantization.keras.quantize_apply
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model

def train_qat_convert_tflite(per_tensor):
    # Conv2D (no bias) + BatchNorm + ReLU: the pattern that should be folded.
    model = keras.Sequential([
        keras.layers.InputLayer(input_shape=(128, 128, 3)),

        keras.layers.Conv2D(3, 3, padding='same', use_bias=False),
        keras.layers.BatchNormalization(),
        keras.layers.Activation('relu'),

        keras.layers.Softmax(),
    ])

    # Annotate the whole model, then apply the default 8-bit scheme;
    # disable_per_axis=True should force per-tensor weight quantization.
    annotated_model = quantize_annotate_model(model)
    scheme = default_8bit_quantize_scheme.Default8BitQuantizeScheme(
        disable_per_axis=per_tensor)
    q_aware_model = quantize_apply(annotated_model, scheme=scheme)

    q_aware_model.compile(
        optimizer='adam',
        loss=keras.losses.MeanAbsoluteError(),
        metrics=['accuracy'],
    )

    # Train for a single epoch on random data, just to exercise the QAT graph.
    q_aware_model.fit(
        x=tf.random.normal((128, 128, 128, 3)),
        y=tf.random.normal((128, 128, 128, 3)),
        batch_size=16,
        epochs=1,
    )
    q_aware_model.save(f'{per_tensor=}.h5')

    # Convert with default optimizations so the QAT fake-quant nodes are
    # materialized as a quantized TFLite model.
    converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    quantized_tflite_model = converter.convert()
    with open(f'{per_tensor=}.tflite', 'wb') as f:
        f.write(quantized_tflite_model)


train_qat_convert_tflite(per_tensor=True)
train_qat_convert_tflite(per_tensor=False)
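
To make the failure concrete, the converted model can be inspected with the TFLite interpreter: a tensor that fell back to float reports no quantization scales, while a per-channel tensor reports one scale per output channel. A minimal sketch, assuming quantized_tflite_model holds the bytes produced above (e.g. by having train_qat_convert_tflite return them):

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()

# A per-tensor model should report a single scale per tensor;
# a per-channel model reports one scale per output channel,
# and an unquantized (float) tensor reports none.
for detail in interpreter.get_tensor_details():
    scales = detail['quantization_parameters']['scales']
    print(detail['name'], detail['dtype'], 'num_scales =', len(scales))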

Screenshots

[screenshot: keras-h5 model graph (2024-05-15_13-35)]

Additional context
I tested #552, but in the case of a simple 1-layer CNN (see the code above) there are no custom layers, so the if statement in the _replace function is False and I still reach the next line.
In the Keras h5 model, I can see that the BN layer is quantized per-channel: in both cases the quantization parameters are tensors, not scalars as expected for per-tensor mode.
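
One way to verify this directly is to print the shapes of the quantizer min/max variables held by the QAT model; in per-tensor mode they should be scalars. A minimal sketch (the name-suffix filter is only a heuristic for locating the quantizer state):

# q_aware_model is the model returned by quantize_apply above.
for v in q_aware_model.variables:
    # Quantizer state lives in variables named '...min'/'...max';
    # shape () means per-tensor, shape (channels,) means per-channel.
    if v.name.endswith('min:0') or v.name.endswith('max:0'):
        print(v.name, v.shape)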
