Skip to content

Commit

Permalink
linear weights
Browse files Browse the repository at this point in the history
  • Loading branch information
Mathilde Caron committed Sep 23, 2021
1 parent 9bebc3f commit d2f3156
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 37 deletions.
27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,12 @@ We release the logs and weights from evaluating the different models:
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_linearweights.pth">linear weights</a></td>
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain_eval_linear_log.txt">logs</a></td>
</tr>
<tr>
<td>ViT-B/8</td>
<td>80.1%</td>
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_linearweights.pth">linear weights</a></td>
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain_eval_linear_log.txt">logs</a></td>
</tr>
<tr>
<td>xcit_small_12_p16</td>
<td>77.8%</td>
Expand Down Expand Up @@ -319,6 +325,27 @@ We release the logs and weights from evaluating the different models:
</tr>
</table>

You can check the performance of the pretrained weights on the ImageNet validation set by running the following command lines (note that `--data_path` should point to the ImageNet root folder containing the `train` and `val` subfolders, since the script appends the split name itself):
```
python eval_linear.py --evaluate --arch vit_small --patch_size 16 --data_path /path/to/imagenet
```

```
python eval_linear.py --evaluate --arch vit_small --patch_size 8 --data_path /path/to/imagenet
```

```
python eval_linear.py --evaluate --arch vit_base --patch_size 16 --n_last_blocks 1 --avgpool_patchtokens true --data_path /path/to/imagenet
```

```
python eval_linear.py --evaluate --arch vit_base --patch_size 8 --n_last_blocks 1 --avgpool_patchtokens true --data_path /path/to/imagenet
```

```
python eval_linear.py --evaluate --arch resnet50 --data_path /path/to/imagenet
```

## Evaluation: DAVIS 2017 Video object segmentation
Please make sure you are using PyTorch version 1.7.1, since we are currently unable to reproduce the results with the more recent PyTorch 1.8.1.

Expand Down
82 changes: 45 additions & 37 deletions eval_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,37 +34,6 @@ def eval_linear(args):
print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
cudnn.benchmark = True

# ============ preparing data ... ============
train_transform = pth_transforms.Compose([
pth_transforms.RandomResizedCrop(224),
pth_transforms.RandomHorizontalFlip(),
pth_transforms.ToTensor(),
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
val_transform = pth_transforms.Compose([
pth_transforms.Resize(256, interpolation=3),
pth_transforms.CenterCrop(224),
pth_transforms.ToTensor(),
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform)
dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform)
sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
train_loader = torch.utils.data.DataLoader(
dataset_train,
sampler=sampler,
batch_size=args.batch_size_per_gpu,
num_workers=args.num_workers,
pin_memory=True,
)
val_loader = torch.utils.data.DataLoader(
dataset_val,
batch_size=args.batch_size_per_gpu,
num_workers=args.num_workers,
pin_memory=True,
)
print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")

# ============ building network ... ============
# if the network is a Vision Transformer (i.e. vit_tiny, vit_small, vit_base)
if args.arch in vits.__dict__.keys():
Expand Down Expand Up @@ -92,6 +61,44 @@ def eval_linear(args):
linear_classifier = linear_classifier.cuda()
linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu])

# ============ preparing data ... ============
val_transform = pth_transforms.Compose([
pth_transforms.Resize(256, interpolation=3),
pth_transforms.CenterCrop(224),
pth_transforms.ToTensor(),
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform)
val_loader = torch.utils.data.DataLoader(
dataset_val,
batch_size=args.batch_size_per_gpu,
num_workers=args.num_workers,
pin_memory=True,
)

if args.evaluate:
utils.load_pretrained_linear_weights(linear_classifier, args.arch, args.patch_size)
test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
return

train_transform = pth_transforms.Compose([
pth_transforms.RandomResizedCrop(224),
pth_transforms.RandomHorizontalFlip(),
pth_transforms.ToTensor(),
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform)
sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
train_loader = torch.utils.data.DataLoader(
dataset_train,
sampler=sampler,
batch_size=args.batch_size_per_gpu,
num_workers=args.num_workers,
pin_memory=True,
)
print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")

# set optimizer
optimizer = torch.optim.SGD(
linear_classifier.parameters(),
Expand Down Expand Up @@ -157,10 +164,10 @@ def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool):
with torch.no_grad():
if "vit" in args.arch:
intermediate_output = model.get_intermediate_layers(inp, n)
output = [x[:, 0] for x in intermediate_output]
output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
if avgpool:
output.append(torch.mean(intermediate_output[-1][:, 1:], dim=1))
output = torch.cat(output, dim=-1)
output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
output = output.reshape(output.shape[0], -1)
else:
output = model(inp)
output = linear_classifier(output)
Expand Down Expand Up @@ -199,10 +206,10 @@ def validate_network(val_loader, model, linear_classifier, n, avgpool):
with torch.no_grad():
if "vit" in args.arch:
intermediate_output = model.get_intermediate_layers(inp, n)
output = [x[:, 0] for x in intermediate_output]
output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
if avgpool:
output.append(torch.mean(intermediate_output[-1][:, 1:], dim=1))
output = torch.cat(output, dim=-1)
output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
output = output.reshape(output.shape[0], -1)
else:
output = model(inp)
output = linear_classifier(output)
Expand Down Expand Up @@ -269,5 +276,6 @@ def forward(self, x):
parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.")
parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints')
parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier')
parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
args = parser.parse_args()
eval_linear(args)
20 changes: 20 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,26 @@ def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_nam
print("There is no reference weights available for this model => We use random weights.")


def load_pretrained_linear_weights(linear_classifier, model_name, patch_size):
    """Load the published DINO linear-probe weights into ``linear_classifier``.

    Looks up the reference checkpoint for the given backbone ``model_name``
    and ``patch_size``; if none is known, the classifier is left with its
    current (random) weights and a message is printed instead.
    """
    # Published linear-probe checkpoints, keyed by (architecture, patch size).
    _CHECKPOINTS = {
        ("vit_small", 16): "dino_deitsmall16_pretrain/dino_deitsmall16_linearweights.pth",
        ("vit_small", 8): "dino_deitsmall8_pretrain/dino_deitsmall8_linearweights.pth",
        ("vit_base", 16): "dino_vitbase16_pretrain/dino_vitbase16_linearweights.pth",
        ("vit_base", 8): "dino_vitbase8_pretrain/dino_vitbase8_linearweights.pth",
    }
    if model_name == "resnet50":
        # ResNet-50 has a single reference checkpoint regardless of patch_size.
        url = "dino_resnet50_pretrain/dino_resnet50_linearweights.pth"
    else:
        url = _CHECKPOINTS.get((model_name, patch_size))
    if url is None:
        print("We use random linear weights.")
        return
    print("We load the reference pretrained linear weights.")
    state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)["state_dict"]
    linear_classifier.load_state_dict(state_dict, strict=True)


def clip_gradients(model, clip):
norms = []
for name, p in model.named_parameters():
Expand Down

0 comments on commit d2f3156

Please sign in to comment.