|
4 | 4 | from models.backbone.linear import LinearBackbone
|
5 | 5 | from models.backbone.resnet import ResNetBackbone
|
6 | 6 | from models.backbone.vgg import VGGBackbone
|
| 7 | +from models.backbone.vit import ViTBackbone |
7 | 8 | from models.layers.activation import SignHashLayer, StochasticBinaryLayer
|
8 | 9 | from models.layers.bihalf import BiHalfLayer
|
9 | 10 | from models.layers.zm import MeanOnlyBatchNorm
|
@@ -36,6 +37,15 @@ def get_backbone(backbone, nbit, nclass, pretrained, freeze_weight, **kwargs):
|
36 | 37 | vgg_size='vgg16bn', freeze_weight=freeze_weight, **kwargs)
|
37 | 38 | elif backbone == 'linear':
|
38 | 39 | return LinearBackbone(nclass=nclass, nbit=nbit, **kwargs)
|
| 40 | + elif backbone == 'vit': |
| 41 | + return ViTBackbone(nbit=nbit, nclass=nclass, vit_name='vit_base_patch16_224', |
| 42 | + pretrained=pretrained, freeze_weight=freeze_weight, **kwargs) |
| 43 | + elif backbone == 'vittiny': |
| 44 | + return ViTBackbone(nbit=nbit, nclass=nclass, vit_name='vit_tiny_patch16_224', |
| 45 | + pretrained=pretrained, freeze_weight=freeze_weight, **kwargs) |
| 46 | + elif backbone == 'vitsmall': |
| 47 | + return ViTBackbone(nbit=nbit, nclass=nclass, vit_name='vit_small_patch16_224', |
| 48 | + pretrained=pretrained, freeze_weight=freeze_weight, **kwargs) |
39 | 49 | else:
|
40 | 50 | raise NotImplementedError('The backbone not implemented.')
|
41 | 51 |
|
|
0 commit comments