Skip to content

Commit 2721614

Browse files
committed
refactor data generator using load function
1 parent 56581e6 commit 2721614

File tree

2 files changed

+58
-74
lines changed

2 files changed

+58
-74
lines changed

generator.py

+14-29
Original file line number | Diff line number | Diff line change
@@ -25,16 +25,8 @@
2525
from concurrent.futures.thread import ThreadPoolExecutor
2626

2727

28-
class DataGenerator:
28+
class DataGenerator(tf.keras.utils.Sequence):
2929
def __init__(self, root_path, image_paths, input_shape, batch_size, class_names, aug_brightness=0.0, aug_contrast=0.0, aug_rotate=0, aug_h_flip=False):
30-
self.generator_flow = GeneratorFlow(root_path, image_paths, class_names, input_shape, batch_size, aug_brightness, aug_contrast, aug_rotate, aug_h_flip)
31-
32-
def flow(self):
33-
return self.generator_flow
34-
35-
36-
class GeneratorFlow(tf.keras.utils.Sequence):
37-
def __init__(self, root_path, image_paths, class_names, input_shape, batch_size, aug_brightness, aug_contrast, aug_rotate, aug_h_flip):
3830
assert 0.0 <= aug_brightness <= 1.0
3931
assert 0.0 <= aug_contrast <= 1.0
4032
assert type(aug_h_flip) == bool
@@ -59,17 +51,19 @@ def __init__(self, root_path, image_paths, class_names, input_shape, batch_size,
5951
self.transform = A.Compose(aug_methods)
6052
self.augmentation = len(aug_methods) > 0
6153

62-
def __getitem__(self, index):
54+
def __len__(self):
    """Return the number of complete batches that fit in ``self.image_paths``.

    Any trailing images that do not fill a whole batch are not counted.
    """
    full_batches, _leftover = divmod(len(self.image_paths), self.batch_size)
    return full_batches
56+
57+
def load(self):
6358
fs = []
64-
for i in range(self.batch_size):
59+
for _ in range(self.batch_size):
6560
fs.append(self.pool.submit(self.load_img, self.get_next_image_path()))
6661
batch_x = []
6762
batch_y = []
6863
for f in fs:
6964
img, path = f.result()
7065
x = self.preprocess(img, aug=self.augmentation)
7166
batch_x.append(x)
72-
7367
dir_name = path.replace(self.root_path, '').split('/')[1]
7468
y = np.zeros((self.num_classes,), dtype=np.float32)
7569
if dir_name != 'unknown':
@@ -79,23 +73,6 @@ def __getitem__(self, index):
7973
batch_y = np.asarray(batch_y).reshape((self.batch_size, self.num_classes)).astype('float32')
8074
return batch_x, batch_y
8175

82-
def __len__(self):
83-
return int(np.floor(len(self.image_paths) / self.batch_size))
84-
85-
def get_next_image_path(self):
86-
path = self.image_paths[self.img_index]
87-
self.img_index += 1
88-
if self.img_index == len(self.image_paths):
89-
self.img_index = 0
90-
np.random.shuffle(self.image_paths)
91-
return path
92-
93-
def random_blur(self, img):
94-
if np.random.rand() > 0.5:
95-
kernel_size = np.random.choice([3, 5])
96-
img = cv2.GaussianBlur(img, (kernel_size, kernel_size), 0)
97-
return img
98-
9976
def preprocess(self, img, aug=False):
10077
img = cv2.resize(img, (self.input_shape[1], self.input_shape[0]))
10178
if aug:
@@ -105,6 +82,14 @@ def preprocess(self, img, aug=False):
10582
x = np.asarray(img).reshape(self.input_shape).astype('float32') / 255.0
10683
return x
10784

85+
def get_next_image_path(self):
    """Return the path at the current cursor and advance the cursor.

    When the returned path was the last one in ``self.image_paths``, the
    cursor wraps back to 0 and the list is shuffled in place so the next
    pass visits the paths in a new order.
    """
    cursor = self.img_index
    selected = self.image_paths[cursor]
    if cursor + 1 == len(self.image_paths):
        # End of the pass: restart from the beginning with a fresh order.
        self.img_index = 0
        np.random.shuffle(self.image_paths)
    else:
        self.img_index = cursor + 1
    return selected
92+
10893
def load_img(self, path):
    """Read the file at ``path`` and decode it into an image.

    The file is read as raw ``uint8`` bytes and decoded with
    ``cv2.imdecode``. Grayscale decoding is used when the last entry of
    ``self.input_shape`` is 1, colour decoding otherwise.

    Returns:
        A ``(img, path)`` tuple: the decoded image and the path it came from.
    """
    raw_bytes = np.fromfile(path, dtype=np.uint8)
    if self.input_shape[-1] == 1:
        decode_flag = cv2.IMREAD_GRAYSCALE
    else:
        decode_flag = cv2.IMREAD_COLOR
    decoded = cv2.imdecode(raw_bytes, decode_flag)
    return decoded, path

sigmoid_classifier.py

+44-45
Original file line number | Diff line number | Diff line change
@@ -245,6 +245,7 @@ def train(self):
245245
if self.pretrained_iteration_count >= self.iterations:
246246
print(f'pretrained iteration count {self.pretrained_iteration_count} is greater or equal than target iterations {self.iterations}')
247247
exit(0)
248+
248249
self.model.summary()
249250
print(f'\ntrain on {len(self.train_image_paths)} samples')
250251
print(f'validate on {len(self.validation_image_paths)} samples\n')
@@ -256,54 +257,51 @@ def train(self):
256257
eta_calculator = ETACalculator(iterations=self.iterations, start_iteration=iteration_count)
257258
eta_calculator.start()
258259
while True:
259-
for idx, (batch_x, batch_y) in enumerate(self.train_data_generator.flow()):
260-
lr_scheduler.update(optimizer, iteration_count)
261-
loss = self.compute_gradient(self.model, optimizer, batch_x, batch_y, loss_function)
262-
if self.show_class_activation_map and iteration_count % 100 == 0:
263-
try_count = 0
264-
while True:
265-
if try_count > len(batch_x):
266-
break
267-
rnum = random.randint(0, len(batch_x) - 1)
268-
if np.all(batch_y[rnum] < 0.3): # skip cam view if unknown data
269-
continue
270-
else:
271-
new_input_tensor = batch_x[rnum]
272-
label_idx = np.argmax(batch_y[rnum]).item()
273-
break
274-
self.draw_cam(new_input_tensor, label_idx)
275-
if self.live_loss_plot_flag:
276-
self.live_loss_plot.update(loss)
277-
iteration_count += 1
278-
progress_str = eta_calculator.update(iteration_count)
279-
self.print_loss(progress_str, loss)
280-
if iteration_count % 2000 == 0:
281-
self.save_last_model(self.model, iteration_count)
282-
if iteration_count == self.iterations:
283-
self.save_last_model(self.model, iteration_count)
284-
self.save_model(iteration_count)
285-
self.remove_last_model()
286-
print('train end successfully')
287-
exit(0)
288-
elif iteration_count >= int(self.iterations * self.warm_up) and self.checkpoint_interval > 0 and iteration_count % self.checkpoint_interval == 0:
289-
self.save_model(iteration_count)
260+
batch_x, batch_y = self.train_data_generator.load()
261+
lr_scheduler.update(optimizer, iteration_count)
262+
loss = self.compute_gradient(self.model, optimizer, batch_x, batch_y, loss_function)
263+
if self.show_class_activation_map and iteration_count % 100 == 0:
264+
try_count = 0
265+
while True:
266+
if try_count > len(batch_x):
267+
break
268+
rnum = random.randint(0, len(batch_x) - 1)
269+
if np.all(batch_y[rnum] < 0.3): # skip cam view if unknown data
270+
continue
271+
else:
272+
new_input_tensor = batch_x[rnum]
273+
label_idx = np.argmax(batch_y[rnum]).item()
274+
break
275+
self.draw_cam(new_input_tensor, label_idx)
276+
if self.live_loss_plot_flag:
277+
self.live_loss_plot.update(loss)
278+
iteration_count += 1
279+
progress_str = eta_calculator.update(iteration_count)
280+
self.print_loss(progress_str, loss)
281+
if iteration_count % 2000 == 0:
282+
self.save_last_model(self.model, iteration_count)
283+
if iteration_count == self.iterations:
284+
self.save_last_model(self.model, iteration_count)
285+
self.save_model(iteration_count)
286+
self.remove_last_model()
287+
print('train end successfully')
288+
exit(0)
289+
elif iteration_count >= int(self.iterations * self.warm_up) and self.checkpoint_interval > 0 and iteration_count % self.checkpoint_interval == 0:
290+
self.save_model(iteration_count)
290291

291292
def save_model(self, iteration_count):
292293
print()
293-
if self.validation_data_generator.flow() is None:
294-
self.save_last_model(self.model, iteration_count)
294+
val_acc, val_class_score, val_unknown_score = self.evaluate(unknown_threshold=0.5, dataset='validation')
295+
model_name = f'model_{iteration_count}_iter_acc_{val_acc:.4f}_class_score_{val_class_score:.4f}'
296+
if self.include_unknown:
297+
model_name += f'_unknown_score_{val_unknown_score:.4f}'
298+
if val_acc > self.max_val_acc:
299+
self.max_val_acc = val_acc
300+
model_name = f'{self.checkpoint_path}/best_{model_name}.h5'
301+
print(f'[best model saved]\n')
295302
else:
296-
val_acc, val_class_score, val_unknown_score = self.evaluate(unknown_threshold=0.5, dataset='validation')
297-
model_name = f'model_{iteration_count}_iter_acc_{val_acc:.4f}_class_score_{val_class_score:.4f}'
298-
if self.include_unknown:
299-
model_name += f'_unknown_score_{val_unknown_score:.4f}'
300-
if val_acc > self.max_val_acc:
301-
self.max_val_acc = val_acc
302-
model_name = f'{self.checkpoint_path}/best_{model_name}.h5'
303-
print(f'[best model saved]\n')
304-
else:
305-
model_name = f'{self.checkpoint_path}/{model_name}.h5'
306-
self.model.save(model_name, include_optimizer=False)
303+
model_name = f'{self.checkpoint_path}/{model_name}.h5'
304+
self.model.save(model_name, include_optimizer=False)
307305

308306
def evaluate(self, dataset, unknown_threshold=0.5):
309307
assert dataset in ['train', 'validation']
@@ -322,7 +320,8 @@ def graph_forward(model, x):
322320
hit_unknown_count = total_unknown_count = 0
323321
hit_scores = np.zeros(shape=(num_classes,), dtype=np.float32)
324322
unknown_score_sum = 0.0
325-
for batch_x, batch_y in tqdm(data_generator.flow()):
323+
for _ in tqdm(range(len(data_generator))):
324+
batch_x, batch_y = data_generator.load()
326325
y = graph_forward(self.model, batch_x)[0]
327326
max_score_index = np.argmax(y)
328327
max_score = y[max_score_index]

0 commit comments

Comments
 (0)