Skip to content

Commit 7253c79

Browse files
committed
Switch from an explicit eval dataset to a train/eval split ratio and add configuration file support.
Add initial config files for hyu's pdvd and renny's pdhd datasets.
1 parent 1f6e471 commit 7253c79

File tree

4 files changed

+98
-49
lines changed

4 files changed

+98
-49
lines changed

wirecell/dnn/__main__.py

Lines changed: 55 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
import click
44
import torch
5+
from torch.utils.data import DataLoader
6+
57
from pathlib import Path
6-
from wirecell.util.cli import context, log, jsonnet_loader
8+
from wirecell.util.cli import context, log, jsonnet_loader, anyconfig_file
79
from wirecell.util.paths import unglob, listify
810

911

@@ -17,16 +19,23 @@ def cli(ctx):
1719
'''
1820
pass
1921

22+
@cli.command('dump-config')
23+
@anyconfig_file("wirecelldnn")
24+
@click.pass_context
25+
def dump_config(ctx, config):
26+
print(config)
27+
28+
return
29+
30+
31+
train_defaults = dict(epochs=1, batch=1, device='cpu', name='dnnroi', train_ratio=0.8)
2032
@cli.command('train')
21-
@click.option("-c", "--config",
22-
type=click.Path(),
23-
help="Set configuration file")
24-
@click.option("-e", "--epochs", default=1,
33+
@click.option("-e", "--epochs", default=None, type=int,
2534
help="Number of epochs over which to train. "
2635
"This is a relative count if the training starts with a -l/--load'ed state.")
27-
@click.option("-b", "--batch", default=1,
36+
@click.option("-b", "--batch", default=None, type=int,
2837
help="Batch size")
29-
@click.option("-d", "--device", default='cpu',
38+
@click.option("-d", "--device", default=None, type=str,
3039
help="The compute device")
3140
@click.option("--cache/--no-cache", is_flag=True, default=False,
3241
help="Cache data in memory")
@@ -38,33 +47,40 @@ def cli(ctx):
3847
@click.option("--checkpoint-modulus", default=1,
3948
help="Checkpoint modulus. "
4049
"If checkpoint path is given, the training is checkpointed ever this many epochs..")
41-
@click.option("-n", "--name", default='dnnroi',
42-
help="The application name (def=dnnroi)")
50+
@click.option("-a", "--app", default=None, type=str,
51+
help="The application name")
4352
@click.option("-l", "--load", default=None,
4453
help="File name providing the initial model state dict (def=None - construct fresh)")
4554
@click.option("-s", "--save", default=None,
4655
help="File name to save model state dict after training (def=None - results not saved)")
47-
@click.option("--eval-files", multiple=True, type=str, # fixme: remove this in favor of a single file set and a train/eval partitioning
48-
help="File path or globs as comma separated list to use for evaluation dataset")
49-
@click.argument("train_files", nargs=-1)
56+
@click.option("--train-ratio", default=None, type=float,
57+
help="Fraction of samples to use for training (default=1.0, no evaluation loss calculated)")
58+
@anyconfig_file("wirecelldnn", section='train', defaults=train_defaults)
59+
@click.argument("files", nargs=-1)
5060
@click.pass_context
5161
def train(ctx, config, epochs, batch, device, cache, debug_torch,
5262
checkpoint_save, checkpoint_modulus,
53-
name, load, save, eval_files, train_files):
63+
app, load, save, train_ratio, files):
5464
'''
5565
Train a model.
5666
'''
57-
if not train_files:
67+
68+
if not files: # args not processed by anyconfig_file
69+
try:
70+
files = config['train']['files']
71+
except KeyError:
72+
files = None
73+
if not files:
5874
raise click.BadArgumentUsage("no training files given")
59-
train_files = unglob(listify(train_files))
60-
log.info(f'training files: {train_files}')
75+
files = unglob(listify(files))
76+
log.info(f'training files: {files}')
6177

6278
if device == 'gpu': device = 'cuda'
63-
log.info(f'using device {device}')
6479

6580
if debug_torch:
6681
torch.autograd.set_detect_anomaly(True)
6782

83+
name = app
6884
app = getattr(dnn.apps, name)
6985

7086
net = app.Network()
@@ -78,24 +94,17 @@ def train(ctx, config, epochs, batch, device, cache, debug_torch,
7894
raise click.FileError(load, 'warning: DNN module load file does not exist')
7995
history = dnn.io.load_checkpoint(load, net, opt)
8096

81-
train_ds = app.Dataset(train_files, cache=cache)
82-
ntrain = len(train_ds)
83-
if ntrain == 0:
84-
raise click.BadArgumentUsage(f'no samples from {len(train_files)} files')
85-
86-
from torch.utils.data import DataLoader
87-
train_dl = DataLoader(train_ds, batch_size=batch, shuffle=True, pin_memory=True)
88-
89-
neval = 0
90-
eval_dl = None
91-
if eval_files:
92-
eval_files = unglob(listify(eval_files, delim=","))
93-
log.info(f'eval files: {eval_files}')
94-
eval_ds = app.Dataset(eval_files, cache=cache)
95-
neval = len(eval_ds)
96-
eval_dl = DataLoader(train_ds, batch_size=batch, shuffle=False, pin_memory=True)
97-
else:
98-
log.info("no eval files")
97+
ds = app.Dataset(files, cache=cache, config=config.get("dataset", None))
98+
if len(ds) == 0:
99+
raise click.BadArgumentUsage(f'no samples from {len(files)} files')
100+
101+
tbatch,ebatch = batch,1
102+
103+
dses = dnn.data.train_eval_split(ds, train_ratio)
104+
dles = [DataLoader(one, batch_size=bb, shuffle=True, pin_memory=True) for one,bb in zip(dses, [tbatch,ebatch])]
105+
106+
ntrain = len(dses[0])
107+
neval = len(dses[1])
99108

100109
# History
101110
run_history = history.get("runs", dict())
@@ -104,9 +113,8 @@ def train(ctx, config, epochs, batch, device, cache, debug_torch,
104113
this_run_number = max(run_history.keys()) + 1
105114
this_run = dict(
106115
run = this_run_number,
107-
train_files = train_files,
116+
data_files = files,
108117
ntrain = ntrain,
109-
eval_files = eval_files or [],
110118
neval = neval,
111119
nepochs = epochs,
112120
batch = batch,
@@ -128,13 +136,17 @@ def saveit(path):
128136
dnn.io.save_checkpoint(path, net, opt, runs=run_history, epochs=epoch_history)
129137

130138
for this_epoch_number in range(first_epoch_number, first_epoch_number + epochs):
131-
train_losses = trainer.epoch(train_dl)
132-
train_loss = sum(train_losses)/ntrain
133139

134-
eval_losses = []
140+
train_loss = 0
141+
train_losses = []
142+
if ntrain:
143+
train_losses = trainer.epoch(dles[0])
144+
train_loss = sum(train_losses)/ntrain
145+
135146
eval_loss = 0
136-
if eval_dl:
137-
eval_losses = trainer.evaluate(eval_dl)
147+
eval_losses = []
148+
if neval:
149+
eval_losses = trainer.evaluate(dles[1])
138150
eval_loss = sum(eval_losses) / neval
139151

140152
this_epoch = dict(
@@ -146,7 +158,7 @@ def saveit(path):
146158
eval_loss=eval_loss)
147159
epoch_history[this_epoch_number] = this_epoch
148160

149-
log.info(f'run: {this_run_number} epoch: {this_epoch_number} loss: {train_loss} eval: {eval_loss}')
161+
log.info(f'run: {this_run_number} epoch: {this_epoch_number} loss: {train_loss:.4e} [b={tbatch},n={ntrain}] eval: {eval_loss:.4e} [b={ebatch},n={neval}]')
150162

151163
if checkpoint_save:
152164
if this_epoch_number % checkpoint_modulus == 0:

wirecell/dnn/apps/dnnroi/data.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
from .transforms import Rec as Rect, Tru as Trut, Params as TrParams
1414

15+
import logging
16+
log = logging.getLogger("wirecell.dnn")
1517

1618
class Rec(hdf.Single):
1719
'''
@@ -21,7 +23,7 @@ class Rec(hdf.Single):
2123
OmnibusSigProc in HDF5 "frame file" form.
2224
'''
2325

24-
file_re = r'.*g4-rec-r(\d+)\.h5'
26+
file_re = r'.*g4-rec-[r]?(\d+)\.h5'
2527

2628
path_res = tuple(
2729
r'/(\d+)/%s\d'%tag for tag in [
@@ -47,7 +49,7 @@ class Tru(hdf.Single):
4749
This consists of the target ROI
4850
'''
4951

