 @author: sakurai
 """

-import os
-import time
-import copy
-import numpy as np
-import matplotlib.pyplot as plt
-import six
-
-import chainer
-from chainer import cuda
-from chainer import optimizers
-from tqdm import tqdm
 import colorama
+
 from sklearn.model_selection import ParameterSampler

-from functions.clustering_loss import clustering_loss
-import common
-from datasets import data_provider
-from models.modified_googlenet import ModifiedGoogLeNet
-from common import UniformDistribution, LogUniformDistribution
+from lib.functions.clustering_loss import clustering_loss
+from lib.common.utils import (
+    UniformDistribution, LogUniformDistribution, load_params)
+from lib.common.train_eval import train

 colorama.init()


-def main(param_dict, save_distance_matrix=False):
-    script_filename = os.path.splitext(os.path.basename(__file__))[0]
-    device = 0
-    xp = chainer.cuda.cupy
-    config_parser = six.moves.configparser.ConfigParser()
-    config_parser.read('config')
-    log_dir_path = os.path.expanduser(config_parser.get('logs', 'dir_path'))
-
-    p = common.Logger(log_dir_path, **param_dict)  # hyperparameters
-
-    ##########################################################
-    # load database
-    ##########################################################
-    streams = data_provider.get_streams(p.batch_size, dataset=p.dataset,
-                                        method='clustering')
-    stream_train, stream_train_eval, stream_test = streams
-    iter_train = stream_train.get_epoch_iterator()
-
-    ##########################################################
-    # construct the model
-    ##########################################################
-    model = ModifiedGoogLeNet(p.out_dim, p.normalize_output)
-    if device >= 0:
-        model.to_gpu()
-    xp = model.xp
-    if p.optimizer == 'Adam':
-        optimizer = optimizers.Adam(p.learning_rate)
-    elif p.optimizer == 'RMSProp':
-        optimizer = optimizers.RMSprop(p.learning_rate)
+def lossfun_one_batch(model, params, batch):
+    x_data, c_data = batch
+    x_data = model.xp.asarray(x_data)
+    c_data = model.xp.asarray(c_data)
+
+    y = model(x_data)
+
+    # decay gamma at regular interval
+    if type(params.gamma) is not float:
+        params.gamma = params.gamma_init
+        params.num_updates = 0
     else:
-        raise ValueError
-    optimizer.setup(model)
-    optimizer.add_hook(chainer.optimizer.WeightDecay(p.l2_weight_decay))
-    gamma = p.gamma_init
-
-    logger = common.Logger(log_dir_path)
-    logger.soft_test_best = [0]
-    time_origin = time.time()
-    try:
-        for epoch in range(p.num_epochs):
-            time_begin = time.time()
-            epoch_losses = []
-
-            for i in tqdm(range(p.num_batches_per_epoch),
-                          desc='# {}'.format(epoch)):
-                # the first half of a batch are the anchors and the latters
-                # are the positive examples corresponding to each anchor
-                x_data, c_data = next(iter_train)
-                if device >= 0:
-                    x_data = cuda.to_gpu(x_data, device)
-                    c_data = cuda.to_gpu(c_data, device)
-                y = model(x_data, train=True)
-
-                loss = clustering_loss(y, c_data, gamma)
-                optimizer.zero_grads()
-                loss.backward()
-                optimizer.update()
-
-                epoch_losses.append(loss.data)
-                y = y_a = y_p = loss = None
-
-            loss_average = cuda.to_cpu(xp.array(
-                xp.hstack(epoch_losses).mean()))
-
-            # average accuracy and distance matrix for training data
-            D, soft, hard, retrieval = common.evaluate(
-                model, stream_train_eval.get_epoch_iterator(), p.distance_type,
-                return_distance_matrix=save_distance_matrix)
-
-            # average accuracy and distance matrix for testing data
-            D_test, soft_test, hard_test, retrieval_test = common.evaluate(
-                model, stream_test.get_epoch_iterator(), p.distance_type,
-                return_distance_matrix=save_distance_matrix)
-
-            time_end = time.time()
-            epoch_time = time_end - time_begin
-            total_time = time_end - time_origin
-
-            logger.epoch = epoch
-            logger.total_time = total_time
-            logger.loss_log.append(loss_average)
-            logger.train_log.append([soft[0], hard[0], retrieval[0]])
-            logger.test_log.append(
-                [soft_test[0], hard_test[0], retrieval_test[0]])
-
-            # retain the model if it scored the best test acc. ever
-            if soft_test[0] > logger.soft_test_best[0]:
-                logger.model_best = copy.deepcopy(model)
-                logger.optimizer_best = copy.deepcopy(optimizer)
-                logger.epoch_best = epoch
-                logger.D_best = D
-                logger.D_test_best = D_test
-                logger.soft_best = soft
-                logger.soft_test_best = soft_test
-                logger.hard_best = hard
-                logger.hard_test_best = hard_test
-                logger.retrieval_best = retrieval
-                logger.retrieval_test_best = retrieval_test
-
-            print("#", epoch)
-            print("time: {} ({})".format(epoch_time, total_time))
-            print("[train] loss:", loss_average)
-            print("[train] soft:", soft)
-            print("[train] hard:", hard)
-            print("[train] retr:", retrieval)
-            print("[test] soft:", soft_test)
-            print("[test] hard:", hard_test)
-            print("[test] retr:", retrieval_test)
-            print("[best] soft: {} (at # {})".format(logger.soft_test_best,
-                                                     logger.epoch_best))
-            print(p, 'gamma:{}'.format(gamma))
-            # print norms of the weights
-            params = xp.hstack([xp.linalg.norm(param.data)
-                                for param in model.params()]).tolist()
-            print("|W|", map(lambda param: float('%0.2f' % param), params))
-            print()
-
-            # Draw plots
-            if save_distance_matrix:
-                plt.figure(figsize=(8, 4))
-                plt.subplot(1, 2, 1)
-                mat = plt.matshow(D, fignum=0, cmap=plt.cm.gray)
-                plt.colorbar(mat, fraction=0.045)
-                plt.subplot(1, 2, 2)
-                mat = plt.matshow(D_test, fignum=0, cmap=plt.cm.gray)
-                plt.colorbar(mat, fraction=0.045)
-                plt.tight_layout()
-
-            plt.figure(figsize=(8, 4))
-            plt.subplot(1, 2, 1)
-            plt.plot(logger.loss_log, label="tr-loss")
-            plt.grid()
-            plt.legend(loc='best')
-            plt.subplot(1, 2, 2)
-            plt.plot(logger.train_log)
-            plt.plot(logger.test_log)
-            plt.grid()
-            plt.legend(["tr-soft", "tr-hard", "tr-retr",
-                        "te-soft", "te-hard", "te-retr"],
-                       bbox_to_anchor=(1.4, 1))
-            plt.ylim([0.0, 1.0])
-            plt.xlim([0, p.num_epochs])
-            plt.tight_layout()
-            plt.show()
-            plt.draw()
-
-            loss = None
-            D = None
-            D_test = None
-
-            gamma *= p.gamma_decay
-
-    except KeyboardInterrupt:
-        pass
-
-    dir_name = "-".join([script_filename, time.strftime("%Y%m%d%H%M%S"),
-                         str(logger.soft_test_best[0])])
-
-    logger.save(dir_name)
-    p.save(dir_name)
-
-    print("total epochs: {} ({} [s])".format(logger.epoch, logger.total_time))
-    print("best test score (at # {})".format(logger.epoch_best))
-    print("[test] soft:", logger.soft_test_best)
-    print("[test] hard:", logger.hard_test_best)
-    print("[test] retr:", logger.retrieval_test_best)
-    print(str(p).replace(', ', '\n'))
-    print()
+        if (params.num_updates != 0 and
+                params.num_updates % params.num_batches_per_epoch == 0):
+            params.gamma *= params.gamma_decay
+        params.num_updates += 1
+
+    return clustering_loss(y, c_data, params.gamma)
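Note (not part of the commit): after this change the script only defines the per-batch loss. The epoch loop, evaluation and logging presumably live in lib/common/train_eval.py, whose train() appears to call lossfun_one_batch(model, params, batch) once per mini-batch. The sketch below illustrates that assumed calling contract; apart from lossfun_one_batch itself and the hyperparameter names, every name in it is hypothetical.

    from types import SimpleNamespace

    # Hypothetical driver; the real loop is train() in lib/common/train_eval.py.
    def run_one_epoch(model, optimizer, iter_train, params):
        for _ in range(params.num_batches_per_epoch):
            batch = next(iter_train)               # (x_data, c_data) arrays
            loss = lossfun_one_batch(model, params, batch)
            model.cleargrads()                     # Chainer-style update step
            loss.backward()
            optimizer.update()

    # params must carry a mutable gamma slot: lossfun_one_batch resets it to
    # gamma_init on the first call and multiplies it by gamma_decay once every
    # num_batches_per_epoch subsequent calls, i.e. roughly once per epoch.
    params = SimpleNamespace(gamma=None, gamma_init=10.0, gamma_decay=0.94,
                             num_batches_per_epoch=500)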


 if __name__ == '__main__':
+    param_filename = 'clustering_cub200_2011.yaml'
+    random_search_mode = True
     random_state = None
     num_runs = 100000
     save_distance_matrix = False
-    param_distributions = dict(
-        learning_rate=LogUniformDistribution(low=1e-6, high=1e-4),
-        gamma_init=LogUniformDistribution(low=1e+1, high=1e+4),
-        gamma_decay=UniformDistribution(low=0.7, high=1.0),
-        l2_weight_decay=LogUniformDistribution(low=1e-5, high=1e-2),
-        optimizer=['RMSProp', 'Adam']  # 'RMSProp' or 'Adam'
-    )
-    static_params = dict(
-        num_epochs=15,
-        num_batches_per_epoch=500,
-        batch_size=120,
-        out_dim=64,
-#        learning_rate=0.0001,
-#        gamma_init=10.0,
-#        gamma_decay=0.94,
-        crop_size=224,
-        normalize_output=True,
-#        l2_weight_decay=0,  # non-negative constant
-#        optimizer='RMSProp',  # 'Adam' or 'RMSProp'
-        distance_type='euclidean',  # 'euclidean' or 'cosine'
-        dataset='cars196'  # 'cars196' or 'cub200_2011' or 'products'
-    )
-
-    sampler = ParameterSampler(param_distributions, num_runs, random_state)
-
-    for random_params in sampler:
-        params = {}
-        params.update(random_params)
-        params.update(static_params)
-
-        main(params, save_distance_matrix)
+
+    if random_search_mode:
+        param_distributions = dict(
+            learning_rate=LogUniformDistribution(low=1e-6, high=1e-4),
+            gamma_init=LogUniformDistribution(low=1e+1, high=1e+4),
+            gamma_decay=UniformDistribution(low=0.7, high=1.0),
+            l2_weight_decay=LogUniformDistribution(low=1e-5, high=1e-2),
+            optimizer=['RMSProp', 'Adam']  # 'RMSProp' or 'Adam'
+        )
+        static_params = dict(
+            num_epochs=15,
+            num_batches_per_epoch=500,
+            batch_size=120,
+            out_dim=64,
+#            learning_rate=0.0001,
+#            gamma_init=10.0,
+#            gamma_decay=0.94,
+            crop_size=224,
+            normalize_output=True,
+#            l2_weight_decay=0,  # non-negative constant
+#            optimizer='RMSProp',  # 'Adam' or 'RMSProp'
+            distance_type='euclidean',  # 'euclidean' or 'cosine'
+            dataset='cub200_2011',  # 'cars196' or 'cub200_2011' or 'products'
+            method='n_pairs_mc'  # sampling method for batch construction
+        )
+
+        sampler = ParameterSampler(param_distributions, num_runs, random_state)
+
+        for random_params in sampler:
+            params = {}
+            params.update(random_params)
+            params.update(static_params)
+
+            stop = train(__file__, lossfun_one_batch, params,
+                         save_distance_matrix)
+            if stop:
+                break
+    else:
+        print('Train once using config file "{}".'.format(param_filename))
+        params = load_params(param_filename)
+        train(__file__, lossfun_one_batch, params, save_distance_matrix)
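For readers unfamiliar with scikit-learn's ParameterSampler: list-valued entries (such as optimizer) are sampled uniformly by value, while any other value must expose a scipy-style rvs() method, so UniformDistribution and LogUniformDistribution in lib/common/utils.py presumably wrap such samplers; load_params presumably reads the YAML file into the same kind of parameter mapping for a single fixed run. Below is a self-contained sketch of the sampling mechanism, with LogUniform as a hypothetical stand-in for the real class.

    import numpy as np
    from sklearn.model_selection import ParameterSampler

    # Hypothetical stand-in for lib.common.utils.LogUniformDistribution, which
    # this diff does not show.  ParameterSampler only needs an rvs() method.
    class LogUniform(object):
        def __init__(self, low, high):
            self.low, self.high = low, high

        def rvs(self, random_state=None):
            rng = (random_state if isinstance(random_state, np.random.RandomState)
                   else np.random.RandomState(random_state))
            return float(np.exp(rng.uniform(np.log(self.low), np.log(self.high))))

    # Draw three trial configurations, mirroring the random-search loop above.
    distributions = dict(learning_rate=LogUniform(1e-6, 1e-4),
                         optimizer=['RMSProp', 'Adam'])
    for trial in ParameterSampler(distributions, n_iter=3):
        print(trial)  # e.g. {'learning_rate': 3.1e-06, 'optimizer': 'Adam'}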