Forgot to compact attention pool branches after verifying #2507

Merged: 3 commits merged on Jun 5, 2025

README.md: 19 additions, 0 deletions

@@ -12,6 +12,25 @@

## What's New

## June 5, 2025
* Initial NaFlexVit model code. NaFlexVit is a Vision Transformer with:
  1. Encapsulated embedding and position encoding in a single module
  2. Support for nn.Linear patch embedding on pre-patchified (dictionary) inputs
  3. Support for NaFlex variable aspect, variable resolution (SigLIP-2: https://arxiv.org/abs/2502.14786)
  4. Support for FlexiViT variable patch size (https://arxiv.org/abs/2212.08013)
  5. Support for NaViT fractional/factorized position embedding (https://arxiv.org/abs/2307.06304)
* Existing ViT models in `vision_transformer.py` can be loaded into the NaFlexVit model by adding the `use_naflex=True` flag to `create_model` (a short Python sketch follows this list)
  * Some native weights coming soon
* A full NaFlex data pipeline is available that allows training / fine-tuning / evaluating with variable aspect / size images
  * To enable it in `train.py` and `validate.py`, add the `--naflex-loader` arg; it must be used with a NaFlexVit model
  * To evaluate an existing (classic) ViT loaded into the NaFlexVit model with the NaFlex data pipeline:
    * `python validate.py /imagenet --amp -j 8 --model vit_base_patch16_224 --model-kwargs use_naflex=True --naflex-loader --naflex-max-seq-len 256`
* Training has some extra args worth noting
  * The `--naflex-train-seq-lens` argument specifies which sequence lengths to randomly pick from per batch during training
  * The `--naflex-max-seq-len` argument sets the target sequence length for validation
  * Adding `--model-kwargs enable_patch_interpolator=True --naflex-patch-sizes 12 16 24` will enable random patch size selection per batch with interpolation
  * The `--naflex-loss-scale` arg changes the loss scaling mode per batch relative to the batch size; the `timm` NaFlex loader changes the batch size for each sequence length
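
A minimal Python sketch of the `use_naflex=True` path above; the model name, `pretrained=True`, and the fixed-size dummy input are illustrative choices, not requirements:

```python
import timm
import torch

# Load an existing (classic) ViT's weights into the NaFlexVit implementation
# by passing use_naflex=True through create_model.
model = timm.create_model('vit_base_patch16_224', pretrained=True, use_naflex=True)
model.eval()

# A plain fixed-size image tensor should still be accepted (NaFlexVit patchifies
# internally); the NaFlex loader (--naflex-loader) is what enables variable
# aspect / resolution batches during training and evaluation.
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000]) for the ImageNet-1k head
```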

## May 28, 2025
* Add a number of small/fast models thanks to https://github.com/brianhou0208
* SwiftFormer - [(ICCV2023) SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications](https://github.com/Amshaker/SwiftFormer)

timm/models/naflexvit.py: 10 additions, 21 deletions

@@ -1192,27 +1192,16 @@ def _pool(
             patch_valid: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if self.attn_pool is not None:
-            # For attention pooling, we need to pass the mask for NaFlex models
-            if self.pool_include_prefix:
-                # Include all tokens in attention pooling - create mask for all tokens including prefix
-                attn_mask = create_attention_mask(
-                    patch_valid,
-                    num_prefix_tokens=self.num_prefix_tokens,
-                    symmetric=False,
-                    q_len=1,
-                    dtype=x.dtype,
-                )
-                x = self.attn_pool(x, attn_mask=attn_mask)
-            else:
-                # Exclude prefix tokens from attention pooling (default behavior)
-                attn_mask = create_attention_mask(
-                    patch_valid,
-                    num_prefix_tokens=0,  # No prefix tokens when we slice them off
-                    symmetric=False,
-                    q_len=1,
-                    dtype=x.dtype,
-                )
-                x = self.attn_pool(x[:, self.num_prefix_tokens:], attn_mask=attn_mask)
+            attn_mask = create_attention_mask(
+                patch_valid,
+                num_prefix_tokens=self.num_prefix_tokens if self.pool_include_prefix else 0,
+                symmetric=False,
+                q_len=1,
+                dtype=x.dtype,
+            )
+            if not self.pool_include_prefix:
+                x = x[:, self.num_prefix_tokens:]
+            x = self.attn_pool(x, attn_mask=attn_mask)
             return x

         pool_type = self.global_pool if pool_type is None else pool_type
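
For reference, a minimal self-contained sketch (not the timm implementation; `toy_mask` and the softmax "pool" below stand in for `create_attention_mask` and the real attention-pool module) of why the two branches collapse into a single path: the only differences between them were the prefix-token count passed to the mask builder and whether `x` was sliced before pooling.

```python
import torch

def toy_mask(patch_valid: torch.Tensor, num_prefix_tokens: int) -> torch.Tensor:
    # Hypothetical stand-in for create_attention_mask: prefix tokens are always
    # attendable, padded (invalid) patches get -inf so softmax ignores them.
    if num_prefix_tokens:
        prefix = patch_valid.new_ones(patch_valid.shape[0], num_prefix_tokens)
        patch_valid = torch.cat([prefix, patch_valid], dim=1)
    mask = torch.zeros(patch_valid.shape, dtype=torch.float32)
    return mask.masked_fill(~patch_valid.bool(), float('-inf'))[:, None, :]  # (B, 1, N)

def pool(x, patch_valid, num_prefix_tokens=1, pool_include_prefix=False):
    # Single code path mirroring the compacted branch above: the prefix-token
    # count fed to the mask builder depends on pool_include_prefix, and the
    # prefix tokens are optionally sliced off before pooling.
    attn_mask = toy_mask(patch_valid, num_prefix_tokens if pool_include_prefix else 0)
    if not pool_include_prefix:
        x = x[:, num_prefix_tokens:]
    attn = torch.softmax(attn_mask, dim=-1)  # toy pool: uniform weights over unmasked tokens, q_len=1
    return attn @ x  # (B, 1, C)

B, N, C = 2, 8, 4
x = torch.randn(B, 1 + N, C)  # one prefix (class) token + N patch tokens
patch_valid = torch.tensor([[1] * 8, [1] * 5 + [0] * 3])  # second sample has 3 padded patches
print(pool(x, patch_valid).shape)                             # torch.Size([2, 1, 4])
print(pool(x, patch_valid, pool_include_prefix=True).shape)   # torch.Size([2, 1, 4])
```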