Skip to content

Commit 46d3daa

Browse files
Updated output config file to be same as config file (#267)
* Updated output config file to be same as config file * Remove yaml from imports since it is not used anymore * Edited to copy config-lock using shutil - Removed write_config_to_output_dir() function - Uses shutil to write to the output dir - Edited so that config-lock.yaml is copied after output_dir is created * Modified to copy config.yaml to output dir in rest of the templates * Satisfy lint --------- Co-authored-by: vfdev <[email protected]>
1 parent e5cd96b commit 46d3daa

File tree

4 files changed

+16
-12
lines changed
  • src/templates

4 files changed

+16
-12
lines changed

src/templates/template-text-classification/main.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import os
22
from pprint import pformat
3+
from shutil import copy
34
from typing import Any, cast
45

56
import ignite.distributed as idist
6-
import yaml
77
from data import setup_data
88
from ignite.engine import Events
99
from ignite.handlers import LRScheduler, PiecewiseLinear
@@ -27,8 +27,10 @@ def run(local_rank: int, config: Any):
2727
rank = idist.get_rank()
2828
manual_seed(config.seed + rank)
2929

30-
# create output folder
30+
# create output folder and copy config file to output dir
3131
config.output_dir = setup_output_dir(config, rank)
32+
if rank == 0:
33+
copy(config.config, f"{config.output_dir}/config-lock.yaml")
3234

3335
# download datasets and create dataloaders
3436
dataloader_train, dataloader_eval = setup_data(config)
@@ -79,7 +81,6 @@ def run(local_rank: int, config: Any):
7981
# print training configurations
8082
logger = setup_logging(config)
8183
logger.info("Configuration: \n%s", pformat(vars(config)))
82-
(config.output_dir / "config-lock.yaml").write_text(yaml.dump(config))
8384
trainer.logger = evaluator.logger = logger
8485

8586
if isinstance(lr_scheduler, PyTorchLRScheduler):

src/templates/template-vision-classification/main.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from pprint import pformat
2+
from shutil import copy
23
from typing import Any
34

45
import ignite.distributed as idist
5-
import yaml
66
from data import setup_data
77
from ignite.engine import Events
88
from ignite.handlers import PiecewiseLinear
@@ -19,8 +19,10 @@ def run(local_rank: int, config: Any):
1919
rank = idist.get_rank()
2020
manual_seed(config.seed + rank)
2121

22-
# create output folder
22+
# create output folder and copy config file to output dir
2323
config.output_dir = setup_output_dir(config, rank)
24+
if rank == 0:
25+
copy(config.config, f"{config.output_dir}/config-lock.yaml")
2426

2527
# download datasets and create dataloaders
2628
dataloader_train, dataloader_eval = setup_data(config)
@@ -60,7 +62,6 @@ def run(local_rank: int, config: Any):
6062
# print training configurations
6163
logger = setup_logging(config)
6264
logger.info("Configuration: \n%s", pformat(vars(config)))
63-
(config.output_dir / "config-lock.yaml").write_text(yaml.dump(config))
6465
trainer.logger = evaluator.logger = logger
6566

6667
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)

src/templates/template-vision-dcgan/main.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from pprint import pformat
2+
from shutil import copy
23
from typing import Any
34

45
import ignite.distributed as idist
56
import torch
67
import torchvision.utils as vutils
7-
import yaml
88
from data import setup_data
99
from ignite.engine import Events
1010
from ignite.utils import manual_seed
@@ -22,8 +22,10 @@ def run(local_rank: int, config: Any):
2222
rank = idist.get_rank()
2323
manual_seed(config.seed + rank)
2424

25-
# create output folder
25+
# create output folder and copy config file to output dir
2626
config.output_dir = setup_output_dir(config, rank)
27+
if rank == 0:
28+
copy(config.config, f"{config.output_dir}/config-lock.yaml")
2729

2830
# download datasets and create dataloaders
2931
dataloader_train, dataloader_eval, num_channels = setup_data(config)
@@ -77,7 +79,6 @@ def run(local_rank: int, config: Any):
7779
# print training configurations
7880
logger = setup_logging(config)
7981
logger.info("Configuration: \n%s", pformat(vars(config)))
80-
(config.output_dir / "config-lock.yaml").write_text(yaml.dump(config))
8182
trainer.logger = evaluator.logger = logger
8283

8384
# setup ignite handlers

src/templates/template-vision-segmentation/main.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from functools import partial
22
from pprint import pformat
3+
from shutil import copy
34
from typing import Any, cast
45

56
import ignite.distributed as idist
6-
import yaml
77
from data import denormalize, setup_data
88
from ignite.engine import Events
99
from ignite.handlers import LRScheduler
@@ -27,8 +27,10 @@ def run(local_rank: int, config: Any):
2727
rank = idist.get_rank()
2828
manual_seed(config.seed + rank)
2929

30-
# create output folder
30+
# create output folder and copy config file to output dir
3131
config.output_dir = setup_output_dir(config, rank)
32+
if rank == 0:
33+
copy(config.config, f"{config.output_dir}/config-lock.yaml")
3234

3335
# download datasets and create dataloaders
3436
dataloader_train, dataloader_eval = setup_data(config)
@@ -73,7 +75,6 @@ def run(local_rank: int, config: Any):
7375
# print training configurations
7476
logger = setup_logging(config)
7577
logger.info("Configuration: \n%s", pformat(vars(config)))
76-
(config.output_dir / "config-lock.yaml").write_text(yaml.dump(config))
7778
trainer.logger = evaluator.logger = logger
7879

7980
if isinstance(lr_scheduler, PyTorchLRScheduler):

0 commit comments

Comments
 (0)