|
10 | 10 | import yaml
|
11 | 11 | from ignite.contrib.engines import common
|
12 | 12 | from ignite.engine import Engine
|
| 13 | + |
| 14 | +#::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.limit_sec) { :::# |
13 | 15 | from ignite.engine.events import Events
|
14 |
| -from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine |
| 16 | + |
| 17 | +#::: } :::# |
| 18 | +#::: if (it.save_training || it.save_evaluation) { :::# |
| 19 | +from ignite.handlers import ( |
| 20 | + Checkpoint, |
| 21 | + DiskSaver, |
| 22 | + global_step_from_engine, |
| 23 | +) # usort: skip |
| 24 | + |
| 25 | +#::: } else { :::# |
| 26 | +from ignite.handlers import Checkpoint |
| 27 | + |
| 28 | +#::: } :::# |
| 29 | +#::: if (it.patience) { :::# |
15 | 30 | from ignite.handlers.early_stopping import EarlyStopping
|
| 31 | + |
| 32 | +#::: } :::# |
| 33 | +#::: if (it.terminate_on_nan) { :::# |
16 | 34 | from ignite.handlers.terminate_on_nan import TerminateOnNan
|
| 35 | + |
| 36 | +#::: } :::# |
| 37 | +#::: if (it.limit_sec) { :::# |
17 | 38 | from ignite.handlers.time_limit import TimeLimit
|
| 39 | + |
| 40 | +#::: } :::# |
18 | 41 | from ignite.utils import setup_logger
|
19 | 42 |
|
20 | 43 |
|
@@ -141,72 +164,6 @@ def setup_logging(config: Any) -> Logger:
|
141 | 164 | return logger
|
142 | 165 |
|
143 | 166 |
|
144 |
| -#::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.limit_sec) { :::# |
145 |
| - |
146 |
| - |
147 |
| -def setup_handlers( |
148 |
| - trainer: Engine, |
149 |
| - evaluator: Engine, |
150 |
| - config: Any, |
151 |
| - to_save_train: Optional[dict] = None, |
152 |
| - to_save_eval: Optional[dict] = None, |
153 |
| -): |
154 |
| - """Setup Ignite handlers.""" |
155 |
| - |
156 |
| - ckpt_handler_train = ckpt_handler_eval = None |
157 |
| - #::: if (it.save_training || it.save_evaluation) { :::# |
158 |
| - # checkpointing |
159 |
| - saver = DiskSaver(config.output_dir / "checkpoints", require_empty=False) |
160 |
| - #::: if (it.save_training) { :::# |
161 |
| - ckpt_handler_train = Checkpoint( |
162 |
| - to_save_train, |
163 |
| - saver, |
164 |
| - filename_prefix=config.filename_prefix, |
165 |
| - n_saved=config.n_saved, |
166 |
| - ) |
167 |
| - trainer.add_event_handler( |
168 |
| - Events.ITERATION_COMPLETED(every=config.save_every_iters), |
169 |
| - ckpt_handler_train, |
170 |
| - ) |
171 |
| - #::: } :::# |
172 |
| - #::: if (it.save_evaluation) { :::# |
173 |
| - global_step_transform = None |
174 |
| - if to_save_train.get("trainer", None) is not None: |
175 |
| - global_step_transform = global_step_from_engine(to_save_train["trainer"]) |
176 |
| - ckpt_handler_eval = Checkpoint( |
177 |
| - to_save_eval, |
178 |
| - saver, |
179 |
| - filename_prefix="best", |
180 |
| - n_saved=config.n_saved, |
181 |
| - global_step_transform=global_step_transform, |
182 |
| - ) |
183 |
| - evaluator.add_event_handler(Events.EPOCH_COMPLETED(every=1), ckpt_handler_eval) |
184 |
| - #::: } :::# |
185 |
| - #::: } :::# |
186 |
| - |
187 |
| - #::: if (it.patience) { :::# |
188 |
| - # early stopping |
189 |
| - |
190 |
| - es = EarlyStopping(config.patience, score_fn, trainer) |
191 |
| - evaluator.add_event_handler(Events.EPOCH_COMPLETED, es) |
192 |
| - #::: } :::# |
193 |
| - |
194 |
| - #::: if (it.terminate_on_nan) { :::# |
195 |
| - # terminate on nan |
196 |
| - trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) |
197 |
| - #::: } :::# |
198 |
| - |
199 |
| - #::: if (it.limit_sec) { :::# |
200 |
| - # time limit |
201 |
| - trainer.add_event_handler(Events.ITERATION_COMPLETED, TimeLimit(config.limit_sec)) |
202 |
| - #::: } :::# |
203 |
| - #::: if (it.save_training || it.save_evaluation) { :::# |
204 |
| - return ckpt_handler_train, ckpt_handler_eval |
205 |
| - #::: } :::# |
206 |
| - |
207 |
| - |
208 |
| -#::: } :::# |
209 |
| - |
210 | 167 | #::: if (it.logger) { :::#
|
211 | 168 |
|
212 | 169 |
|
|
0 commit comments