
How to disable early stopping during the first n epochs #3317

Closed
5o1 opened this issue Dec 24, 2024 · 2 comments

5o1 commented Dec 24, 2024

Because the metrics are very unstable during the first 50 epochs, I want to disable early stopping during that period.

I checked the parameters of the early stopping class and it doesn't seem to provide this option.

How can I achieve this?

5o1 added the question label Dec 24, 2024

vfdev-5 commented Dec 24, 2024

@5o1 you can filter the events on which the early stopping handler is applied. For example, attach the early stopping handler like this:

es_handler = EarlyStopping(..., trainer=trainer)

# Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
def training_epoch_filtering(*args):
    if trainer.state.epoch > 50:
        return True
    return False

evaluator.add_event_handler(Events.COMPLETED(event_filter=training_epoch_filtering), es_handler)
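
A minor variant (not part of the original reply): the filter can also be written inline, since ignite calls the event filter with (engine, event) and expects a boolean:

evaluator.add_event_handler(
    Events.COMPLETED(event_filter=lambda engine, event: trainer.state.epoch > 50),
    es_handler,
)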

Here is a working example:

import logging

from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping
from ignite.utils import setup_logger


train_data = range(5)
eval_data = range(4)
max_epochs = 20


def train_step(engine, batch):
    pass

trainer = Engine(train_step)


evaluator = Engine(lambda e, b: None)


mock_accuracy = [
    0.3,
    0.26,
    0.22,
    0.25,
    0.5,
    0.6,
    0.7,
]



@trainer.on(Events.EPOCH_COMPLETED)
def run_validation():
    evaluator.run(eval_data)
    # mock validation accuracy score
    evaluator.state.metrics["accuracy"] = mock_accuracy[min(trainer.state.epoch, len(mock_accuracy) - 1)]
    print(
        f"{trainer.state.epoch} / {trainer.state.max_epochs} | {trainer.state.iteration} - run validation: "
        f"accuracy={evaluator.state.metrics['accuracy']}",
        flush=True
    )

# Optionally, we define the score function to print when EarlyStopping is called
def score_function(engine):
    print("Called EarlyStopping score function")
    return engine.state.metrics["accuracy"]


es_handler = EarlyStopping(
    patience=3, 
    score_function=score_function, 
    trainer=trainer
)
# This will log the early stopping patience counter increments,
# for example: ES DEBUG: EarlyStopping: 1 / 3
es_handler.logger = setup_logger("ES", level=logging.DEBUG)


# Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
def training_epoch_filtering(*args):
    if trainer.state.epoch > 5:
        return True
    return False

evaluator.add_event_handler(Events.COMPLETED(event_filter=training_epoch_filtering), es_handler)
    
trainer.run(train_data, max_epochs=max_epochs)

Output:

1 / 20 | 5 - run validation: accuracy=0.26
2 / 20 | 10 - run validation: accuracy=0.22
3 / 20 | 15 - run validation: accuracy=0.25
4 / 20 | 20 - run validation: accuracy=0.5
5 / 20 | 25 - run validation: accuracy=0.6
Called EarlyStopping score function
6 / 20 | 30 - run validation: accuracy=0.7
Called EarlyStopping score function
7 / 20 | 35 - run validation: accuracy=0.7

2024-12-24 12:51:02,524 ES DEBUG: EarlyStopping: 1 / 3

Called EarlyStopping score function
8 / 20 | 40 - run validation: accuracy=0.7

2024-12-24 12:51:02,527 ES DEBUG: EarlyStopping: 2 / 3

Called EarlyStopping score function
9 / 20 | 45 - run validation: accuracy=0.7

2024-12-24 12:51:02,533 ES DEBUG: EarlyStopping: 3 / 3
2024-12-24 12:51:02,534 ES INFO: EarlyStopping: Stop training

Called EarlyStopping score function
10 / 20 | 50 - run validation: accuracy=0.7

State:
	iteration: 50
	epoch: 10
	epoch_length: 5
	max_epochs: 20
	output: <class 'NoneType'>
	batch: 4
	metrics: <class 'dict'>
	dataloader: <class 'range'>
	seed: <class 'NoneType'>
	times: <class 'dict'>

We can see that early stopping was not called for the first 5 trainer epochs, and that training then stopped before max_epochs because the mocked validation accuracy stalls at 0.7 from epoch 7 onwards.
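
If you want to reuse this pattern with different thresholds, the filter can be built by a small factory. A minimal sketch, assuming the same trainer, evaluator and es_handler as above (the helper name skip_first_n_epochs is hypothetical, not part of ignite):

def skip_first_n_epochs(trainer, n):
    # Returns an event filter that is False for the first `n` trainer epochs.
    def _filter(engine, event):
        return trainer.state.epoch > n
    return _filter

evaluator.add_event_handler(
    Events.COMPLETED(event_filter=skip_first_n_epochs(trainer, 50)),
    es_handler,
)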


5o1 commented Dec 24, 2024

It works. Thank you.

5o1 closed this as completed Dec 24, 2024