Commit d2864f6

Examples of training bias
1 parent 36f8bd5 commit d2864f6

File tree

1 file changed: +231 -0 lines changed


examples/learnable_bias.py

Lines changed: 231 additions & 0 deletions
@@ -0,0 +1,231 @@
import logging
from enum import Enum

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


class BiasType(Enum):
    RELATIVE_1D = "relative_1d"
    ABSOLUTE_2D = "absolute_2d"
    HEAD_SPECIFIC = "head_specific"
    BATCH_HEAD = "batch_head"
    MULTIPLICATIVE = "multiplicative"
    LOCAL_WINDOW = "local_window"
    GLOBAL_TOKENS = "global_tokens"
    WEIRD = "weird"
    OFFSET = "offset"


class AttentionTrainer:
    def __init__(
        self,
        batch_size: int = 8,
        num_heads: int = 4,
        seq_length: int = 256,
        head_dim: int = 64,
        device: str = "cuda",
        dtype: torch.dtype = torch.float32,
        window_size: int = 16,
        learning_rate: float = 1e-1,
    ):
        self.B = batch_size
        self.H = num_heads
        self.S = seq_length
        self.D = head_dim
        self.W = window_size
        self.device = device
        self.dtype = dtype
        self.lr = learning_rate
        self.which_bias = torch.tensor(0, device=device)
        self.offset = None

        # Map each bias type to the generator that allocates its learnable parameters
        self.bias_generators = {
            BiasType.RELATIVE_1D: self._generate_relative_1d_bias,
            BiasType.ABSOLUTE_2D: self._generate_absolute_2d_bias,
            BiasType.HEAD_SPECIFIC: self._generate_head_specific_bias,
            BiasType.BATCH_HEAD: self._generate_batch_head_bias,
            BiasType.MULTIPLICATIVE: self._generate_multiplicative_bias,
            BiasType.LOCAL_WINDOW: self._generate_local_window_bias,
            BiasType.GLOBAL_TOKENS: self._generate_global_tokens_bias,
            BiasType.WEIRD: self._generate_weird_bias,
            BiasType.OFFSET: self._generate_offset_bias,
        }

        # Map each bias type to the score_mod function that applies it
        self.bias_functions = {
            BiasType.RELATIVE_1D: self._apply_relative_1d_bias,
            BiasType.ABSOLUTE_2D: self._apply_absolute_2d_bias,
            BiasType.HEAD_SPECIFIC: self._apply_head_specific_bias,
            BiasType.BATCH_HEAD: self._apply_batch_head_bias,
            BiasType.MULTIPLICATIVE: self._apply_multiplicative_bias,
            BiasType.LOCAL_WINDOW: self._apply_local_window_bias,
            BiasType.GLOBAL_TOKENS: self._apply_global_tokens_bias,
            BiasType.WEIRD: self._apply_weird_bias,
            BiasType.OFFSET: self._apply_offset_bias,
        }

    def _generate_tensor(self, *size):
        # Learnable parameter: requires_grad=True so gradients flow back
        # through the score_mod during flex_attention's backward pass
        return torch.randn(
            *size, device=self.device, dtype=self.dtype, requires_grad=True
        )

    # Bias Generators

    def _generate_relative_1d_bias(self):
        return self._generate_tensor(2 * self.S)

    def _generate_absolute_2d_bias(self):
        return self._generate_tensor(self.S, self.S)

    def _generate_head_specific_bias(self):
        return self._generate_tensor(self.H, self.S, self.S)

    def _generate_batch_head_bias(self):
        return self._generate_tensor(self.B, self.H, self.S, self.S)

    def _generate_multiplicative_bias(self):
        return self._generate_tensor(self.S)

    def _generate_local_window_bias(self):
        # One learnable value per relative offset in [-W, W]
        return self._generate_tensor(2 * self.W + 1)

    def _generate_learned_pattern_bias(self):
        return self._generate_tensor(self.H, self.D)

    def _generate_global_tokens_bias(self):
        return self._generate_tensor(self.S)

    def _generate_weird_bias(self):
        return self._generate_tensor(self.B, self.H, 4, self.S)

    def _generate_offset_bias(self):
        # Generate both the bias and offset tensors
        bias = self._generate_tensor(self.S)
        self.offset = torch.randint(0, self.S, (self.S,), device=self.device)
        return bias

    # Bias Application Functions

    def _apply_relative_1d_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[torch.abs(q_idx - kv_idx)]

    def _apply_absolute_2d_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[q_idx, kv_idx]

    def _apply_head_specific_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[h, q_idx, kv_idx]

    def _apply_batch_head_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[b, h, q_idx, kv_idx]

    def _apply_multiplicative_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score * bias[q_idx]

    def _apply_local_window_bias(self, score, b, h, q_idx, kv_idx, bias):
        window_idx = torch.clamp(q_idx - kv_idx + self.W, 0, 2 * self.W)
        return score + bias[window_idx]

    def _apply_global_tokens_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[kv_idx]

    def _apply_weird_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[b, h, self.which_bias, q_idx]

    def _apply_offset_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[self.offset[q_idx]]

    def generate_dummy_data(self) -> TensorDataset:
        """Generate dummy training data."""
        queries = torch.randn(1, self.B, self.H, self.S, self.D, device=self.device)
        keys = torch.randn(1, self.B, self.H, self.S, self.D, device=self.device)
        values = torch.randn(1, self.B, self.H, self.S, self.D, device=self.device)

        targets = torch.randn(1, self.B, self.H, self.S, self.D, device=self.device)

        return TensorDataset(queries, keys, values, targets)

    def train(
        self,
        bias_type: BiasType = BiasType.RELATIVE_1D,
        num_epochs: int = 10,
        batch_size: int = 4,
    ):
        """Train the attention model with the specified bias type."""
        # Generate bias parameters
        bias = self.bias_generators[bias_type]()
        optimizer = Adam([bias], lr=self.lr)

        # Generate dummy dataset
        dataset = self.generate_dummy_data()
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # Create bias function closure
        def bias_func(score, b, h, q_idx, kv_idx):
            return self.bias_functions[bias_type](score, b, h, q_idx, kv_idx, bias)

        # Compile the attention function
        flex_compiled = torch.compile(
            flex_attention, backend="eager", fullgraph=True, dynamic=False
        )

        # Training loop
        for epoch in range(num_epochs):
            total_loss = 0.0
            with tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
                for batch_idx, (q_batch, k_batch, v_batch, targets) in enumerate(pbar):
                    q_batch.requires_grad_()
                    optimizer.zero_grad()

                    # Forward pass
                    outputs = flex_compiled(
                        q_batch[0], k_batch[0], v_batch[0], score_mod=bias_func
                    )

                    # Compute loss (MSE for this example)
                    loss = F.mse_loss(outputs, targets[0])

                    # Backward pass
                    loss.backward()
                    optimizer.step()

                    total_loss += loss.item()
                    pbar.set_postfix({"loss": f"{loss.item():.6f}"})

            avg_loss = total_loss / len(dataloader)
            logger.info(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.6f}")

        return bias, avg_loss


def main(
    bias_type: BiasType = BiasType.RELATIVE_1D,
    num_epochs: int = 100,
    batch_size: int = 4,
):
    trainer = AttentionTrainer()
    trained_bias, final_loss = trainer.train(
        bias_type=bias_type,
        num_epochs=num_epochs,
        batch_size=batch_size,
    )

    logger.info(f"Final loss: {final_loss:.6f}")


if __name__ == "__main__":
    from jsonargparse import CLI

    CLI(main)
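
For reference, a minimal sketch of driving this example directly from Python rather than the CLI. It assumes the file is importable from your working directory and that a CUDA device is available, since AttentionTrainer defaults to device="cuda":

    from learnable_bias import AttentionTrainer, BiasType

    # Train a local-window bias for a few epochs, then inspect the learned parameters.
    # LOCAL_WINDOW allocates a (2 * window_size + 1,) tensor, one value per relative offset.
    trainer = AttentionTrainer(seq_length=128, window_size=16)
    bias, final_loss = trainer.train(bias_type=BiasType.LOCAL_WINDOW, num_epochs=5)
    print(bias.shape, final_loss)

Because the entry point is jsonargparse's CLI(main), the parameters of main should also be exposed as command-line flags (e.g. --bias_type, --num_epochs, --batch_size); the exact spelling accepted for the enum value depends on jsonargparse's enum handling.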
