4 changes: 2 additions & 2 deletions tome/patch/mae.py
@@ -15,7 +15,7 @@

 from tome.utils import parse_r

-from .timm import ToMeBlock, ToMeAttention
+from .timm import ToMeBlock, FlashAttnToMeAttention


 def make_tome_class(transformer_class):
@@ -100,4 +100,4 @@ def apply_patch(
             module.__class__ = ToMeBlock
             module._tome_info = model._tome_info
         elif isinstance(module, Attention):
-            module.__class__ = ToMeAttention
+            module.__class__ = FlashAttnToMeAttention
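With this change, patching an MAE model routes every `Attention` module through the FlashAttention-backed class added in `tome/patch/timm.py` below. A minimal usage sketch, assuming the upstream ToMe entry point `tome.patch.mae` (which wraps `apply_patch`) and a timm ViT as a stand-in model; FlashAttention additionally requires a CUDA device and fp16/bf16 activations:

```python
# Hypothetical usage sketch: the model below is a stand-in for an MAE
# fine-tuned ViT, and tome.patch.mae is assumed to expose apply_patch as in
# the upstream ToMe repo. flash-attn needs CUDA and half-precision inputs.
import timm
import torch
import tome

model = timm.create_model("vit_base_patch16_224", pretrained=False)
tome.patch.mae(model)   # Block -> ToMeBlock, Attention -> FlashAttnToMeAttention
model.r = 16            # merge 16 tokens per layer (standard ToMe setting)

model = model.half().cuda().eval()
with torch.inference_mode():
    _ = model(torch.randn(1, 3, 224, 224, device="cuda", dtype=torch.float16))
```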
38 changes: 38 additions & 0 deletions tome/patch/timm.py
@@ -17,6 +17,8 @@
 from tome.merge import bipartite_soft_matching, merge_source, merge_wavg
 from tome.utils import parse_r

+from flash_attn import flash_attn_qkvpacked_func
+

 class ToMeBlock(Block):
     """
@@ -96,6 +98,42 @@ def forward(
         return x, k.mean(1)


+class FlashAttnToMeAttention(Attention):
+    """
+    Modifications:
+     - Apply FlashAttention (flash_attn_qkvpacked_func) instead of eager attention.
+     - Do not apply proportional attention (not used for MAE models).
+     - Return the mean of k over heads from attention.
+    """
+
+    def forward(
+        self, x: torch.Tensor, size: torch.Tensor = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Note: this is copied from timm.models.vision_transformer.Attention with modifications.
+        B, N, C = x.shape
+        try:
+            qkv_bias = torch.cat(  # MAE-style modules keep separate q/v biases; k has none
+                (self.q_bias,
+                 torch.zeros_like(self.v_bias,
+                                  requires_grad=False),
+                 self.v_bias))
+        except (AttributeError, TypeError):  # no q_bias / v_bias on this module
+            qkv_bias = None
+        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, -1)
+        k = qkv.permute(2, 0, 3, 1, 4)[1]  # (B, num_heads, N, head_dim)
+
+        x = flash_attn_qkvpacked_func(qkv, dropout_p=0.0,
+                                      softmax_scale=self.scale,
+                                      causal=False)
+        x = x.reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        # Return k as well here (used as the token-merging metric)
+        return x, k.mean(1)
+
+
 def make_tome_class(transformer_class):
     class ToMeVisionTransformer(transformer_class):
         """
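For reference, a standalone sketch of the shape contract the new forward relies on: `flash_attn_qkvpacked_func` takes a packed (B, N, 3, heads, head_dim) tensor in fp16/bf16 on a CUDA device and returns (B, N, heads, head_dim), which the patched forward flattens back to (B, N, C), while `k.mean(1)` gives the (B, N, head_dim) metric that ToMe merges on. The sizes below are arbitrary and purely illustrative.

```python
# Shape sanity check for the packed-QKV FlashAttention call used above.
# Requires a CUDA device and the flash-attn package; sizes are arbitrary.
import torch
from flash_attn import flash_attn_qkvpacked_func

B, N, H, D = 2, 197, 12, 64        # batch, tokens, heads, head dim
C = H * D                          # embedding dim seen by self.proj

qkv = torch.randn(B, N, 3, H, D, dtype=torch.float16, device="cuda")
k = qkv.permute(2, 0, 3, 1, 4)[1]  # (B, H, N, D), as in the patched forward

out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=D ** -0.5, causal=False)
assert out.shape == (B, N, H, D)
assert out.reshape(B, N, C).shape == (B, N, C)
assert k.mean(1).shape == (B, N, D)  # per-token size metric returned to ToMe
```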