22import time
33import copy
44import random
5+ import logging
56import numpy as np
67import torch
78import torch .nn as nn
@@ -96,7 +97,7 @@ def _forward_dataset(model, criterion, data_loader, opt):
9697 if random .random () > opt .validate_ratio :
9798 continue
9899 if opt .mode == "Test" :
99- print "test %s/%s image" % (i , len (data_loader ))
100+ logging . info ( "test %s/%s image" % (i , len (data_loader ) ))
100101 sum_batch += 1
101102 inputs , targets = data
102103 output , loss , loss_list = _forward (model , criterion , inputs , targets , opt , "Validate" )
@@ -118,6 +119,7 @@ def _forward_dataset(model, criterion, data_loader, opt):
118119 for index , loss in enumerate (loss_list ):
119120 avg_loss [index ] += loss
120121 # average on batches
122+ print sum_batch
121123 for index , item in enumerate (accuracy ):
122124 for k ,v in item .iteritems ():
123125 accuracy [index ][k ]["ratio" ] /= float (sum_batch )
@@ -129,19 +131,16 @@ def validate(model, criterion, val_set, opt):
129131 return _forward_dataset (model , criterion , val_set , opt )
130132
131133def test (model , criterion , test_set , opt ):
132- print "####################Test Model###################"
134+ logging . info ( "####################Test Model###################" )
133135 test_accuracy , test_loss = _forward_dataset (model , criterion , test_set , opt )
134- print "data_dir: " , opt .data_dir + "/TestSet/"
135- print "state_dict: " , opt .model_dir + "/" + opt .checkpoint_name
136- util .print_loss (test_loss , "Test" )
137- util .print_accuracy (test_accuracy , "Test" )
138- test_result = os .path .join (opt .test_dir , "result.txt" )
139- with open (test_result , 'w' ) as t :
140- for index , item in enumerate (test_accuracy ):
141- t .write ("Attribute %d:\n " % (index ))
142- for top_k , value in item .iteritems ():
143- t .write ("----Accuracy of Top%d: %f\n " % (top_k , value ["ratio" ]))
144- print "#################Finished Testing################"
136+ logging .info ("data_dir: " + opt .data_dir + "/TestSet/" )
137+ logging .info ("state_dict: " + opt .model_dir + "/" + opt .checkpoint_name )
138+ logging .info ("score_thres:" + str (opt .score_thres ))
139+ for index , item in enumerate (test_accuracy ):
140+ logging .info ("Attribute %d:" % (index ))
141+ for top_k , value in item .iteritems ():
142+ logging .info ("----Accuracy of Top%d: %f" % (top_k , value ["ratio" ]))
143+ logging .info ("#################Finished Testing################" )
145144
146145def train (model , criterion , train_set , val_set , opt , labels = None ):
147146 # define web visualizer using visdom
@@ -165,11 +164,11 @@ def train(model, criterion, train_set, val_set, opt, labels=None):
165164 # record forward and backward times
166165 train_batch_num = len (train_set )
167166 total_batch_iter = 0
168- print "####################Train Model###################"
167+ logging . info ( "####################Train Model###################" )
169168 for epoch in range (opt .sum_epoch ):
170169 epoch_start_t = time .time ()
171170 epoch_batch_iter = 0
172- print 'Begin of epoch %d' % (epoch )
171+ logging . info ( 'Begin of epoch %d' % (epoch ) )
173172 for i , data in enumerate (train_set ):
174173 iter_start_t = time .time ()
175174 # train
@@ -215,7 +214,7 @@ def train(model, criterion, train_set, val_set, opt, labels=None):
215214 webvis .plot_images (image_dict , opt .display_id + 2 * opt .class_num , epoch , save_result )
216215
217216 # validate and display validate loss and accuracy
218- if total_batch_iter % opt .display_validate_freq == 0 and len ( val_set ) > 0 :
217+ if len ( val_set ) > 0 and total_batch_iter % opt .display_validate_freq == 0 :
219218 val_accuracy , val_loss = validate (model , criterion , val_set , opt )
220219 x_axis = epoch + float (epoch_batch_iter )/ train_batch_num
221220 accuracy_list = [val_accuracy [i ][opt .top_k [0 ]]["ratio" ] for i in range (len (val_accuracy ))]
@@ -227,21 +226,21 @@ def train(model, criterion, train_set, val_set, opt, labels=None):
227226
228227 # save snapshot
229228 if total_batch_iter % opt .save_batch_iter_freq == 0 :
230- print "saving the latest model (epoch %d, total_batch_iter %d)" % (epoch , total_batch_iter )
229+ logging . info ( "saving the latest model (epoch %d, total_batch_iter %d)" % (epoch , total_batch_iter ) )
231230 util .save_model (model , opt , epoch )
232231 # TODO snapshot loss and accuracy
233232
234- print ('End of epoch %d / %d \t Time Taken: %d sec' %
233+ logging . info ('End of epoch %d / %d \t Time Taken: %d sec' %
235234 (epoch , opt .sum_epoch , time .time () - epoch_start_t ))
236235
237- if epoch % opt .save_epoch_freq :
238- print 'saving the model at the end of epoch %d, iters %d' % (epoch , total_batch_iter )
236+ if epoch % opt .save_epoch_freq == 0 :
237+ logging . info ( 'saving the model at the end of epoch %d, iters %d' % (epoch , total_batch_iter ) )
239238 util .save_model (model , opt , epoch )
240239
241240 # adjust learning rate
242241 scheduler .step ()
243242 lr = optimizer .param_groups [0 ]['lr' ]
244- print ('learning rate = %.7f' % lr , ' epoch = %d' % (epoch ))
243+ logging . info ('learning rate = %.7f epoch = %d' % (lr , epoch ))
245244
246245def _load_model (opt , num_classes ):
247246 # load model
@@ -251,7 +250,8 @@ def _load_model(opt, num_classes):
251250 tmp_output = templet (tmp_input )
252251 output_dim = int (tmp_output .size ()[- 1 ])
253252 model = BuildMultiLabelModel (templet , output_dim , num_classes )
254- print model
253+ #print model
254+ logging .info (model )
255255
256256 # load exsiting model
257257 if opt .checkpoint_name != "" :
@@ -283,6 +283,22 @@ def main():
283283 op = Options ()
284284 opt = op .parse ()
285285
286+ # log setting
287+ log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
288+ formatter = logging .Formatter (log_format )
289+ if opt .mode == "Train" :
290+ log_path = os .path .join (opt .model_dir , "train.log" )
291+ else :
292+ log_path = os .path .join (opt .test_dir , "test.log" )
293+ fh = logging .FileHandler (log_path , 'a' )
294+ fh .setFormatter (formatter )
295+ ch = logging .StreamHandler ()
296+ ch .setFormatter (formatter )
297+ logging .getLogger ().addHandler (fh )
298+ logging .getLogger ().addHandler (ch )
299+ log_level = logging .INFO
300+ logging .getLogger ().setLevel (log_level )
301+
286302 # load train or test data
287303 data_loader = MultiLabelDataLoader (opt )
288304 if opt .mode == "Train" :
0 commit comments