@@ -771,28 +771,20 @@ def resize_pos_embed(
771
771
antialias : bool = False ,
772
772
) -> torch .Tensor :
773
773
""" Rescale the grid of position embeddings when loading from state_dict.
774
-
775
- *DEPRECATED* This function is being deprecated in favour of resample_abs_pos_embed
776
-
777
- Adapted from:
778
- https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
774
+ *DEPRECATED* This function is being deprecated in favour of using resample_abs_pos_embed
779
775
"""
780
- ntok_new = posemb_new .shape [1 ]
781
- if num_prefix_tokens :
782
- posemb_prefix , posemb_grid = posemb [:, :num_prefix_tokens ], posemb [0 , num_prefix_tokens :]
783
- ntok_new -= num_prefix_tokens
784
- else :
785
- posemb_prefix , posemb_grid = posemb [:, :0 ], posemb [0 ]
786
- gs_old = int (math .sqrt (len (posemb_grid )))
776
+ ntok_new = posemb_new .shape [1 ] - num_prefix_tokens
777
+ ntok_old = posemb .shape [1 ] - num_prefix_tokens
778
+ gs_old = [int (math .sqrt (ntok_old ))] * 2
787
779
if not len (gs_new ): # backwards compatibility
788
780
gs_new = [int (math .sqrt (ntok_new ))] * 2
789
- assert len ( gs_new ) >= 2
790
- _logger . info ( f'Resized position embedding: { posemb . shape } ( { [ gs_old , gs_old ] } ) to { posemb_new . shape } ( { gs_new } ).' )
791
- posemb_grid = posemb_grid . reshape ( 1 , gs_old , gs_old , - 1 ). permute ( 0 , 3 , 1 , 2 )
792
- posemb_grid = F . interpolate ( posemb_grid , size = gs_new , mode = interpolation , antialias = antialias , align_corners = False )
793
- posemb_grid = posemb_grid . permute ( 0 , 2 , 3 , 1 ). reshape ( 1 , gs_new [ 0 ] * gs_new [ 1 ], - 1 )
794
- posemb = torch . cat ([ posemb_prefix , posemb_grid ], dim = 1 )
795
- return posemb
781
+ return resample_abs_pos_embed (
782
+ posemb , gs_new , gs_old ,
783
+ num_prefix_tokens = num_prefix_tokens ,
784
+ interpolation = interpolation ,
785
+ antialias = antialias ,
786
+ verbose = True ,
787
+ )
796
788
797
789
798
790
@torch .no_grad ()
@@ -962,16 +954,6 @@ def _convert_openai_clip(
962
954
v = v .unsqueeze (0 ).unsqueeze (1 )
963
955
elif k == 'pos_embed' :
964
956
v = v .unsqueeze (0 )
965
- if v .shape [1 ] != model .pos_embed .shape [1 ]:
966
- # To resize pos embedding when using model at different size from pretrained weights
967
- num_prefix_tokens = 0 if getattr (model , 'no_embed_class' , False ) \
968
- else getattr (model , 'num_prefix_tokens' , 1 )
969
- v = resample_abs_pos_embed (
970
- v ,
971
- new_size = model .patch_embed .grid_size ,
972
- num_prefix_tokens = num_prefix_tokens ,
973
- verbose = True ,
974
- )
975
957
out_dict [k ] = v
976
958
return out_dict
977
959
@@ -1014,19 +996,17 @@ def checkpoint_filter_fn(
1014
996
prefix = ''
1015
997
1016
998
if 'visual.class_embedding' in state_dict :
1017
- return _convert_openai_clip (state_dict , model )
999
+ state_dict = _convert_openai_clip (state_dict , model )
1018
1000
elif 'module.visual.class_embedding' in state_dict :
1019
- return _convert_openai_clip (state_dict , model , prefix = 'module.visual.' )
1020
-
1021
- if "mask_token" in state_dict :
1001
+ state_dict = _convert_openai_clip (state_dict , model , prefix = 'module.visual.' )
1002
+ elif "mask_token" in state_dict :
1022
1003
state_dict = _convert_dinov2 (state_dict , model )
1023
-
1024
- if "encoder" in state_dict :
1004
+ elif "encoder" in state_dict :
1005
+ # IJEPA, vit in an 'encoder' submodule
1025
1006
state_dict = state_dict ['encoder' ]
1026
1007
prefix = 'module.'
1027
-
1028
- if 'visual.trunk.pos_embed' in state_dict :
1029
- # convert an OpenCLIP model with timm vision encoder
1008
+ elif 'visual.trunk.pos_embed' in state_dict :
1009
+ # OpenCLIP model with timm vision encoder
1030
1010
# FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
1031
1011
prefix = 'visual.trunk.'
1032
1012
0 commit comments