Pruning: improve custom training loop API #271

@alanchiao

Description

The recommended path for pruning with a custom training loop is not as simple as it could be.

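import tempfile

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

pruned_model = setup_pruned_model()

loss = tf.keras.losses.categorical_crossentropy
optimizer = tf.keras.optimizers.Adam()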

log_dir = tempfile.mkdtemp()

# All of the following is boilerplate the user has to write by hand.
pruned_model.optimizer = optimizer
step_callback = tfmot.sparsity.keras.UpdatePruningStep()
step_callback.set_model(pruned_model)
log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir) # optional Tensorboard logging.
log_callback.set_model(pruned_model)

step_callback.on_train_begin()
unused_arg = -1  # placeholder for the argument the callback hooks require
batch_size = 20
for _ in range(3):
    # Only one batch per epoch, given batch_size = 20 and the input shape.
    step_callback.on_train_batch_begin(batch=unused_arg)
    inp = np.reshape(x_train, [batch_size, 10])  # reshape to [batch_size, 10]
    with tf.GradientTape() as tape:
        logits = pruned_model(inp, training=True)
        loss_value = loss(y_train, logits)
    # Compute and apply the gradients after leaving the tape context.
    grads = tape.gradient(loss_value, pruned_model.trainable_variables)
    optimizer.apply_gradients(zip(grads, pruned_model.trainable_variables))

    step_callback.on_epoch_end(batch=unused_arg)
    log_callback.on_epoch_end(batch=unused_arg)
...

The set_model calls and the pruned_model.optimizer assignment are unusual for a custom training loop and easy to miss.
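
One possible direction, sketched here only as an illustration: a small wrapper that performs the wiring once so the loop body no longer touches set_model or pruned_model.optimizer. PruningLoop, RecordingCallback, and DummyModel are hypothetical names, not tfmot APIs; the wrapper only issues the same callback calls as the snippet above, and the stand-in callback exists so the sketch runs without TensorFlow installed.

```python
class PruningLoop:
    """Hypothetical helper (not part of tfmot) that hides the callback wiring."""

    _UNUSED = -1  # placeholder for the callbacks' required `batch` argument

    def __init__(self, model, optimizer, step_callback, log_callback=None):
        # The two easy-to-miss steps from the issue, done once here.
        model.optimizer = optimizer
        self._step = step_callback
        self._callbacks = [cb for cb in (step_callback, log_callback)
                           if cb is not None]
        for cb in self._callbacks:
            cb.set_model(model)

    def begin(self):
        self._step.on_train_begin()

    def begin_batch(self):
        self._step.on_train_batch_begin(batch=self._UNUSED)

    def end_epoch(self):
        for cb in self._callbacks:
            cb.on_epoch_end(batch=self._UNUSED)


class RecordingCallback:
    """Stand-in for a Keras callback; records which hooks were invoked."""

    def __init__(self, name, log):
        self.name, self.log, self.model = name, log, None

    def set_model(self, model):
        self.model = model
        self.log.append((self.name, "set_model"))

    def on_train_begin(self, logs=None):
        self.log.append((self.name, "on_train_begin"))

    def on_train_batch_begin(self, batch, logs=None):
        self.log.append((self.name, "on_train_batch_begin"))

    def on_epoch_end(self, batch, logs=None):
        self.log.append((self.name, "on_epoch_end"))


class DummyModel:
    """Stand-in for the pruned model; only needs an `optimizer` attribute."""
    optimizer = None


calls = []
model = DummyModel()
loop = PruningLoop(model, optimizer="adam",
                   step_callback=RecordingCallback("step", calls),
                   log_callback=RecordingCallback("log", calls))
loop.begin()
for _ in range(2):       # epochs
    loop.begin_batch()   # one batch per epoch, as in the snippet
    # ... forward pass, tape.gradient, and apply_gradients would go here ...
    loop.end_epoch()
```

With the real callbacks, the stand-ins would be replaced by tfmot.sparsity.keras.UpdatePruningStep() and tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir); whether a wrapper like this belongs in the library, and under what name, is left open by the issue.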

Metadata

Labels

feature request
priority:low (Low priority when applied. Intentionally open with no assignee or contributors welcome label.)
technique:pruning (Regarding tfmot.sparsity.keras APIs and docs)
