Skip to content

Commit

Permalink
Give the option to terminate the engine without firing Events.COMPLETED (#3309)
Browse files Browse the repository at this point in the history

* Give the option to terminate the engine without firing Events.COMPLETED. The default behaviour is not changed.

Note that even though Events.COMPLETED is not fired, its timer is updated.

* Update ignite/engine/engine.py

Co-authored-by: vfdev <[email protected]>

* Update ignite/engine/engine.py

Co-authored-by: vfdev <[email protected]>

* Update ignite/engine/engine.py

Co-authored-by: vfdev <[email protected]>

* Update ignite/engine/engine.py

Co-authored-by: vfdev <[email protected]>

* Update ignite/engine/events.py

Co-authored-by: vfdev <[email protected]>

* Argument `skip_event_completed` renamed to `skip_completed`

* - Fixed broken links in the docs.
- Do not update self.state.times[Events.COMPLETED.name]  if terminated
- Fixed unit test

* Update ignite/engine/engine.py

Co-authored-by: vfdev <[email protected]>

* Refactoring and patching.

- Engine time logging moved out of the if clause. In the log message, "complete" has been replaced with "finished" to avoid confusion with the Events.COMPLETED event, which may now be skipped.
- Same changes applied to the method `_internal_run_legacy()`

* Restored .gitignore

Sorry for accidentally including it in the previous commit!

* Update ignite/engine/events.py

* Fixed typo in test_engine.py

* Parametrized test for engine.terminate(skip_completed)

* Update event table

* Fixed documentation

---------

Co-authored-by: vfdev <[email protected]>
  • Loading branch information
bonassifabio and vfdev-5 authored Dec 3, 2024
1 parent 4f46210 commit 6f8ad2a
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 36 deletions.
41 changes: 28 additions & 13 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
self._process_function = process_function
self.last_event_name: Optional[Events] = None
self.should_terminate = False
self.skip_completed_after_termination = False
self.should_terminate_single_epoch = False
self.should_interrupt = False
self.state = State()
Expand Down Expand Up @@ -538,7 +539,7 @@ def call_interrupt():
self.logger.info("interrupt signaled. Engine will interrupt the run after current iteration is finished.")
self.should_interrupt = True

def terminate(self) -> None:
def terminate(self, skip_completed: bool = False) -> None:
"""Sends terminate signal to the engine, so that it terminates completely the run. The run is
terminated after the event on which ``terminate`` method was called. The following events are triggered:
Expand All @@ -547,6 +548,9 @@ def terminate(self) -> None:
- :attr:`~ignite.engine.events.Events.TERMINATE`
- :attr:`~ignite.engine.events.Events.COMPLETED`
Args:
skip_completed: if True, the event :attr:`~ignite.engine.events.Events.COMPLETED` is not fired after
:attr:`~ignite.engine.events.Events.TERMINATE`. Default is False.
Examples:
.. testcode::
Expand Down Expand Up @@ -617,9 +621,12 @@ def terminate():
.. versionchanged:: 0.4.10
Behaviour changed, for details see https://github.com/pytorch/ignite/issues/2669
.. versionchanged:: 0.5.2
Added `skip_completed` flag
"""
self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.")
self.should_terminate = True
self.skip_completed_after_termination = skip_completed

def terminate_epoch(self) -> None:
"""Sends terminate signal to the engine, so that it terminates the current epoch. The run
Expand Down Expand Up @@ -993,13 +1000,17 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
time_taken = time.time() - start_time
# time is available for handlers but must be updated after fire
self.state.times[Events.COMPLETED.name] = time_taken
handlers_start_time = time.time()
self._fire_event(Events.COMPLETED)
time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.COMPLETED.name] = time_taken

# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
if not (self.should_terminate and self.skip_completed_after_termination):
handlers_start_time = time.time()
self._fire_event(Events.COMPLETED)
time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.COMPLETED.name] = time_taken

hours, mins, secs = _to_hours_mins_secs(time_taken)
self.logger.info(f"Engine run complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
self.logger.info(f"Engine run finished. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")

except BaseException as e:
self._dataloader_iter = None
Expand Down Expand Up @@ -1174,13 +1185,17 @@ def _internal_run_legacy(self) -> State:
time_taken = time.time() - start_time
# time is available for handlers but must be updated after fire
self.state.times[Events.COMPLETED.name] = time_taken
handlers_start_time = time.time()
self._fire_event(Events.COMPLETED)
time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.COMPLETED.name] = time_taken

# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
if not (self.should_terminate and self.skip_completed_after_termination):
handlers_start_time = time.time()
self._fire_event(Events.COMPLETED)
time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.COMPLETED.name] = time_taken

hours, mins, secs = _to_hours_mins_secs(time_taken)
self.logger.info(f"Engine run complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
self.logger.info(f"Engine run finished. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")

except BaseException as e:
self._dataloader_iter = None
Expand Down
25 changes: 18 additions & 7 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,36 +259,47 @@ class Events(EventEnum):
- TERMINATE_SINGLE_EPOCH : triggered when the run is about to end the current epoch,
after receiving a :meth:`~ignite.engine.engine.Engine.terminate_epoch()` or
:meth:`~ignite.engine.engine.Engine.terminate()` call.
- EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even
when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called.
- TERMINATE : triggered when the run is about to end completely,
after receiving :meth:`~ignite.engine.engine.Engine.terminate()` call.
- EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even
when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called.
- COMPLETED : triggered when engine's run is completed
- COMPLETED : triggered when engine's run is completed or terminated with
:meth:`~ignite.engine.engine.Engine.terminate()`, unless the flag
`skip_completed` is set to True.
The table below illustrates which events are triggered when various termination methods are called.
.. list-table::
:widths: 24 25 33 18
:widths: 35 38 28 20 20
:header-rows: 1
* - Method
- EVENT_COMPLETED
- TERMINATE_SINGLE_EPOCH
- EPOCH_COMPLETED
- TERMINATE
- COMPLETED
* - no termination
- ✔
- ✗
- ✔
- ✗
- ✔
* - :meth:`~ignite.engine.engine.Engine.terminate_epoch()`
- ✔
- ✔
- ✗
- ✔
* - :meth:`~ignite.engine.engine.Engine.terminate()`
- ✗
- ✔
- ✔
- ✔
* - :meth:`~ignite.engine.engine.Engine.terminate()` with `skip_completed=True`
- ✗
- ✔
- ✔
- ✗
Since v0.3.0, Events become more flexible and allow to pass an event filter to the Engine:
Expand Down Expand Up @@ -357,7 +368,7 @@ class CustomEvents(EventEnum):
STARTED = "started"
"""triggered when engine's run is started."""
COMPLETED = "completed"
"""triggered when engine's run is completed"""
"""triggered when engine's run is completed, or after receiving terminate() call."""

ITERATION_STARTED = "iteration_started"
"""triggered when an iteration is started."""
Expand Down
1 change: 0 additions & 1 deletion tests/ignite/contrib/engines/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torch.utils.data.distributed import DistributedSampler

import ignite.distributed as idist

import ignite.handlers as handlers
from ignite.contrib.engines.common import (
_setup_logging,
Expand Down
45 changes: 30 additions & 15 deletions tests/ignite/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,14 @@ class TestEngine:
def set_interrupt_resume_enabled(self, interrupt_resume_enabled):
Engine.interrupt_resume_enabled = interrupt_resume_enabled

def test_terminate(self):
@pytest.mark.parametrize("skip_completed", [True, False])
def test_terminate(self, skip_completed):
engine = Engine(lambda e, b: 1)
assert not engine.should_terminate
engine.terminate()
assert not engine.skip_completed_after_termination
engine.terminate(skip_completed)
assert engine.should_terminate
assert engine.skip_completed_after_termination == skip_completed

def test_invalid_process_raises_with_invalid_signature(self):
with pytest.raises(ValueError, match=r"Engine must be given a processing function in order to run"):
Expand Down Expand Up @@ -236,25 +239,32 @@ def check_iter_and_data():
assert num_calls_check_iter_epoch == 1

@pytest.mark.parametrize(
"terminate_event, e, i",
"terminate_event, e, i, skip_completed",
[
(Events.STARTED, 0, 0),
(Events.EPOCH_STARTED(once=2), 2, None),
(Events.EPOCH_COMPLETED(once=2), 2, None),
(Events.GET_BATCH_STARTED(once=12), None, 12),
(Events.GET_BATCH_COMPLETED(once=12), None, 12),
(Events.ITERATION_STARTED(once=14), None, 14),
(Events.ITERATION_COMPLETED(once=14), None, 14),
(Events.STARTED, 0, 0, True),
(Events.EPOCH_STARTED(once=2), 2, None, True),
(Events.EPOCH_COMPLETED(once=2), 2, None, True),
(Events.GET_BATCH_STARTED(once=12), None, 12, True),
(Events.GET_BATCH_COMPLETED(once=12), None, 12, False),
(Events.ITERATION_STARTED(once=14), None, 14, True),
(Events.ITERATION_COMPLETED(once=14), None, 14, True),
(Events.STARTED, 0, 0, False),
(Events.EPOCH_STARTED(once=2), 2, None, False),
(Events.EPOCH_COMPLETED(once=2), 2, None, False),
(Events.GET_BATCH_STARTED(once=12), None, 12, False),
(Events.GET_BATCH_COMPLETED(once=12), None, 12, False),
(Events.ITERATION_STARTED(once=14), None, 14, False),
(Events.ITERATION_COMPLETED(once=14), None, 14, False),
],
)
def test_terminate_events_sequence(self, terminate_event, e, i):
def test_terminate_events_sequence(self, terminate_event, e, i, skip_completed):
engine = RecordedEngine(MagicMock(return_value=1))
data = range(10)
max_epochs = 5

@engine.on(terminate_event)
def call_terminate():
engine.terminate()
engine.terminate(skip_completed)

@engine.on(Events.EXCEPTION_RAISED)
def assert_no_exceptions(ee):
Expand All @@ -271,10 +281,15 @@ def assert_no_exceptions(ee):
if e is None:
e = i // len(data) + 1

if skip_completed:
assert engine.called_events[-1] == (e, i, Events.TERMINATE)
assert engine.called_events[-2] == (e, i, terminate_event)
else:
assert engine.called_events[-1] == (e, i, Events.COMPLETED)
assert engine.called_events[-2] == (e, i, Events.TERMINATE)
assert engine.called_events[-3] == (e, i, terminate_event)

assert engine.called_events[0] == (0, 0, Events.STARTED)
assert engine.called_events[-1] == (e, i, Events.COMPLETED)
assert engine.called_events[-2] == (e, i, Events.TERMINATE)
assert engine.called_events[-3] == (e, i, terminate_event)
assert engine._dataloader_iter is None

@pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)])
Expand Down

0 comments on commit 6f8ad2a

Please sign in to comment.