ValueError: prune_low_magnitude can only prune an object of the following types: keras.models.Sequential, keras functional model, keras.layers.Layer, list of keras.layers.Layer. You passed an object of type: Conv2D. #1172

Open
@ShardulNalegave

Describe the bug
I am trying to prune a MobileNetV2 model using prune_low_magnitude, but I run into the following error.

2025-01-19 17:41:03.439078: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Loading MobileNetV2 model...
Pruning the model...
Traceback (most recent call last):
  File "/Users/esp/Projects/gen_models/mbv2_prune.py", line 89, in <module>
    full_workflow()
  File "/Users/esp/Projects/gen_models/mbv2_prune.py", line 81, in full_workflow
    pruned_model = prune_model(model)
  File "/Users/esp/Projects/gen_models/mbv2_prune.py", line 27, in prune_model
    pruned_layer = prune_low_magnitude(tf.keras.layers.Conv2D(
  File "/Users/esp/Projects/gen_models/venv/lib/python3.10/site-packages/tensorflow_model_optimization/python/core/keras/metrics.py", line 74, in inner
    raise error
  File "/Users/esp/Projects/gen_models/venv/lib/python3.10/site-packages/tensorflow_model_optimization/python/core/keras/metrics.py", line 69, in inner
    results = func(*args, **kwargs)
  File "/Users/esp/Projects/gen_models/venv/lib/python3.10/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/prune.py", line 216, in prune_low_magnitude
    raise ValueError(
ValueError: `prune_low_magnitude` can only prune an object of the following types: keras.models.Sequential, keras functional model, keras.layers.Layer, list of keras.layers.Layer. You passed an object of type: Conv2D.

System information

TensorFlow version (installed from source or binary): 2.16.2

TensorFlow Model Optimization version (installed from source or binary): 0.8.0

Python version: 3.10.16

Describe the expected behavior
The layers should be pruned, since Conv2D inherits from Layer, which the error message lists as a supported type.

Describe the current behavior
Even though Conv2D inherits from Layer, which the error message lists as a supported type, the call to prune_low_magnitude fails with the ValueError shown above.
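
One possible explanation, which nothing in this report confirms, is that the Conv2D object inherits from a Layer class defined in a different package than the Layer class that prune_low_magnitude checks against. Two classes with the same name from different packages are unrelated as far as isinstance is concerned. The sketch below illustrates that behavior with stand-in classes only: `package_a`, `package_b`, `make_fake_package`, and `mro_modules` are hypothetical names made up for this example and are not part of TensorFlow or tensorflow_model_optimization.

```python
import types


def make_fake_package(name):
    """Build a stand-in module with its own Layer and Conv2D classes."""
    module = types.ModuleType(name)
    layer_cls = type("Layer", (object,), {"__module__": name})
    conv_cls = type("Conv2D", (layer_cls,), {"__module__": name})
    module.Layer = layer_cls
    module.Conv2D = conv_cls
    return module


def mro_modules(obj):
    """Return 'module.ClassName' for every class in the object's MRO."""
    return [f"{cls.__module__}.{cls.__qualname__}" for cls in type(obj).__mro__]


# Hypothetical package names, used only to label the stand-in classes.
package_a = make_fake_package("package_a")
package_b = make_fake_package("package_b")

layer = package_a.Conv2D()

# Same class names, but only the Layer from the layer's own package matches.
is_layer_in_a = isinstance(layer, package_a.Layer)
is_layer_in_b = isinstance(layer, package_b.Layer)

# Shows which package each base class comes from.
lineage = mro_modules(layer)
print(is_layer_in_a, is_layer_in_b, lineage)
```

Calling a helper like `mro_modules` on the real Conv2D instance from the reproduction code would show which package its Layer base class comes from, and that could then be compared with the package imported by prune.py.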

Code to reproduce the issue

def load_model_mobilenetv2():
    model = MobileNetV2(weights="imagenet", include_top=True, alpha=0.35, input_shape=(224, 224, 3))
    return model

def prune_model(base_model):
    pruning_schedule = PolynomialDecay(
        initial_sparsity=0.0,
        final_sparsity=0.5,
        begin_step=0,
        end_step=1000
    )
    
    pruned_model = Sequential()
    for layer in base_model.layers:
        if isinstance(layer, tf.keras.layers.Dense):
            pruned_layer = prune_low_magnitude(layer, pruning_schedule)
            pruned_model.add(pruned_layer)
        elif isinstance(layer, tf.keras.layers.Conv2D):
            pruned_layer = prune_low_magnitude(tf.keras.layers.Conv2D(
                layer.filters,
                layer.kernel_size,
                layer.strides,
                layer.padding,
                layer.data_format,
                layer.dilation_rate,
                layer.groups,
                layer.activation,
                layer.use_bias,
                layer.kernel_initializer,
                layer.bias_initializer,
                layer.kernel_regularizer,
                layer.bias_regularizer,
                layer.activity_regularizer,
                layer.kernel_constraint,
                layer.bias_constraint,
            ), pruning_schedule)
            pruned_model.add(pruned_layer)
        else:
            pruned_model.add(layer)
    
    # This doesn't work either
    # for layer in base_model.layers:
    #     if isinstance(layer, tf.keras.layers.Dense) or isinstance(layer, tf.keras.layers.Conv2D):
    #         pruned_layer = prune_low_magnitude(layer, pruning_schedule)
    #         pruned_model.add(pruned_layer)
    #     else:
    #         pruned_model.add(layer)

    pruned_model.build(input_shape=(None, 224, 224, 3))
    pruned_model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss="categorical_crossentropy",
        metrics=["accuracy"]
    )
    return pruned_model

Labels: bug (Something isn't working)