Skip to content

Commit

Permalink
Rewritten Engine's terminate and terminate_epoch logic (#2645)
Browse files Browse the repository at this point in the history
* Added test_engine_run_resume

* Terminate/Terminate Single Epoch work on all EPOCH/ITERATION events

* - terminate() work on all events, called on catched _EngineTerminateException
- terminate_epoch work on iteration-based events, called on catched  _EngineTerminateSingleEpochExpection
- Fixed issue when attaching handlers on Events.TERMINATE_SINGLE_EPOCH

* Updated docstring

* Fixed issue with max_iters handling

* Fixed issue with _EngineTerminateException handled as a general exception

* Updated tests and docs

Co-authored-by: Sadra Barikbin <[email protected]>
  • Loading branch information
vfdev-5 and sadra-barikbin authored Aug 23, 2022
1 parent 48364bd commit 32ba9cd
Show file tree
Hide file tree
Showing 3 changed files with 346 additions and 56 deletions.
138 changes: 101 additions & 37 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,12 @@ def execute_something():

self._assert_allowed_event(event_name)

event_args = (Exception(),) if event_name == Events.EXCEPTION_RAISED else ()
event_args = () # type: Tuple[Any, ...]
if event_name == Events.EXCEPTION_RAISED:
event_args += (Exception(),)
elif event_name == Events.TERMINATE_SINGLE_EPOCH:
event_args += (0,)

try:
_check_signature(handler, "handler", self, *(event_args + args), **kwargs)
self._event_handlers[event_name].append((handler, (self,) + args, kwargs))
Expand Down Expand Up @@ -433,14 +438,28 @@ def fire_event(self, event_name: Any) -> None:
return self._fire_event(event_name)

def terminate(self) -> None:
"""Sends terminate signal to the engine, so that it terminates completely the run after
the current iteration."""
"""Sends terminate signal to the engine, so that it terminates completely the run. The run is
terminated after the event on which ``terminate`` method was called. The following events are triggered:
- ...
- Terminating event
- :attr:`~ignite.engine.events.Events.TERMINATE`
- :attr:`~ignite.engine.events.Events.COMPLETED`
"""
self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.")
self.should_terminate = True

