
Commit c9fed9e

vit draft (base,small,tiny), small fix to cifar
1 parent 36434b0 commit c9fed9e

File tree

5 files changed, +64 -1 lines changed


configs.py (+1)

@@ -339,6 +339,7 @@ def dataset(config, filename, transform_mode,
     if dataset_name in ['cifar10', 'cifar10_II', 'cifar100']:
         d = datasets.cifar(nclass, transform=transform, filename=filename, evaluation_protocol=ep, reset=reset,
                            remove_train_from_db=remove_train_from_db, extra_dataset=extra_dataset)
+        logging.info(f'Number of data: {len(d.data)}')
         logging.info(f'Augmentation for {transform_mode}: {transform.transforms}')
     else:
         raise NotImplementedError(f"Not implementation for {dataset_name}")

models/architectures/helper.py (+10)

@@ -4,6 +4,7 @@
 from models.backbone.linear import LinearBackbone
 from models.backbone.resnet import ResNetBackbone
 from models.backbone.vgg import VGGBackbone
+from models.backbone.vit import ViTBackbone
 from models.layers.activation import SignHashLayer, StochasticBinaryLayer
 from models.layers.bihalf import BiHalfLayer
 from models.layers.zm import MeanOnlyBatchNorm

@@ -36,6 +37,15 @@ def get_backbone(backbone, nbit, nclass, pretrained, freeze_weight, **kwargs):
                            vgg_size='vgg16bn', freeze_weight=freeze_weight, **kwargs)
     elif backbone == 'linear':
         return LinearBackbone(nclass=nclass, nbit=nbit, **kwargs)
+    elif backbone == 'vit':
+        return ViTBackbone(nbit=nbit, nclass=nclass, vit_name='vit_base_patch16_224',
+                           pretrained=pretrained, freeze_weight=freeze_weight, **kwargs)
+    elif backbone == 'vittiny':
+        return ViTBackbone(nbit=nbit, nclass=nclass, vit_name='vit_tiny_patch16_224',
+                           pretrained=pretrained, freeze_weight=freeze_weight, **kwargs)
+    elif backbone == 'vitsmall':
+        return ViTBackbone(nbit=nbit, nclass=nclass, vit_name='vit_small_patch16_224',
+                           pretrained=pretrained, freeze_weight=freeze_weight, **kwargs)
     else:
         raise NotImplementedError('The backbone not implemented.')
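For reference, a minimal sketch of how the new backbone keys could be exercised through get_backbone; the nbit and nclass values below are illustrative and not part of the commit:

from models.architectures.helper import get_backbone

# 'vit' maps to vit_base_patch16_224, 'vitsmall' to vit_small_patch16_224,
# and 'vittiny' to vit_tiny_patch16_224 (nbit/nclass here are example values only)
net = get_backbone(backbone='vittiny', nbit=64, nclass=10,
                   pretrained=True, freeze_weight=False)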

models/backbone/vit.py (+49)

@@ -0,0 +1,49 @@
+from abc import ABC
+from typing import List
+
+import timm
+import torch
+from torch import nn
+
+from models.backbone.base_backbone import BaseBackbone
+
+
+class ViTBackbone(BaseBackbone):
+    def __init__(self, nbit, nclass, vit_name, pretrained=False, freeze_weight=False, **kwargs):
+        super(ViTBackbone, self).__init__()
+
+        model = timm.create_model(vit_name, pretrained=pretrained)
+
+        self.patch_embed = model.patch_embed
+        self.cls_token = model.cls_token
+        self.pos_embed = model.pos_embed
+        self.pos_drop = model.pos_drop
+        self.blocks = model.blocks
+        self.norm = model.norm
+        self.pre_logits = model.pre_logits
+        self.head = model.head  # no need train as features_params because not using
+
+        self.in_features = model.head.in_features
+        self.nbit = nbit
+        self.nclass = nclass
+
+        assert freeze_weight is False, \
+            'freeze_weight in backbone deprecated. Use --backbone-lr-scale=0 to freeze backbone'
+
+    def get_features_params(self) -> List:
+        return list(self.parameters())
+
+    def get_hash_params(self) -> List:
+        raise NotImplementedError('no hash layer in backbone')
+
+    def forward(self, x):
+        x = self.patch_embed(x)
+
+        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
+        x = torch.cat((cls_token, x), dim=1)
+
+        x = self.pos_drop(x + self.pos_embed)
+        x = self.blocks(x)
+        x = self.norm(x)
+
+        return self.pre_logits(x[:, 0])
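A minimal smoke-test sketch for the new backbone, assuming timm~=0.5.4 as pinned in requirements.txt; the batch size and feature dimensions below are illustrative, not part of the commit:

import torch

from models.backbone.vit import ViTBackbone

# pretrained=False avoids downloading weights for a quick shape check
backbone = ViTBackbone(nbit=64, nclass=10, vit_name='vit_tiny_patch16_224', pretrained=False)

x = torch.randn(2, 3, 224, 224)   # the patch16_224 variants expect 224x224 inputs
feats = backbone(x)               # CLS-token features after pre_logits
print(feats.shape)                # torch.Size([2, 192]) for the tiny variant
print(backbone.in_features)       # 192: feature width a downstream hash layer would consume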

requirements.txt (+1)

@@ -1,5 +1,6 @@
 torch~=1.10.0
 torchvision~=0.11.0
+timm~=0.5.4
 tqdm
 opencv-python
 scikit-learn

utils/datasets.py (+3 -1)

@@ -616,7 +616,9 @@ def cifar(nclass, **kwargs):
                    transform=transform, target_transform=one_hot(int(nclass)),
                    train=True, download=True)
     traind = IndexDatasetWrapper(traind)
-    testd = CIFAR(f'data/cifar{nclass}', train=False, download=True)
+    testd = CIFAR(f'data/cifar{nclass}',
+                  transform=transform, target_transform=one_hot(int(nclass)),
+                  train=False, download=True)
     testd = IndexDatasetWrapper(testd)
 
     if ep == 2: # using orig train and test
