Skip to content

Commit cb95f39

Browse files
authored
torch image added (#659)
* torch image added * moved load image format to image_utils with addition of data_format * fixed format * revert changes to torch image_resize * revert changes to backend specific image ops * fixed ops imports * fixed format * fixed format * remove smart resize from * remove smart resize from tensorflow backend * remove backend changes * removed unwanted changes * removed skip test * prefetch after shuffle * added interpolation block * fixed format
1 parent 3bc3544 commit cb95f39

File tree

1 file changed

+28
-3
lines changed

1 file changed

+28
-3
lines changed

keras_core/utils/image_dataset_utils.py

+28-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from keras_core.api_export import keras_core_export
4+
from keras_core.backend.config import standardize_data_format
45
from keras_core.utils import dataset_utils
56
from keras_core.utils import image_utils
67
from keras_core.utils.module_utils import tensorflow as tf
@@ -29,6 +30,7 @@ def image_dataset_from_directory(
2930
interpolation="bilinear",
3031
follow_links=False,
3132
crop_to_aspect_ratio=False,
33+
data_format=None,
3234
):
3335
"""Generates a `tf.data.Dataset` from image files in a directory.
3436
@@ -111,6 +113,8 @@ def image_dataset_from_directory(
111113
(of size `image_size`) that matches the target aspect ratio. By
112114
default (`crop_to_aspect_ratio=False`), aspect ratio may not be
113115
preserved.
116+
data_format: If None uses keras_core.config.image_data_format()
117+
otherwise either 'channel_last' or 'channel_first'.
114118
115119
Returns:
116120
@@ -142,6 +146,7 @@ def image_dataset_from_directory(
142146
- if `color_mode` is `"rgba"`,
143147
there are 4 channels in the image tensors.
144148
"""
149+
145150
if labels not in ("inferred", None):
146151
if not isinstance(labels, (list, tuple)):
147152
raise ValueError(
@@ -220,6 +225,8 @@ def image_dataset_from_directory(
220225
f"class_names. Received: class_names={class_names}"
221226
)
222227

228+
data_format = standardize_data_format(data_format=data_format)
229+
223230
if subset == "both":
224231
(
225232
image_paths_train,
@@ -252,7 +259,9 @@ def image_dataset_from_directory(
252259
num_classes=len(class_names),
253260
interpolation=interpolation,
254261
crop_to_aspect_ratio=crop_to_aspect_ratio,
262+
data_format=data_format,
255263
)
264+
256265
val_dataset = paths_and_labels_to_dataset(
257266
image_paths=image_paths_val,
258267
image_size=image_size,
@@ -262,6 +271,7 @@ def image_dataset_from_directory(
262271
num_classes=len(class_names),
263272
interpolation=interpolation,
264273
crop_to_aspect_ratio=crop_to_aspect_ratio,
274+
data_format=data_format,
265275
)
266276

267277
if batch_size is not None:
@@ -288,6 +298,7 @@ def image_dataset_from_directory(
288298
# Include file paths for images as attribute.
289299
train_dataset.file_paths = image_paths_train
290300
val_dataset.file_paths = image_paths_val
301+
291302
dataset = [train_dataset, val_dataset]
292303
else:
293304
image_paths, labels = dataset_utils.get_training_or_validation_split(
@@ -308,6 +319,7 @@ def image_dataset_from_directory(
308319
num_classes=len(class_names),
309320
interpolation=interpolation,
310321
crop_to_aspect_ratio=crop_to_aspect_ratio,
322+
data_format=data_format,
311323
)
312324

313325
if batch_size is not None:
@@ -320,12 +332,12 @@ def image_dataset_from_directory(
320332
dataset = dataset.shuffle(buffer_size=1024, seed=seed)
321333

322334
dataset = dataset.prefetch(tf.data.AUTOTUNE)
323-
324335
# Users may need to reference `class_names`.
325336
dataset.class_names = class_names
326337

327338
# Include file paths for images as attribute.
328339
dataset.file_paths = image_paths
340+
329341
return dataset
330342

331343

@@ -337,12 +349,19 @@ def paths_and_labels_to_dataset(
337349
label_mode,
338350
num_classes,
339351
interpolation,
352+
data_format,
340353
crop_to_aspect_ratio=False,
341354
):
342355
"""Constructs a dataset of images and labels."""
343356
# TODO(fchollet): consider making num_parallel_calls settable
344357
path_ds = tf.data.Dataset.from_tensor_slices(image_paths)
345-
args = (image_size, num_channels, interpolation, crop_to_aspect_ratio)
358+
args = (
359+
image_size,
360+
num_channels,
361+
interpolation,
362+
data_format,
363+
crop_to_aspect_ratio,
364+
)
346365
img_ds = path_ds.map(
347366
lambda x: load_image(x, *args), num_parallel_calls=tf.data.AUTOTUNE
348367
)
@@ -355,7 +374,12 @@ def paths_and_labels_to_dataset(
355374

356375

357376
def load_image(
358-
path, image_size, num_channels, interpolation, crop_to_aspect_ratio=False
377+
path,
378+
image_size,
379+
num_channels,
380+
interpolation,
381+
data_format,
382+
crop_to_aspect_ratio=False,
359383
):
360384
"""Load an image from a path and resize it."""
361385
img = tf.io.read_file(path)
@@ -369,6 +393,7 @@ def load_image(
369393
img,
370394
image_size,
371395
interpolation=interpolation,
396+
data_format=data_format,
372397
backend_module=tf_backend,
373398
)
374399
else:

0 commit comments

Comments
 (0)