50-
file_re = r'.*g4-tru-r(\d+)\.h5'
52+
file_re = r'.*g4-tru-[r]?(\d+)\.h5'
5153

5254
path_res = tuple(
5355
r'/(\d+)/%s\d'%tag for tag in ['frame_ductor']
@@ -66,13 +68,31 @@ def __init__(self, paths, threshold = 0.5,
6668
super().__init__(dom, paths)
6769

6870

71+
72+
6973
class Dataset(hdf.Multi):
7074
'''
7175
The full DNNROI dataset is effectively zip(Rec,Tru).
7276
'''
73-
def __init__(self, paths, threshold=0.5, cache=False):
77+
def __init__(self, paths, threshold=0.5, cache=False, config=None):
78+
79+
log.debug(f'ddnroi dataset: {config=}')
80+
config = config or dict()
81+
def wash(key):
82+
val = config.get(key, None)
83+
if val is None:
84+
return
85+
if isinstance(val, str) and val.startswith(('[','{')):
86+
val = eval(val) # yes, I know
87+
log.debug(f'dnnroi dataset {key} = {val}')
88+
return val
89+
90+
7491
# fixme: allow configuring the transforms.
75-
super().__init__(Rec(paths, cache=cache),
76-
Tru(paths, threshold, cache=cache))
92+
super().__init__(Rec(paths, cache=cache,
93+
file_re=wash('rec_file_re'),
94+
path_res=wash('rec_path_res')),
95+
Tru(paths, threshold, cache=cache,
96+
file_re=wash('tru_file_re'),
97+
path_res=wash('tru_path_res')))
7798

78-

wirecell/dnn/cfg/hyu-pdvd.cfg

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[train]
2+
app=dnnroi
3+
device=gpu
4+
files=/nfs/data/1/bviren/dnnroi/data/hyu/pdvd/g4-*.h5
5+
batch=10
6+
epochs=25
7+

wirecell/dnn/cfg/renny-pdhd.cfg

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[train]
2+
app=dnnroi
3+
device=gpu
4+
files=/nfs/data/1/bviren/dnnroi/data/renney/train_data_PDHD_fixedbug_separateWC/*.h5
5+
batch=10
6+
epochs=25
7+
8+
[dataset]
9+
tru_path_res=[r'/(\d+)/frame_deposplat\d']
10+

0 commit comments

Comments
 (0)