Skip to content

Pruning fails for siamese networks #270

@digitalheir

Description

@digitalheir

Describe the bug
Following the MNIST siamese example, where one set of weights is used twice in the same network, I wrapped one layer of the shared base network with prune_low_magnitude. Training then fails with the error: tensorflow.python.framework.errors_impl.InvalidArgumentError: assertion failed: [Prune() wrapper requires the UpdatePruningStep callback to be provided during training. Please add it as a callback to your model.fit call.] [Condition x >= y did not hold element-wise:x (assert_greater_equal/ReadVariableOp:0) = ]. This happens even though the UpdatePruningStep callback is supplied to fit():

System information

TensorFlow installed from (source or binary):
binary

TensorFlow version:
2.1.0

TensorFlow Model Optimization version:
0.2.1

Python version:
3.6.10

Describe the expected behavior
Script doesn't crash

Describe the current behavior
Script crashes

Code to reproduce the issue

import random
import tempfile

import numpy as np
import tensorflow as tf

from tensorflow import Variable, float32
from tensorflow.keras import backend as keras_backend
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Flatten, Dense, Dropout, Lambda, Conv2D, MaxPooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import RMSprop
from tensorflow_model_optimization.sparsity import keras as sparsity

def euclidean_distance(vects):
    """Return the Euclidean distance between the two tensors in ``vects``.

    ``vects`` is unpacked into exactly two items. Their squared difference
    is summed along axis 1 (keeping that dimension), and the sum is clamped
    from below at the backend epsilon before the square root is taken.
    """
    left, right = vects
    squared_diff = keras_backend.square(left - right)
    summed = keras_backend.sum(squared_diff, axis=1, keepdims=True)
    # Clamp at epsilon so sqrt never receives a value smaller than epsilon.
    clamped = keras_backend.maximum(summed, keras_backend.epsilon())
    return keras_backend.sqrt(clamped)


def eucl_dist_output_shape(shapes):
    """Return the output shape of the distance layer.

    ``shapes`` must contain exactly two shapes. The result is a 2-tuple made
    of the first entry of the first shape followed by ``1``; the second
    shape is only unpacked, never read.
    """
    first_shape, _ = shapes
    return (first_shape[0], 1)

def random_other_digit(digit, max_num):
    """Pick a random value in ``[0, max_num)`` that is offset from ``digit``.

    A random offset between 1 and ``max_num - 1`` (inclusive) is added to
    ``digit`` and the sum is wrapped modulo ``max_num``, so the result is
    never congruent to ``digit`` modulo ``max_num``.
    """
    offset = random.randrange(1, max_num)
    return (digit + offset) % max_num

def create_pairs(x, digit_indices_in_dataset, num_classes=10):
    """Build alternating same-digit / different-digit image pairs.

    For every class ``d`` and every position ``i`` below the usable set
    size, two pairs are appended in this order:

    * ``[x[idx_d[i]], x[idx_d[i + 1]]]`` -- two samples of the same class,
      labelled ``1``;
    * ``[x[idx_d[i]], x[idx_o[i]]]`` -- the same first sample paired with a
      sample of a randomly chosen other class ``o``, labelled ``0``.

    Args:
        x: Indexable collection of samples.
        digit_indices_in_dataset: Indexable by class; entry ``d`` holds the
            indices into ``x`` of the samples belonging to class ``d``.
        num_classes: Number of classes to iterate over.

    Returns:
        A tuple ``(pairs, labels)`` of numpy arrays.

    Raises:
        ValueError: If the smallest class has fewer than two samples, since
            no same-class pair can be formed in that case.
    """
    # One less than the smallest class size: position i is paired with
    # position i + 1, so the last sample of a class cannot start a pair.
    usable_size = min(
        len(digit_indices_in_dataset[d]) for d in range(num_classes)) - 1
    if usable_size <= 0:
        raise ValueError(
            "Cannot build pairs: every class needs at least 2 samples, "
            "but the smallest class has {}".format(usable_size + 1))

    pairs = []
    labels = []
    for digit in range(num_classes):
        # Loop-invariant for the inner loop, so look it up once per class.
        indices_for_digit = digit_indices_in_dataset[digit]
        for i in range(usable_size):
            anchor_image = x[indices_for_digit[i]]

            same_image = x[indices_for_digit[i + 1]]
            pairs.append([anchor_image, same_image])

            other_digit = random_other_digit(digit, num_classes)
            other_image = x[digit_indices_in_dataset[other_digit][i]]
            pairs.append([anchor_image, other_image])

            # [Same, Different]
            labels += [1, 0]
    return np.array(pairs), np.array(labels)

