7
7
from wirecell .util .paths import unglob , listify
8
8
9
9
10
+ from wirecell import dnn
11
+
12
+
10
13
@context ("dnn" )
11
14
def cli (ctx ):
12
15
'''
@@ -41,68 +44,130 @@ def cli(ctx):
41
44
help = "File name providing the initial model state dict (def=None - construct fresh)" )
42
45
@click .option ("-s" , "--save" , default = None ,
43
46
help = "File name to save model state dict after training (def=None - results not saved)" )
44
- @click .argument ("files" , nargs = - 1 )
47
+ @click .option ("--eval-files" , multiple = True , type = str , # fixme: remove this in favor of a single file set and a train/eval partitioning
48
+ help = "File path or globs as comma separated list to use for evaluation dataset" )
49
+ @click .argument ("train_files" , nargs = - 1 )
45
50
@click .pass_context
46
51
def train(ctx, config, epochs, batch, device, cache, debug_torch,
          checkpoint_save, checkpoint_modulus,
          name, load, save, eval_files, train_files):
    '''
    Train a model.

    Builds the app (network/optimizer/criterion/trainer) named by
    ``name``, optionally resumes from a ``load`` checkpoint, runs
    ``epochs`` training epochs over ``train_files`` and (optionally)
    evaluates each epoch on ``eval_files``.  Per-run and per-epoch
    history is accumulated and written with every checkpoint save.
    '''
    if not train_files:
        raise click.BadArgumentUsage("no training files given")
    train_files = unglob(listify(train_files))
    log.info(f'training files: {train_files}')

    if device == 'gpu': device = 'cuda'
    log.info(f'using device {device}')

    if debug_torch:
        torch.autograd.set_detect_anomaly(True)

    app = getattr(dnn.apps, name)

    net = app.Network()
    opt = app.Optimizer(net.parameters())
    crit = app.Criterion()
    trainer = app.Trainer(net, opt, crit, device=device)

    history = dict()
    if load:
        if not Path(load).exists():
            raise click.FileError(load, 'warning: DNN module load file does not exist')
        # Restores net/opt state in place and returns the saved history dict.
        history = dnn.io.load_checkpoint(load, net, opt)

    train_ds = app.Dataset(train_files, cache=cache)
    ntrain = len(train_ds)
    if ntrain == 0:
        raise click.BadArgumentUsage(f'no samples from {len(train_files)} files')

    from torch.utils.data import DataLoader
    train_dl = DataLoader(train_ds, batch_size=batch, shuffle=True, pin_memory=True)

    neval = 0
    eval_dl = None
    if eval_files:
        eval_files = unglob(listify(eval_files, delim=","))
        log.info(f'eval files: {eval_files}')
        eval_ds = app.Dataset(eval_files, cache=cache)
        neval = len(eval_ds)
        if neval == 0:
            # Guard against a ZeroDivisionError at eval_loss time.
            raise click.BadArgumentUsage(f'no samples from {len(eval_files)} eval files')
        # Bug fix: was DataLoader(train_ds, ...), which silently
        # "evaluated" on the training set instead of the eval set.
        eval_dl = DataLoader(eval_ds, batch_size=batch, shuffle=False, pin_memory=True)
    else:
        log.info("no eval files")

    # History.  A "run" records the configuration of one invocation of
    # this command; "epochs" accumulate across resumed runs.
    run_history = history.get("runs", dict())
    this_run_number = 0
    if run_history:
        # NOTE(review): assumes keys survive checkpoint round-trip as
        # ints - confirm dnn.io does not stringify them.
        this_run_number = max(run_history.keys()) + 1
    this_run = dict(
        run=this_run_number,
        train_files=train_files,
        ntrain=ntrain,
        eval_files=eval_files or [],
        neval=neval,
        nepochs=epochs,
        batch=batch,
        device=device,
        cache=cache,
        name=name,
        load=load,
    )
    run_history[this_run_number] = this_run

    epoch_history = history.get("epochs", dict())
    first_epoch_number = 0
    if epoch_history:
        first_epoch_number = max(epoch_history.keys()) + 1

    def saveit(path):
        # Persist model/optimizer state plus full history; no-op on
        # empty path so unconditional calls below stay simple.
        if not path:
            return
        dnn.io.save_checkpoint(path, net, opt, runs=run_history, epochs=epoch_history)

    for this_epoch_number in range(first_epoch_number, first_epoch_number + epochs):
        train_losses = trainer.epoch(train_dl)
        train_loss = sum(train_losses) / ntrain

        eval_losses = []
        eval_loss = 0
        # Explicit None-check: DataLoader truthiness follows __len__,
        # which is not the question being asked here.
        if eval_dl is not None:
            eval_losses = trainer.evaluate(eval_dl)
            eval_loss = sum(eval_losses) / neval

        this_epoch = dict(
            run=this_run_number,
            epoch=this_epoch_number,
            train_losses=train_losses,
            train_loss=train_loss,
            eval_losses=eval_losses,
            eval_loss=eval_loss)
        epoch_history[this_epoch_number] = this_epoch

        log.info(f'run: {this_run_number} epoch: {this_epoch_number} loss: {train_loss} eval: {eval_loss}')

        if checkpoint_save:
            if this_epoch_number % checkpoint_modulus == 0:
                # The checkpoint file name may interpolate any run or
                # epoch field, eg "ckpt-{epoch}.pt".
                parms = dict(this_run, **this_epoch)
                cpath = checkpoint_save.format(**parms)
                saveit(cpath)
    saveit(save)
102
157
103
- if save :
104
- io .save_checkpoint (save , net , opt , epoch = tot_epoch , loss = loss )
105
158
159
@cli.command('dump')
@click.argument("checkpoint")
@click.pass_context
def dump(ctx, checkpoint):
    '''
    Dump info about a checkpoint file.

    Prints one summary line per recorded run and per recorded epoch
    found in the checkpoint's "runs" and "epochs" history dicts.
    '''
    state = dnn.io.load_checkpoint_raw(checkpoint)
    # The history dicts are keyed by run/epoch number, but each record
    # also carries its own "run"/"epoch" field, so only values are needed.
    for robj in state.get("runs", {}).values():
        print('run: {run} ntrain: {ntrain} neval: {neval}'.format(**robj))
    for eobj in state.get("epochs", {}).values():
        print('run: {run} epoch: {epoch} train: {train_loss} eval: {eval_loss}'.format(**eobj))
106
171
107
172
@cli .command ('extract' )
108
173
@click .option ("-o" , "--output" , default = 'samples.npz' ,
@@ -120,7 +185,6 @@ def extract(ctx, output, sample, datapaths):
120
185
samples = map (int ,listify (* sample , delim = "," ))
121
186
122
187
# fixme: make choice of dataset optional
123
- from wirecell .dnn .apps import dnnroi as app
124
188
ds = app .Dataset (datapaths )
125
189
126
190
log .info (f'dataset has { len (ds )} entries from { len (datapaths )} data paths' )
0 commit comments