@@ -245,6 +245,7 @@ def train(self):
         if self.pretrained_iteration_count >= self.iterations:
             print(f'pretrained iteration count {self.pretrained_iteration_count} is greater or equal than target iterations {self.iterations}')
             exit(0)
+
         self.model.summary()
         print(f'\ntrain on {len(self.train_image_paths)} samples')
         print(f'validate on {len(self.validation_image_paths)} samples\n')
@@ -256,54 +257,51 @@ def train(self):
         eta_calculator = ETACalculator(iterations=self.iterations, start_iteration=iteration_count)
         eta_calculator.start()
         while True:
-            for idx, (batch_x, batch_y) in enumerate(self.train_data_generator.flow()):
-                lr_scheduler.update(optimizer, iteration_count)
-                loss = self.compute_gradient(self.model, optimizer, batch_x, batch_y, loss_function)
-                if self.show_class_activation_map and iteration_count % 100 == 0:
-                    try_count = 0
-                    while True:
-                        if try_count > len(batch_x):
-                            break
-                        rnum = random.randint(0, len(batch_x) - 1)
-                        if np.all(batch_y[rnum] < 0.3):  # skip cam view if unknown data
-                            continue
-                        else:
-                            new_input_tensor = batch_x[rnum]
-                            label_idx = np.argmax(batch_y[rnum]).item()
-                            break
-                    self.draw_cam(new_input_tensor, label_idx)
-                if self.live_loss_plot_flag:
-                    self.live_loss_plot.update(loss)
-                iteration_count += 1
-                progress_str = eta_calculator.update(iteration_count)
-                self.print_loss(progress_str, loss)
-                if iteration_count % 2000 == 0:
-                    self.save_last_model(self.model, iteration_count)
-                if iteration_count == self.iterations:
-                    self.save_last_model(self.model, iteration_count)
-                    self.save_model(iteration_count)
-                    self.remove_last_model()
-                    print('train end successfully')
-                    exit(0)
-                elif iteration_count >= int(self.iterations * self.warm_up) and self.checkpoint_interval > 0 and iteration_count % self.checkpoint_interval == 0:
-                    self.save_model(iteration_count)
+            batch_x, batch_y = self.train_data_generator.load()
+            lr_scheduler.update(optimizer, iteration_count)
+            loss = self.compute_gradient(self.model, optimizer, batch_x, batch_y, loss_function)
+            if self.show_class_activation_map and iteration_count % 100 == 0:
+                try_count = 0
+                while True:
+                    if try_count > len(batch_x):
+                        break
+                    rnum = random.randint(0, len(batch_x) - 1)
+                    if np.all(batch_y[rnum] < 0.3):  # skip cam view if unknown data
+                        continue
+                    else:
+                        new_input_tensor = batch_x[rnum]
+                        label_idx = np.argmax(batch_y[rnum]).item()
+                        break
+                self.draw_cam(new_input_tensor, label_idx)
+            if self.live_loss_plot_flag:
+                self.live_loss_plot.update(loss)
+            iteration_count += 1
+            progress_str = eta_calculator.update(iteration_count)
+            self.print_loss(progress_str, loss)
+            if iteration_count % 2000 == 0:
+                self.save_last_model(self.model, iteration_count)
+            if iteration_count == self.iterations:
+                self.save_last_model(self.model, iteration_count)
+                self.save_model(iteration_count)
+                self.remove_last_model()
+                print('train end successfully')
+                exit(0)
+            elif iteration_count >= int(self.iterations * self.warm_up) and self.checkpoint_interval > 0 and iteration_count % self.checkpoint_interval == 0:
+                self.save_model(iteration_count)


     def save_model(self, iteration_count):
         print()
-        if self.validation_data_generator.flow() is None:
-            self.save_last_model(self.model, iteration_count)
+        val_acc, val_class_score, val_unknown_score = self.evaluate(unknown_threshold=0.5, dataset='validation')
+        model_name = f'model_{iteration_count}_iter_acc_{val_acc:.4f}_class_score_{val_class_score:.4f}'
+        if self.include_unknown:
+            model_name += f'_unknown_score_{val_unknown_score:.4f}'
+        if val_acc > self.max_val_acc:
+            self.max_val_acc = val_acc
+            model_name = f'{self.checkpoint_path}/best_{model_name}.h5'
+            print(f'[best model saved]\n')
         else:
-            val_acc, val_class_score, val_unknown_score = self.evaluate(unknown_threshold=0.5, dataset='validation')
-            model_name = f'model_{iteration_count}_iter_acc_{val_acc:.4f}_class_score_{val_class_score:.4f}'
-            if self.include_unknown:
-                model_name += f'_unknown_score_{val_unknown_score:.4f}'
-            if val_acc > self.max_val_acc:
-                self.max_val_acc = val_acc
-                model_name = f'{self.checkpoint_path}/best_{model_name}.h5'
-                print(f'[best model saved]\n')
-            else:
-                model_name = f'{self.checkpoint_path}/{model_name}.h5'
-            self.model.save(model_name, include_optimizer=False)
+            model_name = f'{self.checkpoint_path}/{model_name}.h5'
+        self.model.save(model_name, include_optimizer=False)

     def evaluate(self, dataset, unknown_threshold=0.5):
         assert dataset in ['train', 'validation']
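Note on the loop change above: each pass of `while True:` now pulls exactly one batch with `self.train_data_generator.load()` instead of iterating a `flow()` generator, so `iteration_count` alone drives the learning-rate schedule, checkpointing, and termination, and the whole body sits one indentation level shallower. The generator class itself is not part of this diff; the fragment below is only a minimal sketch of the interface the new loop appears to rely on, and the class name `BatchLoader` plus all of its internals are assumptions, not the project's actual implementation.

import numpy as np

class BatchLoader:
    """Hypothetical stand-in for the data generator interface assumed by the new loop:
    load() returns one (batch_x, batch_y) pair and may be called indefinitely,
    while __len__() reports the number of batches in a single pass."""

    def __init__(self, x, y, batch_size):
        self.x = np.asarray(x, dtype=np.float32)
        self.y = np.asarray(y, dtype=np.float32)
        self.batch_size = batch_size
        self.indices = np.arange(len(self.x))
        self.cursor = 0
        np.random.shuffle(self.indices)

    def __len__(self):
        # Batches per full pass over the data.
        return max(1, len(self.x) // self.batch_size)

    def load(self):
        # Wrap around (and reshuffle) once a pass is exhausted, so the
        # training loop can keep calling load() without extra bookkeeping.
        if self.cursor + self.batch_size > len(self.indices):
            np.random.shuffle(self.indices)
            self.cursor = 0
        picked = self.indices[self.cursor:self.cursor + self.batch_size]
        self.cursor += self.batch_size
        return self.x[picked], self.y[picked]

With an interface shaped like this, `load()` can be called forever from the training loop, while `len(...)` gives a bounded evaluation pass a natural stopping point.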
@@ -322,7 +320,8 @@ def graph_forward(model, x):
         hit_unknown_count = total_unknown_count = 0
         hit_scores = np.zeros(shape=(num_classes,), dtype=np.float32)
         unknown_score_sum = 0.0
-        for batch_x, batch_y in tqdm(data_generator.flow()):
+        for _ in tqdm(range(len(data_generator))):
+            batch_x, batch_y = data_generator.load()
             y = graph_forward(self.model, batch_x)[0]
             max_score_index = np.argmax(y)
             max_score = y[max_score_index]
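The evaluation hunk applies the same pattern: a bounded `for _ in tqdm(range(len(data_generator)))` loop with one `load()` call per step replaces iterating `flow()`, which is also why `save_model()` above no longer needs a `flow() is None` guard before evaluating. As a hedged illustration only, a stripped-down bounded evaluation pass might look like the sketch below; the real `evaluate()` additionally tracks per-class hit scores, unknown counts, and `unknown_score_sum`, and `graph_forward` is assumed here to return per-sample score vectors, matching the `[0]` indexing in the diff.

import numpy as np

def top1_accuracy(model, data_generator, graph_forward):
    """Hypothetical bounded evaluation pass over len(data_generator) batches."""
    hit_count = total_count = 0
    for _ in range(len(data_generator)):
        batch_x, batch_y = data_generator.load()
        y = graph_forward(model, batch_x)[0]  # assumed: scores for the first (or only) sample
        predicted = int(np.argmax(y))
        expected = int(np.argmax(batch_y[0]))
        hit_count += int(predicted == expected)
        total_count += 1
    return hit_count / max(1, total_count)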