Description
I am trying to prune the CNN model defined in the Python script below.
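For context, `prune_low_magnitude` is built on a simple idea: set the weights with the smallest absolute values to zero until a target sparsity is reached. Below is a framework-free sketch of just that step. It uses a hypothetical `magnitude_prune` helper on a flat list. tfmot itself keeps a mask per weight tensor and updates it according to the pruning schedule, so this is an illustration of the idea, not its implementation.

```python
def magnitude_prune(weights, sparsity):
    """Return a copy of `weights` with the smallest-magnitude fraction set to 0.0.

    Hypothetical helper for illustration only; ties are broken by position
    because Python's sort is stable.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be between 0 and 1")
    num_to_prune = round(sparsity * len(weights))
    # Indices ordered from smallest to largest absolute value
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:num_to_prune])
    return [0.0 if i in pruned else w for i, w in enumerate(weights)]


print(magnitude_prune([0.1, -2.0, 0.05, 3.0, -0.3, 0.7], 0.5))
# [0.0, -2.0, 0.0, 3.0, 0.0, 0.7]
```

With `final_sparsity=0.5`, half of each prunable kernel ends up as zeros in this way.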
Describe the bug
Every time I run the code below, it fails with an error. Is this a version mismatch or something else?
System information
TensorFlow version (installed from source or binary): 2.17.1
TensorFlow Model Optimization version (installed from source or binary): 0.8.0
Python version: 3.10.12 [GCC 11.4.0]
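On the version question: as far as I understand, from TensorFlow 2.16 onward `tf.keras` points to Keras 3 by default, while TensorFlow Model Optimization 0.8.0 still expects Keras 2, which for these TensorFlow releases ships as the separate `tf_keras` package. The script builds the model with `tensorflow.keras`, so a Keras 3 / Keras 2 mix-up is one plausible cause. The sketch below is a hypothetical, TensorFlow-free check of that setup. `installed_version`, `major_minor` and `keras_setup_issues` are illustrative helpers, and both the rule that `tf_keras` should match TensorFlow's major.minor version and the accepted values of `TF_USE_LEGACY_KERAS` are assumptions.

```python
import os
from importlib import metadata
from typing import List, Mapping, Optional, Tuple


def installed_version(dist_name: str) -> Optional[str]:
    """Installed version of a pip distribution, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def major_minor(version: str) -> Tuple[int, int]:
    """'2.17.1' -> (2, 17)."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def keras_setup_issues(tf_version: str, tf_keras_version: Optional[str],
                       env: Mapping[str, str]) -> List[str]:
    """Hypothetical check: reasons tf.keras may resolve to Keras 3 instead of tf_keras."""
    if major_minor(tf_version) < (2, 16):
        return []  # before TF 2.16, tf.keras is Keras 2
    issues = []
    if tf_keras_version is None:
        issues.append("tf_keras is not installed")
    elif major_minor(tf_keras_version) != major_minor(tf_version):
        issues.append(f"tf_keras {tf_keras_version} does not match TensorFlow {tf_version}")
    # Values assumed to be accepted by TensorFlow for this switch
    if env.get("TF_USE_LEGACY_KERAS") not in ("1", "true", "True"):
        issues.append("TF_USE_LEGACY_KERAS is not set to 1")
    return issues


if __name__ == "__main__":
    # Assumes TensorFlow was installed as the `tensorflow` pip package
    tf_version = installed_version("tensorflow")
    if tf_version is None:
        print("TensorFlow is not installed")
    else:
        problems = keras_setup_issues(tf_version, installed_version("tf_keras"), os.environ)
        print("\n".join(problems) or "tf.keras should resolve to Keras 2 (tf_keras)")
```

Run it in the same environment and shell as the training script. If it lists problems, the commonly suggested workaround is to `pip install tf_keras` (same minor version as TensorFlow, e.g. 2.17.x here) and to set `os.environ["TF_USE_LEGACY_KERAS"] = "1"` at the very top of the script, before the first `import tensorflow`.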
Code to reproduce the issue
```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow.keras import layers, models
import numpy as np

# Load the data. The original snippet used train_images/test_images without
# defining them; CIFAR-10 is assumed here, matching the output filename.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()
# Scale pixel values from [0, 255] to [0, 1]
train_images = train_images.astype(np.float32) / 255.0
test_images = test_images.astype(np.float32) / 255.0

# Define the pruning parameters: target sparsity ramps from 0% to 50%
# between training steps 2000 and 4000
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0,
    final_sparsity=0.5,
    begin_step=2000,
    end_step=4000,
)

# Create the model using keras.models.Sequential
base_model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10),  # raw logits, hence from_logits=True in the loss below
])

# Build the model
base_model.build(input_shape=(None, 32, 32, 3))

# Apply the pruning wrapper. The schedule is passed directly as the
# `pruning_schedule` keyword argument; prune_low_magnitude has no
# `pruning_config` parameter.
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(
    base_model,
    pruning_schedule=pruning_schedule,
)

# Callbacks (UpdatePruningStep is required when training a pruned model)
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        "pruned_model_checkpoint.h5",
        save_best_only=True,
        monitor="val_accuracy",
    ),
    tf.keras.callbacks.EarlyStopping(
        monitor="val_accuracy",
        patience=3,
    ),
    tfmot.sparsity.keras.UpdatePruningStep(),
]

# Compile the model
model_for_pruning.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# Train the model with pruning
history = model_for_pruning.fit(
    train_images, train_labels,
    batch_size=32,
    epochs=10,
    validation_data=(test_images, test_labels),
    callbacks=callbacks,
)

# Evaluate the pruned model
test_loss, test_acc = model_for_pruning.evaluate(test_images, test_labels, verbose=2)
print(f"Pruned Test Accuracy: {(test_acc * 100):.2f}%")

# Strip the pruning wrappers to get a plain Keras model
final_model = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

# Save the pruned and stripped model
final_model.save("cifar10_pruned_cnn_model.h5")
print("Pruned model saved as 'cifar10_pruned_cnn_model.h5'")
```
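To sanity-check the schedule parameters, here is a framework-free sketch of how, as I understand it, `PolynomialDecay` computes the target sparsity. It assumes the defaults `power=3` and `frequency=100`, which the script does not override; `polynomial_decay_sparsity` and `is_pruning_step` are hypothetical names. With 50,000 CIFAR-10 training images and `batch_size=32`, one epoch is 1,563 steps, so pruning starts during epoch 2 and reaches 50% during epoch 3. Because `ModelCheckpoint(save_best_only=True)` keeps whichever epoch scores best on `val_accuracy`, the saved checkpoint is not necessarily the fully pruned one.

```python
import math


def polynomial_decay_sparsity(step, initial_sparsity=0.0, final_sparsity=0.5,
                              begin_step=2000, end_step=4000, power=3):
    """Target sparsity at `step` (assumed to mirror tfmot's PolynomialDecay formula)."""
    # Fraction of the pruning window completed, clamped to [0, 1]
    p = (step - begin_step) / (end_step - begin_step)
    p = min(1.0, max(0.0, p))
    return (initial_sparsity - final_sparsity) * (1.0 - p) ** power + final_sparsity


def is_pruning_step(step, begin_step=2000, end_step=4000, frequency=100):
    """Masks are assumed to be updated only inside [begin_step, end_step], every `frequency` steps."""
    in_range = begin_step <= step <= end_step
    return in_range and (step - begin_step) % frequency == 0


# CIFAR-10 has 50,000 training images; with batch_size=32 that is 1,563 steps per epoch
steps_per_epoch = math.ceil(50_000 / 32)

if __name__ == "__main__":
    for epoch in range(1, 5):
        last_step = epoch * steps_per_epoch - 1
        print(f"end of epoch {epoch}: step {last_step}, "
              f"target sparsity {polynomial_decay_sparsity(last_step):.3f}")
```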
Screenshots
N/A
Additional context
Please help me solve this issue.