From d2f3156bb34c32ae54312de7fe4e31580f12ff7f Mon Sep 17 00:00:00 2001 From: Mathilde Caron Date: Thu, 23 Sep 2021 19:04:19 +0200 Subject: [PATCH] linear weights --- README.md | 27 +++++++++++++++++ eval_linear.py | 82 +++++++++++++++++++++++++++----------------------- utils.py | 20 ++++++++++++ 3 files changed, 92 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 94529d042..72c28096b 100644 --- a/README.md +++ b/README.md @@ -287,6 +287,12 @@ We release the logs and weights from evaluating the different models: linear weights logs + + ViT-B/8 + 80.1% + linear weights + logs + xcit_small_12_p16 77.8% @@ -319,6 +325,27 @@ We release the logs and weights from evaluating the different models: +You can check the performance of the pretrained weights on ImageNet validation set by running the following command lines: +``` +python eval_linear.py --evaluate --arch vit_small --patch_size 16 --data_path /path/to/imagenet/train +``` + +``` +python eval_linear.py --evaluate --arch vit_small --patch_size 8 --data_path /path/to/imagenet/train +``` + +``` +python eval_linear.py --evaluate --arch vit_base --patch_size 16 --n_last_blocks 1 --avgpool_patchtokens true --data_path /path/to/imagenet/train +``` + +``` +python eval_linear.py --evaluate --arch vit_base --patch_size 8 --n_last_blocks 1 --avgpool_patchtokens true --data_path /path/to/imagenet/train +``` + +``` +python eval_linear.py --evaluate --arch resnet50 --data_path /path/to/imagenet/train +``` + ## Evaluation: DAVIS 2017 Video object segmentation Please verify that you're using pytorch version 1.7.1 since we are not able to reproduce the results with most recent pytorch 1.8.1 at the moment. diff --git a/eval_linear.py b/eval_linear.py index e95315bf8..81eb94fd0 100644 --- a/eval_linear.py +++ b/eval_linear.py @@ -34,37 +34,6 @@ def eval_linear(args): print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) cudnn.benchmark = True - # ============ preparing data ... ============ - train_transform = pth_transforms.Compose([ - pth_transforms.RandomResizedCrop(224), - pth_transforms.RandomHorizontalFlip(), - pth_transforms.ToTensor(), - pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - ]) - val_transform = pth_transforms.Compose([ - pth_transforms.Resize(256, interpolation=3), - pth_transforms.CenterCrop(224), - pth_transforms.ToTensor(), - pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - ]) - dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform) - dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform) - sampler = torch.utils.data.distributed.DistributedSampler(dataset_train) - train_loader = torch.utils.data.DataLoader( - dataset_train, - sampler=sampler, - batch_size=args.batch_size_per_gpu, - num_workers=args.num_workers, - pin_memory=True, - ) - val_loader = torch.utils.data.DataLoader( - dataset_val, - batch_size=args.batch_size_per_gpu, - num_workers=args.num_workers, - pin_memory=True, - ) - print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.") - # ============ building network ... ============ # if the network is a Vision Transformer (i.e. vit_tiny, vit_small, vit_base) if args.arch in vits.__dict__.keys(): @@ -92,6 +61,44 @@ def eval_linear(args): linear_classifier = linear_classifier.cuda() linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu]) + # ============ preparing data ... ============ + val_transform = pth_transforms.Compose([ + pth_transforms.Resize(256, interpolation=3), + pth_transforms.CenterCrop(224), + pth_transforms.ToTensor(), + pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform) + val_loader = torch.utils.data.DataLoader( + dataset_val, + batch_size=args.batch_size_per_gpu, + num_workers=args.num_workers, + pin_memory=True, + ) + + if args.evaluate: + utils.load_pretrained_linear_weights(linear_classifier, args.arch, args.patch_size) + test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens) + print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") + return + + train_transform = pth_transforms.Compose([ + pth_transforms.RandomResizedCrop(224), + pth_transforms.RandomHorizontalFlip(), + pth_transforms.ToTensor(), + pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform) + sampler = torch.utils.data.distributed.DistributedSampler(dataset_train) + train_loader = torch.utils.data.DataLoader( + dataset_train, + sampler=sampler, + batch_size=args.batch_size_per_gpu, + num_workers=args.num_workers, + pin_memory=True, + ) + print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.") + # set optimizer optimizer = torch.optim.SGD( linear_classifier.parameters(), @@ -157,10 +164,10 @@ def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool): with torch.no_grad(): if "vit" in args.arch: intermediate_output = model.get_intermediate_layers(inp, n) - output = [x[:, 0] for x in intermediate_output] + output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1) if avgpool: - output.append(torch.mean(intermediate_output[-1][:, 1:], dim=1)) - output = torch.cat(output, dim=-1) + output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1) + output = output.reshape(output.shape[0], -1) else: output = model(inp) output = linear_classifier(output) @@ -199,10 +206,10 @@ def validate_network(val_loader, model, linear_classifier, n, avgpool): with torch.no_grad(): if "vit" in args.arch: intermediate_output = model.get_intermediate_layers(inp, n) - output = [x[:, 0] for x in intermediate_output] + output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1) if avgpool: - output.append(torch.mean(intermediate_output[-1][:, 1:], dim=1)) - output = torch.cat(output, dim=-1) + output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1) + output = output.reshape(output.shape[0], -1) else: output = model(inp) output = linear_classifier(output) @@ -269,5 +276,6 @@ def forward(self, x): parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.") parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints') parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier') + parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set') args = parser.parse_args() eval_linear(args) diff --git a/utils.py b/utils.py index 978d79d6a..958625012 100644 --- a/utils.py +++ b/utils.py @@ -109,6 +109,26 @@ def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_nam print("There is no reference weights available for this model => We use random weights.") +def load_pretrained_linear_weights(linear_classifier, model_name, patch_size): + url = None + if model_name == "vit_small" and patch_size == 16: + url = "dino_deitsmall16_pretrain/dino_deitsmall16_linearweights.pth" + elif model_name == "vit_small" and patch_size == 8: + url = "dino_deitsmall8_pretrain/dino_deitsmall8_linearweights.pth" + elif model_name == "vit_base" and patch_size == 16: + url = "dino_vitbase16_pretrain/dino_vitbase16_linearweights.pth" + elif model_name == "vit_base" and patch_size == 8: + url = "dino_vitbase8_pretrain/dino_vitbase8_linearweights.pth" + elif model_name == "resnet50": + url = "dino_resnet50_pretrain/dino_resnet50_linearweights.pth" + if url is not None: + print("We load the reference pretrained linear weights.") + state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)["state_dict"] + linear_classifier.load_state_dict(state_dict, strict=True) + else: + print("We use random linear weights.") + + def clip_gradients(model, clip): norms = [] for name, p in model.named_parameters():