Bounded attention #892
Open

wants to merge 7 commits into main

141 changes: 141 additions & 0 deletions ldm/modules/bounded_attention.py
@@ -0,0 +1,141 @@
"""
Bounded Attention patch for Stable-Diffusion v1 UNet
(implements Dahary et al., 2024)
"""
from __future__ import annotations
import contextlib
from dataclasses import dataclass
from typing import List, Optional
import torch
import torch.nn.functional as F

@dataclass
class SubjectMask:
    start: int
    end: int  # half-open [start, end)

    def slice(self, L: int) -> slice:
        return slice(max(0, self.start), min(L, self.end))


# ---------- helpers ---------------------------------------------------------

def _build_key_mask(L: int, subjects: List[SubjectMask], device) -> torch.Tensor:
    """
    returns [Nsubj+1, 1, 1, L] bool
      - first N slices = subjects
      - last slice     = background (everything not in any subject span)
    """
    full = torch.zeros(L, dtype=torch.bool, device=device)
    subj_masks = []
    covered = torch.zeros_like(full)

    for sm in subjects:
        m = full.clone()
        m[sm.slice(L)] = True
        covered |= m
        subj_masks.append(m)

    # background "bucket"
    background = ~covered
    return torch.stack([*subj_masks, background])[:, None, None, :]  # [N+1,1,1,L]


def _safe_softmax(attn: torch.Tensor, mask: torch.Tensor, dim=-1) -> torch.Tensor:
    """
    Softmax with -inf masking that **guarantees** each row has ≥1 valid key.
    If a row would be all -inf we instead fall back to an un-masked softmax
    for that row only (uniform attention ≈ no harm, avoids NaNs).
    """
    max_neg = -torch.finfo(attn.dtype).max
    attn = attn.masked_fill(~mask, max_neg)

    # rows where everything is masked: reset them to 0 so softmax is uniform
    all_masked = (mask.sum(dim=dim, keepdim=True) == 0)
    if all_masked.any():
        attn = attn.masked_fill(all_masked, 0.0)

    # compute in float32 for stability, then cast back so the subsequent
    # matmul with the value tensor keeps the model's dtype (e.g. fp16)
    return F.softmax(attn, dim=dim, dtype=torch.float32).to(attn.dtype)


# ---------- monkey-patch machinery -----------------------------------------

_patch: Optional[tuple] = None

def enable_bounded_attention(model, subjects: List[SubjectMask]):
    """
    Enable bounded attention by monkey-patching `CrossAttention.forward`
    (this affects every CrossAttention instance, i.e. all attention layers in `model`).
    """
    global _patch
    if _patch is not None:
        raise RuntimeError("bounded attention is already enabled")

    from ldm.modules.attention import CrossAttention  # import locally
    orig_forward = CrossAttention.forward

    def forward_ba(self, x, context=None, mask=None):
        h = self.heads
        context = x if context is None else context  # self-attention

        B, Lq, _ = x.shape
        Lk = context.shape[1]
        device = context.device

        # build / cache the per-bucket key masks
        if (not hasattr(self, "_ba_kmask")
                or self._ba_kmask.shape[-1] != Lk):
            self._ba_kmask = _build_key_mask(Lk, subjects, device)  # [N+1,1,1,Lk]

        # decide, per key position, which bucket it belongs to
        # rule of thumb: position ∈ subject_i → bucket i
        #                else                 → background bucket (last index)
        bucket_ids = torch.full((Lk,), len(subjects), device=device, dtype=torch.long)
        for i, sm in enumerate(subjects):
            bucket_ids[sm.slice(Lk)] = i  # assign ids

        # projections
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)

        dim_head = q.shape[-1] // h
        q, k, v = map(lambda t: t.view(B, -1, h, dim_head).transpose(1, 2),
                      (q, k, v))  # (B,h,Len,dh)

        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # (B,h,Lq,Lk)

        # build a key-mask broadcastable to (B,h,Lq,Lk) by picking the right
        # bucket for each query
        if Lq == Lk:
            # 1a. queries and keys share the same token axis: each query
            #     attends only to the keys of its own bucket
            per_query = self._ba_kmask[bucket_ids]            # (Lq,1,1,Lk)
            kmask = per_query.reshape(Lq, Lk)[None, None]     # (1,1,Lq,Lk)
        else:
            # 1b. no per-query bucket assignment is available here, so fall
            #     back to the union of all buckets (subjects + background)
            kmask = self._ba_kmask.any(0, keepdim=True)       # (1,1,1,Lk)

        # 2. Bring in SD's own mask (if it exists)
        if mask is not None:  # mask: assumed broadcastable to (B,1,1,Lk)
            kmask = kmask & mask

        # 3. Masked softmax broadcasts to (B,h,Lq,Lk) automatically
        probs = _safe_softmax(attn, kmask, dim=-1)

        out = torch.matmul(probs, v)  # (B,h,Lq,dh)
        out = out.transpose(1, 2).reshape(B, Lq, h * dim_head)
        return self.to_out(out)

    CrossAttention.forward = forward_ba
    _patch = (CrossAttention, orig_forward)


def disable_bounded_attention():
    global _patch
    if _patch is None:
        return
    cls, orig_fwd = _patch
    cls.forward = orig_fwd
    _patch = None


@contextlib.contextmanager
def bounded_attention(model, subjects: List[SubjectMask]):
    enable_bounded_attention(model, subjects)
    try:
        yield
    finally:
        disable_bounded_attention()
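
For reference, a minimal smoke test of the helpers above (not part of the patch): it only assumes that ldm/modules/bounded_attention.py is importable, and uses made-up token spans plus a toy attention tensor rather than a real diffusion model.

    import torch
    from ldm.modules.bounded_attention import SubjectMask, _build_key_mask, _safe_softmax

    # two hypothetical subject spans inside a 77-token prompt
    subjects = [SubjectMask(2, 5), SubjectMask(7, 10)]
    L = 77

    kmask = _build_key_mask(L, subjects, device="cpu")   # (3,1,1,77): 2 subjects + background
    assert kmask.shape == (len(subjects) + 1, 1, 1, L)

    # toy attention scores for 1 batch, 1 head, 4 queries
    attn = torch.randn(1, 1, 4, L)
    probs = _safe_softmax(attn, kmask[0])                # restrict all queries to subject 0's keys
    assert torch.allclose(probs.sum(-1), torch.ones(1, 1, 4))
    assert not torch.isnan(probs).any()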
46 changes: 45 additions & 1 deletion ldm/modules/encoders/modules.py
@@ -1,10 +1,12 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import clip
import open_clip as clip
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
import kornia
from ldm.modules.rope_utils import build_rope_cache, apply_rope


from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test

@@ -140,10 +142,17 @@ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_l
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        # === Inject RoPE into attention layers ===
        # named_modules() yields dotted paths, so resolve the parent module
        # before replacing the child attribute; a plain setattr on
        # self.transformer would not swap nested submodules.
        for name, module in self.transformer.named_modules():
            if isinstance(module, torch.nn.MultiheadAttention):
                parent_name, _, child_name = name.rpartition(".")
                parent = (self.transformer.get_submodule(parent_name)
                          if parent_name else self.transformer)
                setattr(parent, child_name, RoPEAttentionWrapper(module))
                print(f"[RoPE] Wrapped attention module: {name}")

        self.device = device
        self.max_length = max_length
        self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
@@ -227,6 +236,41 @@ def forward(self, x):
        # x is assumed to be in range [-1,1]
        return self.model.encode_image(self.preprocess(x))


class RoPEAttentionWrapper(nn.Module):
    """Wraps an nn.MultiheadAttention layer and applies rotary position
    embeddings (RoPE) to the query/key projections.

    Assumes batch-first self-attention input of shape (B, S, C); any extra
    positional arguments (e.g. key/value/attn_mask) are ignored.
    """

    def __init__(self, attn_layer):
        super().__init__()
        self.attn = attn_layer
        self.rope_cache = None

    def forward(self, x, *args, **kwargs):
        B, S, C = x.shape  # batch, seq_len, channels
        device = x.device
        num_heads = self.attn.num_heads
        head_dim = C // num_heads

        # Linear projection to get QKV
        qkv = F.linear(x, self.attn.in_proj_weight, self.attn.in_proj_bias)
        qkv = qkv.view(B, S, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Build the RoPE cache if it is missing or the sequence length changed
        if self.rope_cache is None or self.rope_cache[0].shape[2] != S:
            self.rope_cache = build_rope_cache(S, head_dim, device)

        # Apply RoPE to queries and keys
        q = apply_rope(q, self.rope_cache)
        k = apply_rope(k, self.rope_cache)

        # Scaled dot-product attention
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * (head_dim ** -0.5)
        attn_weights = attn_weights.softmax(dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        attn_output = attn_output.transpose(1, 2).reshape(B, S, C)
        output = self.attn.out_proj(attn_output)

        return output


if __name__ == "__main__":
    from ldm.util import count_params
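As a quick sanity check of the wrapper, a hypothetical snippet like the following exercises it on a standalone batch-first nn.MultiheadAttention layer (outside of CLIP); the layer sizes here are made up for illustration.

    import torch
    import torch.nn as nn
    from ldm.modules.encoders.modules import RoPEAttentionWrapper

    embed_dim, num_heads, seq_len = 64, 8, 16
    attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
    wrapped = RoPEAttentionWrapper(attn)

    x = torch.randn(2, seq_len, embed_dim)   # (B, S, C), batch-first
    out = wrapped(x)                         # same shape as the input
    assert out.shape == (2, seq_len, embed_dim)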
20 changes: 20 additions & 0 deletions ldm/modules/rope_utils.py
@@ -0,0 +1,20 @@
# ldm/modules/rope_utils.py

import torch

def build_rope_cache(seq_len, head_dim, device):
    # keep the inverse frequencies on `device` so the cache matches the activations
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
    freqs = torch.einsum('i,j->ij', t, inv_freq)   # (seq_len, head_dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)        # (seq_len, head_dim)
    sin_emb = emb.sin()[None, None, :, :]          # (1, 1, seq_len, head_dim)
    cos_emb = emb.cos()[None, None, :, :]
    return sin_emb, cos_emb

def apply_rope(x, rope_cache):
    # rotate-half formulation, matching the cat((freqs, freqs)) cache layout:
    # split the head dim into two halves and rotate them against each other
    sin_emb, cos_emb = rope_cache
    x1, x2 = x.chunk(2, dim=-1)                    # each (..., head_dim/2)
    rotated = torch.cat((-x2, x1), dim=-1)         # (..., head_dim)
    return x * cos_emb + rotated * sin_emb
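
A small self-contained check of these helpers (toy shapes chosen for illustration): RoPE rotates each pair of features by a position-dependent angle, so applying it should leave the shape unchanged and preserve per-token norms.

    import torch
    from ldm.modules.rope_utils import build_rope_cache, apply_rope

    B, H, S, D = 2, 4, 16, 32                  # toy batch / heads / seq_len / head_dim
    x = torch.randn(B, H, S, D)

    sin_emb, cos_emb = build_rope_cache(S, D, device=x.device)
    x_rot = apply_rope(x, (sin_emb, cos_emb))

    assert x_rot.shape == x.shape
    # rotations preserve the norm of each token's feature vector
    assert torch.allclose(x_rot.norm(dim=-1), x.norm(dim=-1), atol=1e-5)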