
Commit 02e0cd5

daverim authored and tensorflower-gardener committed
Support quantization of gelu activation
PiperOrigin-RevId: 424128991
1 parent 28df91b commit 02e0cd5
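
To illustrate what this commit enables, here is a hedged end-to-end sketch: a Keras model using the 'gelu' activation passed through tfmot's quantize-aware-training entry point. It assumes TF 2.x and tensorflow-model-optimization at or after this revision; the layer sizes are arbitrary, and this is a usage sketch rather than code from the commit.

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# A small model whose first Dense layer uses the gelu activation.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(16, activation='gelu'),
    tf.keras.layers.Dense(4),
])

# With this change, quantize_model no longer rejects gelu; the activation is
# wrapped with quantize-aware ops (both pre- and post-activation fake quant).
qat_model = tfmot.quantization.keras.quantize_model(model)
qat_model.summary()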

File tree

3 files changed: +13 −3 lines changed


tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry.py

Lines changed: 1 addition & 1 deletion
@@ -491,7 +491,7 @@ def get_output_quantizers(self, layer):
           'Default8BitActivationQuantizeConfig.'.format(
               layer.activation))
 
-    if layer.activation.__name__ in ['relu', 'swish']:
+    if layer.activation.__name__ in ['relu', 'swish', 'gelu']:
       # 'relu' should generally get fused into the previous layer.
       return [quantizers.MovingAverageQuantizer(
          num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]
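
For reference, a hedged sketch of the output quantizer this registry entry returns: an asymmetric, per-tensor, 8-bit MovingAverageQuantizer, which after this change also applies when the layer's activation is gelu. This constructs the public tfmot quantizer directly and is illustrative only.

import tensorflow_model_optimization as tfmot

# Same configuration the registry returns above: 8-bit, per-tensor,
# asymmetric, full range.
output_quantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer(
    num_bits=8, per_axis=False, symmetric=False, narrow_range=False)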

tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation.py

Lines changed: 3 additions & 2 deletions
@@ -75,12 +75,13 @@ class QuantizeAwareActivation(object):
 
   # TODO(pulkitb): Other activations such as elu, tanh etc., should just work
   # on inclusion. Verify in TFLite before enabling.
+  # gelu requires both because it is not folded by tflite.
 
   # These activations should be quantized prior to the activation being applied.
-  _PRE_QUANT_ACTIVATIONS = frozenset({'softmax', 'sigmoid', 'tanh'})
+  _PRE_QUANT_ACTIVATIONS = frozenset({'softmax', 'sigmoid', 'tanh', 'gelu'})
 
   # These activations should be quantized after the activation has been applied.
-  _POST_QUANT_ACTIVATIONS = frozenset({'linear', 'relu', 'swish'})
+  _POST_QUANT_ACTIVATIONS = frozenset({'linear', 'relu', 'swish', 'gelu'})
 
   # Don't take any quantize operations for these activations.
   _NO_QUANTIZE_ACTIVATIONS = frozenset({'NoOpActivation'})
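
The added comment notes that gelu needs both pre- and post-activation quantization because TFLite does not fold it into the preceding op. Below is a conceptual sketch of that behavior using TensorFlow's public fake-quant op with placeholder ranges; the real implementation instead learns the ranges with MovingAverageQuantizer, and the helper name here is hypothetical.

import tensorflow as tf

def quantize_aware_gelu(x, num_bits=8):
  # Pre-activation fake quant: gelu is in _PRE_QUANT_ACTIVATIONS.
  x = tf.quantization.fake_quant_with_min_max_vars(
      x, min=-6.0, max=6.0, num_bits=num_bits)
  y = tf.nn.gelu(x)
  # Post-activation fake quant: gelu is also in _POST_QUANT_ACTIVATIONS.
  return tf.quantization.fake_quant_with_min_max_vars(
      y, min=-6.0, max=6.0, num_bits=num_bits)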

tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py

Lines changed: 9 additions & 0 deletions
@@ -59,6 +59,15 @@ def call(self, inputs, training=None):
   def compute_output_shape(self, input_shape):
     return input_shape
 
+  def testSupportedPreAndPostActivation(self):
+    layer = self.TestLayer()
+    layer.activation = QuantizeAwareActivation(
+        activations.get('gelu'), self.quantizer, 0, layer)
+    model = keras.Sequential([layer])
+    names = ', '.join([weight.name for weight in model.layers[-1].weights])
+    self.assertIn('pre_activation', names)
+    self.assertIn('post_activation', names)
+
   def testConstruction_SupportedAndUnsupportedActivations(self):
     layer = self.TestLayer()
