Skip to content

Commit aa865be

Browse files
committed
add train and test log
1 parent e8782f7 commit aa865be

File tree

4 files changed

+53
-56
lines changed

4 files changed

+53
-56
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ Data Format Explanation.
6363
- [x] Support model finetuning
6464
- [x] Complete test module
6565
- [ ] Add a switch to control whether loss and accuracy curves are displayed on one plot or on multiple plots
66-
- [ ] Train and Test Log
66+
- [x] Train and Test Log
6767

6868

6969
## Reference

demo.sh

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/bin/bash
22

3+
#--------------train--------------
34
# To visualize on the web, you have to start a visdom server. Results are shown at localhost:8900
45
# 1. create screen:
56
# screen -S visdom.8900
@@ -8,5 +9,9 @@
89
# 3. leave screen:
910
# ctrl + a + d
1011
# 4. start demo.sh
11-
python multi_label_classifier.py --dir "./test/celeba/" --mode "Train" --name "test" --batch_size 64 --gpu_ids 0 --input_channel 3 --load_size 144 --input_size 128 --mean [0,0,0] --std [1,1,1] --ratio "[0.94, 0.03, 0.03]" --shuffle --load_thread 8 --sum_epoch 20 --lr_decay_in_epoch 4 --display_port 8900 --validate_ratio 0.1 --top_k "(1,)" --score_thres 0.1 --display_train_freq 20 --display_validate_freq 20 --save_epoch_freq 1 --display_image_ratio 0.2
12+
#python multi_label_classifier.py --dir "./test/celeba/" --mode "Train" --name "test" --batch_size 64 --gpu_ids 0 --input_channel 3 --load_size 144 --input_size 128 --mean [0,0,0] --std [1,1,1] --ratio "[0.94, 0.03, 0.03]" --shuffle --load_thread 8 --sum_epoch 20 --lr_decay_in_epoch 4 --display_port 8900 --validate_ratio 0.1 --top_k "(1,)" --score_thres 0.1 --display_train_freq 20 --display_validate_freq 20 --save_epoch_freq 1 --display_image_ratio 0.2
1213
# 5. Open localhost:8900 in your browser; the loss and accuracy curves and training image samples will appear there once training is under way.
14+
15+
16+
#--------------test--------------
17+
python multi_label_classifier.py --dir "./test/celeba/" --mode "Test" --name "test" --batch_size 1 --gpu_ids 4 --input_channel 3 --load_size 144 --input_size 128 --mean [0,0,0] --std [1,1,1] --shuffle --load_thread 1 --top_k "(1,2)" --score_thres 0.1 --checkpoint_name "epoch_1_snapshot.pth"

multi_label_classifier.py

+38-22
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import time
33
import copy
44
import random
5+
import logging
56
import numpy as np
67
import torch
78
import torch.nn as nn
@@ -96,7 +97,7 @@ def _forward_dataset(model, criterion, data_loader, opt):
9697
if random.random() > opt.validate_ratio:
9798
continue
9899
if opt.mode == "Test":
99-
print "test %s/%s image" %(i, len(data_loader))
100+
logging.info("test %s/%s image" %(i, len(data_loader)))
100101
sum_batch += 1
101102
inputs, targets = data
102103
output, loss, loss_list = _forward(model, criterion, inputs, targets, opt, "Validate")
@@ -118,6 +119,7 @@ def _forward_dataset(model, criterion, data_loader, opt):
118119
for index, loss in enumerate(loss_list):
119120
avg_loss[index] += loss
120121
# average on batches
122+
print sum_batch
121123
for index, item in enumerate(accuracy):
122124
for k,v in item.iteritems():
123125
accuracy[index][k]["ratio"] /= float(sum_batch)
@@ -129,19 +131,16 @@ def validate(model, criterion, val_set, opt):
129131
return _forward_dataset(model, criterion, val_set, opt)
130132

