Skip to content

Commit 67b0b3d

Browse files
authored
Merge pull request #2121 from huggingface/cleanup_vit_convert
Improve ViT conversions. The OpenAI CLIP conversion now passes through the main conversion path (checkpoint_filter_fn) instead of returning early.
2 parents 492947d + c559c39 commit 67b0b3d

File tree

1 file changed

+18
-38
lines changed

1 file changed

+18
-38
lines changed

timm/models/vision_transformer.py

Lines changed: 18 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -771,28 +771,20 @@ def resize_pos_embed(
771771
antialias: bool = False,
772772
) -> torch.Tensor:
773773
""" Rescale the grid of position embeddings when loading from state_dict.
774-
775-
*DEPRECATED* This function is being deprecated in favour of resample_abs_pos_embed
776-
777-
Adapted from:
778-
https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
774+
*DEPRECATED* This function is being deprecated in favour of using resample_abs_pos_embed
779775
"""
780-
ntok_new = posemb_new.shape[1]
781-
if num_prefix_tokens:
782-
posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
783-
ntok_new -= num_prefix_tokens
784-
else:
785-
posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
786-
gs_old = int(math.sqrt(len(posemb_grid)))
776+
ntok_new = posemb_new.shape[1] - num_prefix_tokens
777+
ntok_old = posemb.shape[1] - num_prefix_tokens
778+
gs_old = [int(math.sqrt(ntok_old))] * 2
787779
if not len(gs_new): # backwards compatibility
788780
gs_new = [int(math.sqrt(ntok_new))] * 2
789-
assert len(gs_new) >= 2
790-
_logger.info(f'Resized position embedding: {posemb.shape} ({[gs_old, gs_old]}) to {posemb_new.shape} ({gs_new}).')
791-
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
792-
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=interpolation, antialias=antialias, align_corners=False)
793-
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
794-
posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
795-
return posemb
781+
return resample_abs_pos_embed(
782+
posemb, gs_new, gs_old,
783+
num_prefix_tokens=num_prefix_tokens,
784+
interpolation=interpolation,
785+
antialias=antialias,
786+
verbose=True,
787+
)
796788

797789

798790
@torch.no_grad()
@@ -962,16 +954,6 @@ def _convert_openai_clip(
962954
v = v.unsqueeze(0).unsqueeze(1)
963955
elif k == 'pos_embed':
964956
v = v.unsqueeze(0)
965-
if v.shape[1] != model.pos_embed.shape[1]:
966-
# To resize pos embedding when using model at different size from pretrained weights
967-
num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) \
968-
else getattr(model, 'num_prefix_tokens', 1)
969-
v = resample_abs_pos_embed(
970-
v,
971-
new_size=model.patch_embed.grid_size,
972-
num_prefix_tokens=num_prefix_tokens,
973-
verbose=True,
974-
)
975957
out_dict[k] = v
976958
return out_dict
977959

@@ -1014,19 +996,17 @@ def checkpoint_filter_fn(
1014996
prefix = ''
1015997

1016998
if 'visual.class_embedding' in state_dict:
1017-
return _convert_openai_clip(state_dict, model)
999+
state_dict = _convert_openai_clip(state_dict, model)
10181000
elif 'module.visual.class_embedding' in state_dict:
1019-
return _convert_openai_clip(state_dict, model, prefix='module.visual.')
1020-
1021-
if "mask_token" in state_dict:
1001+
state_dict = _convert_openai_clip(state_dict, model, prefix='module.visual.')
1002+
elif "mask_token" in state_dict:
10221003
state_dict = _convert_dinov2(state_dict, model)
1023-
1024-
if "encoder" in state_dict:
1004+
elif "encoder" in state_dict:
1005+
# IJEPA, vit in an 'encoder' submodule
10251006
state_dict = state_dict['encoder']
10261007
prefix = 'module.'
1027-
1028-
if 'visual.trunk.pos_embed' in state_dict:
1029-
# convert an OpenCLIP model with timm vision encoder
1008+
elif 'visual.trunk.pos_embed' in state_dict:
1009+
# OpenCLIP model with timm vision encoder
10301010
# FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
10311011
prefix = 'visual.trunk.'
10321012

0 commit comments

Comments
 (0)