import logging
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention
from enum import Enum
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


class BiasType(Enum):
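    """Bias variants exercised by this script; each member is paired with a
    generator and a score_mod application function in AttentionTrainer."""
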
    RELATIVE_1D = "relative_1d"
    ABSOLUTE_2D = "absolute_2d"
    HEAD_SPECIFIC = "head_specific"
    BATCH_HEAD = "batch_head"
    MULTIPLICATIVE = "multiplicative"
    LOCAL_WINDOW = "local_window"
    GLOBAL_TOKENS = "global_tokens"
    WEIRD = "weird"
    OFFSET = "offset"


class AttentionTrainer:
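    """Optimizes a single learnable attention-bias tensor by applying it through
    flex_attention's score_mod hook and training against random dummy targets."""
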
    def __init__(
        self,
        batch_size: int = 8,
        num_heads: int = 4,
        seq_length: int = 256,
        head_dim: int = 64,
        device: str = "cuda",
        dtype: torch.dtype = torch.float32,
        window_size: int = 16,
        learning_rate: float = 1e-1,
    ):
        self.B = batch_size
        self.H = num_heads
        self.S = seq_length
        self.D = head_dim
        self.W = window_size
        self.device = device
        self.dtype = dtype
        self.lr = learning_rate
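        # Auxiliary state used by the WEIRD and OFFSET score mods: `which_bias`
        # selects one of the four rows of the WEIRD bias tensor, and `offset`
        # (populated by _generate_offset_bias) remaps query positions.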
        self.which_bias = torch.tensor(0, device=device)
        self.offset = None

        # Map each bias type to its generator
        self.bias_generators = {
            BiasType.RELATIVE_1D: self._generate_relative_1d_bias,
            BiasType.ABSOLUTE_2D: self._generate_absolute_2d_bias,
            BiasType.HEAD_SPECIFIC: self._generate_head_specific_bias,
            BiasType.BATCH_HEAD: self._generate_batch_head_bias,
            BiasType.MULTIPLICATIVE: self._generate_multiplicative_bias,
            BiasType.LOCAL_WINDOW: self._generate_local_window_bias,
            BiasType.GLOBAL_TOKENS: self._generate_global_tokens_bias,
            BiasType.WEIRD: self._generate_weird_bias,
            BiasType.OFFSET: self._generate_offset_bias,
        }

        # Map each bias type to the function that applies it to attention scores
        self.bias_functions = {
            BiasType.RELATIVE_1D: self._apply_relative_1d_bias,
            BiasType.ABSOLUTE_2D: self._apply_absolute_2d_bias,
            BiasType.HEAD_SPECIFIC: self._apply_head_specific_bias,
            BiasType.BATCH_HEAD: self._apply_batch_head_bias,
            BiasType.MULTIPLICATIVE: self._apply_multiplicative_bias,
            BiasType.LOCAL_WINDOW: self._apply_local_window_bias,
            BiasType.GLOBAL_TOKENS: self._apply_global_tokens_bias,
            BiasType.WEIRD: self._apply_weird_bias,
            BiasType.OFFSET: self._apply_offset_bias,
        }

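    # Every bias tensor is created with requires_grad=True so it can be trained
    # directly by the optimizer.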
    def _generate_tensor(self, *size):
        return torch.randn(
            *size, device=self.device, dtype=self.dtype, requires_grad=True
        )

    # Bias Generators
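    # Each generator returns a fresh trainable tensor shaped to match how the
    # corresponding score_mod indexes it.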

    def _generate_relative_1d_bias(self):
        return self._generate_tensor(2 * self.S)

    def _generate_absolute_2d_bias(self):
        return self._generate_tensor(self.S, self.S)

    def _generate_head_specific_bias(self):
        return self._generate_tensor(self.H, self.S, self.S)

    def _generate_batch_head_bias(self):
        return self._generate_tensor(self.B, self.H, self.S, self.S)

    def _generate_multiplicative_bias(self):
        return self._generate_tensor(self.S)

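    # One learnable value per clamped relative offset in [-W, W] (2W + 1 entries).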
    def _generate_local_window_bias(self):
        return self._generate_tensor(2 * self.W + 1)

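    # Not registered in bias_generators/bias_functions; currently unused.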
    def _generate_learned_pattern_bias(self):
        return self._generate_tensor(self.H, self.D)

    def _generate_global_tokens_bias(self):
        return self._generate_tensor(self.S)

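    # Shape (B, H, 4, S); _apply_weird_bias uses self.which_bias to select one of
    # the 4 slices, which is then indexed by q_idx.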
    def _generate_weird_bias(self):
        return self._generate_tensor(self.B, self.H, 4, self.S)

    def _generate_offset_bias(self):
        # Generate both the bias and offset tensors
        bias = self._generate_tensor(self.S)
        self.offset = torch.randint(0, self.S, (self.S,), device=self.device)
        return bias

    # Bias Application Functions
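    # score_mod callbacks: flex_attention calls these with (score, b, h, q_idx, kv_idx);
    # the learnable bias is bound via a closure in train().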
    def _apply_relative_1d_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[torch.abs(q_idx - kv_idx)]

    def _apply_absolute_2d_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[q_idx, kv_idx]

    def _apply_head_specific_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[h, q_idx, kv_idx]

    def _apply_batch_head_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[b, h, q_idx, kv_idx]

    def _apply_multiplicative_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score * bias[q_idx]

    def _apply_local_window_bias(self, score, b, h, q_idx, kv_idx, bias):
        window_idx = torch.clamp(q_idx - kv_idx + self.W, 0, 2 * self.W)
        return score + bias[window_idx]

    def _apply_global_tokens_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[kv_idx]

    def _apply_weird_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[b, h, self.which_bias, q_idx]

    def _apply_offset_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[self.offset[q_idx]]

    def generate_dummy_data(self) -> TensorDataset:
        """Generate dummy training data."""
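        # The leading dim of 1 gives TensorDataset a single (B, H, S, D) sample;
        # the training loop indexes [0] to strip the DataLoader's batch dimension.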
        queries = torch.randn(
            1, self.B, self.H, self.S, self.D, device=self.device, dtype=self.dtype
        )
        keys = torch.randn(
            1, self.B, self.H, self.S, self.D, device=self.device, dtype=self.dtype
        )
        values = torch.randn(
            1, self.B, self.H, self.S, self.D, device=self.device, dtype=self.dtype
        )

        targets = torch.randn(
            1, self.B, self.H, self.S, self.D, device=self.device, dtype=self.dtype
        )

        return TensorDataset(queries, keys, values, targets)

    def train(
        self,
        bias_type: BiasType = BiasType.RELATIVE_1D,
        num_epochs: int = 10,
        batch_size: int = 4,
    ):
        """Train the attention model with the specified bias type."""
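        # Only the bias tensor returned by the generator is optimized; the
        # queries, keys, values, and targets are fixed random dummy data.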
        # Generate bias parameters
        bias = self.bias_generators[bias_type]()
        optimizer = Adam([bias], lr=self.lr)

        # Generate dummy dataset
        dataset = self.generate_dummy_data()
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # Create bias function closure
        def bias_func(score, b, h, q_idx, kv_idx):
            return self.bias_functions[bias_type](score, b, h, q_idx, kv_idx, bias)

        # Compile the attention function
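        # backend="eager" skips Inductor code generation and runs the captured
        # graph eagerly, which keeps this example simple; presumably the default
        # backend would be used in practice to get fused kernels.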
        flex_compiled = torch.compile(
            flex_attention, backend="eager", fullgraph=True, dynamic=False
        )

        # Training loop
        for epoch in range(num_epochs):
            total_loss = 0.0
            with tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
                for batch_idx, (q_batch, k_batch, v_batch, targets) in enumerate(pbar):
                    q_batch.requires_grad_()
                    optimizer.zero_grad()

                    # Forward pass
                    outputs = flex_compiled(
                        q_batch[0], k_batch[0], v_batch[0], score_mod=bias_func
                    )

                    # Compute loss (MSE for this example)
                    loss = F.mse_loss(outputs, targets[0])

                    # Backward pass
                    loss.backward()
                    optimizer.step()

                    total_loss += loss.item()
                    pbar.set_postfix({"loss": f"{loss.item():.6f}"})

            avg_loss = total_loss / len(dataloader)
            logger.info(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.6f}")

        return bias, avg_loss


def main(
    bias_type: BiasType = BiasType.RELATIVE_1D,
    num_epochs: int = 100,
    batch_size: int = 4,
):
    trainer = AttentionTrainer()
    trained_bias, final_loss = trainer.train(
        bias_type=bias_type,
        num_epochs=num_epochs,
        batch_size=batch_size,
    )

    logger.info(f"Final loss: {final_loss:.6f}")


if __name__ == "__main__":
    from jsonargparse import CLI

    CLI(main)
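
    # Example invocation (assuming this file is saved as flex_bias_train.py):
    #   python flex_bias_train.py --bias_type RELATIVE_1D --num_epochs 10 --batch_size 4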