runner.KerasTrainer


Trains using the tf.keras.Model.fit training loop.

Inherits From: Trainer

runner.KerasTrainer(
    strategy: tf.distribute.Strategy,
    *,
    model_dir: str,
    checkpoint_options: Optional[runner.KerasTrainerCheckpointOptions] = None,
    backup_dir: Optional[str] = None,
    steps_per_epoch: Optional[int] = None,
    verbose: Union[int, str] = 'auto',
    validation_steps: Optional[int] = None,
    validation_per_epoch: Optional[int] = None,
    validation_freq: Optional[int] = None,
    summarize_every_n_steps: Union[int, str] = 500,
    checkpoint_every_n_steps: Union[int, str] = 'epoch',
    backup_and_restore: bool = True,
    callbacks: Optional[Sequence[tf.keras.callbacks.Callback]] = None,
    restore_best_weights: Optional[bool] = None,
    options: Optional[runner.KerasTrainerOptions] = None
)

Args

strategy A tf.distribute.Strategy.
model_dir A model directory for summaries.
checkpoint_options An optional configuration for checkpointing-related settings. If checkpoint_options.checkpoint_dir is unset, os.path.join(model_dir, "ckpnt") is used.
backup_dir An optional directory for backups. If unset, os.path.join(model_dir, "backup") is used.
steps_per_epoch The number of training steps per epoch. Optional; if unspecified, each epoch ends when the training tf.data.Dataset is exhausted.
verbose Forwarded to tf.keras.Model.fit(). Possible values are 0 (silent), 1 (progress bar), 2 (one line per epoch), and "auto" (default), which lets Keras choose the verbosity.
validation_steps The number of steps used during validation. Optional; if unspecified, the entire validation tf.data.Dataset is evaluated.
validation_per_epoch The number of validation runs per training epoch. Optional; if unspecified, one validation is performed per training epoch. At most one of validation_per_epoch and validation_freq can be specified.
validation_freq The number of training epochs to run before each new validation run. Optional; if unspecified, validation is performed after every training epoch. At most one of validation_per_epoch and validation_freq can be specified.
summarize_every_n_steps The frequency for writing TensorBoard summaries, as an integer number of steps, or "epoch" for once per epoch, or "never".
checkpoint_every_n_steps The frequency for writing the latest model checkpoint, as an integer number of steps, "epoch" for once per epoch, or "never". Unless this is "never", the best model is also saved after each validation run, because the validation metric is only available once validation has completed.
backup_and_restore Whether to back up and restore training state using tf.keras.callbacks.BackupAndRestore. The backup directory is determined by backup_dir.
callbacks Optional additional tf.keras.callbacks.Callback instances to pass to tf.keras.Model.fit.
restore_best_weights Whether to restore the best model weights, as determined by tf.keras.callbacks.ModelCheckpoint, after training. Requires checkpoint_every_n_steps to be something other than "never". If unspecified, its value is determined when train(...) is invoked: True if valid_ds_provider is not None, else False.
options An optional runner.KerasTrainerOptions.
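The defaults and constraints listed above can be summarized in a short sketch. The helper `resolve_trainer_args` and the `ResolvedConfig` class are hypothetical, written only to illustrate the documented behavior; they are not part of TF-GNN, and raising `ValueError` for invalid combinations is an assumption about how violations might be reported.

```python
import os
from dataclasses import dataclass
from typing import Optional, Union

# Hypothetical helper (not TF-GNN code): mirrors the documented defaults and
# constraints of runner.KerasTrainer's constructor arguments.

Frequency = Union[int, str]


def _check_frequency(name: str, value: Frequency) -> None:
    """Accepts a positive step count, "epoch" or "never"."""
    if isinstance(value, str):
        if value not in ("epoch", "never"):
            raise ValueError(f"{name} must be an int, 'epoch' or 'never'")
    elif not isinstance(value, int) or value <= 0:
        raise ValueError(f"{name} must be a positive int")


@dataclass
class ResolvedConfig:
    checkpoint_dir: str
    backup_dir: str


def resolve_trainer_args(
    model_dir: str,
    *,
    checkpoint_dir: Optional[str] = None,
    backup_dir: Optional[str] = None,
    validation_per_epoch: Optional[int] = None,
    validation_freq: Optional[int] = None,
    summarize_every_n_steps: Frequency = 500,
    checkpoint_every_n_steps: Frequency = "epoch",
    restore_best_weights: Optional[bool] = None,
) -> ResolvedConfig:
    # At most one way of scheduling validation may be given.
    if validation_per_epoch is not None and validation_freq is not None:
        raise ValueError("Specify at most one of validation_per_epoch and validation_freq")
    _check_frequency("summarize_every_n_steps", summarize_every_n_steps)
    _check_frequency("checkpoint_every_n_steps", checkpoint_every_n_steps)
    # Restoring the best weights needs checkpoints to restore from.
    if restore_best_weights and checkpoint_every_n_steps == "never":
        raise ValueError("restore_best_weights requires checkpoint_every_n_steps != 'never'")
    # Unset directories default to subdirectories of model_dir.
    return ResolvedConfig(
        checkpoint_dir=checkpoint_dir or os.path.join(model_dir, "ckpnt"),
        backup_dir=backup_dir or os.path.join(model_dir, "backup"),
    )
```

Under these assumptions, `resolve_trainer_args("/tmp/model")` places checkpoints under `/tmp/model/ckpnt` and backups under `/tmp/model/backup`.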

Attributes

model_dir
strategy

Methods

train


train(
    model_fn: Callable[[], tf.keras.Model],
    train_ds_provider: runner.DatasetProvider,
    *,
    epochs: int = 1,
    valid_ds_provider: Optional[runner.DatasetProvider] = None
) -> tf.keras.Model

Runs tf.keras.Model.fit with the tf.distribute.Strategy provided.

Args
model_fn A ModelFn, to be invoked in the tf.distribute.Strategy scope.
train_ds_provider A function that returns a tf.data.Dataset for training. The items of the tf.data.Dataset are pairs (graph_tensor, label) that represent one batch of per-replica training inputs after GraphTensor.merge_batch_to_components() has been applied.
epochs The number of epochs to train, adjusted for validation_per_epoch.
valid_ds_provider An optional function that returns a tf.data.Dataset for validation. The items of the tf.data.Dataset are pairs (graph_tensor, label) that represent one batch of per-replica validation inputs after GraphTensor.merge_batch_to_components() has been applied.
Returns
A trained tf.keras.Model.
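The note that epochs are "adjusted for validation_per_epoch" can be illustrated with a minimal sketch. It assumes, without confirmation from this page, that each training epoch is split into validation_per_epoch shorter tf.keras.Model.fit() epochs, each followed by a validation run, and that this split requires steps_per_epoch to be set. `plan_fit` and `FitPlan` are hypothetical names; only the restore_best_weights default follows the rule documented for the constructor.

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical sketch (not TF-GNN source): one plausible mapping from the
# user-facing `epochs` to tf.keras.Model.fit() arguments.


@dataclass
class FitPlan:
    epochs: int  # `epochs` passed to Model.fit()
    steps_per_epoch: Optional[int]  # `steps_per_epoch` passed to Model.fit()
    validation_freq: int  # validate every N fit() epochs
    restore_best_weights: bool


def plan_fit(
    epochs: int,
    *,
    steps_per_epoch: Optional[int] = None,
    validation_per_epoch: Optional[int] = None,
    validation_freq: Optional[int] = None,
    restore_best_weights: Optional[bool] = None,
    has_valid_ds: bool = False,
) -> FitPlan:
    if restore_best_weights is None:
        # Documented default: restore only when a validation dataset is given.
        restore_best_weights = has_valid_ds
    if validation_per_epoch is None:
        return FitPlan(epochs, steps_per_epoch, validation_freq or 1, restore_best_weights)
    if steps_per_epoch is None:
        # Assumption: splitting an epoch requires a known epoch length.
        raise ValueError("validation_per_epoch requires steps_per_epoch")
    # Assumption: each training epoch becomes `validation_per_epoch` shorter
    # fit() epochs, and validation runs after every one of them.
    return FitPlan(
        epochs=epochs * validation_per_epoch,
        steps_per_epoch=steps_per_epoch // validation_per_epoch,
        validation_freq=1,
        restore_best_weights=restore_best_weights,
    )
```

Under this assumption, epochs=3 with steps_per_epoch=1000 and validation_per_epoch=4 becomes 12 fit() epochs of 250 steps each, with a validation run after every one.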