# Number of digit classes; previously used below without being defined,
# which raised NameError. Matches the default of create_pairs().
num_classes = 10

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Keep only the first 5000 samples of each split.
x_train = x_train[0:5000]
y_train = y_train[0:5000]

x_test = x_test[0:5000]
y_test = y_test[0:5000]

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')

# Scale the pixel values by 1/255.
x_train /= 255
x_test /= 255
# Per-sample shape (everything after the leading sample axis).
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]

tr_pairs, tr_y = create_pairs(x_train, digit_indices, num_classes)
print("Pairs created")

print("Creating testing pairs...")
# Rebuild the per-class index lists, this time from the test labels.
digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices, num_classes)
def create_base_network_pruned(input_shape, begin_step, end_step):
    """Build the base network to be shared (eq. to feature extraction).

    The network flattens its input and applies three 128-unit ReLU Dense
    layers with 0.1 dropout after the first two. Only the middle Dense
    layer is wrapped with ``prune_low_magnitude``, using a polynomial-decay
    schedule from 50% to 90% sparsity between ``begin_step`` and
    ``end_step``, updated every 100 steps.

    Args:
        input_shape: Shape of one input sample, passed to ``Input``.
        begin_step: Step at which the pruning schedule starts.
        end_step: Step at which the pruning schedule ends.

    Returns:
        A ``Model`` mapping the input to the output of the last Dense layer.
    """
    schedule = sparsity.PolynomialDecay(
        initial_sparsity=0.50,
        final_sparsity=0.90,
        begin_step=begin_step,
        end_step=end_step,
        frequency=100,
    )
    network_input = Input(shape=input_shape)
    features = Flatten()(network_input)
    features = Dense(128, activation='relu')(features)
    features = Dropout(0.1)(features)
    # The only layer in this network that is wrapped for pruning.
    pruned_dense = sparsity.prune_low_magnitude(
        Dense(128, activation='relu'), pruning_schedule=schedule)
    features = pruned_dense(features)
    features = Dropout(0.1)(features)
    features = Dense(128, activation='relu')(features)
    return Model(network_input, features)

def siamese_dense_pruned(input_shape, begin_step, end_step):
    """Build a two-input siamese model around one shared, pruned base network.

    Both inputs are passed through the same base-network instance and the
    two outputs are combined by a ``Lambda`` layer computing
    ``euclidean_distance``.

    Args:
        input_shape: Shape of one input sample, used for both inputs.
        begin_step: Forwarded to ``create_base_network_pruned``.
        end_step: Forwarded to ``create_base_network_pruned``.

    Returns:
        A ``Model`` taking ``[input_a, input_b]`` and returning the distance.
    """
    shared_branch = create_base_network_pruned(input_shape, begin_step, end_step)
    input_a = Input(shape=input_shape)
    input_b = Input(shape=input_shape)
    # A single `shared_branch` instance is called on both inputs, so the
    # two branches share one set of weights.
    embeddings = [shared_branch(input_a), shared_branch(input_b)]
    distance_layer = Lambda(euclidean_distance,
                            output_shape=eucl_dist_output_shape)
    return Model([input_a, input_b], distance_layer(embeddings))

epochs = 20
batch_size = 128
num_train_samples = x_train.shape[0]
# Total number of steps: batches per epoch (rounded up) times epochs.
end_step = np.ceil(1.0 * num_train_samples / batch_size).astype(np.int32) * epochs
# Integer division keeps begin_step a whole step count, like end_step.
begin_step = end_step // 5
model = siamese_dense_pruned(input_shape, begin_step, end_step)
rms = RMSprop()
# NOTE(review): contrastive_loss and accuracy are not defined anywhere in
# this snippet; presumably they come from the MNIST siamese example the
# report follows -- confirm and define them before running.
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
# Directory for the pruning summaries; previously `logdir` was undefined.
logdir = tempfile.mkdtemp()
callbacks = [
    sparsity.UpdatePruningStep(),
    sparsity.PruningSummaries(log_dir=logdir, profile_batch=0),
]
# The model has two inputs, so split each pair array into one array per
# branch (axis 1 indexes the two members of a pair).
training_pairs = [tr_pairs[:, 0], tr_pairs[:, 1]]
testing_pairs = [te_pairs[:, 0], te_pairs[:, 1]]
model.fit(training_pairs, tr_y,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(testing_pairs, te_y),
          callbacks=callbacks)

Additional context
This might be related to the same pruned layers being called twice within one model (the shared base network is applied to both inputs).

Metadata

Metadata

Assignees

Labels

feature request; priority:low (Low priority when applied. Intentionally open with no assignee or contributors welcome label.); technique:pruning (Regarding tfmot.sparsity.keras APIs and docs)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions