QAT support for LayerNormalization #942

Open
@tom-arm

Description

System information

  • TensorFlow version (you are using): 2.8
  • Are you willing to contribute it (Yes/No): Yes

Motivation
This would be beneficial for models that use this layer; for example, Transformer models rely heavily on LayerNormalization.

Describe the feature
Be able to run QAT on a model that contains a LayerNormalization layer.

Describe how existing APIs don't satisfy your use case (optional if obvious)
As an example, the following code snippet will fail:

import tensorflow as tf
import tensorflow_model_optimization as tfmot

from tensorflow import keras

model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
    keras.layers.LayerNormalization(axis=3),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])

# Fails: the default 8-bit scheme has no QuantizeConfig registered for LayerNormalization.
quant_model = tfmot.quantization.keras.quantize_model(model)
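
Until LayerNormalization is supported out of the box, the documented custom-QuantizeConfig path (quantize_annotate_layer plus quantize_apply) can serve as a stopgap. The sketch below is one possible configuration, not a vetted recipe: the LayerNormQuantizeConfig class is a hypothetical name introduced here for illustration, and it deliberately leaves the gamma/beta weights in float and only fake-quantizes the layer output.

import tensorflow as tf
import tensorflow_model_optimization as tfmot

from tensorflow import keras

quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model

# Hypothetical config for illustration: skip the gamma/beta weights and
# fake-quantize only the layer's output with an 8-bit moving-average quantizer.
class LayerNormQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
    def get_weights_and_quantizers(self, layer):
        return []  # leave gamma/beta in float as a first approximation

    def get_activations_and_quantizers(self, layer):
        return []

    def set_quantize_weights(self, layer, quantize_weights):
        pass

    def set_quantize_activations(self, layer, quantize_activations):
        pass

    def get_output_quantizers(self, layer):
        return [tfmot.quantization.keras.quantizers.MovingAverageQuantizer(
            num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]

    def get_config(self):
        return {}

model = quantize_annotate_model(keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
    quantize_annotate_layer(
        keras.layers.LayerNormalization(axis=3),
        quantize_config=LayerNormQuantizeConfig()),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
]))

# The custom config class must be in scope when quantize_apply deserializes it.
with tfmot.quantization.keras.quantize_scope(
        {'LayerNormQuantizeConfig': LayerNormQuantizeConfig}):
    quant_model = tfmot.quantization.keras.quantize_apply(model)

Since quantize_annotate_model wraps the remaining supported layers with their default configs, Conv2D and Dense still pick up the standard 8-bit scheme. A proper fix would register a default QuantizeConfig for LayerNormalization in the built-in scheme so that plain quantize_model works.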