def terminate_epoch(self) -> None:
"""Sends terminate signal to the engine, so that it terminates the current epoch
after the current iteration."""
"""Sends terminate signal to the engine, so that it terminates the current epoch. The run
continues from the next epoch. The following events are triggered:
- ...
- Event on which ``terminate_epoch`` method is called
- :attr:`~ignite.engine.events.Events.TERMINATE_SINGLE_EPOCH`
- :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED`
- :attr:`~ignite.engine.events.Events.EPOCH_STARTED`
- ...
"""
self.logger.info(
"Terminate current epoch is signaled. "
"Current epoch iteration will stop after current iteration is finished."
Expand Down Expand Up @@ -742,33 +761,43 @@ def _internal_run(self) -> State:
self.should_terminate = self.should_terminate_single_epoch = False
self._init_timers(self.state)
try:
start_time = time.time()
self._fire_event(Events.STARTED)
while not self._is_done(self.state) and not self.should_terminate:
self.state.epoch += 1
self._fire_event(Events.EPOCH_STARTED)

if self._dataloader_iter is None:
self._setup_engine()

time_taken = self._run_once_on_dataset()
# time is available for handlers but must be update after fire
self.state.times[Events.EPOCH_COMPLETED.name] = time_taken
handlers_start_time = time.time()
if self.should_terminate:
self._fire_event(Events.TERMINATE)
else:
try:
start_time = time.time()
self._fire_event(Events.STARTED)
self._maybe_terminate()

while not self._is_done(self.state) and not self.should_terminate:
self.state.epoch += 1
handlers_start_time = time.time()
self._fire_event(Events.EPOCH_STARTED)
epoch_time_taken = time.time() - handlers_start_time
self._maybe_terminate()

if self._dataloader_iter is None:
self._setup_engine()

epoch_time_taken += self._run_once_on_dataset()

# time is available for handlers but must be updated after fire
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

handlers_start_time = time.time()
self._fire_event(Events.EPOCH_COMPLETED)
time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.EPOCH_COMPLETED.name] = time_taken
hours, mins, secs = _to_hours_mins_secs(time_taken)
self.logger.info(f"Epoch[{self.state.epoch}] Complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
if self.should_terminate:
break
epoch_time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
self._maybe_terminate()

hours, mins, secs = _to_hours_mins_secs(epoch_time_taken)
self.logger.info(
f"Epoch[{self.state.epoch}] Complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}"
)

except _EngineTerminateException:
self._fire_event(Events.TERMINATE)

time_taken = time.time() - start_time
# time is available for handlers but must be update after fire
# time is available for handlers but must be updated after fire
self.state.times[Events.COMPLETED.name] = time_taken
handlers_start_time = time.time()
self._fire_event(Events.COMPLETED)
Expand All @@ -786,6 +815,13 @@ def _internal_run(self) -> State:
self._dataloader_iter = None
return self.state

def _maybe_terminate(self) -> None:
if self.should_terminate:
raise _EngineTerminateException()

if self.should_terminate_single_epoch:
raise _EngineTerminateSingleEpochException()

def _run_once_on_dataset(self) -> float:
start_time = time.time()

Expand All @@ -805,8 +841,12 @@ def _run_once_on_dataset(self) -> float:
# Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
if self.last_event_name != Events.DATALOADER_STOP_ITERATION:
self._fire_event(Events.GET_BATCH_STARTED)
self._maybe_terminate()

self.state.batch = next(self._dataloader_iter)
self._fire_event(Events.GET_BATCH_COMPLETED)
self._maybe_terminate()

iter_counter += 1
should_exit = False
except StopIteration:
Expand Down Expand Up @@ -835,29 +875,37 @@ def _run_once_on_dataset(self) -> float:
break

self._fire_event(Events.DATALOADER_STOP_ITERATION)
self._setup_dataloader_iter()
self._maybe_terminate()

self._setup_dataloader_iter()
should_exit = True

continue

self.state.iteration += 1
self._fire_event(Events.ITERATION_STARTED)
self._maybe_terminate()

self.state.output = self._process_function(self, self.state.batch)
self._fire_event(Events.ITERATION_COMPLETED)

if self.should_terminate or self.should_terminate_single_epoch:
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
self.should_terminate_single_epoch = False
self._setup_dataloader_iter()
break
self._maybe_terminate()

if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
break

if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
self.should_terminate = True
break
raise _EngineTerminateException()

except _EngineTerminateSingleEpochException:
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
self.should_terminate_single_epoch = False
self._setup_dataloader_iter()

except _EngineTerminateException as e:
# we need to reraise this exception such that it is not handled
# as a general exception by the code below
raise e

except Exception as e:
self.logger.error(f"Current run is terminating due to exception: {e}")
Expand All @@ -870,3 +918,19 @@ def _get_none_data_iter(size: int) -> Iterator:
# Sized iterator for data as None
for _ in range(size):
yield None


class _EngineTerminateSingleEpochException(Exception):
"""
Exception associated with Terminate Single Epoch event
"""

pass


class _EngineTerminateException(Exception):
"""
Exception associated with Terminate event
"""

pass
2 changes: 1 addition & 1 deletion ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ class CustomEvents(EventEnum):
"""triggered when the run is about to end completely, after receiving terminate() call."""
TERMINATE_SINGLE_EPOCH = "terminate_single_epoch"
"""triggered when the run is about to end the current epoch,
after receiving a terminate_epoch() or terminate() call."""
after receiving a terminate_epoch() call."""

def __or__(self, other: Any) -> "EventsList":
return EventsList() | self | other
Expand Down
Loading

0 comments on commit 32ba9cd

Please sign in to comment.