@@ -342,6 +342,8 @@ def keras_image_generator(self, image_format='channels_first'):
342
342
:return: An image generator.
343
343
"""
344
344
345
+ # TODO: Always return at least the original dataset as well as the augmented dataset
346
+
345
347
while True :
346
348
batch_indices = list (range (0 , len (self .augmentor_images )))
347
349
for i in range (0 , len (self .augmentor_images )):
@@ -352,19 +354,20 @@ def keras_image_generator(self, image_format='channels_first'):
352
354
353
355
if image_format == 'channels_first' :
354
356
num_of_channels = len (im_PIL .getbands ())
355
- im_array = im_array .reshape (num_of_channels , im_PIL .width , im_PIL .height )
356
- yield im_array , self .augmentor_images [im_index ].class_label_int
357
+ im_array = im_array .reshape (1 , num_of_channels , im_PIL .width , im_PIL .height )
358
+ yield im_array , self .augmentor_images [im_index ].categorical_label
357
359
elif image_format == 'channels_last' :
358
360
num_of_channels = len (im_PIL .getbands ())
359
- im_array = im_array .reshape (im_PIL .width , im_PIL .height , num_of_channels )
360
- yield im_array , self .augmentor_images [im_index ].class_label_int
361
+ im_array = im_array .reshape (1 , im_PIL .width , im_PIL .height , num_of_channels )
362
+ yield im_array , self .augmentor_images [im_index ].categorical_label
361
363
362
364
def keras_image_generator_with_replacement (self , image_format = 'channels_first' ):
363
- while True :
364
- im_index = random .randint (0 , len (self .augmentor_images ))
365
- im_PIL = self ._execute (self .augmentor_images [im_index ], save_to_disk = False )
366
- im_array = np .asarray (im_PIL )
367
- yield im_array , self .augmentor_images [im_index ].class_label_int
365
+ raise NotImplementedError ("This method is currently not implemented. Use keras_image_generator()." )
366
+ #while True:
367
+ # im_index = random.randint(0, len(self.augmentor_images))
368
+ # im_PIL = self._execute(self.augmentor_images[im_index], save_to_disk=False)
369
+ # im_array = np.asarray(im_PIL)
370
+ # yield im_array, self.augmentor_images[im_index].class_label_int
368
371
369
372
def add_operation (self , operation ):
370
373
"""
0 commit comments