131133
def test(model, criterion, test_set, opt):
132-
print "####################Test Model###################"
134+
logging.info("####################Test Model###################")
133135
test_accuracy, test_loss = _forward_dataset(model, criterion, test_set, opt)
134-
print "data_dir: ", opt.data_dir + "/TestSet/"
135-
print "state_dict: ", opt.model_dir + "/" + opt.checkpoint_name
136-
util.print_loss(test_loss, "Test")
137-
util.print_accuracy(test_accuracy, "Test")
138-
test_result = os.path.join(opt.test_dir, "result.txt")
139-
with open(test_result, 'w') as t:
140-
for index, item in enumerate(test_accuracy):
141-
t.write("Attribute %d:\n" %(index))
142-
for top_k, value in item.iteritems():
143-
t.write("----Accuracy of Top%d: %f\n" %(top_k, value["ratio"]))
144-
print "#################Finished Testing################"
136+
logging.info("data_dir: " + opt.data_dir + "/TestSet/")
137+
logging.info("state_dict: " + opt.model_dir + "/" + opt.checkpoint_name)
138+
logging.info("score_thres:"+ str(opt.score_thres))
139+
for index, item in enumerate(test_accuracy):
140+
logging.info("Attribute %d:" %(index))
141+
for top_k, value in item.iteritems():
142+
logging.info("----Accuracy of Top%d: %f" %(top_k, value["ratio"]))
143+
logging.info("#################Finished Testing################")
145144

146145
def train(model, criterion, train_set, val_set, opt, labels=None):
147146
# define web visualizer using visdom
@@ -165,11 +164,11 @@ def train(model, criterion, train_set, val_set, opt, labels=None):
165164
# record forward and backward times
166165
train_batch_num = len(train_set)
167166
total_batch_iter = 0
168-
print "####################Train Model###################"
167+
logging.info("####################Train Model###################")
169168
for epoch in range(opt.sum_epoch):
170169
epoch_start_t = time.time()
171170
epoch_batch_iter = 0
172-
print 'Begin of epoch %d' %(epoch)
171+
logging.info('Begin of epoch %d' %(epoch))
173172
for i, data in enumerate(train_set):
174173
iter_start_t = time.time()
175174
# train
@@ -215,7 +214,7 @@ def train(model, criterion, train_set, val_set, opt, labels=None):
215214
webvis.plot_images(image_dict, opt.display_id + 2*opt.class_num, epoch, save_result)
216215

217216
# validate and display validate loss and accuracy
218-
if total_batch_iter % opt.display_validate_freq == 0 and len(val_set) > 0:
217+
if len(val_set) > 0 and total_batch_iter % opt.display_validate_freq == 0:
219218
val_accuracy, val_loss = validate(model, criterion, val_set, opt)
220219
x_axis = epoch + float(epoch_batch_iter)/train_batch_num
221220
accuracy_list = [val_accuracy[i][opt.top_k[0]]["ratio"] for i in range(len(val_accuracy))]
@@ -227,21 +226,21 @@ def train(model, criterion, train_set, val_set, opt, labels=None):
227226

228227
# save snapshot
229228
if total_batch_iter % opt.save_batch_iter_freq == 0:
230-
print "saving the latest model (epoch %d, total_batch_iter %d)" %(epoch, total_batch_iter)
229+
logging.info("saving the latest model (epoch %d, total_batch_iter %d)" %(epoch, total_batch_iter))
231230
util.save_model(model, opt, epoch)
232231
# TODO snapshot loss and accuracy
233232

