 import click
 import torch
+from torch.utils.data import DataLoader
+
 from pathlib import Path
-from wirecell.util.cli import context, log, jsonnet_loader
+from wirecell.util.cli import context, log, jsonnet_loader, anyconfig_file
 from wirecell.util.paths import unglob, listify


@@ -17,16 +19,23 @@ def cli(ctx):
     '''
     pass

+@cli.command('dump-config')
+@anyconfig_file("wirecelldnn")
+@click.pass_context
+def dump_config(ctx, config):
+    print(config)
+
+    return
+
+
+train_defaults = dict(epochs=1, batch=1, device='cpu', name='dnnroi', train_ratio=0.8)
 @cli.command('train')
-@click.option("-c", "--config",
-              type=click.Path(),
-              help="Set configuration file")
-@click.option("-e", "--epochs", default=1,
+@click.option("-e", "--epochs", default=None, type=int,
               help="Number of epochs over which to train. "
               "This is a relative count if the training starts with a -l/--load'ed state.")
-@click.option("-b", "--batch", default=1,
+@click.option("-b", "--batch", default=None, type=int,
               help="Batch size")
-@click.option("-d", "--device", default='cpu',
+@click.option("-d", "--device", default=None, type=str,
               help="The compute device")
 @click.option("--cache/--no-cache", is_flag=True, default=False,
               help="Cache data in memory")
@@ -38,33 +47,40 @@ def cli(ctx):
 @click.option("--checkpoint-modulus", default=1,
               help="Checkpoint modulus. "
               "If checkpoint path is given, the training is checkpointed every this many epochs.")
-@click.option("-n", "--name", default='dnnroi',
-              help="The application name (def=dnnroi)")
+@click.option("-a", "--app", default=None, type=str,
+              help="The application name")
 @click.option("-l", "--load", default=None,
               help="File name providing the initial model state dict (def=None - construct fresh)")
 @click.option("-s", "--save", default=None,
               help="File name to save model state dict after training (def=None - results not saved)")
-@click.option("--eval-files", multiple=True, type=str,  # fixme: remove this in favor of a single file set and a train/eval partitioning
-              help="File path or globs as comma separated list to use for evaluation dataset")
-@click.argument("train_files", nargs=-1)
+@click.option("--train-ratio", default=None, type=float,
+              help="Fraction of samples to use for training (default=1.0, no evaluation loss calculated)")
+@anyconfig_file("wirecelldnn", section='train', defaults=train_defaults)
+@click.argument("files", nargs=-1)
 @click.pass_context
 def train(ctx, config, epochs, batch, device, cache, debug_torch,
           checkpoint_save, checkpoint_modulus,
-          name, load, save, eval_files, train_files):
+          app, load, save, train_ratio, files):
     '''
     Train a model.
     '''
-    if not train_files:
+
+    if not files:               # args not processed by anyconfig_files
+        try:
+            files = config['train']['files']
+        except KeyError:
+            files = None
+    if not files:
         raise click.BadArgumentUsage("no training files given")
-    train_files = unglob(listify(train_files))
-    log.info(f'training files: {train_files}')
+    files = unglob(listify(files))
+    log.info(f'training files: {files}')

     if device == 'gpu': device = 'cuda'
-    log.info(f'using device {device}')

     if debug_torch:
         torch.autograd.set_detect_anomaly(True)

+    name = app
     app = getattr(dnn.apps, name)

     net = app.Network()
@@ -78,24 +94,17 @@ def train(ctx, config, epochs, batch, device, cache, debug_torch,
         raise click.FileError(load, 'warning: DNN module load file does not exist')
     history = dnn.io.load_checkpoint(load, net, opt)

-    train_ds = app.Dataset(train_files, cache=cache)
-    ntrain = len(train_ds)
-    if ntrain == 0:
-        raise click.BadArgumentUsage(f'no samples from {len(train_files)} files')
-
-    from torch.utils.data import DataLoader
-    train_dl = DataLoader(train_ds, batch_size=batch, shuffle=True, pin_memory=True)
-
-    neval = 0
-    eval_dl = None
-    if eval_files:
-        eval_files = unglob(listify(eval_files, delim=","))
-        log.info(f'eval files: {eval_files}')
-        eval_ds = app.Dataset(eval_files, cache=cache)
-        neval = len(eval_ds)
-        eval_dl = DataLoader(train_ds, batch_size=batch, shuffle=False, pin_memory=True)
-    else:
-        log.info("no eval files")
+    ds = app.Dataset(files, cache=cache, config=config.get("dataset", None))
+    if len(ds) == 0:
+        raise click.BadArgumentUsage(f'no samples from {len(files)} files')
+
+    tbatch, ebatch = batch, 1
+
+    dses = dnn.data.train_eval_split(ds, train_ratio)
+    dles = [DataLoader(one, batch_size=bb, shuffle=True, pin_memory=True) for one, bb in zip(dses, [tbatch, ebatch])]
+
+    ntrain = len(dses[0])
+    neval = len(dses[1])

     # History
     run_history = history.get("runs", dict())
@@ -104,9 +113,8 @@ def train(ctx, config, epochs, batch, device, cache, debug_torch,
     this_run_number = max(run_history.keys()) + 1
     this_run = dict(
         run=this_run_number,
-        train_files=train_files,
+        data_files=files,
         ntrain=ntrain,
-        eval_files=eval_files or [],
         neval=neval,
         nepochs=epochs,
         batch=batch,
@@ -128,13 +136,17 @@ def saveit(path):
         dnn.io.save_checkpoint(path, net, opt, runs=run_history, epochs=epoch_history)

     for this_epoch_number in range(first_epoch_number, first_epoch_number + epochs):
-        train_losses = trainer.epoch(train_dl)
-        train_loss = sum(train_losses)/ntrain

-        eval_losses = []
+        train_loss = 0
+        train_losses = []
+        if ntrain:
+            train_losses = trainer.epoch(dles[0])
+            train_loss = sum(train_losses)/ntrain
+
         eval_loss = 0
-        if eval_dl:
-            eval_losses = trainer.evaluate(eval_dl)
+        eval_losses = []
+        if neval:
+            eval_losses = trainer.evaluate(dles[1])
             eval_loss = sum(eval_losses) / neval

         this_epoch = dict(
@@ -146,7 +158,7 @@ def saveit(path):
             eval_loss=eval_loss)
         epoch_history[this_epoch_number] = this_epoch

-        log.info(f'run: {this_run_number} epoch: {this_epoch_number} loss: {train_loss} eval: {eval_loss}')
+        log.info(f'run: {this_run_number} epoch: {this_epoch_number} loss: {train_loss:.4e} [b={tbatch},n={ntrain}] eval: {eval_loss:.4e} [b={ebatch},n={neval}]')

         if checkpoint_save:
             if this_epoch_number % checkpoint_modulus == 0:
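
For orientation: the fallback files = config['train']['files'] and the config.get("dataset", None) call imply a loaded config with (at least) a "train" section and an optional "dataset" entry once anyconfig_file("wirecelldnn", ...) has processed it. The sketch below is illustrative only: apart from those two keys, which the code above reads directly, the file globs shown and the way section values merge with train_defaults are assumptions, not something this diff fixes.

# Hypothetical shape of a loaded "wirecelldnn" config (illustration only).
example_config = {
    "train": {
        # hypothetical file globs; consulted only when no FILES arguments
        # are given on the command line
        "files": ["rec-*.npz"],
        # remaining keys assumed to be merged with train_defaults by the
        # anyconfig_file decorator
        "epochs": 5,
        "batch": 4,
    },
    # optional dataset configuration forwarded to app.Dataset(..., config=...)
    "dataset": None,
}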
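
The new dnn.data.train_eval_split() helper is referenced but not shown in this diff. As a rough sketch of the intended behavior, a ratio-based partition of one dataset into train and eval subsets, something equivalent can be built on torch.utils.data.random_split; the real wirecell implementation may differ in shuffling, seeding and edge-case handling.

from torch.utils.data import random_split

def train_eval_split_sketch(ds, train_ratio):
    '''Split ds into (train, eval) subsets by train_ratio.

    Illustration only; not the actual dnn.data.train_eval_split.
    '''
    ntrain = max(0, min(len(ds), int(round(len(ds) * train_ratio))))
    # random_split returns Subset objects; the second one is empty when
    # train_ratio is 1.0, which matches the "if neval:" guard above.
    return random_split(ds, [ntrain, len(ds) - ntrain])

With train_ratio=1.0 the evaluation subset comes back empty, so neval is 0 and the epoch loop skips trainer.evaluate(), consistent with the --train-ratio help text.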