Open
Labels
feature request, priority:low, technique:pruning
Description
The recommended path for pruning with a custom training loop is not as simple as it could be.
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# setup_pruned_model(), x_train and y_train come from the surrounding test.
BATCH_SIZE = 20
unused_arg = -1  # the pruning callbacks ignore the batch/epoch index.

pruned_model = setup_pruned_model()
loss = tf.keras.losses.categorical_crossentropy
optimizer = tf.keras.optimizers.Adam()
log_dir = tempfile.mkdtemp()
# Pruning-specific setup: none of this is standard custom-training-loop boilerplate.
pruned_model.optimizer = optimizer
step_callback = tfmot.sparsity.keras.UpdatePruningStep()
step_callback.set_model(pruned_model)
log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir)  # optional TensorBoard logging.
log_callback.set_model(pruned_model)
step_callback.on_train_begin()
for _ in range(3):
  # Only one batch per epoch, given batch_size = 20 and the input shape.
  step_callback.on_train_batch_begin(batch=unused_arg)
  inp = np.reshape(x_train, [BATCH_SIZE, 10])  # per-sample shape: [10].
  with tf.GradientTape() as tape:
    logits = pruned_model(inp, training=True)
    loss_value = loss(y_train, logits)
  grads = tape.gradient(loss_value, pruned_model.trainable_variables)
  optimizer.apply_gradients(zip(grads, pruned_model.trainable_variables))
  step_callback.on_epoch_end(batch=unused_arg)
  log_callback.on_epoch_end(batch=unused_arg)
...
The set_model calls and the pruned_model.optimizer assignment are unusual steps that users could easily miss.
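One way to make this harder to get wrong would be a small wrapper that performs the easy-to-miss wiring in one place. The sketch below is purely illustrative: PruningLoopHooks and its methods are hypothetical names, not part of tfmot. It takes the callbacks as arguments (so it does not import tfmot itself), attaches the optimizer to the model as the snippet above does, calls set_model on every callback, and forwards the per-batch and per-epoch hooks.

```python
class PruningLoopHooks:
    """Hypothetical helper (not a tfmot API) bundling pruning-callback wiring
    for a custom training loop."""

    def __init__(self, model, optimizer, step_callback, log_callback=None):
        # Mirror the snippet above: the model must carry its optimizer
        # before the pruning callbacks run.
        model.optimizer = optimizer
        self._step_callback = step_callback
        self._callbacks = [step_callback]
        if log_callback is not None:
            self._callbacks.append(log_callback)
        # Every callback needs a reference to the model.
        for cb in self._callbacks:
            cb.set_model(model)

    def train_begin(self):
        # Only the step callback needs the train-begin hook in the snippet above.
        self._step_callback.on_train_begin()

    def batch_begin(self, batch=-1):
        # The batch index is ignored by the pruning callbacks; -1 is a placeholder.
        self._step_callback.on_train_batch_begin(batch=batch)

    def epoch_end(self, epoch=-1):
        # Passed positionally so it works whatever the parameter is named.
        for cb in self._callbacks:
            cb.on_epoch_end(epoch)
```

In a real loop you would pass tfmot.sparsity.keras.UpdatePruningStep() and, optionally, tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir), then call hooks.train_begin() before the loop, hooks.batch_begin() at the start of each batch, and hooks.epoch_end(epoch) after each epoch. Injecting the callbacks keeps the wrapper independent of any particular tfmot version.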
BogdanDidenko