Update examples (huggingface#638)
regisss authored and Jinyan chen committed Feb 27, 2024
1 parent 78fc45f commit 6940ec6
Showing 3 changed files with 16 additions and 62 deletions.
6 changes: 5 additions & 1 deletion examples/contrastive-image-text/run_bridgetower.py
@@ -617,7 +617,11 @@ def transform_images(examples):
         trainer.save_metrics("test", metrics)
 
     # 12. Write Training Stats and push to hub.
-    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "contrastive-image-text-modeling"}
+    finetuned_from = model_args.model_name_or_path
+    # If from a local directory, don't set `finetuned_from` as this is required to be a valid repo. id on the Hub.
+    if os.path.isdir(finetuned_from):
+        finetuned_from = None
+    kwargs = {"finetuned_from": finetuned_from, "tasks": "contrastive-image-text-modeling"}
     if data_args.dataset_name is not None:
         kwargs["dataset_tags"] = data_args.dataset_name
         if data_args.dataset_config_name is not None:
6 changes: 5 additions & 1 deletion examples/contrastive-image-text/run_clip.py
@@ -603,7 +603,11 @@ def filter_corrupt_images(examples):
         trainer.save_metrics("eval", metrics)
 
     # 11. Write Training Stats and push to hub.
-    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "contrastive-image-text-modeling"}
+    finetuned_from = model_args.model_name_or_path
+    # If from a local directory, don't set `finetuned_from` as this is required to be a valid repo. id on the Hub.
+    if os.path.isdir(finetuned_from):
+        finetuned_from = None
+    kwargs = {"finetuned_from": finetuned_from, "tasks": "contrastive-image-text-modeling"}
     if data_args.dataset_name is not None:
         kwargs["dataset_tags"] = data_args.dataset_name
         if data_args.dataset_config_name is not None:
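Both run_bridgetower.py and run_clip.py apply the same fix: `finetuned_from` is only passed on when it is a Hub repo ID, not a local checkpoint directory. A minimal standalone sketch of the guard, where the helper name build_model_card_kwargs is illustrative and not part of the scripts:

import os

def build_model_card_kwargs(model_name_or_path: str) -> dict:
    # Kwargs later handed to the trainer's hub/model-card helpers.
    finetuned_from = model_name_or_path
    # A local checkpoint directory is not a valid repo ID on the Hub, so drop it.
    if os.path.isdir(finetuned_from):
        finetuned_from = None
    return {"finetuned_from": finetuned_from, "tasks": "contrastive-image-text-modeling"}

Called with a local path such as an output directory, this returns finetuned_from=None, so no invalid repo ID ends up in the generated model card.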
66 changes: 6 additions & 60 deletions tests/example_diff/run_image_classification.txt
@@ -30,84 +30,30 @@
 > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks.
 > check_min_version("4.34.0")
 > check_optimum_habana_min_version("1.8.1")
-113a122,129
-> image_column_name: str = field(
-> default="image",
-> metadata={"help": "The name of the dataset column containing the image data. Defaults to 'image'."},
-> )
-> label_column_name: str = field(
-> default="label",
-> metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'."},
-> )
-178,183d193
-< def collate_fn(examples):
-< pixel_values = torch.stack([example["pixel_values"] for example in examples])
-< labels = torch.tensor([example["labels"] for example in examples])
-< return {"pixel_values": pixel_values, "labels": labels}
-<
-<
-189c199
+191c199
 < parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
 ---
 > parser = HfArgumentParser((ModelArguments, DataTrainingArguments, GaudiTrainingArguments))
-226a237,243
+228a237,243
 > gaudi_config = GaudiConfig.from_pretrained(
 > training_args.gaudi_config_name,
 > cache_dir=model_args.cache_dir,
 > revision=model_args.model_revision,
 > use_auth_token=True if model_args.use_auth_token else None,
 > )
 >
-227a245
+229a245
 > mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast
-229,230c247,249
+231,232c247,249
 < f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
 < + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
 ---
 > f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
 > + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, "
 > + f"mixed-precision training: {mixed_precision}"
-258d276
-< task="image-classification",
-271d288
-< task="image-classification",
-273a291,309
-> dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names
-> if data_args.image_column_name not in dataset_column_names:
-> raise ValueError(
-> f"--image_column_name {data_args.image_column_name} not found in dataset '{data_args.dataset_name}'. "
-> "Make sure to set `--image_column_name` to the correct audio column - one of "
-> f"{', '.join(dataset_column_names)}."
-> )
-> if data_args.label_column_name not in dataset_column_names:
-> raise ValueError(
-> f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. "
-> "Make sure to set `--label_column_name` to the correct text column - one of "
-> f"{', '.join(dataset_column_names)}."
-> )
->
-> def collate_fn(examples):
-> pixel_values = torch.stack([example["pixel_values"] for example in examples])
-> labels = torch.tensor([example[data_args.label_column_name] for example in examples])
-> return {"pixel_values": pixel_values, "labels": labels}
->
-283c319
-< labels = dataset["train"].features["labels"].names
----
-> labels = dataset["train"].features[data_args.label_column_name].names
-357c393
-< _train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]
----
-> _train_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
-363c399,401
-< example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]]
----
-> example_batch["pixel_values"] = [
-> _val_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
-> ]
-387c425
+408c425
 < trainer = Trainer(
 ---
 > trainer = GaudiTrainer(
-388a427
+409a427
 > gaudi_config=gaudi_config,
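After this update, the regenerated tests/example_diff/run_image_classification.txt keeps only the Gaudi-specific differences (the version checks, GaudiTrainingArguments, GaudiConfig, the mixed-precision log line, and GaudiTrainer with its gaudi_config argument), while the --image_column_name/--label_column_name handling now matches upstream. The collate_fn shown above batches pixel tensors and reads labels from the configured label column; a rough sketch of that pattern, where the make_collate_fn factory wrapper is an illustrative addition rather than code from the example:

import torch

def make_collate_fn(label_column_name: str = "label"):
    # Stack image tensors and read labels from the configured dataset column.
    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        labels = torch.tensor([example[label_column_name] for example in examples])
        return {"pixel_values": pixel_values, "labels": labels}
    return collate_fn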
