From 6940ec6c292f03705d4e6da92f08e2ac871807ed Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Tue, 16 Jan 2024 10:41:28 +0100 Subject: [PATCH] Update examples (#638) --- .../contrastive-image-text/run_bridgetower.py | 6 +- examples/contrastive-image-text/run_clip.py | 6 +- .../example_diff/run_image_classification.txt | 66 ++----------------- 3 files changed, 16 insertions(+), 62 deletions(-) diff --git a/examples/contrastive-image-text/run_bridgetower.py b/examples/contrastive-image-text/run_bridgetower.py index 471914c666..9037dccff2 100644 --- a/examples/contrastive-image-text/run_bridgetower.py +++ b/examples/contrastive-image-text/run_bridgetower.py @@ -617,7 +617,11 @@ def transform_images(examples): trainer.save_metrics("test", metrics) # 12. Write Training Stats and push to hub. - kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "contrastive-image-text-modeling"} + finetuned_from = model_args.model_name_or_path + # If from a local directory, don't set `finetuned_from` as this is required to be a valid repo. id on the Hub. + if os.path.isdir(finetuned_from): + finetuned_from = None + kwargs = {"finetuned_from": finetuned_from, "tasks": "contrastive-image-text-modeling"} if data_args.dataset_name is not None: kwargs["dataset_tags"] = data_args.dataset_name if data_args.dataset_config_name is not None: diff --git a/examples/contrastive-image-text/run_clip.py b/examples/contrastive-image-text/run_clip.py index a0fd41c2b4..aaa90c4752 100644 --- a/examples/contrastive-image-text/run_clip.py +++ b/examples/contrastive-image-text/run_clip.py @@ -603,7 +603,11 @@ def filter_corrupt_images(examples): trainer.save_metrics("eval", metrics) # 11. Write Training Stats and push to hub. - kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "contrastive-image-text-modeling"} + finetuned_from = model_args.model_name_or_path + # If from a local directory, don't set `finetuned_from` as this is required to be a valid repo. id on the Hub. + if os.path.isdir(finetuned_from): + finetuned_from = None + kwargs = {"finetuned_from": finetuned_from, "tasks": "contrastive-image-text-modeling"} if data_args.dataset_name is not None: kwargs["dataset_tags"] = data_args.dataset_name if data_args.dataset_config_name is not None: diff --git a/tests/example_diff/run_image_classification.txt b/tests/example_diff/run_image_classification.txt index 5c4a9d61c2..5db5dfc7d7 100644 --- a/tests/example_diff/run_image_classification.txt +++ b/tests/example_diff/run_image_classification.txt @@ -30,27 +30,11 @@ > # Will error if the minimal version of Transformers and Optimum Habana are not installed. Remove at your own risks. > check_min_version("4.34.0") > check_optimum_habana_min_version("1.8.1") -113a122,129 -> image_column_name: str = field( -> default="image", -> metadata={"help": "The name of the dataset column containing the image data. Defaults to 'image'."}, -> ) -> label_column_name: str = field( -> default="label", -> metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'."}, -> ) -178,183d193 -< def collate_fn(examples): -< pixel_values = torch.stack([example["pixel_values"] for example in examples]) -< labels = torch.tensor([example["labels"] for example in examples]) -< return {"pixel_values": pixel_values, "labels": labels} -< -< -189c199 +191c199 < parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) --- > parser = HfArgumentParser((ModelArguments, DataTrainingArguments, GaudiTrainingArguments)) -226a237,243 +228a237,243 > gaudi_config = GaudiConfig.from_pretrained( > training_args.gaudi_config_name, > cache_dir=model_args.cache_dir, @@ -58,56 +42,18 @@ > use_auth_token=True if model_args.use_auth_token else None, > ) > -227a245 +229a245 > mixed_precision = training_args.bf16 or gaudi_config.use_torch_autocast -229,230c247,249 +231,232c247,249 < f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " < + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" --- > f"Process rank: {training_args.local_rank}, device: {training_args.device}, " > + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, " > + f"mixed-precision training: {mixed_precision}" -258d276 -< task="image-classification", -271d288 -< task="image-classification", -273a291,309 -> dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names -> if data_args.image_column_name not in dataset_column_names: -> raise ValueError( -> f"--image_column_name {data_args.image_column_name} not found in dataset '{data_args.dataset_name}'. " -> "Make sure to set `--image_column_name` to the correct audio column - one of " -> f"{', '.join(dataset_column_names)}." -> ) -> if data_args.label_column_name not in dataset_column_names: -> raise ValueError( -> f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. " -> "Make sure to set `--label_column_name` to the correct text column - one of " -> f"{', '.join(dataset_column_names)}." -> ) -> -> def collate_fn(examples): -> pixel_values = torch.stack([example["pixel_values"] for example in examples]) -> labels = torch.tensor([example[data_args.label_column_name] for example in examples]) -> return {"pixel_values": pixel_values, "labels": labels} -> -283c319 -< labels = dataset["train"].features["labels"].names ---- -> labels = dataset["train"].features[data_args.label_column_name].names -357c393 -< _train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"] ---- -> _train_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name] -363c399,401 -< example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]] ---- -> example_batch["pixel_values"] = [ -> _val_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name] -> ] -387c425 +408c425 < trainer = Trainer( --- > trainer = GaudiTrainer( -388a427 +409a427 > gaudi_config=gaudi_config,