Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed `EADDRINUSE` errors in distributed tests with port manager and retry logic ([#21309](https://github.com/Lightning-AI/pytorch-lightning/pull/21309))


- Learning rate scheduler is stepped at the end of epoch when `on_train_batch_start` returns -1 ([#21296](https://github.com/Lightning-AI/pytorch-lightning/issues/21296)).



---
Expand Down
1 change: 1 addition & 0 deletions src/lightning/pytorch/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[int]:
"""Called in the training loop before anything happens for that batch.

If you return -1 here, you will skip training for the rest of the current epoch.
Learning rate scheduler will still be stepped at the end of epoch.

Args:
batch: The batched data as it is returned by the training DataLoader.
Expand Down
41 changes: 24 additions & 17 deletions src/lightning/pytorch/loops/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,30 +325,33 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
trainer._logger_connector.on_batch_start(batch)

batch_output: _BATCH_OUTPUTS_TYPE = None # for mypy
should_skip_rest_of_epoch = False

if batch is None and not using_dataloader_iter:
self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
else:
# hook
call._call_callback_hooks(trainer, "on_train_batch_start", batch, batch_idx)
response = call._call_lightning_module_hook(trainer, "on_train_batch_start", batch, batch_idx)
call._call_strategy_hook(trainer, "on_train_batch_start", batch, batch_idx)
if response == -1:
self.batch_progress.increment_processed()
raise StopIteration

self.batch_progress.increment_started()

kwargs = (
self._build_kwargs(OrderedDict(), batch, batch_idx)
if not using_dataloader_iter
else OrderedDict(any=dataloader_iter)
)
with trainer.profiler.profile("run_training_batch"):
if trainer.lightning_module.automatic_optimization:
# in automatic optimization, there can only be one optimizer
batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
else:
batch_output = self.manual_optimization.run(kwargs)
should_skip_rest_of_epoch = response == -1
# Signal this is the last batch for the current epoch
if should_skip_rest_of_epoch:
self.batch_progress.increment_by(0, is_last_batch=True)
else:
self.batch_progress.increment_started()

kwargs = (
self._build_kwargs(OrderedDict(), batch, batch_idx)
if not using_dataloader_iter
else OrderedDict(any=dataloader_iter)
)
with trainer.profiler.profile("run_training_batch"):
if trainer.lightning_module.automatic_optimization:
# in automatic optimization, there can only be one optimizer
batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
else:
batch_output = self.manual_optimization.run(kwargs)

self.batch_progress.increment_processed()

Expand All @@ -358,6 +361,10 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
if self._num_ready_batches_reached():
self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

if should_skip_rest_of_epoch:
# Only raise StopIteration now so that the training epoch loop can finish
raise StopIteration

if using_dataloader_iter:
# update the hook kwargs now that the step method might have consumed the iterator
batch = data_fetcher._batch
Expand Down
25 changes: 25 additions & 0 deletions tests/tests_pytorch/loops/test_training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def on_train_batch_start(self, batch, batch_idx):
assert trainer.fit_loop.batch_idx == batch_idx_
assert trainer.global_step == batch_idx_ * max_epochs

assert trainer.is_last_batch


def test_should_stop_mid_epoch(tmp_path):
"""Test that training correctly stops mid epoch and that validation is still called at the right time."""
Expand Down Expand Up @@ -305,3 +307,26 @@ def test_eval_mode_warning(tmp_path, warn):
w for w in warning_list if issubclass(w.category, PossibleUserWarning) and "eval mode" in str(w.message)
]
assert len(eval_warnings) == 0, "Expected no eval mode warnings"


@pytest.mark.parametrize(("max_epochs", "batch_idx_"), [(2, 5), (3, 8)])
def test_lr_updated_on_train_batch_start_returns_minus_one(tmp_path, max_epochs, batch_idx_):
"""Test that when the rest of the epoch is skipped, due to on_train_batch_start returning -1, the learning rate is
still updated when it should, at the end of the epoch."""

class TestModel(BoringModel):
def on_train_batch_start(self, batch, batch_idx):
if batch_idx == batch_idx_:
return -1
return super().on_train_batch_start(batch, batch_idx)

model = TestModel()
init_lr = 0.1
trainer = Trainer(default_root_dir=tmp_path, limit_train_batches=10, max_epochs=max_epochs)
trainer.fit(model)

adjusted_lr = [pg["lr"] for pg in trainer.optimizers[0].param_groups]

assert len(trainer.lr_scheduler_configs) == 1
assert all(a == adjusted_lr[0] for a in adjusted_lr)
assert init_lr * 0.1**max_epochs == adjusted_lr[0]
Loading