Skip to content

Commit 4f4e5d0

Browse files
Xharktensorflower-gardener
authored and committed
Add Conv2DTranspose support.
PiperOrigin-RevId: 353149071
1 parent 37006f0 commit 4f4e5d0

File tree

4 files changed

+39
-2
lines changed

4 files changed

+39
-2
lines changed

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ class Default8BitQuantizeRegistry(
9292
# layers.DepthwiseConv2D is supported and handled in code below.
9393

9494
# _QuantizeInfo(layers.Conv3D, ['kernel'], ['activation']),
95-
# _QuantizeInfo(layers.Conv2DTranspose, ['kernel'], ['activation']),
9695
# _QuantizeInfo(layers.Conv3DTranspose, ['kernel'], ['activation']),
9796
_no_quantize(layers.Cropping1D),
9897
_no_quantize(layers.Cropping2D),
@@ -198,6 +197,9 @@ def __init__(self):
198197
self._layer_quantize_map[
199198
layers.DepthwiseConv2D] = Default8BitConvQuantizeConfig(
200199
['depthwise_kernel'], ['activation'], False)
200+
self._layer_quantize_map[layers.Conv2DTranspose] = \
201+
Default8BitConvTransposeQuantizeConfig(
202+
['kernel'], ['activation'], False)
201203

202204
def _is_supported_layer(self, layer_class):
203205
return layer_class in self._layer_quantize_map
@@ -509,6 +511,17 @@ def __init__(self, weight_attrs, activation_attrs, quantize_output):
509511
)
510512

511513

514+
class Default8BitConvTransposeQuantizeConfig(Default8BitQuantizeConfig):
  """QuantizeConfig for Conv2DTranspose layers.

  Behaves like the base `Default8BitQuantizeConfig`, except that the
  weight quantizer is replaced with the Conv2DTranspose-specific one.
  """

  def __init__(self, weight_attrs, activation_attrs, quantize_output):
    """Initializes the config and swaps in the transpose weights quantizer.

    Args:
      weight_attrs: Attribute names of the weights to quantize.
      activation_attrs: Attribute names of the activations to quantize.
      quantize_output: Whether to quantize the layer output.
    """
    super(Default8BitConvTransposeQuantizeConfig, self).__init__(
        weight_attrs, activation_attrs, quantize_output)

    # Override the base class's weight quantizer with one that handles
    # the Conv2DTranspose kernel layout.
    self.weight_quantizer = (
        default_8bit_quantizers.Default8BitConvTransposeWeightsQuantizer())
523+
524+
512525
def _types_dict():
513526
return {
514527
'Default8BitQuantizeConfig':

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantizers.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,20 @@ def build(self, tensor_shape, name, layer):
4141
trainable=False)
4242

4343
return {'min_var': min_weight, 'max_var': max_weight}
44+
45+
46+
class Default8BitConvTransposeWeightsQuantizer(quantizers.LastValueQuantizer):
  """Quantizer for handling weights in Conv2DTranspose layers."""

  def __init__(self):
    """Construct LastValueQuantizer with params specific for TFLite Conv2DTranspose."""

    # Per-tensor (per_axis=False), symmetric, narrow-range 8-bit quantization.
    super(Default8BitConvTransposeWeightsQuantizer, self).__init__(
        num_bits=8, per_axis=False, symmetric=True, narrow_range=True)

  def __call__(self, inputs, training, weights, **kwargs):
    """Quantizes the kernel, tracking min/max on a transposed axis layout.

    Args:
      inputs: Weight tensor to quantize. NOTE(review): presumably the
        Conv2DTranspose kernel with the output/input channel axes in the
        last two positions — confirm against the layer's kernel shape.
      training: Whether the graph is currently training.
      weights: Quantization variables (min/max) built for this quantizer.
      **kwargs: Additional args passed through to the base quantizer.

    Returns:
      Quantized weight tensor with the original axis order restored.
    """
    # Swap the last two axes before quantizing, then swap back afterwards,
    # so the base quantizer sees the transposed layout.
    # NOTE(review): with per_axis=False the min/max are per-tensor, so the
    # transpose presumably exists to match TFLite's expected
    # Conv2DTranspose kernel layout — confirm.
    outputs = tf.transpose(inputs, (0, 1, 3, 2))
    outputs = super(Default8BitConvTransposeWeightsQuantizer,
                    self).__call__(outputs, training, weights, **kwargs)
    outputs = tf.transpose(outputs, (0, 1, 3, 2))
    return outputs

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/quantize_numerical_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,13 @@ def _get_upsampling2d_bilinear_model(self):
153153
x = tf.keras.layers.UpSampling2D(size=(1, 5), interpolation='bilinear')(i)
154154
return tf.keras.Model(i, x)
155155

156+
def _get_conv2d_transpose_model(self):
157+
i = tf.keras.Input(shape=(32, 32, 3))
158+
x = tf.keras.layers.Conv2DTranspose(
159+
2, kernel_size=(3, 3), strides=(2, 2))(
160+
i)
161+
return tf.keras.Model(i, x)
162+
156163
@parameterized.parameters([
157164
_get_single_conv_model, _get_single_dense_model,
158165
_get_single_conv_relu_model, _get_stacked_convs_model,
@@ -165,6 +172,7 @@ def _get_upsampling2d_bilinear_model(self):
165172
# TODO(tfmot): There are gaps between ResizeBilinear with FakeQuant and
166173
# TFLite quantized ResizeBilinear op. It has a bit more quantization
167174
# error than other ops in this test now.
175+
_get_conv2d_transpose_model,
168176
])
169177
def testModelEndToEnd(self, model_fn):
170178
# 1. Check whether quantized model graph can be constructed.

tensorflow_model_optimization/python/core/quantization/keras/quantize_functional_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,6 @@ class QuantizeFullIntegerModelTest(tf.test.TestCase, parameterized.TestCase):
298298
layers.UpSampling3D,
299299
# Not done since not registered since not per-axis yet.
300300
layers.Conv1D,
301-
layers.Conv2DTranspose,
302301
]
303302
])
304303
def testQuantizeSingleLayer_ProducesFullIntegerModel_TF2(

0 commit comments

Comments
 (0)