Description
System information
- TensorFlow version (you are using): 2.8
- Are you willing to contribute it (Yes/No): Yes
Motivation
This would benefit any model that uses the LayerNormalization layer, which is common in Transformer models, for example.
Describe the feature
Support quantization-aware training (QAT) on models that contain a LayerNormalization layer.
Describe how existing APIs don't satisfy your use case (optional if obvious)
For example, the following snippet fails at the quantize_model call:
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow import keras

model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
    keras.layers.LayerNormalization(axis=3),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])

quant_model = tfmot.quantization.keras.quantize_model(model)