2
2
import time
3
3
import copy
4
4
import random
5
+ import logging
5
6
import numpy as np
6
7
import torch
7
8
import torch .nn as nn
@@ -96,7 +97,7 @@ def _forward_dataset(model, criterion, data_loader, opt):
96
97
if random .random () > opt .validate_ratio :
97
98
continue
98
99
if opt .mode == "Test" :
99
- print "test %s/%s image" % (i , len (data_loader ))
100
+ logging . info ( "test %s/%s image" % (i , len (data_loader ) ))
100
101
sum_batch += 1
101
102
inputs , targets = data
102
103
output , loss , loss_list = _forward (model , criterion , inputs , targets , opt , "Validate" )
@@ -118,6 +119,7 @@ def _forward_dataset(model, criterion, data_loader, opt):
118
119
for index , loss in enumerate (loss_list ):
119
120
avg_loss [index ] += loss
120
121
# average on batches
122
+ print sum_batch
121
123
for index , item in enumerate (accuracy ):
122
124
for k ,v in item .iteritems ():
123
125
accuracy [index ][k ]["ratio" ] /= float (sum_batch )
@@ -129,19 +131,16 @@ def validate(model, criterion, val_set, opt):
129
131
return _forward_dataset (model , criterion , val_set , opt )
130
132
131
133
def test (model , criterion , test_set , opt ):
132
- print "####################Test Model###################"
134
+ logging . info ( "####################Test Model###################" )
133
135
test_accuracy , test_loss = _forward_dataset (model , criterion , test_set , opt )
134
- print "data_dir: " , opt .data_dir + "/TestSet/"
135
- print "state_dict: " , opt .model_dir + "/" + opt .checkpoint_name
136
- util .print_loss (test_loss , "Test" )
137
- util .print_accuracy (test_accuracy , "Test" )
138
- test_result = os .path .join (opt .test_dir , "result.txt" )
139
- with open (test_result , 'w' ) as t :
140
- for index , item in enumerate (test_accuracy ):
141
- t .write ("Attribute %d:\n " % (index ))
142
- for top_k , value in item .iteritems ():
143
- t .write ("----Accuracy of Top%d: %f\n " % (top_k , value ["ratio" ]))
144
- print "#################Finished Testing################"
136
+ logging .info ("data_dir: " + opt .data_dir + "/TestSet/" )
137
+ logging .info ("state_dict: " + opt .model_dir + "/" + opt .checkpoint_name )
138
+ logging .info ("score_thres:" + str (opt .score_thres ))
139
+ for index , item in enumerate (test_accuracy ):
140
+ logging .info ("Attribute %d:" % (index ))
141
+ for top_k , value in item .iteritems ():
142
+ logging .info ("----Accuracy of Top%d: %f" % (top_k , value ["ratio" ]))
143
+ logging .info ("#################Finished Testing################" )
145
144
146
145
def train (model , criterion , train_set , val_set , opt , labels = None ):
147
146
# define web visualizer using visdom
@@ -165,11 +164,11 @@ def train(model, criterion, train_set, val_set, opt, labels=None):
165
164
# record forward and backward times
166
165
train_batch_num = len (train_set )
167
166
total_batch_iter = 0
168
- print "####################Train Model###################"
167
+ logging . info ( "####################Train Model###################" )
169
168
for epoch in range (opt .sum_epoch ):
170
169
epoch_start_t = time .time ()
171
170
epoch_batch_iter = 0
172
- print 'Begin of epoch %d' % (epoch )
171
+ logging . info ( 'Begin of epoch %d' % (epoch ) )
173
172
for i , data in enumerate (train_set ):
174
173
iter_start_t = time .time ()
175
174
# train
@@ -215,7 +214,7 @@ def train(model, criterion, train_set, val_set, opt, labels=None):
215
214
webvis .plot_images (image_dict , opt .display_id + 2 * opt .class_num , epoch , save_result )
216
215
217
216
# validate and display validate loss and accuracy
218
- if total_batch_iter % opt .display_validate_freq == 0 and len ( val_set ) > 0 :
217
+ if len ( val_set ) > 0 and total_batch_iter % opt .display_validate_freq == 0 :
219
218
val_accuracy , val_loss = validate (model , criterion , val_set , opt )
220
219
x_axis = epoch + float (epoch_batch_iter )/ train_batch_num
221
220
accuracy_list = [val_accuracy [i ][opt .top_k [0 ]]["ratio" ] for i in range (len (val_accuracy ))]
@@ -227,21 +226,21 @@ def train(model, criterion, train_set, val_set, opt, labels=None):
227
226
228
227
# save snapshot
229
228
if total_batch_iter % opt .save_batch_iter_freq == 0 :
230
- print "saving the latest model (epoch %d, total_batch_iter %d)" % (epoch , total_batch_iter )
229
+ logging . info ( "saving the latest model (epoch %d, total_batch_iter %d)" % (epoch , total_batch_iter ) )
231
230
util .save_model (model , opt , epoch )
232
231
# TODO snapshot loss and accuracy
233
232
234
- print ('End of epoch %d / %d \t Time Taken: %d sec' %
233
+ logging . info ('End of epoch %d / %d \t Time Taken: %d sec' %
235
234
(epoch , opt .sum_epoch , time .time () - epoch_start_t ))
236
235
237
- if epoch % opt .save_epoch_freq :
238
- print 'saving the model at the end of epoch %d, iters %d' % (epoch , total_batch_iter )
236
+ if epoch % opt .save_epoch_freq == 0 :
237
+ logging . info ( 'saving the model at the end of epoch %d, iters %d' % (epoch , total_batch_iter ) )
239
238
util .save_model (model , opt , epoch )
240
239
241
240
# adjust learning rate
242
241
scheduler .step ()
243
242
lr = optimizer .param_groups [0 ]['lr' ]
244
- print ('learning rate = %.7f' % lr , ' epoch = %d' % (epoch ))
243
+ logging . info ('learning rate = %.7f epoch = %d' % (lr , epoch ))
245
244
246
245
def _load_model (opt , num_classes ):
247
246
# load model
@@ -251,7 +250,8 @@ def _load_model(opt, num_classes):
251
250
tmp_output = templet (tmp_input )
252
251
output_dim = int (tmp_output .size ()[- 1 ])
253
252
model = BuildMultiLabelModel (templet , output_dim , num_classes )
254
- print model
253
+ #print model
254
+ logging .info (model )
255
255
256
256
# load exsiting model
257
257
if opt .checkpoint_name != "" :
@@ -283,6 +283,22 @@ def main():
283
283
op = Options ()
284
284
opt = op .parse ()
285
285
286
+ # log setting
287
+ log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
288
+ formatter = logging .Formatter (log_format )
289
+ if opt .mode == "Train" :
290
+ log_path = os .path .join (opt .model_dir , "train.log" )
291
+ else :
292
+ log_path = os .path .join (opt .test_dir , "test.log" )
293
+ fh = logging .FileHandler (log_path , 'a' )
294
+ fh .setFormatter (formatter )
295
+ ch = logging .StreamHandler ()
296
+ ch .setFormatter (formatter )
297
+ logging .getLogger ().addHandler (fh )
298
+ logging .getLogger ().addHandler (ch )
299
+ log_level = logging .INFO
300
+ logging .getLogger ().setLevel (log_level )
301
+
286
302
# load train or test data
287
303
data_loader = MultiLabelDataLoader (opt )
288
304
if opt .mode == "Train" :
0 commit comments