1
1
import numpy as np
2
2
3
3
from keras_core .api_export import keras_core_export
4
+ from keras_core .backend .config import standardize_data_format
4
5
from keras_core .utils import dataset_utils
5
6
from keras_core .utils import image_utils
6
7
from keras_core .utils .module_utils import tensorflow as tf
@@ -29,6 +30,7 @@ def image_dataset_from_directory(
29
30
interpolation = "bilinear" ,
30
31
follow_links = False ,
31
32
crop_to_aspect_ratio = False ,
33
+ data_format = None ,
32
34
):
33
35
"""Generates a `tf.data.Dataset` from image files in a directory.
34
36
@@ -111,6 +113,8 @@ def image_dataset_from_directory(
111
113
(of size `image_size`) that matches the target aspect ratio. By
112
114
default (`crop_to_aspect_ratio=False`), aspect ratio may not be
113
115
preserved.
116
+ data_format: If None uses keras_core.config.image_data_format()
117
+ otherwise either 'channel_last' or 'channel_first'.
114
118
115
119
Returns:
116
120
@@ -142,6 +146,7 @@ def image_dataset_from_directory(
142
146
- if `color_mode` is `"rgba"`,
143
147
there are 4 channels in the image tensors.
144
148
"""
149
+
145
150
if labels not in ("inferred" , None ):
146
151
if not isinstance (labels , (list , tuple )):
147
152
raise ValueError (
@@ -220,6 +225,8 @@ def image_dataset_from_directory(
220
225
f"class_names. Received: class_names={ class_names } "
221
226
)
222
227
228
+ data_format = standardize_data_format (data_format = data_format )
229
+
223
230
if subset == "both" :
224
231
(
225
232
image_paths_train ,
@@ -252,7 +259,9 @@ def image_dataset_from_directory(
252
259
num_classes = len (class_names ),
253
260
interpolation = interpolation ,
254
261
crop_to_aspect_ratio = crop_to_aspect_ratio ,
262
+ data_format = data_format ,
255
263
)
264
+
256
265
val_dataset = paths_and_labels_to_dataset (
257
266
image_paths = image_paths_val ,
258
267
image_size = image_size ,
@@ -262,6 +271,7 @@ def image_dataset_from_directory(
262
271
num_classes = len (class_names ),
263
272
interpolation = interpolation ,
264
273
crop_to_aspect_ratio = crop_to_aspect_ratio ,
274
+ data_format = data_format ,
265
275
)
266
276
267
277
if batch_size is not None :
@@ -288,6 +298,7 @@ def image_dataset_from_directory(
288
298
# Include file paths for images as attribute.
289
299
train_dataset .file_paths = image_paths_train
290
300
val_dataset .file_paths = image_paths_val
301
+
291
302
dataset = [train_dataset , val_dataset ]
292
303
else :
293
304
image_paths , labels = dataset_utils .get_training_or_validation_split (
@@ -308,6 +319,7 @@ def image_dataset_from_directory(
308
319
num_classes = len (class_names ),
309
320
interpolation = interpolation ,
310
321
crop_to_aspect_ratio = crop_to_aspect_ratio ,
322
+ data_format = data_format ,
311
323
)
312
324
313
325
if batch_size is not None :
@@ -320,12 +332,12 @@ def image_dataset_from_directory(
320
332
dataset = dataset .shuffle (buffer_size = 1024 , seed = seed )
321
333
322
334
dataset = dataset .prefetch (tf .data .AUTOTUNE )
323
-
324
335
# Users may need to reference `class_names`.
325
336
dataset .class_names = class_names
326
337
327
338
# Include file paths for images as attribute.
328
339
dataset .file_paths = image_paths
340
+
329
341
return dataset
330
342
331
343
@@ -337,12 +349,19 @@ def paths_and_labels_to_dataset(
337
349
label_mode ,
338
350
num_classes ,
339
351
interpolation ,
352
+ data_format ,
340
353
crop_to_aspect_ratio = False ,
341
354
):
342
355
"""Constructs a dataset of images and labels."""
343
356
# TODO(fchollet): consider making num_parallel_calls settable
344
357
path_ds = tf .data .Dataset .from_tensor_slices (image_paths )
345
- args = (image_size , num_channels , interpolation , crop_to_aspect_ratio )
358
+ args = (
359
+ image_size ,
360
+ num_channels ,
361
+ interpolation ,
362
+ data_format ,
363
+ crop_to_aspect_ratio ,
364
+ )
346
365
img_ds = path_ds .map (
347
366
lambda x : load_image (x , * args ), num_parallel_calls = tf .data .AUTOTUNE
348
367
)
@@ -355,7 +374,12 @@ def paths_and_labels_to_dataset(
355
374
356
375
357
376
def load_image (
358
- path , image_size , num_channels , interpolation , crop_to_aspect_ratio = False
377
+ path ,
378
+ image_size ,
379
+ num_channels ,
380
+ interpolation ,
381
+ data_format ,
382
+ crop_to_aspect_ratio = False ,
359
383
):
360
384
"""Load an image from a path and resize it."""
361
385
img = tf .io .read_file (path )
@@ -369,6 +393,7 @@ def load_image(
369
393
img ,
370
394
image_size ,
371
395
interpolation = interpolation ,
396
+ data_format = data_format ,
372
397
backend_module = tf_backend ,
373
398
)
374
399
else :
0 commit comments