240 changes: 131 additions & 109 deletions tensorflow_datasets/vision_language/wit_kaggle/wit_kaggle.py
@@ -65,6 +65,117 @@
_BEAM_NAMESPACE = "TFDS_WIT_KAGGLE"


def _get_csv_reader(filename, *, counter):
if filename.suffix == ".gz":
counter("gz_csv_files").inc()
g = tf.io.gfile.GFile(filename, "rb")
f = gzip.open(g, "rt", newline="")
else:
counter("normal_csv_files").inc()
f = tf.io.gfile.GFile(filename, "r")
  # Allow very large csv fields, since image bytes are embedded as base64.
  # The value passed to field_size_limit() must fit in a C long; sys.maxsize
  # does on 64-bit Unix platforms.
csv.field_size_limit(sys.maxsize)
return csv.reader(f, delimiter="\t")
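
# A minimal sketch (added commentary, not part of the PR) of calling
# _get_csv_reader outside Beam. The stub mimics the counter factory built in
# _generate_examples via functools.partial(beam.metrics.Metrics.counter,
# _BEAM_NAMESPACE): called with a metric name, it returns an object with an
# .inc() method. The path is a hypothetical placeholder; any object with a
# .suffix attribute (e.g. pathlib.Path) works for the compression check.
#
#   import pathlib
#
#   class _NoOpCounter:
#     def inc(self, n=1):
#       pass
#
#   reader = _get_csv_reader(
#       pathlib.Path("/tmp/image_pixels_part-00000.tsv.gz"),
#       counter=lambda name: _NoOpCounter())
#   first_row = next(reader)  # e.g. [image_url, image_pixels, metadata_url]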


def _read_pixel_rows(filename, *, counter):
r"""Contains image_url \t image_pixel \t metadata_url."""
reader = _get_csv_reader(filename, counter=counter)
for row in reader:
counter("pixel_rows").inc()
image_url, image_representation, metadata_url = row
if image_url:
yield [image_url, (image_representation, metadata_url)]
else:
counter("pixel_rows_no_image_url").inc()


def _read_resnet_rows(filename, *, counter):
r"""Contains image_url \t resnet_embedding."""
reader = _get_csv_reader(filename, counter=counter)
for row in reader:
counter("resnet_rows").inc()
image_url, image_representation = row
if image_url:
yield [image_url, image_representation]
else:
counter("resnet_rows_no_image_url").inc()


def _read_samples_rows(folder_path, *, builder_config, counter):
"""Contains samples: train and test have different fields."""
for filename in tf.io.gfile.listdir(folder_path):
file_path = folder_path / filename
f = tf.io.gfile.GFile(file_path, "r")
    # Allow very large csv fields, since sample rows contain long text fields.
    # The value passed to field_size_limit() must fit in a C long; sys.maxsize
    # does on 64-bit Unix platforms.
csv.field_size_limit(sys.maxsize)
csv_reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_ALL)
for row in csv_reader:
counter("samples_rows").inc()
sample = {
feature_key: row[feature_key]
for feature_key in builder_config.split_specific_features.keys()
}
image_url = row["image_url"]
if image_url:
yield [image_url, sample]
else:
counter("samples_rows_no_image_url").inc()


def _process_examples(el, *, builder_config, counter):
"""Process examples."""
sample_url, sample_fields = el
# Each image_url can be associated with multiple samples (e.g., multiple
# languages).
for i, sample_info in enumerate(sample_fields["sample_info"]):
sample_id = f"{i}_{sample_url}"
sample = {"image_url": sample_url}
for feature_key in builder_config.split_specific_features.keys():
sample[feature_key] = sample_info[feature_key]
is_boolean_feature = (
builder_config.split_specific_features[feature_key].np_dtype
== np.bool_
)
if is_boolean_feature:
sample[feature_key] = bool_utils.parse_bool(sample[feature_key])
# Test samples don't have gold captions.
if "caption_title_and_reference_description" not in sample_info:
sample["caption_title_and_reference_description"] = ""

# We output image data only if there is at least one image
# representation per image_url.
# Not all of the samples in the competition have corresponding image
# data. In case multiple different image representations are associated
# with the same image_url, we don't know which one is correct and don't
# output any.
if len(set(sample_fields["image_pixels"])) == 1:
sample_image, sample_metadata = sample_fields["image_pixels"][0]
sample["image"] = io.BytesIO(base64.b64decode(sample_image))
sample["metadata_url"] = sample_metadata
else:
if len(set(sample_fields["image_pixels"])) > 1:
counter("image_pixels_multiple").inc()
else:
counter("image_pixels_missing").inc()
sample["image"] = io.BytesIO(base64.b64decode(_EMPTY_IMAGE_BYTES))
sample["metadata_url"] = ""

if len(set(sample_fields["image_resnet"])) == 1:
image_resnet = [
float(x) for x in sample_fields["image_resnet"][0].split(",")
]
sample["embedding"] = image_resnet
else:
if len(set(sample_fields["image_resnet"])) > 1:
counter("image_resnet_multiple").inc()
else:
counter("image_resnet_missing").inc()
sample["embedding"] = builder_config.empty_resnet_embedding

yield sample_id, sample
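
# A minimal sketch (added commentary, not part of the PR) of driving
# _process_examples with a hand-built grouped element. The config and counter
# are hypothetical stand-ins for WitKaggleConfig and
# beam.metrics.Metrics.counter; the field names and the 2048-float embedding
# size are illustrative assumptions.
#
#   import types
#
#   _stub_config = types.SimpleNamespace(
#       split_specific_features={
#           "caption_title_and_reference_description": types.SimpleNamespace(
#               np_dtype=str),
#       },
#       empty_resnet_embedding=[0.0] * 2048,
#   )
#   _el = (
#       "https://example.org/cat.jpg",
#       {
#           "sample_info": [
#               {"caption_title_and_reference_description": "A cat."}],
#           "image_pixels": [("aGVsbG8=", "https://example.org/meta")],
#           "image_resnet": ["0.1,0.2,0.3"],
#       },
#   )
#   _counter = lambda name: types.SimpleNamespace(inc=lambda: None)
#   for sample_id, sample in _process_examples(
#       _el, builder_config=_stub_config, counter=_counter):
#     print(sample_id, sample["embedding"])  # 0_https://... [0.1, 0.2, 0.3]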


class WitKaggleConfig(tfds.core.BuilderConfig):
"""BuilderConfig for WitKaggle."""

@@ -285,119 +396,15 @@ def _generate_examples(
beam = tfds.core.lazy_imports.apache_beam
counter = functools.partial(beam.metrics.Metrics.counter, _BEAM_NAMESPACE)

def _get_csv_reader(filename):
if filename.suffix == ".gz":
counter("gz_csv_files").inc()
g = tf.io.gfile.GFile(filename, "rb")
f = gzip.open(g, "rt", newline="")
else:
counter("normal_csv_files").inc()
f = tf.io.gfile.GFile(filename, "r")
# Limit to 100 MB. Value must be smaller than the C long maximum value.
csv.field_size_limit(sys.maxsize)
return csv.reader(f, delimiter="\t")

def _read_pixel_rows(filename):
r"""Contains image_url \t image_pixel \t metadata_url."""
reader = _get_csv_reader(filename)
for row in reader:
counter("pixel_rows").inc()
image_url, image_representation, metadata_url = row
if image_url:
yield [image_url, (image_representation, metadata_url)]
else:
counter("pixel_rows_no_image_url").inc()

def _read_resnet_rows(filename):
r"""Contains image_url \t resnet_embedding."""
reader = _get_csv_reader(filename)
for row in reader:
counter("resnet_rows").inc()
image_url, image_representation = row
if image_url:
yield [image_url, image_representation]
else:
counter("resnet_rows_no_image_url").inc()

def _read_samples_rows(folder_path):
"""Contains samples: train and test have different fields."""
for filename in tf.io.gfile.listdir(folder_path):
file_path = folder_path / filename
f = tf.io.gfile.GFile(file_path, "r")
# Limit to 100 MB. Value must be smaller than the C long maximum value.
csv.field_size_limit(sys.maxsize)
csv_reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_ALL)
for row in csv_reader:
counter("samples_rows").inc()
sample = {
feature_key: row[feature_key]
for feature_key in self.builder_config.split_specific_features.keys()
}
image_url = row["image_url"]
if image_url:
yield [image_url, sample]
else:
counter("samples_rows_no_image_url").inc()

def _process_examples(el):
sample_url, sample_fields = el
# Each image_url can be associated with multiple samples (e.g., multiple
# languages).
for i, sample_info in enumerate(sample_fields["sample_info"]):
sample_id = f"{i}_{sample_url}"
sample = {"image_url": sample_url}
for feature_key in self.builder_config.split_specific_features.keys():
sample[feature_key] = sample_info[feature_key]
is_boolean_feature = (
self.builder_config.split_specific_features[feature_key].np_dtype
== np.bool_
)
if is_boolean_feature:
sample[feature_key] = bool_utils.parse_bool(sample[feature_key])
# Test samples don't have gold captions.
if "caption_title_and_reference_description" not in sample_info:
sample["caption_title_and_reference_description"] = ""

# We output image data only if there is at least one image
# representation per image_url.
# Not all of the samples in the competition have corresponding image
# data. In case multiple different image representations are associated
# with the same image_url, we don't know which one is correct and don't
# output any.
if len(set(sample_fields["image_pixels"])) == 1:
sample_image, sample_metadata = sample_fields["image_pixels"][0]
sample["image"] = io.BytesIO(base64.b64decode(sample_image))
sample["metadata_url"] = sample_metadata
else:
if len(set(sample_fields["image_pixels"])) > 1:
counter("image_pixels_multiple").inc()
else:
counter("image_pixels_missing").inc()
sample["image"] = io.BytesIO(base64.b64decode(_EMPTY_IMAGE_BYTES))
sample["metadata_url"] = ""

if len(set(sample_fields["image_resnet"])) == 1:
image_resnet = [
float(x) for x in sample_fields["image_resnet"][0].split(",")
]
sample["embedding"] = image_resnet
else:
if len(set(sample_fields["image_resnet"])) > 1:
counter("image_resnet_multiple").inc()
else:
counter("image_resnet_missing").inc()
sample["embedding"] = self.builder_config.empty_resnet_embedding

yield sample_id, sample

# Read embeddings and bytes representations from (possibly compressed) csv.
image_resnet_files = [
image_resnet_path / f for f in tf.io.gfile.listdir(image_resnet_path)
]
resnet_collection = (
pipeline
| "Collection from resnet files" >> beam.Create(image_resnet_files)
| "Get embeddings per image" >> beam.FlatMap(_read_resnet_rows)
| "Get embeddings per image"
>> beam.FlatMap(functools.partial(_read_resnet_rows, counter=counter))
)

image_pixel_files = [
@@ -406,14 +413,22 @@ def _process_examples(el):
pixel_collection = (
pipeline
| "Collection from pixel files" >> beam.Create(image_pixel_files)
| "Get pixels per image" >> beam.FlatMap(_read_pixel_rows)
| "Get pixels per image"
>> beam.FlatMap(functools.partial(_read_pixel_rows, counter=counter))
)

# Read samples from tsv files.
sample_collection = (
pipeline
| "Collection from sample files" >> beam.Create(samples_path)
| "Get samples" >> beam.FlatMap(_read_samples_rows)
| "Get samples"
>> beam.FlatMap(
functools.partial(
_read_samples_rows,
builder_config=self.builder_config,
counter=counter,
)
)
)

# Combine the features and yield examples.
@@ -425,5 +440,12 @@
}
| "Group by image_url" >> beam.CoGroupByKey()
| "Reshuffle" >> beam.Reshuffle()
| "Process and yield examples" >> beam.FlatMap(_process_examples)
| "Process and yield examples"
>> beam.FlatMap(
functools.partial(
_process_examples,
builder_config=self.builder_config,
counter=counter,
)
)
)
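
For context, a hedged sketch of generating this Beam-backed dataset locally with the DirectRunner. The config name and manual directory are assumptions, not taken from this diff; the actual config names are listed in the builder's BUILDER_CONFIGS.

    import apache_beam as beam
    import tensorflow_datasets as tfds

    # Assumed config name; verify via tfds.builder_cls("wit_kaggle").BUILDER_CONFIGS.
    builder = tfds.builder("wit_kaggle/train_with_extended_features")
    builder.download_and_prepare(
        download_config=tfds.download.DownloadConfig(
            # The Kaggle competition files must be downloaded manually first.
            manual_dir="/path/to/manual/wit_kaggle",
            beam_options=beam.options.pipeline_options.PipelineOptions(
                flags=["--runner=DirectRunner"]),
        )
    )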