234-
print('End of epoch %d / %d \t Time Taken: %d sec' %
233+
logging.info('End of epoch %d / %d \t Time Taken: %d sec' %
235234
(epoch, opt.sum_epoch, time.time() - epoch_start_t))
236235

237-
if epoch % opt.save_epoch_freq:
238-
print 'saving the model at the end of epoch %d, iters %d' %(epoch, total_batch_iter)
236+
if epoch % opt.save_epoch_freq == 0:
237+
logging.info('saving the model at the end of epoch %d, iters %d' %(epoch, total_batch_iter))
239238
util.save_model(model, opt, epoch)
240239

241240
# adjust learning rate
242241
scheduler.step()
243242
lr = optimizer.param_groups[0]['lr']
244-
print('learning rate = %.7f' % lr, 'epoch = %d' %(epoch))
243+
logging.info('learning rate = %.7f epoch = %d' %(lr,epoch))
245244

246245
def _load_model(opt, num_classes):
247246
# load model
@@ -251,7 +250,8 @@ def _load_model(opt, num_classes):
251250
tmp_output = templet(tmp_input)
252251
output_dim = int(tmp_output.size()[-1])
253252
model = BuildMultiLabelModel(templet, output_dim, num_classes)
254-
print model
253+
#print model
254+
logging.info(model)
255255

256256
# load existing model
257257
if opt.checkpoint_name != "":
@@ -283,6 +283,22 @@ def main():
283283
op = Options()
284284
opt = op.parse()
285285

286+
# log setting
287+
log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
288+
formatter = logging.Formatter(log_format)
289+
if opt.mode == "Train":
290+
log_path = os.path.join(opt.model_dir, "train.log")
291+
else:
292+
log_path = os.path.join(opt.test_dir, "test.log")
293+
fh = logging.FileHandler(log_path, 'a')
294+
fh.setFormatter(formatter)
295+
ch = logging.StreamHandler()
296+
ch.setFormatter(formatter)
297+
logging.getLogger().addHandler(fh)
298+
logging.getLogger().addHandler(ch)
299+
log_level = logging.INFO
300+
logging.getLogger().setLevel(log_level)
301+
286302
# load train or test data
287303
data_loader = MultiLabelDataLoader(opt)
288304
if opt.mode == "Train":

util/util.py

+8-32
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,16 @@
1-
from __future__ import print_function
1+
import os
22
import torch
33
import numpy as np
4+
import logging
45
from PIL import Image
5-
import os
6-
76

8-
# Converts a Tensor into a Numpy array
9-
# |imtype|: the desired type of the converted numpy array
107
def tensor2im(image_tensor, imtype=np.uint8):
118
image_numpy = image_tensor.cpu().float().numpy()
129
if image_numpy.shape[0] == 1:
1310
image_numpy = np.tile(image_numpy, (3, 1, 1))
1411
image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0
1512
return image_numpy.astype(imtype)
1613

17-
def diagnose_network(net, name='network'):
18-
mean = 0.0
19-
count = 0
20-
for param in net.parameters():
21-
if param.grad is not None:
22-
mean += torch.mean(torch.abs(param.grad.data))
23-
count += 1
24-
if count > 0:
25-
mean = mean / count
26-
print(name)
27-
print(mean)
28-
2914
def save_image(image_numpy, image_path):
3015
image_pil = Image.fromarray(image_numpy)
3116
image_pil.save(image_path)
@@ -36,15 +21,6 @@ def save_model(model, opt, epoch):
3621
if opt.cuda and torch.cuda.is_available():
3722
model.cuda(opt.devices[0])
3823

39-
def print_numpy(x, val=True, shp=False):
40-
x = x.astype(np.float64)
41-
if shp:
42-
print('shape,', x.shape)
43-
if val:
44-
x = x.flatten()
45-
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
46-
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
47-
4824
def mkdirs(paths):
4925
if isinstance(paths, list) and not isinstance(paths, str):
5026
for path in paths:
@@ -62,19 +38,19 @@ def rmdir(path):
6238

6339
def print_loss(loss_list, label, epoch=0, batch_iter=0):
6440
if label == "Test":
65-
print("[ %s Loss ] of Test Dataset:" % (label))
41+
logging.info("[ %s Loss ] of Test Dataset:" % (label))
6642
else:
67-
print("[ %s Loss ] of Epoch %d Batch %d" % (label, epoch, batch_iter))
43+
logging.info("[ %s Loss ] of Epoch %d Batch %d" % (label, epoch, batch_iter))
6844

6945
for index, loss in enumerate(loss_list):
70-
print("----Attribute %d: %f" %(index, loss))
46+
logging.info("----Attribute %d: %f" %(index, loss))
7147

7248
def print_accuracy(accuracy_list, label, epoch=0, batch_iter=0):
7349
if label == "Test":
74-
print("[ %s Accuracy ] of Test Dataset:" % (label))
50+
logging.info("[ %s Accuracy ] of Test Dataset:" % (label))
7551
else:
76-
print("[ %s Accuracy ] of Epoch %d Batch %d" %(label, epoch, batch_iter))
52+
logging.info("[ %s Accuracy ] of Epoch %d Batch %d" %(label, epoch, batch_iter))
7753

7854
for index, item in enumerate(accuracy_list):
7955
for top_k, value in item.iteritems():
80-
print("----Attribute %d Top%d: %f" %(index, top_k, value["ratio"]))
56+
logging.info("----Attribute %d Top%d: %f" %(index, top_k, value["ratio"]))

0 commit comments

Comments
 (0)