
Commit 6c376d6

akolesnikoff, lucasb-eyer, and xiaohuazhai committed
Add pretrained ViT-S/16 models (90, 150 and 300 epochs) + misc updates
Co-authored-by: Lucas Beyer <[email protected]>
Co-authored-by: Xiaohua Zhai <[email protected]>
1 parent 8ca9d84 commit 6c376d6

4 files changed, +57 -18 lines changed

Diff for: README.md

+7 -2

````diff
@@ -181,11 +181,16 @@ gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=0 --command "rm
 If you want to integrate other public or custom datasets, i.e. imagenet2012,
 please follow [the official guideline](https://www.tensorflow.org/datasets/catalog/overview).
 
+## Pre-trained models
+
+For the full list of pre-trained models check out the `load` function defined in
+the same module as the model code. And for example config on how to use these
+models, see `configs/transfer.py`.
+
 ## Run the transfer script on TPU VMs
 
 The following command line fine-tunes a pre-trained `vit-i21k-augreg-b/32` model
-on `cifar10` dataset. Please check `transfer.py` directly for more supported
-datasets and models.
+on `cifar10` dataset.
 
 ```
 gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/transfer.py:model=vit-i21k-augreg-b/32,dataset=cifar10,crop=resmall_crop --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'` --config.lr=0.03"
````

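The new README section points at the `load` function in `big_vision/models/vit.py` for the list of checkpoints. Below is a minimal, hypothetical sketch (not part of this commit) of restoring one of the new ViT-S/16 checkpoints through that function; the shortcut names appear in the vit.py diff further down, `init_params` is assumed to be the random-init parameter tree of a matching S/16 model, and the `dont_load` patterns mirror `configs/transfer.py`:

```python
# Hypothetical usage sketch, not from the repo. Assumes big_vision is on
# PYTHONPATH and the gs://big_vision bucket is readable.
from big_vision.models import vit


def restore_s16_backbone(init_params):
  """init_params: random-init parameter tree of a matching ViT-S/16 model."""
  return vit.load(
      init_params=init_params,
      init_file='i1k-s16-300ep',  # shortcut -> gs://big_vision/vit_s16_i1k_300ep.npz
      model_cfg=dict(variant='S/16', pool_type='gap', posemb='sincos2d', rep_size=True),
      dont_load=('head/kernel', 'head/bias'),  # keep the freshly initialized classifier head
  )
```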
Diff for: big_vision/configs/load_and_eval.py

+25 -0

````diff
@@ -95,6 +95,31 @@ def bit_paper(config):
   )
 
 
+def vit_i1k(config):
+  # We could omit init_{shapes,types} if we wanted, as they are the default.
+  config.init_shapes = [(1, 224, 224, 3)]
+  config.init_types = ['float32']
+  config.num_classes = 1000
+
+  config.model_name = 'vit'
+  config.model_init = ''  # Will be set in sweep.
+  config.model = dict(variant='S/16', pool_type='gap', posemb='sincos2d',
+                      rep_size=True)
+
+  config.evals = [
+      ('fewshot', 'fewshot_lsr'),
+      ('val', 'classification'),
+  ]
+  config.fewshot = get_fewshot_lsr()
+  config.val = dict(
+      dataset='imagenet2012',
+      split='validation',
+      pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(1000, key="label", key_result="labels")|keep("image", "labels")',
+      loss_name='softmax_xent',
+      cache_final=False,  # Only run once, on low-mem machine.
+  )
+
+
 def vit_i21k(config):
   # We could omit init_{shapes,types} if we wanted, as they are the default.
   config.init_shapes = [(1, 224, 224, 3)]
````

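The `pp_fn` string in `config.val` is a pipe-separated chain of big_vision preprocessing ops. Purely as an illustration of what those named ops do to a single example, here is a rough plain-TensorFlow equivalent (a sketch, not the repo's pp op registry; details are approximate):

```python
# Illustrative translation of the pp_fn ops above into plain TensorFlow.
import tensorflow as tf


def illustrate_pp(example):
  img = tf.io.decode_jpeg(example['image'], channels=3)     # decode
  hw = tf.cast(tf.shape(img)[:2], tf.float32)
  scale = 256.0 / tf.reduce_min(hw)                         # resize_small(256): short side -> 256
  new_hw = tf.cast(tf.round(hw * scale), tf.int32)
  img = tf.image.resize(img, new_hw)
  img = tf.image.resize_with_crop_or_pad(img, 224, 224)     # central_crop(224)
  img = img / 127.5 - 1.0                                    # value_range(-1, 1)
  labels = tf.one_hot(example['label'], 1000)                # onehot(1000, key="label", key_result="labels")
  return {'image': img, 'labels': labels}                    # keep("image", "labels")
```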
Diff for: big_vision/configs/transfer.py

+6 -1

````diff
@@ -42,10 +42,15 @@ def _set_model(config, model):
   config.model_load = dict(dont_load=['head/kernel', 'head/bias'])
 
   if model == 'vit-i21k-augreg-b/32':
-    # Load "recommented" upstream B/32 from https://arxiv.org/abs/2106.10270
+    # Load "recommended" upstream B/32 from https://arxiv.org/abs/2106.10270
     config.model_name = 'vit'
     config.model_init = 'howto-i21k-B/32'
     config.model = dict(variant='B/32', pool_type='tok')
+  elif model == 'vit-s16':
+    config.model_name = 'vit'
+    config.model_init = 'i1k-s16-300ep'
+    config.model = dict(variant='S/16', pool_type='gap', posemb='sincos2d',
+                        rep_size=True)
   else:
     raise ValueError(f'Unknown model: {model}, please define customized model.')
````

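Note that, unlike the B/32 branch, the new `vit-s16` entry selects `posemb='sincos2d'`, i.e. fixed 2D sine-cosine position embeddings rather than a learned table. A NumPy sketch of the standard construction, for intuition only and not necessarily identical to the repo's implementation:

```python
import numpy as np


def posemb_sincos_2d(h, w, width, temperature=10_000.0):
  """Fixed 2D sin-cos position embeddings for an h x w patch grid."""
  assert width % 4 == 0, 'width must be a multiple of 4'
  y, x = np.mgrid[:h, :w]
  omega = np.arange(width // 4) / (width // 4 - 1)
  omega = 1.0 / (temperature ** omega)
  y = y.reshape(-1, 1) * omega.reshape(1, -1)
  x = x.reshape(-1, 1) * omega.reshape(1, -1)
  # One (h*w, width) table, with no parameters to learn or checkpoint.
  return np.concatenate([np.sin(x), np.cos(x), np.sin(y), np.cos(y)], axis=1)
```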
Diff for: big_vision/models/vit.py

+19 -15

````diff
@@ -290,12 +290,13 @@ def fix_old_checkpoints(params):
   # This means a B/32@224px would have 7x7+1 posembs. This is useless and clumsy
   # so we changed to add posemb then concat [cls]. We can recover the old
   # checkpoint by manually summing [cls] token and its posemb entry.
-  pe = params["pos_embedding"]
-  if int(np.sqrt(pe.shape[1])) ** 2 + 1 == int(pe.shape[1]):
-    logging.info("ViT: Loading and fixing combined cls+posemb")
-    pe_cls, params["pos_embedding"] = pe[:, :1], pe[:, 1:]
-    if "cls" in params:
-      params["cls"] += pe_cls
+  if "pos_embedding" in params:
+    pe = params["pos_embedding"]
+    if int(np.sqrt(pe.shape[1])) ** 2 + 1 == int(pe.shape[1]):
+      logging.info("ViT: Loading and fixing combined cls+posemb")
+      pe_cls, params["pos_embedding"] = pe[:, :1], pe[:, 1:]
+      if "cls" in params:
+        params["cls"] += pe_cls
 
   # MAP-head variants during ViT-G development had it inlined:
   if "probe" in params:
@@ -308,8 +309,10 @@ def fix_old_checkpoints(params):
 def load(init_params, init_file, model_cfg, dont_load=()):  # pylint: disable=invalid-name because we had to CamelCase above.
   """Load init from checkpoint, both old model and this one. +Hi-res posemb."""
 
+  del model_cfg
   # Shortcut names for some canonical paper checkpoints:
   init_file = {
+      # pylint: disable=line-too-long
       # pylint: disable=line-too-long
       # Recommended models from https://arxiv.org/abs/2106.10270
       # Many more models at https://github.com/google-research/vision_transformer
@@ -320,24 +323,25 @@ def load(init_params, init_file, model_cfg, dont_load=()): # pylint: disable=in
       "howto-i21k-B/16": "gs://vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz",
       "howto-i21k-B/8": "gs://vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz",
       "howto-i21k-L/16": "gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
+
+      # Better plain vit-s16 baselines from https://arxiv.org/abs/2205.01580
+      "i1k-s16-90ep": "gs://big_vision/vit_s16_i1k_90ep.npz",
+      "i1k-s16-150ep": "gs://big_vision/vit_s16_i1k_150ep.npz",
+      "i1k-s16-300ep": "gs://big_vision/vit_s16_i1k_300ep.npz",
+      # pylint: disable=line-too-long
       # pylint: enable=line-too-long
   }.get(init_file, init_file)
   restored_params = utils.load_params(None, init_file)
 
-  # The following allows implementing both fine-tuning head variants from
-  # (internal link)
-  # depending on the value of `rep_size` in the fine-tuning job.
-  if model_cfg.get("rep_size", False) in (None, False):
-    restored_params.pop("pre_logits", None)
-
   fix_old_checkpoints(restored_params)
 
   # possibly use the random init for some of the params (such as, the head).
   restored_params = common.merge_params(restored_params, init_params, dont_load)
 
   # resample posemb if needed.
-  restored_params["pos_embedding"] = resample_posemb(
-      old=restored_params["pos_embedding"],
-      new=init_params["pos_embedding"])
+  if "pos_embedding" in init_params:
+    restored_params["pos_embedding"] = resample_posemb(
+        old=restored_params["pos_embedding"],
+        new=init_params["pos_embedding"])
 
   return restored_params
````

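Both new `"pos_embedding" in ...` checks guard against parameter trees that simply have no position-embedding table, as is the case for models using fixed `sincos2d` embeddings such as the new S/16 configs; with nothing stored, there is nothing to split off or resample. For reference, the usual resampling trick that a `resample_posemb`-style function applies when a learned table is present looks roughly like this (a sketch, not the repo's implementation):

```python
# Sketch of the standard posemb-resampling trick: reshape the learned table to
# its 2D grid, resize bilinearly, and flatten back to a sequence of tokens.
import jax.image


def resize_posemb(old, new_hw):
  """old: (1, gs*gs, width) learned table; new_hw: (new_h, new_w) target grid."""
  _, num, width = old.shape
  gs = int(num ** 0.5)  # assume a square source grid
  grid = old.reshape(gs, gs, width)
  grid = jax.image.resize(grid, (*new_hw, width), method='linear')
  return grid.reshape(1, new_hw[0] * new_hw[1], width)
```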