@@ -290,12 +290,13 @@ def fix_old_checkpoints(params):
290
290
# This means a B/32@224px would have 7x7+1 posembs. This is useless and clumsy
291
291
# so we changed to add posemb then concat [cls]. We can recover the old
292
292
# checkpoint by manually summing [cls] token and its posemb entry.
293
- pe = params ["pos_embedding" ]
294
- if int (np .sqrt (pe .shape [1 ])) ** 2 + 1 == int (pe .shape [1 ]):
295
- logging .info ("ViT: Loading and fixing combined cls+posemb" )
296
- pe_cls , params ["pos_embedding" ] = pe [:, :1 ], pe [:, 1 :]
297
- if "cls" in params :
298
- params ["cls" ] += pe_cls
293
+ if "pos_embedding" in params :
294
+ pe = params ["pos_embedding" ]
295
+ if int (np .sqrt (pe .shape [1 ])) ** 2 + 1 == int (pe .shape [1 ]):
296
+ logging .info ("ViT: Loading and fixing combined cls+posemb" )
297
+ pe_cls , params ["pos_embedding" ] = pe [:, :1 ], pe [:, 1 :]
298
+ if "cls" in params :
299
+ params ["cls" ] += pe_cls
299
300
300
301
# MAP-head variants during ViT-G development had it inlined:
301
302
if "probe" in params :
@@ -308,8 +309,10 @@ def fix_old_checkpoints(params):
308
309
def load (init_params , init_file , model_cfg , dont_load = ()): # pylint: disable=invalid-name because we had to CamelCase above.
309
310
"""Load init from checkpoint, both old model and this one. +Hi-res posemb."""
310
311
312
+ del model_cfg
311
313
# Shortcut names for some canonical paper checkpoints:
312
314
init_file = {
315
+ # pylint: disable=line-too-long
313
316
# pylint: disable=line-too-long
314
317
# Recommended models from https://arxiv.org/abs/2106.10270
315
318
# Many more models at https://github.com/google-research/vision_transformer
@@ -320,24 +323,25 @@ def load(init_params, init_file, model_cfg, dont_load=()): # pylint: disable=in
320
323
"howto-i21k-B/16" : "gs://vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz" ,
321
324
"howto-i21k-B/8" : "gs://vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz" ,
322
325
"howto-i21k-L/16" : "gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz" ,
326
+
327
+ # Better plain vit-s16 baselines from https://arxiv.org/abs/2205.01580
328
+ "i1k-s16-90ep" : "gs://big_vision/vit_s16_i1k_90ep.npz" ,
329
+ "i1k-s16-150ep" : "gs://big_vision/vit_s16_i1k_150ep.npz" ,
330
+ "i1k-s16-300ep" : "gs://big_vision/vit_s16_i1k_300ep.npz" ,
331
+ # pylint: disable=line-too-long
323
332
# pylint: enable=line-too-long
324
333
}.get (init_file , init_file )
325
334
restored_params = utils .load_params (None , init_file )
326
335
327
- # The following allows implementing both fine-tuning head variants from
328
- # (internal link)
329
- # depending on the value of `rep_size` in the fine-tuning job.
330
- if model_cfg .get ("rep_size" , False ) in (None , False ):
331
- restored_params .pop ("pre_logits" , None )
332
-
333
336
fix_old_checkpoints (restored_params )
334
337
335
338
# Possibly use the random init for some of the params (such as the head).
336
339
restored_params = common .merge_params (restored_params , init_params , dont_load )
337
340
338
341
# resample posemb if needed.
339
- restored_params ["pos_embedding" ] = resample_posemb (
340
- old = restored_params ["pos_embedding" ],
341
- new = init_params ["pos_embedding" ])
342
+ if "pos_embedding" in init_params :
343
+ restored_params ["pos_embedding" ] = resample_posemb (
344
+ old = restored_params ["pos_embedding" ],
345
+ new = init_params ["pos_embedding" ])
342
346
343
347
return restored_params
0 commit comments