
Commit 76f93a4

Use train function for main scripts
1 parent 0c91c48 commit 76f93a4

File tree

3 files changed: +170 −632 lines changed

main_clustering_loss.py

+65 −214
@@ -5,231 +5,82 @@
 @author: sakurai
 """
 
-import os
-import time
-import copy
-import numpy as np
-import matplotlib.pyplot as plt
-import six
-
-import chainer
-from chainer import cuda
-from chainer import optimizers
-from tqdm import tqdm
 import colorama
+
 from sklearn.model_selection import ParameterSampler
 
-from functions.clustering_loss import clustering_loss
-import common
-from datasets import data_provider
-from models.modified_googlenet import ModifiedGoogLeNet
-from common import UniformDistribution, LogUniformDistribution
+from lib.functions.clustering_loss import clustering_loss
+from lib.common.utils import (
+    UniformDistribution, LogUniformDistribution, load_params)
+from lib.common.train_eval import train
 
 
 colorama.init()
 
 
-def main(param_dict, save_distance_matrix=False):
-    script_filename = os.path.splitext(os.path.basename(__file__))[0]
-    device = 0
-    xp = chainer.cuda.cupy
-    config_parser = six.moves.configparser.ConfigParser()
-    config_parser.read('config')
-    log_dir_path = os.path.expanduser(config_parser.get('logs', 'dir_path'))
-
-    p = common.Logger(log_dir_path, **param_dict)  # hyperparameters
-
-    ##########################################################
-    # load database
-    ##########################################################
-    streams = data_provider.get_streams(p.batch_size, dataset=p.dataset,
-                                        method='clustering')
-    stream_train, stream_train_eval, stream_test = streams
-    iter_train = stream_train.get_epoch_iterator()
-
-    ##########################################################
-    # construct the model
-    ##########################################################
-    model = ModifiedGoogLeNet(p.out_dim, p.normalize_output)
-    if device >= 0:
-        model.to_gpu()
-    xp = model.xp
-    if p.optimizer == 'Adam':
-        optimizer = optimizers.Adam(p.learning_rate)
-    elif p.optimizer == 'RMSProp':
-        optimizer = optimizers.RMSprop(p.learning_rate)
+def lossfun_one_batch(model, params, batch):
+    x_data, c_data = batch
+    x_data = model.xp.asarray(x_data)
+    c_data = model.xp.asarray(c_data)
+
+    y = model(x_data)
+
+    # decay gamma at regular interval
+    if type(params.gamma) is not float:
+        params.gamma = params.gamma_init
+        params.num_updates = 0
     else:
-        raise ValueError
-    optimizer.setup(model)
-    optimizer.add_hook(chainer.optimizer.WeightDecay(p.l2_weight_decay))
-    gamma = p.gamma_init
-
-    logger = common.Logger(log_dir_path)
-    logger.soft_test_best = [0]
-    time_origin = time.time()
-    try:
-        for epoch in range(p.num_epochs):
-            time_begin = time.time()
-            epoch_losses = []
-
-            for i in tqdm(range(p.num_batches_per_epoch),
-                          desc='# {}'.format(epoch)):
-                # the first half of a batch are the anchors and the latters
-                # are the positive examples corresponding to each anchor
-                x_data, c_data = next(iter_train)
-                if device >= 0:
-                    x_data = cuda.to_gpu(x_data, device)
-                    c_data = cuda.to_gpu(c_data, device)
-                y = model(x_data, train=True)
-
-                loss = clustering_loss(y, c_data, gamma)
-                optimizer.zero_grads()
-                loss.backward()
-                optimizer.update()
-
-                epoch_losses.append(loss.data)
-                y = y_a = y_p = loss = None
-
-            loss_average = cuda.to_cpu(xp.array(
-                xp.hstack(epoch_losses).mean()))
-
-            # average accuracy and distance matrix for training data
-            D, soft, hard, retrieval = common.evaluate(
-                model, stream_train_eval.get_epoch_iterator(), p.distance_type,
-                return_distance_matrix=save_distance_matrix)
-
-            # average accuracy and distance matrix for testing data
-            D_test, soft_test, hard_test, retrieval_test = common.evaluate(
-                model, stream_test.get_epoch_iterator(), p.distance_type,
-                return_distance_matrix=save_distance_matrix)
-
-            time_end = time.time()
-            epoch_time = time_end - time_begin
-            total_time = time_end - time_origin
-
-            logger.epoch = epoch
-            logger.total_time = total_time
-            logger.loss_log.append(loss_average)
-            logger.train_log.append([soft[0], hard[0], retrieval[0]])
-            logger.test_log.append(
-                [soft_test[0], hard_test[0], retrieval_test[0]])
-
-            # retain the model if it scored the best test acc. ever
-            if soft_test[0] > logger.soft_test_best[0]:
-                logger.model_best = copy.deepcopy(model)
-                logger.optimizer_best = copy.deepcopy(optimizer)
-                logger.epoch_best = epoch
-                logger.D_best = D
-                logger.D_test_best = D_test
-                logger.soft_best = soft
-                logger.soft_test_best = soft_test
-                logger.hard_best = hard
-                logger.hard_test_best = hard_test
-                logger.retrieval_best = retrieval
-                logger.retrieval_test_best = retrieval_test
-
-            print("#", epoch)
-            print("time: {} ({})".format(epoch_time, total_time))
-            print("[train] loss:", loss_average)
-            print("[train] soft:", soft)
-            print("[train] hard:", hard)
-            print("[train] retr:", retrieval)
-            print("[test] soft:", soft_test)
-            print("[test] hard:", hard_test)
-            print("[test] retr:", retrieval_test)
-            print("[best] soft: {} (at # {})".format(logger.soft_test_best,
-                                                     logger.epoch_best))
-            print(p, 'gamma:{}'.format(gamma))
-            # print norms of the weights
-            params = xp.hstack([xp.linalg.norm(param.data)
-                                for param in model.params()]).tolist()
-            print("|W|", map(lambda param: float('%0.2f' % param), params))
-            print()
-
-            # Draw plots
-            if save_distance_matrix:
-                plt.figure(figsize=(8, 4))
-                plt.subplot(1, 2, 1)
-                mat = plt.matshow(D, fignum=0, cmap=plt.cm.gray)
-                plt.colorbar(mat, fraction=0.045)
-                plt.subplot(1, 2, 2)
-                mat = plt.matshow(D_test, fignum=0, cmap=plt.cm.gray)
-                plt.colorbar(mat, fraction=0.045)
-                plt.tight_layout()
-
-            plt.figure(figsize=(8, 4))
-            plt.subplot(1, 2, 1)
-            plt.plot(logger.loss_log, label="tr-loss")
-            plt.grid()
-            plt.legend(loc='best')
-            plt.subplot(1, 2, 2)
-            plt.plot(logger.train_log)
-            plt.plot(logger.test_log)
-            plt.grid()
-            plt.legend(["tr-soft", "tr-hard", "tr-retr",
-                        "te-soft", "te-hard", "te-retr"],
-                       bbox_to_anchor=(1.4, 1))
-            plt.ylim([0.0, 1.0])
-            plt.xlim([0, p.num_epochs])
-            plt.tight_layout()
-            plt.show()
-            plt.draw()
-
-            loss = None
-            D = None
-            D_test = None
-
-            gamma *= p.gamma_decay
-
-    except KeyboardInterrupt:
-        pass
-
-    dir_name = "-".join([script_filename, time.strftime("%Y%m%d%H%M%S"),
-                         str(logger.soft_test_best[0])])
-
-    logger.save(dir_name)
-    p.save(dir_name)
-
-    print("total epochs: {} ({} [s])".format(logger.epoch, logger.total_time))
-    print("best test score (at # {})".format(logger.epoch_best))
-    print("[test] soft:", logger.soft_test_best)
-    print("[test] hard:", logger.hard_test_best)
-    print("[test] retr:", logger.retrieval_test_best)
-    print(str(p).replace(', ', '\n'))
-    print()
+        if (params.num_updates != 0 and
+                params.num_updates % params.num_batches_per_epoch == 0):
+            params.gamma *= params.gamma_decay
+        params.num_updates += 1
+
+    return clustering_loss(y, c_data, params.gamma)
 
 
 if __name__ == '__main__':
+    param_filename = 'clustering_cub200_2011.yaml'
+    random_search_mode = True
     random_state = None
     num_runs = 100000
     save_distance_matrix = False
-    param_distributions = dict(
-        learning_rate=LogUniformDistribution(low=1e-6, high=1e-4),
-        gamma_init=LogUniformDistribution(low=1e+1, high=1e+4),
-        gamma_decay=UniformDistribution(low=0.7, high=1.0),
-        l2_weight_decay=LogUniformDistribution(low=1e-5, high=1e-2),
-        optimizer=['RMSProp', 'Adam']  # 'RMSPeop' or 'Adam'
-    )
-    static_params = dict(
-        num_epochs=15,
-        num_batches_per_epoch=500,
-        batch_size=120,
-        out_dim=64,
-        # learning_rate=0.0001,
-        # gamma_init=10.0,
-        # gamma_decay=0.94,
-        crop_size=224,
-        normalize_output=True,
-        # l2_weight_decay=0,  # non-negative constant
-        # optimizer='RMSProp',  # 'Adam' or 'RMSPeop'
-        distance_type='euclidean',  # 'euclidean' or 'cosine'
-        dataset='cars196'  # 'cars196' or 'cub200_2011' or 'products'
-    )
-
-    sampler = ParameterSampler(param_distributions, num_runs, random_state)
-
-    for random_params in sampler:
-        params = {}
-        params.update(random_params)
-        params.update(static_params)
-
-        main(params, save_distance_matrix)
+
+    if random_search_mode:
+        param_distributions = dict(
+            learning_rate=LogUniformDistribution(low=1e-6, high=1e-4),
+            gamma_init=LogUniformDistribution(low=1e+1, high=1e+4),
+            gamma_decay=UniformDistribution(low=0.7, high=1.0),
+            l2_weight_decay=LogUniformDistribution(low=1e-5, high=1e-2),
+            optimizer=['RMSProp', 'Adam']  # 'RMSPeop' or 'Adam'
+        )
+        static_params = dict(
+            num_epochs=15,
+            num_batches_per_epoch=500,
+            batch_size=120,
+            out_dim=64,
+            # learning_rate=0.0001,
+            # gamma_init=10.0,
+            # gamma_decay=0.94,
+            crop_size=224,
+            normalize_output=True,
+            # l2_weight_decay=0,  # non-negative constant
+            # optimizer='RMSProp',  # 'Adam' or 'RMSPeop'
+            distance_type='euclidean',  # 'euclidean' or 'cosine'
+            dataset='cub200_2011',  # 'cars196' or 'cub200_2011' or 'products'
+            method='n_pairs_mc'  # sampling method for batch construction
+        )
+
+        sampler = ParameterSampler(param_distributions, num_runs, random_state)
+
+        for random_params in sampler:
+            params = {}
+            params.update(random_params)
+            params.update(static_params)
+
+            stop = train(__file__, lossfun_one_batch, params,
+                         save_distance_matrix)
+            if stop:
+                break
+    else:
+        print('Train once using config file "{}".'.format(param_filename))
+        params = load_params(param_filename)
+        train(__file__, lossfun_one_batch, params, save_distance_matrix)
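
Note on the new entry point: the per-epoch training and evaluation loop that used to live in main() is now delegated to lib.common.train_eval.train, which this script calls as train(__file__, lossfun_one_batch, params, save_distance_matrix) and whose return value apparently tells the random-search loop whether to stop. The params dict assembled under __main__ is presumably wrapped into an attribute-style object somewhere in the driver, since lossfun_one_batch reads params.gamma and params.num_updates. The following is only a hypothetical sketch of that calling convention; the real train implementation is not part of this commit, and toy_train, toy_lossfun, and DummyModel are invented names used purely for illustration.

# Hypothetical sketch of the driver/callback contract implied by this diff.
# The actual lib.common.train_eval.train is not shown here; the toy driver
# below only mirrors its assumed signature and the way it would invoke
# lossfun_one_batch(model, params, batch) once per mini-batch.

from types import SimpleNamespace

import numpy as np


class DummyModel(object):
    """Stand-in for ModifiedGoogLeNet: exposes .xp and is callable on a batch."""
    xp = np  # the real model exposes cupy once it has been moved to the GPU

    def __call__(self, x_data):
        return x_data.mean(axis=1)  # placeholder "embedding"


def toy_lossfun(model, params, batch):
    """Same signature as lossfun_one_batch above, but returns a plain float so
    the sketch runs without chainer or the clustering loss."""
    x_data, c_data = batch
    y = model(model.xp.asarray(x_data))
    return float(y.mean())


def toy_train(script_path, lossfun_one_batch, params, save_distance_matrix=False):
    """Toy driver mirroring the assumed interface of lib.common.train_eval.train:
    iterate epochs and batches, call the loss callback, return a stop flag."""
    model = DummyModel()
    for epoch in range(params.num_epochs):
        for _ in range(params.num_batches_per_epoch):
            batch = (np.random.rand(params.batch_size, 8).astype(np.float32),
                     np.random.randint(0, 5, params.batch_size))
            lossfun_one_batch(model, params, batch)  # real driver would backprop this loss
    return False  # True would make the random-search loop above break


# attribute-style params, standing in for whatever wrapper the real driver uses
toy_params = SimpleNamespace(num_epochs=1, num_batches_per_epoch=2, batch_size=4)
print(toy_train(__file__, toy_lossfun, toy_params))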

0 commit comments
