
Commit 19dfdc7

add XCiT (#99)
* Update __init__.py
* Create xcit.py
* Create config files for XCiT
* Update config files for XCiT
* Update Usage_ViTs.md for XCiT
* Rename yml to yaml
* Update xcit.py
* Update config files for XCiT
* Create README.md for XCiT
* Update __init__.py
* Create regnet.py
* Create DistillationWrapper.py
* Update __init__.py
* Update extract_weight.py
* Update GETTING_STARTED.md
* Update DistillationWrapper.py
* Update GETTING_STARTED.md
* Update config files for XCiT
* Update README.md
1 parent e292e33 commit 19dfdc7

30 files changed: +3434 -0 lines changed

configs/xcit/README.md

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
## XCiT: Cross-Covariance Image Transformers ([arXiv](https://arxiv.org/abs/2106.09681))

## Introduction

Following tremendous success in natural language processing, transformers have recently shown much promise for computer vision. The self-attention operation underlying transformers yields global interactions between all tokens, i.e. words or image patches, and enables flexible modelling of image data beyond the local interactions of convolutions. This flexibility, however, comes with a quadratic complexity in time and memory, hindering application to long sequences and high-resolution images. We propose a *transposed* version of self-attention that operates across feature channels rather than tokens, where the interactions are based on the cross-covariance matrix between keys and queries. The resulting cross-covariance attention (XCA) has linear complexity in the number of tokens, and allows efficient processing of high-resolution images. Our cross-covariance image transformer (XCiT), built upon XCA, combines the accuracy of conventional transformers with the scalability of convolutional architectures. We validate the effectiveness and generality of XCiT by reporting excellent results on multiple vision benchmarks, including (self-supervised) image classification on ImageNet-1k, object detection and instance segmentation on COCO, and semantic segmentation on ADE20k.

![XCiT](https://user-images.githubusercontent.com/42234328/154954202-e51e6c9d-68af-4f42-b466-2db3a82fd19a.png)
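Conceptually, XCA is multi-head attention with the roles of tokens and channels transposed. The minimal Paddle sketch below is ours, for illustration only (bias flags, attention dropout, and the surrounding local patch interaction blocks are omitted); the authoritative implementation is the `xcit.py` created by this commit.

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class XCA(nn.Layer):
    """Sketch of cross-covariance attention: attend over channels, not tokens."""

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        # Learnable per-head temperature that scales the d x d attention map.
        self.temperature = self.create_parameter(
            shape=[num_heads, 1, 1],
            default_initializer=nn.initializer.Constant(1.0))
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):                      # x: (B, N, C)
        B, N, C = x.shape
        d = C // self.num_heads
        qkv = self.qkv(x).reshape([B, N, 3, self.num_heads, d])
        qkv = qkv.transpose([2, 0, 3, 4, 1])   # (3, B, heads, d, N)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # L2-normalize along the token axis, then form the d x d
        # cross-covariance map: cost is O(N * d^2), linear in tokens N.
        q = F.normalize(q, axis=-1)
        k = F.normalize(k, axis=-1)
        attn = paddle.matmul(q, k, transpose_y=True) * self.temperature
        attn = F.softmax(attn, axis=-1)
        out = paddle.matmul(attn, v)           # (B, heads, d, N)
        out = out.transpose([0, 3, 1, 2]).reshape([B, N, C])
        return self.proj(out)


# Sanity check: 196 tokens (14x14 patches) of dimension 128.
tokens = paddle.randn([2, 196, 128])
print(XCA(dim=128, num_heads=8)(tokens).shape)  # [2, 196, 128]
```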
## Getting Started

#### Train with a single GPU

```bash
python tools/train.py -c configs/xcit/${XCIT_ARCH}.yaml
```

#### Train with multiple GPUs

```bash
python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c configs/xcit/${XCIT_ARCH}.yaml
```

#### Evaluate

```bash
python tools/train.py -c configs/xcit/${XCIT_ARCH}.yaml --load ${XCIT_WEIGHT_FILE} --evaluate-only
```
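For example, to evaluate the released `xcit_nano_12_p8_224` weights from the Model Zoo below (the local filename is illustrative):

```bash
wget https://passl.bj.bcebos.com/vision_transformers/pvt_v2/xcit_nano_12_p8_224.pdparams
python tools/train.py -c configs/xcit/xcit_nano_12_p8_224.yaml \
    --load xcit_nano_12_p8_224.pdparams --evaluate-only
```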
#### Knowledge distillation

For knowledge distillation, simply replace `${XCIT_ARCH}.yaml` with the corresponding distillation config file, `${XCIT_ARCH}_dist.yaml`, in the commands above. We provide pretrained weights for the Teacher model `RegNetY_160`, which can be downloaded [here](https://passl.bj.bcebos.com/vision_transformers/pvt_v2/regnety_160.pdparams).

Checkpoints saved during distillation training contain both the Teacher's and the Student's weights. You can extract the Student's weights with the following command:

```bash
python tools/extract_weight.py ${DISTILLATION_WEIGHTS_FILE} --prefix Student --remove_prefix --output ${STUDENT_WEIGHTS_FILE}
```
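For instance, after distilling a nano model, the Student could be extracted like this (both file paths are illustrative, not the checkpoint names PASSL necessarily produces):

```bash
python tools/extract_weight.py output_dir/epoch_400.pdparams \
    --prefix Student --remove_prefix --output xcit_nano_12_p8_224_student.pdparams
```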
## Model Zoo

The results are evaluated on the ImageNet2012 validation set.

| Arch | Weight | Top-1 Acc | Top-5 Acc | Crop ratio | # Params |
| ---- | ------ | --------- | --------- | ---------- | -------- |
| xcit_nano_12_p8_224 | [pretrain 1k](https://passl.bj.bcebos.com/vision_transformers/pvt_v2/xcit_nano_12_p8_224.pdparams) | 73.90 | 92.13 | 1.0 | 3.05M |
| xcit_tiny_12_p8_224 | [pretrain 1k](https://passl.bj.bcebos.com/vision_transformers/pvt_v2/xcit_tiny_12_p8_224.pdparams) | 79.68 | 95.04 | 1.0 | 6.71M |
| xcit_tiny_24_p8_224 | [pretrain 1k](https://passl.bj.bcebos.com/vision_transformers/pvt_v2/xcit_tiny_24_p8_224.pdparams) | 81.87 | 95.97 | 1.0 | 12.11M |
| xcit_small_12_p8_224 | [pretrain 1k](https://passl.bj.bcebos.com/vision_transformers/pvt_v2/xcit_small_12_p8_224.pdparams) | 83.36 | 96.51 | 1.0 | 26.21M |
| xcit_small_24_p8_224 | [pretrain 1k](https://passl.bj.bcebos.com/vision_transformers/pvt_v2/xcit_small_24_p8_224.pdparams) | 83.82 | 96.65 | 1.0 | 47.63M |
| xcit_medium_24_p8_224 | [pretrain 1k](https://passl.bj.bcebos.com/vision_transformers/pvt_v2/xcit_medium_24_p8_224.pdparams) | 83.73 | 96.39 | 1.0 | 84.32M |
| xcit_large_24_p8_224 | [pretrain 1k](https://passl.bj.bcebos.com/vision_transformers/pvt_v2/xcit_large_24_p8_224.pdparams) | 84.42 | 96.65 | 1.0 | 188.93M |
| xcit_nano_12_p16_224 | [pretrain 1k](https://passl.bj.bcebos.com/vision_transformers/pvt_v2/xcit_nano_12_p16_224.pdparams) | 70.01 | 89.82 | 1.0 | 3.05M |
| xcit_tiny_12_p16_224 | [pretrain 1k](https://passl.bj.bcebos.com/vision_transformers/pvt_v2/xcit_tiny_12_p16_224.pdparams) | 77.15 | 93.72 | 1.0 | 6.72M |
| xcit_tiny_24_p16_224 | [pretrain 1k](https://passl.bj.bcebos.com/vision_transformers/pvt_v2/xcit_tiny_24_p16_224.pdparams) | 79.42 | 94.86 | 1.0 | 12.12M |
| xcit_small_12_p16_224 | [pretrain 1k](https://passl.bj.bcebos.com/vision_transformers/pvt_v2/xcit_small_12_p16_224.pdparams) | 81.89 | 95.83 | 1.0 | 26.25M |
| xcit_small_24_p16_224 | [pretrain 1k](https://passl.bj.bcebos.com/vision_transformers/pvt_v2/xcit_small_24_p16_224.pdparams) | 82.51 | 95.97 | 1.0 | 47.67M |
| xcit_medium_24_p16_224 | [pretrain 1k](https://passl.bj.bcebos.com/vision_transformers/pvt_v2/xcit_medium_24_p16_224.pdparams) | 82.67 | 95.91 | 1.0 | 84.40M |
| xcit_large_24_p16_224 | [pretrain 1k](https://passl.bj.bcebos.com/vision_transformers/pvt_v2/xcit_large_24_p16_224.pdparams) | 82.89 | 95.89 | 1.0 | 189.10M |
## Usage

```python
import paddle.nn as nn

from passl.modeling.backbones import build_backbone
from passl.modeling.heads import build_head
from passl.utils.config import get_config


class Model(nn.Layer):
    def __init__(self, cfg_file):
        super().__init__()
        cfg = get_config(cfg_file)
        self.backbone = build_backbone(cfg.model.architecture)
        self.head = build_head(cfg.model.head)

    def forward(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x


cfg_file = "configs/xcit/xcit_nano_12_p8_224.yaml"
m = Model(cfg_file)
```
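A quick smoke test of the assembled model, assuming the head returns class logits when called on a plain image batch:

```python
import paddle

# Run a dummy ImageNet-sized batch through the model (assumes the head
# returns logits of shape [batch, 1000] when called without labels).
x = paddle.randn([1, 3, 224, 224])
logits = m(x)
print(logits.shape)  # expected: [1, 1000]
```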
## Reference

```
@article{xcit,
  title={{XCiT}: Cross-Covariance Image Transformers},
  author={Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Hervé Jegou},
  year={2021},
  eprint={2106.09681},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
epochs: 400
output_dir: output_dir
seed: 0

model:
  name: SwinWrapper
  architecture:
    name: XCiT
    patch_size: 16
    embed_dim: 768
    depth: 24
    num_heads: 16
    eta: 1e-5
    tokens_norm: True
  head:
    name: SwinTransformerClsHead
    in_channels: 768  # equal to architecture.embed_dim
    num_classes: 1000

dataloader:
  train:
    num_workers: 8
    sampler:
      batch_size: 16
      shuffle: True
      drop_last: True
    dataset:
      name: ImageNet
      dataroot: data/ILSVRC2012/train/
      return_label: True
      transforms:
        - name: RandomResizedCrop
          size: 224
          scale: [0.08, 1.]
          interpolation: 'bicubic'
        - name: RandomHorizontalFlip
        - name: AutoAugment
          config_str: 'rand-m9-mstd0.5-inc1'
          interpolation: 'bicubic'
          img_size: 224
        - name: Transpose
        - name: Normalize
          mean: [123.675, 116.28, 103.53]
          std: [58.395, 57.12, 57.375]
        - name: RandomErasing
          prob: 0.25
          mode: 'pixel'
          max_count: 1
      batch_transforms:
        - name: Mixup
          mixup_alpha: 0.8
          prob: 1.
          switch_prob: 0.5
          mode: 'batch'
          cutmix_alpha: 1.0
  val:
    num_workers: 8
    sampler:
      batch_size: 128
      shuffle: False
      drop_last: False
    dataset:
      name: ImageNet
      dataroot: data/ILSVRC2012/val
      return_label: True
      transforms:
        - name: Resize
          size: 224
          interpolation: 'bicubic'
        - name: CenterCrop
          size: 224
        - name: Transpose
        - name: Normalize
          mean: [123.675, 116.28, 103.53]
          std: [58.395, 57.12, 57.375]

lr_scheduler:
  name: LinearWarmup
  learning_rate:
    name: CosineAnnealingDecay
    learning_rate: 5e-4
    T_max: 400
    eta_min: 1e-5
  warmup_steps: 5
  start_lr: 1e-6
  end_lr: 5e-4

optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  weight_decay: 0.05
  exclude_from_weight_decay: ["temperature", "pos_embed", "cls_token", "dist_token"]

log_config:
  name: LogHook
  interval: 10

checkpoint:
  name: CheckpointHook
  by_epoch: True
  interval: 1
  max_keep_ckpts: 50

custom_config:
  - name: EvaluateHook
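As a reading aid, the `lr_scheduler` block above corresponds to the following Paddle schedulers (a sketch of the settings only, not PASSL's actual construction code):

```python
import paddle

# Warmup from 1e-6 up to the base LR of 5e-4 over 5 steps, then cosine
# decay down to eta_min=1e-5 over T_max=400, mirroring the YAML above.
cosine = paddle.optimizer.lr.CosineAnnealingDecay(
    learning_rate=5e-4, T_max=400, eta_min=1e-5)
scheduler = paddle.optimizer.lr.LinearWarmup(
    learning_rate=cosine, warmup_steps=5, start_lr=1e-6, end_lr=5e-4)
```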
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
epochs: 400
output_dir: output_dir
seed: 0

model:
  name: SwinWrapper
  architecture:
    name: XCiT
    patch_size: 8
    embed_dim: 768
    depth: 24
    num_heads: 16
    eta: 1e-5
    tokens_norm: True
  head:
    name: SwinTransformerClsHead
    in_channels: 768  # equal to architecture.embed_dim
    num_classes: 1000

dataloader:
  train:
    num_workers: 8
    sampler:
      batch_size: 16
      shuffle: True
      drop_last: True
    dataset:
      name: ImageNet
      dataroot: data/ILSVRC2012/train/
      return_label: True
      transforms:
        - name: RandomResizedCrop
          size: 224
          scale: [0.08, 1.]
          interpolation: 'bicubic'
        - name: RandomHorizontalFlip
        - name: AutoAugment
          config_str: 'rand-m9-mstd0.5-inc1'
          interpolation: 'bicubic'
          img_size: 224
        - name: Transpose
        - name: Normalize
          mean: [123.675, 116.28, 103.53]
          std: [58.395, 57.12, 57.375]
        - name: RandomErasing
          prob: 0.25
          mode: 'pixel'
          max_count: 1
      batch_transforms:
        - name: Mixup
          mixup_alpha: 0.8
          prob: 1.
          switch_prob: 0.5
          mode: 'batch'
          cutmix_alpha: 1.0
  val:
    num_workers: 8
    sampler:
      batch_size: 128
      shuffle: False
      drop_last: False
    dataset:
      name: ImageNet
      dataroot: data/ILSVRC2012/val
      return_label: True
      transforms:
        - name: Resize
          size: 224
          interpolation: 'bicubic'
        - name: CenterCrop
          size: 224
        - name: Transpose
        - name: Normalize
          mean: [123.675, 116.28, 103.53]
          std: [58.395, 57.12, 57.375]

lr_scheduler:
  name: LinearWarmup
  learning_rate:
    name: CosineAnnealingDecay
    learning_rate: 5e-4
    T_max: 400
    eta_min: 1e-5
  warmup_steps: 5
  start_lr: 1e-6
  end_lr: 5e-4

optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  weight_decay: 0.05
  exclude_from_weight_decay: ["temperature", "pos_embed", "cls_token", "dist_token"]

log_config:
  name: LogHook
  interval: 10

checkpoint:
  name: CheckpointHook
  by_epoch: True
  interval: 1
  max_keep_ckpts: 50

custom_config:
  - name: EvaluateHook
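Likewise, the `optimizer` block maps onto Paddle's AdamW. The sketch below uses `apply_decay_param_fun` to mimic `exclude_from_weight_decay`; PASSL's actual wiring may differ, and the Linear layer is a stand-in for the built XCiT model.

```python
import paddle

# Parameters whose names contain these substrings get no weight decay,
# mirroring exclude_from_weight_decay in the config above.
no_decay = ["temperature", "pos_embed", "cls_token", "dist_token"]
model = paddle.nn.Linear(8, 8)  # placeholder module for illustration

optimizer = paddle.optimizer.AdamW(
    learning_rate=5e-4,
    beta1=0.9,
    beta2=0.999,
    weight_decay=0.05,
    parameters=model.parameters(),
    apply_decay_param_fun=lambda name: not any(k in name for k in no_decay),
)
```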
