|
43 | 43 | )
|
44 | 44 | _NUM_CORRUPT_IMAGES = 1738
|
45 | 45 | _DESCRIPTION = (
|
46 |
| - "A large set of images of cats and dogs. " |
47 |
| - "There are %d corrupted images that are dropped." % _NUM_CORRUPT_IMAGES |
| 46 | + "A large set of images of cats and dogs. " |
| 47 | + "There are %d corrupted images that are dropped." % _NUM_CORRUPT_IMAGES |
48 | 48 | )
|
49 | 49 |
|
50 | 50 | _NAME_RE = re.compile(r"^PetImages[\\/](Cat|Dog)[\\/]\d+\.jpg$")
|
51 | 51 |
|
52 | 52 |
|
class CatsVsDogs(tfds.core.GeneratorBasedBuilder):
  """Cats vs Dogs."""

  VERSION = tfds.core.Version("4.0.1")
  RELEASE_NOTES = {
      "4.0.0": "New split API (https://tensorflow.org/datasets/splits)",
      "4.0.1": (
          "Recoding images in generator to fix corrupt JPEG data warnings"
          " (https://github.com/tensorflow/datasets/issues/2188)"
      ),
  }

  def _info(self):
    """Returns the dataset metadata: features, supervised keys, homepage."""
    features = tfds.features.FeaturesDict({
        "image": tfds.features.Image(),
        "image/filename": tfds.features.Text(),  # eg 'PetImages/Dog/0.jpg'
        "label": tfds.features.ClassLabel(names=["cat", "dog"]),
    })
    homepage = "https://www.microsoft.com/en-us/download/details.aspx?id=54765"
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=features,
        supervised_keys=("image", "label"),
        homepage=homepage,
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager):
    """Downloads the archive and declares the single TRAIN split."""
    archive_path = dl_manager.download(_URL)

    # There is no predefined train/val/test split for this dataset.
    train_split = tfds.core.SplitGenerator(
        name=tfds.Split.TRAIN,
        gen_kwargs={"archive": dl_manager.iter_archive(archive_path)},
    )
    return [train_split]

  def _generate_examples(self, archive):
    """Generate Cats vs Dogs images and labels given a directory path."""
    skipped_count = 0
    for member_name, member_fobj in archive:
      key = os.path.normpath(member_name)
      name_match = _NAME_RE.match(key)
      if name_match is None:
        # Non-image archive members (README file, ...) are ignored.
        continue
      label = name_match.group(1).lower()
      # Corrupt files lack the JFIF marker near the start of the stream.
      if tf.compat.as_bytes("JFIF") not in member_fobj.peek(10):
        skipped_count += 1
        continue

      # Some images caused 'Corrupt JPEG data...' messages during training or
      # any other iteration; recoding them once fixes the issue (discussion:
      # https://github.com/tensorflow/datasets/issues/2188).
      # Those messages are now displayed when generating the dataset instead.
      recoded_image = tf.io.encode_jpeg(
          tf.image.decode_image(member_fobj.read())
      )

      # Wrap the recoded bytes back into a zip container so that downstream
      # consumers still receive a file-like object.
      zip_buffer = io.BytesIO()
      with zipfile.ZipFile(zip_buffer, "w") as writer:
        writer.writestr(key, recoded_image.numpy())
      recoded_fobj = zipfile.ZipFile(zip_buffer).open(key)

      yield key, {
          "image": recoded_fobj,
          "image/filename": key,
          "label": label,
      }

    if skipped_count != _NUM_CORRUPT_IMAGES:
      raise ValueError(
          "Expected %d corrupt images, but found %d"
          % (_NUM_CORRUPT_IMAGES, skipped_count)
      )
    logging.warning("%d images were corrupted and were skipped", skipped_count)
0 commit comments