Skip to content

Commit c83d662

Browse files
committed
Keras generator function now returns categorical label for each sample.
1 parent fef4005 commit c83d662

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

Augmentor/ImageUtilities.py

+13
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def __init__(self, image_path, output_directory):
4747
self._class_label = None
4848
self._class_label_int = None
4949
self._label_pair = None
50+
self._categorical_label = None
5051

5152
# Now we call the setters that we require.
5253
self.image_path = image_path
@@ -124,6 +125,14 @@ def class_label_int(self):
124125
def class_label_int(self, value):
125126
self._class_label_int = value
126127

128+
@property
129+
def categorical_label(self):
130+
return self._categorical_label
131+
132+
@categorical_label.setter
133+
def categorical_label(self, value):
134+
self._categorical_label = value
135+
127136
@property
128137
def ground_truth(self):
129138
"""
@@ -199,6 +208,7 @@ def scan(source_directory, abs_output_directory):
199208
a = AugmentorImage(image_path=image_path, output_directory=abs_output_directory)
200209
a.class_label = parent_directory_name
201210
a.class_label_int = label_counter
211+
a.categorical_label = np.ndarray(1, dtype=np.uint32) # TODO: Fix, as this is not good. Maybe leave as None.
202212
augmentor_images.append(a)
203213

204214
class_labels.append((label_counter, parent_directory_name))
@@ -211,9 +221,12 @@ def scan(source_directory, abs_output_directory):
211221
for d in directories:
212222
output_directory = os.path.join(abs_output_directory, os.path.split(d)[1])
213223
for image_path in scan_directory(d):
224+
categorical_label = np.zeros(directory_count, dtype=np.uint32)
214225
a = AugmentorImage(image_path=image_path, output_directory=output_directory)
215226
a.class_label = os.path.split(d)[1]
216227
a.class_label_int = label_counter
228+
categorical_label[label_counter] = 1
229+
a.categorical_label = categorical_label
217230
augmentor_images.append(a)
218231
class_labels.append((os.path.split(d)[1], label_counter))
219232
label_counter += 1

Augmentor/Pipeline.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,8 @@ def keras_image_generator(self, image_format='channels_first'):
342342
:return: An image generator.
343343
"""
344344

345+
# TODO: Always return at least the original dataset as well as the augmented dataset
346+
345347
while True:
346348
batch_indices = list(range(0, len(self.augmentor_images)))
347349
for i in range(0, len(self.augmentor_images)):
@@ -352,19 +354,20 @@ def keras_image_generator(self, image_format='channels_first'):
352354

353355
if image_format == 'channels_first':
354356
num_of_channels = len(im_PIL.getbands())
355-
im_array = im_array.reshape(num_of_channels, im_PIL.width, im_PIL.height)
356-
yield im_array, self.augmentor_images[im_index].class_label_int
357+
im_array = im_array.reshape(1, num_of_channels, im_PIL.width, im_PIL.height)
358+
yield im_array, self.augmentor_images[im_index].categorical_label
357359
elif image_format == 'channels_last':
358360
num_of_channels = len(im_PIL.getbands())
359-
im_array = im_array.reshape(im_PIL.width, im_PIL.height, num_of_channels)
360-
yield im_array, self.augmentor_images[im_index].class_label_int
361+
im_array = im_array.reshape(1, im_PIL.width, im_PIL.height, num_of_channels)
362+
yield im_array, self.augmentor_images[im_index].categorical_label
361363

362364
def keras_image_generator_with_replacement(self, image_format='channels_first'):
363-
while True:
364-
im_index = random.randint(0, len(self.augmentor_images))
365-
im_PIL = self._execute(self.augmentor_images[im_index], save_to_disk=False)
366-
im_array = np.asarray(im_PIL)
367-
yield im_array, self.augmentor_images[im_index].class_label_int
365+
raise NotImplementedError("This method is currently not implemented. Use keras_image_generator().")
366+
#while True:
367+
# im_index = random.randint(0, len(self.augmentor_images))
368+
# im_PIL = self._execute(self.augmentor_images[im_index], save_to_disk=False)
369+
# im_array = np.asarray(im_PIL)
370+
# yield im_array, self.augmentor_images[im_index].class_label_int
368371

369372
def add_operation(self, operation):
370373
"""

0 commit comments

Comments
 (0)