Describe the bug
Hi! I'm having trouble performing quantization-aware training (QAT) on a GNN built with TF-GNN that uses custom layers. I skipped the custom layers and annotated only the built-in Dense layers for quantization, but `quantize_apply` still fails with an unexpected error (see below).
System information
TensorFlow version (installed from binary): 2.18.0
TensorFlow Model Optimization version (installed from binary): 0.8.0
Python version: 3.10.14
Code to reproduce the issue
```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot
import tf_keras as keras  # assumption: the legacy Keras 2 package that tfmot uses (see traceback)

model = ...  # GNN model with custom layers

def apply_quantization_to_dense(layer):
    # Annotate only the built-in Dense layers; return every other layer unchanged.
    if isinstance(layer, (keras.layers.Dense, tf.keras.layers.Dense)):
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
    return layer

# Use `tf.keras.models.clone_model` to apply `apply_quantization_to_dense`
# to the layers of the model.
annotated_model = tf.keras.models.clone_model(
    model,
    clone_function=apply_quantization_to_dense,
)

# Create a quantization-aware model from the non-quantization-aware model.
q_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
```
Describe the current behavior
Error:
```
  File "/home/sicli01/Projects/FluidML/gnn-physics/gnn/quant_finetune.py", line 235, in main
    q_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
  File "/home/sicli01/Projects/FluidML/gnn-physics/gnn_env_tf2_16_1/lib/python3.10/site-packages/tensorflow_model_optimization/python/core/keras/metrics.py", line 74, in inner
    raise error
  File "/home/sicli01/Projects/FluidML/gnn-physics/gnn_env_tf2_16_1/lib/python3.10/site-packages/tensorflow_model_optimization/python/core/keras/metrics.py", line 69, in inner
    results = func(*args, **kwargs)
  File "/home/sicli01/Projects/FluidML/gnn-physics/gnn_env_tf2_16_1/lib/python3.10/site-packages/tensorflow_model_optimization/python/core/quantization/keras/quantize.py", line 490, in quantize_apply
    transformed_model, layer_quantize_map = quantize_transform.apply(
  File "/home/sicli01/Projects/FluidML/gnn-physics/gnn_env_tf2_16_1/lib/python3.10/site-packages/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_layout_transform.py", line 76, in apply
    return model_transformer.ModelTransformer(model, transforms, set(layer_quantize_map.keys()), layer_quantize_map).transform()
  File "/home/sicli01/Projects/FluidML/gnn-physics/gnn_env_tf2_16_1/lib/python3.10/site-packages/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py", line 625, in transform
    transformed_model = keras.Model.from_config(self._config, custom_objects)
  File "/home/sicli01/Projects/FluidML/gnn-physics/gnn_env_tf2_16_1/lib/python3.10/site-packages/tf_keras/src/engine/training.py", line 3325, in from_config
    inputs, outputs, layers = functional.reconstruct_from_config(
  File "/home/sicli01/Projects/FluidML/gnn-physics/gnn_env_tf2_16_1/lib/python3.10/site-packages/tf_keras/src/engine/functional.py", line 1492, in reconstruct_from_config
    process_layer(layer_data)
  File "/home/sicli01/Projects/FluidML/gnn-physics/gnn_env_tf2_16_1/lib/python3.10/site-packages/tf_keras/src/engine/functional.py", line 1473, in process_layer
    layer = deserialize_layer(layer_data, custom_objects=custom_objects)
  File "/home/sicli01/Projects/FluidML/gnn-physics/gnn_env_tf2_16_1/lib/python3.10/site-packages/tf_keras/src/layers/serialization.py", line 276, in deserialize
    return serialization_lib.deserialize_keras_object(
  File "/home/sicli01/Projects/FluidML/gnn-physics/gnn_env_tf2_16_1/lib/python3.10/site-packages/tf_keras/src/saving/serialization_lib.py", line 727, in deserialize_keras_object
    instance = cls.from_config(inner_config)
  File "/home/sicli01/Projects/FluidML/gnn-physics/gnn_env_tf2_16_1/lib/python3.10/site-packages/tensorflow_gnn/keras/layers/graph_update.py", line 227, in from_config
    return cls(**config)
  File "/home/sicli01/Projects/FluidML/gnn-physics/gnn_env_tf2_16_1/lib/python3.10/site-packages/tensorflow_gnn/keras/layers/graph_update.py", line 184, in __init__
    self._init_from_updates(edge_sets, node_sets, context)
  File "/home/sicli01/Projects/FluidML/gnn-physics/gnn_env_tf2_16_1/lib/python3.10/site-packages/tensorflow_gnn/keras/layers/graph_update.py", line 201, in _init_from_updates
    self._node_set_updates = {
  File "/home/sicli01/Projects/FluidML/gnn-physics/gnn_env_tf2_16_1/lib/python3.10/site-packages/tensorflow_gnn/keras/layers/graph_update.py", line 202, in <dictcomp>
    key: _check_is_layer(value, f"GraphUpdate(node_sets={{{key}: ...}}")
  File "/home/sicli01/Projects/FluidML/gnn-physics/gnn_env_tf2_16_1/lib/python3.10/site-packages/tensorflow_gnn/keras/layers/graph_update.py", line 584, in _check_is_layer
    raise ValueError(f"{description} must be a tf.keras.layer.Layer, "
ValueError: GraphUpdate(node_sets={particle: ...} must be a tf.keras.layer.Layer, got type: SharedObjectConfig,
```
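It seems that some layers fail to be deserialized: while the model is rebuilt from its config, the `GraphUpdate` layer for the `particle` node set receives a `SharedObjectConfig` instead of a Keras layer.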
Thanks in advance!