Commit 75030c2
Add history to the checkpoint state
1 parent ec20447 commit 75030c2

6 files changed: 139 additions, 61 deletions

wirecell/dnn/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from . import data, models, io, apps
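
This one-line __init__ is what lets the reworked CLI below reach the submodules as attributes after a single top-level import. A minimal check, assuming the wirecell package is installed:

    from wirecell import dnn
    print(dnn.io, dnn.apps, dnn.models, dnn.data)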

wirecell/dnn/__main__.py

Lines changed: 99 additions & 35 deletions

@@ -7,6 +7,9 @@
 from wirecell.util.paths import unglob, listify
 
 
+from wirecell import dnn
+
+
 @context("dnn")
 def cli(ctx):
     '''
@@ -41,68 +44,130 @@ def cli(ctx):
               help="File name providing the initial model state dict (def=None - construct fresh)")
 @click.option("-s", "--save", default=None,
               help="File name to save model state dict after training (def=None - results not saved)")
-@click.argument("files", nargs=-1)
+@click.option("--eval-files", multiple=True, type=str, # fixme: remove this in favor of a single file set and a train/eval partitioning
+              help="File path or globs as comma separated list to use for evaluation dataset")
+@click.argument("train_files", nargs=-1)
 @click.pass_context
 def train(ctx, config, epochs, batch, device, cache, debug_torch,
           checkpoint_save, checkpoint_modulus,
-          name, load, save, files):
+          name, load, save, eval_files, train_files):
     '''
     Train a model.
     '''
-    if not files:
-        raise click.BadArgumentUsage("no files given")
+    if not train_files:
+        raise click.BadArgumentUsage("no training files given")
+    train_files = unglob(listify(train_files))
+    log.info(f'training files: {train_files}')
 
     if device == 'gpu': device = 'cuda'
     log.info(f'using device {device}')
 
     if debug_torch:
         torch.autograd.set_detect_anomaly(True)
 
-    # fixme: make choice of dataset optional
-    import wirecell.dnn.apps
-    from wirecell.dnn import io
-
-    app = getattr(wirecell.dnn.apps, name)
+    app = getattr(dnn.apps, name)
 
     net = app.Network()
     opt = app.Optimizer(net.parameters())
+    crit = app.Criterion()
+    trainer = app.Trainer(net, opt, crit, device=device)
 
-    par = dict(epoch=0, loss=0)
-
+    history = dict()
     if load:
         if not Path(load).exists():
             raise click.FileError(load, 'warning: DNN module load file does not exist')
-        par = io.load_checkpoint(load, net, opt)
-
-    tot_epoch = par["epoch"]
-    del par
+        history = dnn.io.load_checkpoint(load, net, opt)
 
-    ds = app.Dataset(files, cache=cache)
-    nsamples = len(ds)
-    if nsamples == 0:
-        raise click.BadArgumentUsage(f'no samples from {len(files)} files')
+    train_ds = app.Dataset(train_files, cache=cache)
+    ntrain = len(train_ds)
+    if ntrain == 0:
+        raise click.BadArgumentUsage(f'no samples from {len(train_files)} files')
 
     from torch.utils.data import DataLoader
-    dl = DataLoader(ds, batch_size=batch, shuffle=True, pin_memory=True)
+    train_dl = DataLoader(train_ds, batch_size=batch, shuffle=True, pin_memory=True)
 
-    trainer = app.Trainer(net, device=device)
-
-    checkpoint=2 # fixme make configurable
-    for epoch in range(epochs):
-        losslist = trainer.epoch(dl)
-        loss = sum(losslist)
-        log.debug(f'epoch {tot_epoch} loss {loss}')
+    neval = 0
+    eval_dl = None
+    if eval_files:
+        eval_files = unglob(listify(eval_files, delim=","))
+        log.info(f'eval files: {eval_files}')
+        eval_ds = app.Dataset(eval_files, cache=cache)
+        neval = len(eval_ds)
eval_dl = DataLoader(train_ds, batch_size=batch, shuffle=False, pin_memory=True)
97+
else:
98+
log.info("no eval files")
99+
100+
# History
101+
run_history = history.get("runs", dict())
102+
this_run_number = 0
103+
if run_history:
104+
this_run_number = max(run_history.keys()) + 1
105+
this_run = dict(
106+
run = this_run_number,
107+
train_files = train_files,
108+
ntrain = ntrain,
109+
eval_files = eval_files or [],
110+
neval = neval,
111+
nepochs = epochs,
112+
batch = batch,
113+
device = device,
114+
cache = cache,
115+
name = name,
116+
load = load,
117+
)
118+
run_history[this_run_number] = this_run
119+
120+
epoch_history = history.get("epochs", dict())
121+
first_epoch_number = 0
122+
if epoch_history:
123+
first_epoch_number = max(epoch_history.keys()) + 1
124+
125+
def saveit(path):
126+
if not path:
127+
return
128+
dnn.io.save_checkpoint(path, net, opt, runs=run_history, epochs=epoch_history)
129+
130+
for this_epoch_number in range(first_epoch_number, first_epoch_number + epochs):
131+
train_losses = trainer.epoch(train_dl)
132+
train_loss = sum(train_losses)/ntrain
133+
134+
eval_losses = []
135+
eval_loss = 0
136+
if eval_dl:
137+
eval_losses = trainer.evaluate(eval_dl)
138+
eval_loss = sum(eval_losses) / neval
139+
140+
this_epoch = dict(
141+
run=this_run_number,
142+
epoch=this_epoch_number,
143+
train_losses=train_losses,
144+
train_loss=train_loss,
145+
eval_losses=eval_losses,
146+
eval_loss=eval_loss)
147+
epoch_history[this_epoch_number] = this_epoch
148+
149+
log.info(f'run: {this_run_number} epoch: {this_epoch_number} loss: {train_loss} eval: {eval_loss}')
95150

96151
if checkpoint_save:
97-
if tot_epoch%checkpoint_modulus == 0:
98-
cpath = checkpoint_save.format(epoch=tot_epoch)
99-
io.save_checkpoint(cpath, net, opt,
100-
epoch=tot_epoch, loss=loss)
101-
tot_epoch += 1
152+
if this_epoch_number % checkpoint_modulus == 0:
153+
parms = dict(this_run, **this_epoch)
154+
cpath = checkpoint_save.format(**parms)
155+
saveit(cpath)
156+
saveit(save)
102157

103-
if save:
104-
io.save_checkpoint(save, net, opt, epoch=tot_epoch, loss=loss)
105158

+@cli.command('dump')
+@click.argument("checkpoint")
+@click.pass_context
+def dump(ctx, checkpoint):
+    '''
+    Dump info about a checkpoint file.
+    '''
+    state = dnn.io.load_checkpoint_raw(checkpoint)
+    for rnum, robj in state.get("runs",{}).items():
+        print('run: {run} ntrain: {ntrain} neval: {neval}'.format(**robj))
+    for enum, eobj in state.get("epochs",{}).items():
+        print('run: {run} epoch: {epoch} train: {train_loss} eval: {eval_loss}'.format(**eobj))
 
 @cli.command('extract')
 @click.option("-o", "--output", default='samples.npz',
@@ -120,7 +185,6 @@ def extract(ctx, output, sample, datapaths):
     samples = map(int,listify(*sample, delim=","))
 
     # fixme: make choice of dataset optional
-    from wirecell.dnn.apps import dnnroi as app
     ds = app.Dataset(datapaths)
 
     log.info(f'dataset has {len(ds)} entries from {len(datapaths)} data paths')
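
A checkpoint now carries two integer-keyed history tables: "runs", with one record per train invocation, and "epochs", with one record per epoch accumulated across invocations. Every key of the merged run and epoch records is available to the --checkpoint-save template. A sketch of the expansion; the template file name here is illustrative, not from the commit:

    # Mimic how train() names periodic checkpoints: merge the run record
    # and the epoch record into one namespace for str.format().
    this_run = dict(run=3, ntrain=1000, neval=100, batch=4, device="cpu")
    this_epoch = dict(run=3, epoch=42, train_loss=0.12, eval_loss=0.34)
    parms = dict(this_run, **this_epoch)
    print("check-r{run}-e{epoch}.pt".format(**parms))  # check-r3-e42.pt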

wirecell/dnn/apps/dnnroi/__init__.py

Lines changed: 4 additions & 3 deletions

@@ -1,12 +1,13 @@
 #!/usr/bin/env python
+from torch import optim
 
-
+## The "app" API
 from .model import Network
 from .data import Dataset
-from .train import Classifier as Trainer
+from wirecell.dnn.train import Classifier as Trainer
+from torch.nn import BCELoss as Criterion
 
 
-from torch import optim
 def Optimizer(params):
     return optim.SGD(params, lr=0.1, momentum=0.9, weight_decay=0.0005)
 
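The "app" API comment names the contract the train command now leans on: a module under wirecell.dnn.apps must expose Network, Dataset, Trainer, Criterion and Optimizer. The sketch below restates the wiring from __main__.py above, using the dnnroi app:

    from wirecell import dnn

    app = getattr(dnn.apps, "dnnroi")  # chosen by the train command's name option
    net = app.Network()
    opt = app.Optimizer(net.parameters())
    crit = app.Criterion()
    trainer = app.Trainer(net, opt, crit, device="cpu")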
wirecell/dnn/apps/dnnroi/train.py

Lines changed: 0 additions & 9 deletions
This file was deleted.

wirecell/dnn/io.py

Lines changed: 6 additions & 1 deletion

@@ -20,14 +20,19 @@ def save_checkpoint(path, model, optimizer, **kwds):
     torch.save(kwds, path)
 
 
+def load_checkpoint_raw(path):
+    return torch.load(path, weights_only=True)
+
+
 def load_checkpoint(path, model, optimizer):
     '''
     Load a checkpoint.
 
     The model and optimizer state dicts are updated and a dict of any additional
     parameters is returned.
     '''
-    cp = torch.load(path, weights_only=True)
+    cp = load_checkpoint_raw(path)
     model.load_state_dict(cp.pop("model_state_dict"))
     optimizer.load_state_dict(cp.pop("optimizer_state_dict"))
     return cp
+
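
load_checkpoint_raw() exposes the saved dict without needing model or optimizer instances, which is what the new dump command uses to print the history tables. A short sketch, with "ckpt.pt" a placeholder path:

    from wirecell.dnn import io

    state = io.load_checkpoint_raw("ckpt.pt")
    for enum, eobj in state.get("epochs", {}).items():
        print("run: {run} epoch: {epoch} train: {train_loss} eval: {eval_loss}".format(**eobj))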

wirecell/dnn/train.py

Lines changed: 29 additions & 13 deletions

@@ -19,21 +19,45 @@
     - optimizer.step()
 
 '''
-from torch import optim
+from torch import optim, no_grad
 import torch.nn as nn
 
 def dump(name, data):
     # print(f'{name:20s}: {data.shape} {data.dtype} {data.device}')
     return
 
 class Classifier:
-    def __init__(self, net, device='cpu', optclass = optim.SGD, **optkwds):
+    def __init__(self, net, optimizer, criterion = nn.BCELoss(), device='cpu'):
         net.to(device)
         self._device = device
         self.net = net # model
-        self.optimizer = optclass(net.parameters(), **optkwds)
+        self.optimizer = optimizer
+        self.criterion = criterion
 
-    def epoch(self, data, criterion=nn.BCELoss(), retain_graph=False):
+    def loss(self, features, labels):
+
+        features = features.to(self._device)
+        dump('features', features)
+        labels = labels.to(self._device)
+        dump('labels', labels)
+
+        prediction = self.net(features)
+        dump('prediction', prediction)
+
+        loss = self.criterion(prediction, labels)
+        return loss
+
+    def evaluate(self, data):
+        losses = list()
+        with no_grad():
+            for features, labels in data:
+                loss = self.loss(features, labels)
+                loss = loss.item()
+                losses.append(loss)
+        return losses
+
+
+    def epoch(self, data, retain_graph=False):
         '''
         Train over the batches of the data, return list of losses at each batch.
         '''
@@ -42,15 +66,7 @@ def epoch(self, data, criterion=nn.BCELoss(), retain_graph=False):
         epoch_losses = list()
         for features, labels in data:
 
-            features = features.to(self._device)
-            dump('features', features)
-            labels = labels.to(self._device)
-            dump('labels', labels)
-
-            prediction = self.net(features)
-            dump('prediction', prediction)
-
-            loss = criterion(prediction, labels)
+            loss = self.loss(features, labels)
 
             loss.backward(retain_graph=retain_graph)
             self.optimizer.step()
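
The refactored Classifier takes its optimizer and criterion from the caller, factors the per-batch forward pass into loss(), and shares it between the training epoch() and the new gradient-free evaluate(). A self-contained toy example of driving it directly; the network and data are invented for illustration:

    import torch
    from torch import nn, optim
    from torch.utils.data import DataLoader, TensorDataset
    from wirecell.dnn.train import Classifier

    # Toy binary classifier and random data shaped for nn.BCELoss.
    net = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
    ds = TensorDataset(torch.rand(16, 4), torch.rand(16, 1).round())
    dl = DataLoader(ds, batch_size=4)

    trainer = Classifier(net, optim.SGD(net.parameters(), lr=0.1), nn.BCELoss())
    train_losses = trainer.epoch(dl)    # one pass with backprop and steps
    eval_losses = trainer.evaluate(dl)  # per-batch losses under no_grad
    print(sum(eval_losses) / len(ds))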

0 commit comments