How to keep early stopping disabled during the first n epochs #3317
Comments
@5o1 you can filter the events on which the early stopping handler is applied. For example, if you add the early stopping handler as follows:

```python
es_handler = EarlyStopping(..., trainer=trainer)

# Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
def training_epoch_filtering(*args):
    # Let the handler run only once the *trainer* has passed epoch 50.
    if trainer.state.epoch > 50:
        return True
    return False

evaluator.add_event_handler(Events.COMPLETED(event_filter=training_epoch_filtering), es_handler)
```

Here is a working example:

```python
import logging

from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping
from ignite.utils import setup_logger

train_data = range(5)
eval_data = range(4)
max_epochs = 20


def train_step(engine, batch):
    pass


trainer = Engine(train_step)
evaluator = Engine(lambda e, b: None)

mock_accuracy = [
    0.3,
    0.26,
    0.22,
    0.25,
    0.5,
    0.6,
    0.7,
]


@trainer.on(Events.EPOCH_COMPLETED)
def run_validation():
    evaluator.run(eval_data)
    # mock validation accuracy score
    evaluator.state.metrics["accuracy"] = mock_accuracy[min(trainer.state.epoch, len(mock_accuracy) - 1)]
    print(
        f"{trainer.state.epoch} / {trainer.state.max_epochs} | {trainer.state.iteration} - run validation: "
        f"accuracy={evaluator.state.metrics['accuracy']}",
        flush=True,
    )


# Optionally, we define the score function to print when EarlyStopping is called
def score_function(engine):
    print("Called EarlyStopping score function")
    return engine.state.metrics["accuracy"]


es_handler = EarlyStopping(
    patience=3,
    score_function=score_function,
    trainer=trainer,
)
# This will report early stopping patience counter incrementation:
# For example, ES DEBUG: EarlyStopping: 1 / 3
es_handler.logger = setup_logger("ES", level=logging.DEBUG)


# Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
def training_epoch_filtering(*args):
    if trainer.state.epoch > 5:
        return True
    return False


evaluator.add_event_handler(Events.COMPLETED(event_filter=training_epoch_filtering), es_handler)

trainer.run(train_data, max_epochs=max_epochs)
```
We can see that early stopping was not called during the first 5 trainer epochs. After that, training stopped before reaching max_epochs because the validation accuracy is mocked so that it stalls at 0.7, the last value in mock_accuracy.
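The threshold in `training_epoch_filtering` is hard-coded (50 in the first snippet, 5 in the working example). A small factory can turn it into a parameter. This is a hedged sketch: `skip_first_epochs` is a hypothetical helper, not part of ignite. It relies only on what the example already does: it reads `trainer.state.epoch` and accepts whatever positional arguments ignite passes to an event filter (the engine and the event).

```python
from types import SimpleNamespace
from typing import Any, Callable


def skip_first_epochs(trainer: Any, n: int) -> Callable[..., bool]:
    """Build an event filter that returns True only after `n` trainer epochs.

    Hypothetical helper (not part of ignite). It reads the *trainer's* epoch,
    because the evaluator only runs a single epoch per validation call.
    """

    def event_filter(*args: Any) -> bool:
        # ignite passes (engine, event) here; both are ignored on purpose.
        return trainer.state.epoch > n

    return event_filter


# Stand-in for an ignite trainer: the filter only uses `state.epoch`.
fake_trainer = SimpleNamespace(state=SimpleNamespace(epoch=50))
only_after_50 = skip_first_epochs(fake_trainer, 50)
print(only_after_50(None, None))  # False: epoch 50 is still inside the warm-up
fake_trainer.state.epoch = 51
print(only_after_50(None, None))  # True: early stopping may now run
```

With a real trainer, the helper would replace the hand-written filter, for example `evaluator.add_event_handler(Events.COMPLETED(event_filter=skip_first_epochs(trainer, 50)), es_handler)`.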
It works. Thank you.
Because the metrics are very unstable during the first 50 epochs, I want to disable early stopping during that period.
I checked the parameters of the EarlyStopping class, and it doesn't seem to provide this option.
How can I achieve this?
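If you would rather build the warm-up into the stopping rule itself instead of using an event filter, the rule is small enough to sketch without any framework. This is an illustration under stated assumptions, not ignite code. `WarmupEarlyStopping`, `warmup_epochs`, `step` and `stop_epoch` are hypothetical names. The patience logic is a simplified patience counter: a score counts as an improvement only if it beats the best score by more than `min_delta`.

```python
from typing import List, Optional


class WarmupEarlyStopping:
    """Hypothetical patience-based early stopping that ignores the first
    `warmup_epochs` epochs (a sketch, not ignite's EarlyStopping)."""

    def __init__(self, patience: int, warmup_epochs: int, min_delta: float = 0.0):
        if patience < 1:
            raise ValueError("patience must be >= 1")
        self.patience = patience
        self.warmup_epochs = warmup_epochs
        self.min_delta = min_delta
        self.best_score: Optional[float] = None
        self.counter = 0

    def step(self, epoch: int, score: float) -> bool:
        """Record the validation score for `epoch` (1-based); return True to stop."""
        if epoch <= self.warmup_epochs:
            # Warm-up: unstable early scores are ignored entirely.
            return False
        if self.best_score is None or score > self.best_score + self.min_delta:
            # First score after warm-up, or a real improvement: reset patience.
            self.best_score = score
            self.counter = 0
            return False
        # No improvement: use up one unit of patience.
        self.counter += 1
        return self.counter >= self.patience


def stop_epoch(scores: List[float], patience: int, warmup_epochs: int) -> Optional[int]:
    """Return the 1-based epoch at which training would stop, or None."""
    stopper = WarmupEarlyStopping(patience=patience, warmup_epochs=warmup_epochs)
    for epoch, score in enumerate(scores, start=1):
        if stopper.step(epoch, score):
            return epoch
    return None


# Accuracy that dips early, then stalls at 0.7 (similar to mock_accuracy above).
scores = [0.3, 0.26, 0.22, 0.25, 0.5, 0.6, 0.7, 0.7, 0.7, 0.7, 0.7]
print(stop_epoch(scores, patience=3, warmup_epochs=0))  # 4: the early dip triggers stopping
print(stop_epoch(scores, patience=3, warmup_epochs=5))  # 10
```

Without a warm-up, the early dip exhausts the patience at epoch 4. With `warmup_epochs=5`, the counter only starts once the scores have settled.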