Trains using the tf.keras.Model.fit
training loop.
Inherits From: Trainer
runner.KerasTrainer(
strategy: tf.distribute.Strategy,
*,
model_dir: str,
checkpoint_options: Optional[runner.KerasTrainerCheckpointOptions] = None,
backup_dir: Optional[str] = None,
steps_per_epoch: Optional[int] = None,
verbose: Union[int, str] = 'auto',
validation_steps: Optional[int] = None,
validation_per_epoch: Optional[int] = None,
validation_freq: Optional[int] = None,
summarize_every_n_steps: Union[int, str] = 500,
checkpoint_every_n_steps: Union[int, str] = 'epoch',
backup_and_restore: bool = True,
callbacks: Optional[Sequence[tf.keras.callbacks.Callback]] = None,
restore_best_weights: Optional[bool] = None,
options: Optional[runner.KerasTrainerOptions] = None
)
Args |
strategy
|
A tf.distribute.Strategy.
|
model_dir
|
The model directory. TensorBoard summaries are written here, and the
default checkpoint and backup directories are placed under it.
|
checkpoint_options
|
An optional checkpointing configuration. If
checkpoint_options.checkpoint_dir is unset,
os.path.join(model_dir, "ckpnt") is used.
|
backup_dir
|
An optional directory for backups. If unset,
os.path.join(model_dir, "backup") is used.
|
steps_per_epoch
|
The number of training steps per epoch. Optional; if
unspecified, an epoch ends when the training tf.data.Dataset is exhausted.
|
verbose
|
Forwarded to tf.keras.Model.fit(). Possible values are
0 (silent), 1 (progress bar), 2 (one line per epoch), and
"auto" (default), which lets Keras select the verbosity.
|
validation_steps
|
The number of steps used during validation. Optional; if
unspecified, the entire validation tf.data.Dataset is evaluated.
|
validation_per_epoch
|
The number of validation runs per training epoch.
Optional; if unspecified, one validation run is performed per training
epoch. Only one of validation_per_epoch and validation_freq can be
specified.
|
validation_freq
|
The number of training epochs to run before each new
validation run. Optional; if unspecified, validation is performed
after every training epoch. Only one of validation_per_epoch and
validation_freq can be specified.
|
summarize_every_n_steps
|
The frequency for writing TensorBoard summaries:
an integer number of steps, "epoch" for once per epoch, or
"never".
|
checkpoint_every_n_steps
|
The frequency for checkpointing the latest model: an
integer number of steps, "epoch" for once per epoch, or "never".
Unless this is set to "never", the best model is also saved after each
validation epoch, because the validation metric only becomes available
after a validation epoch.
|
backup_and_restore
|
Whether to back up and restore (as in
tf.keras.callbacks.BackupAndRestore). The backup
directory is determined by backup_dir.
|
callbacks
|
Optional additional tf.keras.callbacks.Callback instances
for tf.keras.Model.fit.
|
restore_best_weights
|
Whether to restore the best model weights, as determined by
tf.keras.callbacks.ModelCheckpoint, after training. Requires a
checkpoint_every_n_steps other than "never". If unspecified,
its value is determined at train(...) invocation: True if
valid_ds_provider is not None, else False.
|
options
|
A KerasTrainerOptions.
|
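The directory defaults described for checkpoint_options and backup_dir can be sketched in plain Python. The helper resolve_dirs below is hypothetical and not part of the runner API; it only mirrors the two os.path.join fallbacks documented above.

```python
import os
from typing import Optional, Tuple


def resolve_dirs(
    model_dir: str,
    checkpoint_dir: Optional[str] = None,
    backup_dir: Optional[str] = None,
) -> Tuple[str, str]:
    """Hypothetical helper: returns (checkpoint_dir, backup_dir).

    Each unset directory falls back to a subdirectory of model_dir,
    following the defaults documented for KerasTrainer.
    """
    if checkpoint_dir is None:
        # Documented default when checkpoint_options.checkpoint_dir is unset.
        checkpoint_dir = os.path.join(model_dir, "ckpnt")
    if backup_dir is None:
        # Documented default when backup_dir is unset.
        backup_dir = os.path.join(model_dir, "backup")
    return checkpoint_dir, backup_dir


if __name__ == "__main__":
    # Both defaults derived from model_dir.
    print(resolve_dirs("/tmp/run1"))
    # An explicit backup_dir overrides only the backup default.
    print(resolve_dirs("/tmp/run1", backup_dir="/mnt/backup"))
```

Keeping both defaults under model_dir means a single directory holds summaries, checkpoints and backups unless you explicitly redirect them.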
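The constraints stated above on validation_per_epoch, validation_freq, the two frequency arguments, and restore_best_weights can be summarized as a small validation sketch. check_frequency and check_trainer_options are hypothetical helpers, not the library's own checks, and rejecting non-positive step counts is an assumption of this sketch rather than documented behavior.

```python
from typing import Optional, Union

# A frequency is an integer number of steps, "epoch", or "never".
FrequencySpec = Union[int, str]


def check_frequency(name: str, value: FrequencySpec) -> None:
    """Hypothetical check for summarize_every_n_steps / checkpoint_every_n_steps."""
    # bool is a subclass of int in Python, so exclude it from the int branch.
    if isinstance(value, int) and not isinstance(value, bool):
        # ASSUMPTION: this sketch requires a positive step count.
        if value <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value}")
    elif value not in ("epoch", "never"):
        raise ValueError(f"{name} must be an int, 'epoch' or 'never', got {value!r}")


def check_trainer_options(
    validation_per_epoch: Optional[int] = None,
    validation_freq: Optional[int] = None,
    summarize_every_n_steps: FrequencySpec = 500,
    checkpoint_every_n_steps: FrequencySpec = "epoch",
    restore_best_weights: Optional[bool] = None,
) -> None:
    """Hypothetical: raises ValueError for option combinations the docs rule out."""
    # Documented: only one of the two validation schedules may be given.
    if validation_per_epoch is not None and validation_freq is not None:
        raise ValueError(
            "Only one of validation_per_epoch and validation_freq can be specified.")
    check_frequency("summarize_every_n_steps", summarize_every_n_steps)
    check_frequency("checkpoint_every_n_steps", checkpoint_every_n_steps)
    # Documented: restoring best weights needs checkpoints to restore from.
    if restore_best_weights and checkpoint_every_n_steps == "never":
        raise ValueError(
            "restore_best_weights requires checkpoint_every_n_steps other than 'never'.")
```

Note that restore_best_weights=None passes this check even with "never", since its value is only settled when train(...) is called.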
train(
model_fn: Callable[[], tf.keras.Model],
train_ds_provider: runner.DatasetProvider,
*,
epochs: int = 1,
valid_ds_provider: Optional[runner.DatasetProvider] = None
) -> tf.keras.Model
Runs tf.keras.Model.fit
with the tf.distribute.Strategy
provided.
Args |
model_fn
|
A ModelFn, to be invoked in the tf.distribute.Strategy
scope.
|
train_ds_provider
|
A function that returns a tf.data.Dataset for
training. The items of the tf.data.Dataset are pairs
(graph_tensor, label) that represent one batch of per-replica training
inputs after GraphTensor.merge_batch_to_components() has been applied.
|
epochs
|
The number of epochs to train, adjusted for validation_per_epoch.
|
valid_ds_provider
|
An optional function that returns a tf.data.Dataset
for validation. The items of the tf.data.Dataset are pairs
(graph_tensor, label) that represent one batch of per-replica validation
inputs after GraphTensor.merge_batch_to_components() has been applied.
|
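Two values are settled only when train(...) is called: the default for restore_best_weights and the epoch count adjusted for validation_per_epoch. The hypothetical plan_fit below sketches both. The restore_best_weights default follows the documentation above; the epoch adjustment (splitting each user-visible epoch into validation_per_epoch shorter Keras epochs, which in this sketch requires steps_per_epoch) is an assumption about what "adjusted" means, not confirmed library behavior.

```python
from typing import NamedTuple, Optional


class FitPlan(NamedTuple):
    """Hypothetical summary of values a trainer might pass on to fit()."""
    epochs: int
    steps_per_epoch: Optional[int]
    restore_best_weights: bool


def plan_fit(
    epochs: int,
    steps_per_epoch: Optional[int],
    validation_per_epoch: Optional[int],
    has_valid_ds_provider: bool,
    restore_best_weights: Optional[bool] = None,
) -> FitPlan:
    # Documented: if unspecified, restore_best_weights is True exactly when
    # a valid_ds_provider is passed to train(...).
    if restore_best_weights is None:
        restore_best_weights = has_valid_ds_provider
    # ASSUMPTION: each user-visible epoch is split into validation_per_epoch
    # shorter Keras epochs; Keras validates at the end of each of them.
    if validation_per_epoch is not None:
        if steps_per_epoch is None:
            raise ValueError("this sketch needs steps_per_epoch with validation_per_epoch")
        epochs *= validation_per_epoch
        steps_per_epoch //= validation_per_epoch
    return FitPlan(epochs, steps_per_epoch, restore_best_weights)
```

For example, plan_fit(2, 1000, 4, has_valid_ds_provider=True) returns FitPlan(epochs=8, steps_per_epoch=250, restore_best_weights=True), so validation would run four times per user-visible epoch under this assumption.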
Returns |
A trained tf.keras.Model.
|