Description
When attempting to sparsely prune or cluster DepthwiseConv2D layers, no clustering or sparse pruning appears to take place.
System information
TensorFlow version (installed from source or binary): 2.16.1/2.15.1
TensorFlow Model Optimization version (installed from source or binary): 0.8.0/0.7.5
Python version: 3.9.18
Describe the expected behavior
After clustering with number_of_clusters set to 3, the kernel of the DepthwiseConv2D layer should have 3 unique weights, not 9.
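To make the expectation concrete, here is a minimal NumPy sketch of what clustering a 3x3x1x1 depthwise kernel to 3 centroids should roughly look like. It is an illustration only and not the tfmot implementation: `cluster_linear` is a hypothetical helper, the kernel values are random stand-ins, and the centroids are assumed to be spaced linearly between the smallest and largest weight with no fine-tuning.

```python
import numpy as np

def cluster_linear(kernel, number_of_clusters=3):
    """Snap every weight to the nearest of k evenly spaced centroids."""
    flat = kernel.reshape(-1)
    # Centroids spaced linearly between the smallest and largest weight.
    centroids = np.linspace(flat.min(), flat.max(), number_of_clusters)
    # Index of the closest centroid for each weight.
    nearest = np.abs(flat[:, None] - centroids[None, :]).argmin(axis=1)
    return centroids[nearest].reshape(kernel.shape)

rng = np.random.default_rng(0)
kernel = rng.normal(size=(3, 3, 1, 1))  # same shape as the depthwise kernel here
clustered = cluster_linear(kernel)

# Unique values before and after clustering.
print(len(np.unique(kernel)), len(np.unique(clustered)))
```

The second number printed can never exceed 3, which is the property the reproduction script below is expected to show for the DepthwiseConv2D kernel.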
Describe the current behavior
The DepthwiseConv2D layer has 9 unique weights.
The same issue occurs for sparse m by n pruning: the weights are not pruned in an m by n pattern.
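For the m by n case, a small check like the following could be used to confirm whether a weight tensor follows the pattern. This is a sketch under assumptions: `satisfies_m_by_n` is a hypothetical helper, "m by n" is taken to mean at least m zeros in every block of n consecutive values, and the tensor is simply flattened, whereas tfmot may group values along a specific axis.

```python
import numpy as np

def satisfies_m_by_n(weights, m=2, n=4):
    """True if every full block of n consecutive values has at least m zeros."""
    flat = np.asarray(weights).reshape(-1)
    usable = (flat.size // n) * n  # ignore a trailing partial block
    blocks = flat[:usable].reshape(-1, n)
    zeros_per_block = (blocks == 0).sum(axis=1)
    return bool(np.all(zeros_per_block >= m))

pruned = np.array([0.0, 0.3, 0.0, -1.2, 0.5, 0.0, 0.0, 0.7])  # 2 zeros per block of 4
dense = np.arange(1.0, 9.0)                                   # no zeros at all

print(satisfies_m_by_n(pruned), satisfies_m_by_n(dense))  # True False
```

Running it on the DepthwiseConv2D kernel after pruning would be expected to return False if the layer is left untouched, as described above.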
Code to reproduce the issue
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow import keras
import numpy as np
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images = train_images / 255.0
test_images = test_images / 255.0
model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.DepthwiseConv2D(kernel_size=(3, 3), activation=tf.nn.relu),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
clustering_params = {
    'number_of_clusters': 3,
    'cluster_centroids_init': CentroidInitialization.LINEAR
}
clustered_model = cluster_weights(model, **clustering_params)
opt = keras.optimizers.Adam(learning_rate=1e-5)
clustered_model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=opt,
    metrics=['accuracy'])
clustered_model.fit(
    train_images,
    train_labels,
    batch_size=500,
    epochs=1,
    validation_split=0.1)
for layer in model.layers:
    for weight in layer.weights:
        print(weight.name)
        print(len(np.unique(weight)))
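The loop above inspects `model`. It may also be worth running the same count on the model returned by `tfmot.clustering.keras.strip_clustering(clustered_model)`, in case the two do not share variables (whether they do is not confirmed here). The helper below is a hypothetical convenience, not part of tfmot: it only assumes an object with `.layers`, where each weight exposes `.name` and `.numpy()`.

```python
import numpy as np

def unique_weight_counts(model):
    """Map each weight name to the number of distinct values it holds."""
    counts = {}
    for layer in model.layers:
        for weight in layer.weights:
            values = weight.numpy()  # eager tf.Variable -> NumPy array
            counts[weight.name] = int(np.unique(values).size)
    return counts

# Hypothetical usage with the script above (requires tensorflow_model_optimization):
# stripped = tfmot.clustering.keras.strip_clustering(clustered_model)
# print(unique_weight_counts(stripped))
```

If clustering were applied to every layer, each kernel entry in the returned dictionary would be expected to be at most 3.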