Skip to content

Commit

Permalink
Fix image-classification example
Browse files Browse the repository at this point in the history
  • Loading branch information
regisss committed Jan 11, 2024
1 parent 503e853 commit 024b22f
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 14 deletions.
4 changes: 4 additions & 0 deletions examples/image-classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ python run_image_classification.py \
--dataset_name cifar10 \
--output_dir /tmp/outputs/ \
--remove_unused_columns False \
--image_column_name img \
--do_train \
--do_eval \
--learning_rate 3e-5 \
Expand Down Expand Up @@ -182,6 +183,7 @@ python ../gaudi_spawn.py \
--dataset_name cifar10 \
--output_dir /tmp/outputs/ \
--remove_unused_columns False \
--image_column_name img \
--do_train \
--do_eval \
--learning_rate 2e-4 \
Expand Down Expand Up @@ -221,6 +223,7 @@ python ../gaudi_spawn.py \
--dataset_name cifar10 \
--output_dir /tmp/outputs/ \
--remove_unused_columns False \
--image_column_name img \
--do_train \
--do_eval \
--learning_rate 2e-4 \
Expand Down Expand Up @@ -276,6 +279,7 @@ python run_image_classification.py \
--dataset_name cifar10 \
--output_dir /tmp/outputs/ \
--remove_unused_columns False \
--image_column_name img \
--do_eval \
--per_device_eval_batch_size 64 \
--use_habana \
Expand Down
45 changes: 31 additions & 14 deletions examples/image-classification/run_image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@ class DataTrainingArguments:
)
},
)
image_column_name: str = field(
default="image",
metadata={"help": "The name of the dataset column containing the image data. Defaults to 'image'."},
)
label_column_name: str = field(
default="label",
metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'."},
)

def __post_init__(self):
if self.dataset_name is None and (self.train_dir is None and self.validation_dir is None):
Expand Down Expand Up @@ -183,12 +191,6 @@ class ModelArguments:
)


def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example["labels"] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}


def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
Expand Down Expand Up @@ -286,11 +288,24 @@ def main():
cache_dir=model_args.cache_dir,
)

# Rename image and label columns if needed (e.g. Cifar10)
if "img" in (dataset["train"].features if "train" in dataset else dataset["validation"].features):
dataset = dataset.rename_column("img", "image")
if "label" in (dataset["train"].features if "train" in dataset else dataset["validation"].features):
dataset = dataset.rename_column("label", "labels")
dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names
if data_args.image_column_name not in dataset_column_names:
raise ValueError(
f"--image_column_name {data_args.image_column_name} not found in dataset '{data_args.dataset_name}'. "
"Make sure to set `--image_column_name` to the correct audio column - one of "
f"{', '.join(dataset_column_names)}."
)
if data_args.label_column_name not in dataset_column_names:
raise ValueError(
f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. "
"Make sure to set `--label_column_name` to the correct text column - one of "
f"{', '.join(dataset_column_names)}."
)

def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example[data_args.label_column_name] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}

# If we don't have a validation split, split off a percentage of train as validation.
data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split
Expand All @@ -301,7 +316,7 @@ def main():

# Prepare label mappings.
# We'll include these in the model's config to get human readable labels in the Inference API.
labels = dataset["train"].features["labels"].names
labels = dataset["train"].features[data_args.label_column_name].names
label2id, id2label = {}, {}
for i, label in enumerate(labels):
label2id[label] = str(i)
Expand Down Expand Up @@ -375,13 +390,15 @@ def compute_metrics(p):
def train_transforms(example_batch):
"""Apply _train_transforms across a batch."""
example_batch["pixel_values"] = [
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
]
return example_batch

def val_transforms(example_batch):
"""Apply _val_transforms across a batch."""
example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]]
example_batch["pixel_values"] = [
_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
]
return example_batch

if training_args.do_train:
Expand Down
4 changes: 4 additions & 0 deletions tests/baselines/swin_base_patch4_window7_224_in22k.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"train_samples_per_second": 203.619,
"extra_arguments": [
"--remove_unused_columns False",
"--image_column_name img",
"--seed 1337",
"--use_hpu_graphs_for_inference",
"--ignore_mismatched_sizes",
Expand All @@ -28,6 +29,7 @@
"train_samples_per_second": 1679.61,
"extra_arguments": [
"--remove_unused_columns False",
"--image_column_name img",
"--seed 1337",
"--use_hpu_graphs_for_inference",
"--ignore_mismatched_sizes",
Expand All @@ -52,6 +54,7 @@
"train_samples_per_second": 840.673,
"extra_arguments": [
"--remove_unused_columns False",
"--image_column_name img",
"--seed 1337",
"--use_hpu_graphs_for_inference",
"--ignore_mismatched_sizes",
Expand All @@ -68,6 +71,7 @@
"train_samples_per_second": 5820.915,
"extra_arguments": [
"--remove_unused_columns False",
"--image_column_name img",
"--seed 1337",
"--use_hpu_graphs_for_inference",
"--ignore_mismatched_sizes",
Expand Down
4 changes: 4 additions & 0 deletions tests/baselines/vit_base_patch16_224_in21k.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"train_samples_per_second": 349.875,
"extra_arguments": [
"--remove_unused_columns False",
"--image_column_name img",
"--seed 1337",
"--use_hpu_graphs_for_inference",
"--dataloader_num_workers 1",
Expand All @@ -27,6 +28,7 @@
"train_samples_per_second": 2509.027,
"extra_arguments": [
"--remove_unused_columns False",
"--image_column_name img",
"--seed 1337",
"--use_hpu_graphs_for_inference",
"--dataloader_num_workers 1",
Expand All @@ -51,6 +53,7 @@
"train_samples_per_second": 904.475,
"extra_arguments": [
"--remove_unused_columns False",
"--image_column_name img",
"--seed 1337",
"--use_hpu_graphs_for_inference",
"--dataloader_num_workers 1",
Expand All @@ -66,6 +69,7 @@
"train_samples_per_second": 4251.991,
"extra_arguments": [
"--remove_unused_columns False",
"--image_column_name img",
"--seed 1337",
"--use_hpu_graphs_for_inference",
"--dataloader_num_workers 1",
Expand Down

0 comments on commit 024b22f

Please sign in to comment.