| 1 | +import math |
| 2 | +import torch |
| 3 | +import torch.nn.functional as F |
| 4 | +from torch import nn |
| 5 | +from dataclasses import dataclass |
| 6 | +from typing import Tuple |
| 7 | + |
| 8 | + |
| 9 | + |
| 10 | +@dataclass |
| 11 | +class ModelConfig: |
| 12 | + vocab_size: int |
| 13 | + seq_len: int |
| 14 | + dim: int |
| 15 | + expand: float |
| 16 | + n_layers: int |
| 17 | + n_heads: int |
| 18 | + rmsnorm_eps: float = 1e-6 |
| 19 | + tie_embeddings: bool = False |
| 20 | + |
| 21 | + |
| 22 | +class MLP(nn.Module): |
| 23 | + |
| 24 | + def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256): |
| 25 | + super().__init__() |
| 26 | + hidden_dim = multiple_of * ( |
| 27 | + (hidden_dim + multiple_of - 1) // multiple_of) |
| 28 | +        self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False)  # packs value and gate halves |
| 29 | +        self.fc2 = nn.Linear(hidden_dim, dim, bias=False) |
| 30 | +        self.glu = nn.GLU(dim=2)  # splits fc1's output in two and gates one half with the other |
| 31 | + |
| 32 | +        # Xavier init (note: Transformer._init_weights later re-initializes these layers) |
| 33 | + nn.init.xavier_uniform_(self.fc1.weight) |
| 34 | + nn.init.xavier_uniform_(self.fc2.weight) |
| 35 | + |
| 36 | + def forward(self, x): |
| 37 | + # x: (bsz, T, dim) |
| 38 | + return self.fc2(self.glu(self.fc1(x))) |
| 39 | + |
| 40 | + |
| 41 | +def precompute_freqs_cis(dim: int, |
| 42 | + end: int, |
| 43 | + theta: float = 10000.0, |
| 44 | + condense_ratio: int = 1): |
| 45 | + inv_freqs = 1.0 / (theta**(torch.arange( |
| 46 | + 0, dim, 2, dtype=torch.float32, device=torch.device("cpu")) / dim)) |
| 47 | + t = torch.arange(end, dtype=torch.float32, |
| 48 | + device=inv_freqs.device) / condense_ratio |
| 49 | + freqs = torch.outer(t, inv_freqs).float() |
| 50 | +    # returns shape (1, end, 1, dim // 2, 2): [..., 0] holds cos, [..., 1] holds sin |
| 51 | +    return torch.stack([ |
| 52 | +        torch.cos(freqs)[None, :, None, :], |
| 53 | +        torch.sin(freqs)[None, :, None, :] |
| 54 | +    ], dim=-1) |
| 55 | + |
| 56 | + |
| 57 | +def apply_rotary_emb_complex_like( |
| 58 | + q: torch.Tensor, k: torch.Tensor, |
| 59 | + freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
| 60 | + # Rotate query and key vectors using RoPE |
| 61 | + qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float() |
| 62 | + rotated_qk_r2 = torch.stack( |
| 63 | + [ |
| 64 | + qk_r2[..., 0] * freqs_cis[..., 0] - |
| 65 | + qk_r2[..., 1] * freqs_cis[..., 1], |
| 66 | + qk_r2[..., 1] * freqs_cis[..., 0] + |
| 67 | + qk_r2[..., 0] * freqs_cis[..., 1], |
| 68 | + ], |
| 69 | + -1, |
| 70 | +    ).flatten(3)  # merge the (pairs, 2) split back into head_dim |
| 71 | +    # split the rotated tensor back into q and k along the stacked head dimension |
| 72 | +    return torch.split(rotated_qk_r2.type_as(q), q.shape[2], dim=2) |
| 73 | + |
| 74 | + |
| 75 | +class Attention(nn.Module): |
| 76 | + |
| 77 | + def __init__(self, cfg: ModelConfig): |
| 78 | + super().__init__() |
| 79 | + assert cfg.dim % cfg.n_heads == 0 |
| 80 | + self.dim = cfg.dim |
| 81 | + self.n_heads = cfg.n_heads |
| 82 | + self.head_dim = cfg.dim // cfg.n_heads |
| 83 | + |
| 84 | + self.w_qkv = nn.Linear(cfg.dim, 3 * cfg.dim, bias=False) |
| 85 | + self.w_out = nn.Linear(cfg.dim, cfg.dim, bias=False) |
| 86 | + |
| 87 | + def forward(self, x, freqs_cis): |
| 88 | + bsz, seqlen, d = x.shape # (bsz, seqlen, d) |
| 89 | + |
| 90 | + q, k, v = self.w_qkv(x).split(d, dim=2) # (bsz, seqlen, d) |
| 91 | + q = q.view(bsz, seqlen, self.n_heads, |
| 92 | + self.head_dim) # (bsz, seqlen, nh, h_dim) |
| 93 | + k = k.view(bsz, seqlen, self.n_heads, |
| 94 | + self.head_dim) # (bsz, seqlen, nh, h_dim) |
| 95 | + v = v.view(bsz, seqlen, self.n_heads, |
| 96 | + self.head_dim) # (bsz, seqlen, nh, h_dim) |
| 97 | + |
| 98 | + q, k = apply_rotary_emb_complex_like( |
| 99 | + q, k, freqs_cis=freqs_cis) # (bsz, seqlen, nh, h_dim) |
| 100 | + |
| 101 | + q = q.transpose(1, 2) # (bsz, nh, seqlen, h_dim) |
| 102 | + k = k.transpose(1, 2) # (bsz, nh, seqlen, h_dim) |
| 103 | + v = v.transpose(1, 2) # (bsz, nh, seqlen, h_dim) |
| 104 | + |
| 105 | + out = F.scaled_dot_product_attention( |
| 106 | + q, k, v, is_causal=True) # (bsz, nh, seqlen, h_dim) |
| 107 | + |
| 108 | + out = out.transpose(1, 2).contiguous().view(bsz, seqlen, |
| 109 | + d) # (bsz, seqlen, d) |
| 110 | + |
| 111 | + return self.w_out(out) |
| 112 | + |
| 113 | + |
| 114 | +class Block(nn.Module): |
| 115 | + |
| 116 | + def __init__(self, layer_id: int, cfg: ModelConfig): |
| 117 | + super().__init__() |
| 118 | + self.attn = Attention(cfg) |
| 119 | + self.attn_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) |
| 120 | + self.mlp = MLP(dim=cfg.dim, hidden_dim=int(cfg.expand * cfg.dim)) |
| 121 | + self.mlp_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) |
| 122 | + self.layer_id = layer_id |
| 123 | + |
| 124 | + def forward(self, x, freqs_cis): |
| 125 | + # x: (bsz, seqlen, dim) |
| 126 | + x = x + self.attn(self.attn_norm(x), freqs_cis) |
| 127 | + x = x + self.mlp(self.mlp_norm(x)) |
| 128 | + return x |
| 129 | + |
| 130 | + |
| 131 | +class Transformer(nn.Module): |
| 132 | + |
| 133 | + def __init__(self, cfg): |
| 134 | + super().__init__() |
| 135 | + self.n_layers = cfg.n_layers |
| 136 | + self.cfg = cfg |
| 137 | + head_dim = cfg.dim // cfg.n_heads |
| 138 | + assert cfg.dim % cfg.n_heads == 0 |
| 139 | + |
| 140 | + self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.dim) |
| 141 | + self.layers = nn.ModuleList( |
| 142 | + [Block(idx, cfg) for idx in range(cfg.n_layers)]) |
| 143 | + self.out_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps) |
| 144 | + self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) |
| 145 | + |
| 146 | +        # Precompute RoPE tables on CPU; the non-persistent buffer moves with the module |
| 147 | +        self.register_buffer('freqs_cis', |
| 148 | +                             precompute_freqs_cis(head_dim, cfg.seq_len, 500000), |
| 149 | +                             persistent=False) |
| 150 | + |
| 151 | + # init all weights, scale residual branches |
| 152 | + self.apply(self._init_weights) |
| 153 | + self._scale_residual_branches() |
| 154 | + |
| 155 | + # Move model to device (which will also move freqs_cis) |
| 156 | + if torch.cuda.is_available(): |
| 157 | + self.cuda() |
| 158 | + |
| 159 | + if cfg.tie_embeddings: |
| 160 | + self.tie_weights() |
| 161 | + |
| 162 | + def forward(self, x): |
| 163 | + # x: (bsz, seqlen) |
| 164 | + x = self.embed_tokens(x) # (bsz, seqlen, dim) |
| 165 | + L = x.shape[1] |
| 166 | + |
| 167 | + # Make sure we have enough precomputed frequencies |
| 168 | + if L > self.freqs_cis.shape[1]: |
| 169 | + # Need to recompute for longer sequence |
| 170 | + head_dim = self.cfg.dim // self.cfg.n_heads |
| 171 | +            new_freqs = precompute_freqs_cis(head_dim, max(L, self.cfg.seq_len), 500000) |
| 172 | +            self.register_buffer('freqs_cis', new_freqs, persistent=False) |
| 173 | +            # keep the refreshed buffer on the same device as the input |
| 174 | +            self.freqs_cis = self.freqs_cis.to(x.device) |
| 175 | + |
| 176 | + # Select the frequencies for current sequence length and ensure correct device |
| 177 | + freqs_cis = self.freqs_cis[:, :L, :].to(x.device) |
| 178 | + |
| 179 | + for layer in self.layers: |
| 180 | + x = layer(x, freqs_cis) # (bsz, seqlen, dim) |
| 181 | + return self.lm_head(self.out_norm(x)) # (bsz, seqlen, vocab_size) |
| 182 | + |
| 183 | + def predict(self, x, k=1): |
| 184 | + """Generate k tokens autoregressively. |
| 185 | + |
| 186 | + Args: |
| 187 | + x: Input token sequence of shape (batch_size, seq_len) |
| 188 | + k: Number of tokens to predict |
| 189 | + |
| 190 | + Returns: |
| 191 | + Tuple of (input_ids, predicted_ids) |
| 192 | + """ |
| 193 | +        # Collect predicted ids for the debug printout below |
| 194 | +        predictions = [] |
| 195 | + |
| 196 | + batch_size = x.shape[0] |
| 197 | + seq_len = x.shape[1] |
| 198 | + |
| 199 | + # Store original input |
| 200 | + original_input = x.clone() |
| 201 | + generated_input = x.clone() |
| 202 | + |
| 203 | + # Generate k tokens autoregressively |
| 204 | + for i in range(k): |
| 205 | + # Get logits for the entire sequence |
| 206 | + logits = self(generated_input) |
| 207 | + |
| 208 | + # Get the logits for the last token in each sequence |
| 209 | + next_token_logits = logits[:, -1, :] |
| 210 | + |
| 211 | +            # Mask the most recently generated token's logit (set it to -inf) so greedy |
| 212 | +            # decoding cannot immediately repeat it and get stuck in a loop |
| 213 | + last_token_id = generated_input[:, -1] |
| 214 | + next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf')) |
| 215 | + |
| 216 | + # Print top 5 tokens for debugging |
| 217 | + if i == 0: |
| 218 | + print("\nPyTorch detailed prediction:") |
| 219 | + top5_values, top5_indices = torch.topk(next_token_logits[0], 5) |
| 220 | + for j, (idx, val) in enumerate(zip(top5_indices.tolist(), top5_values.tolist())): |
| 221 | + prob = torch.softmax(next_token_logits[0], dim=-1)[idx].item() |
| 222 | + print(f" Top {j+1}: Token {idx}, logit={val:.2f}, prob={prob:.6f}") |
| 223 | + |
| 224 | + # Get the most likely token |
| 225 | + next_token = torch.argmax(next_token_logits, dim=-1) |
| 226 | +            predictions.append(next_token[0].item())  # debug trace follows the first sequence |
| 227 | + |
| 228 | + # Append the predicted token to the sequence |
| 229 | + next_token = next_token.unsqueeze(1) # Add sequence dimension |
| 230 | + generated_input = torch.cat([generated_input, next_token], dim=1) |
| 231 | + |
| 232 | + print(f" Full predictions step by step: {predictions}") |
| 233 | + |
| 234 | +        # Return the original input and the k newly generated token ids |
| 235 | + return original_input, generated_input[:, -k:] |
| 236 | + |
| 237 | + def _init_weights(self, module): |
| 238 | + if isinstance(module, nn.Linear): |
| 239 | + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) |
| 240 | + if module.bias is not None: |
| 241 | + torch.nn.init.zeros_(module.bias) |
| 242 | + elif isinstance(module, nn.Embedding): |
| 243 | + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) |
| 244 | + |
| 245 | + def _scale_residual_branches(self): |
| 246 | + for n, p in self.named_parameters(): |
| 247 | + if n.endswith("fc2.weight"): # mlp/glu output layer |
| 248 | + torch.nn.init.normal_(p, |
| 249 | + mean=0.0, |
| 250 | + std=0.02 / math.sqrt(2 * self.n_layers)) |
| 251 | + if n.endswith("w_out.weight"): # attn output layer |
| 252 | + torch.nn.init.normal_(p, |
| 253 | + mean=0.0, |
| 254 | + std=0.02 / math.sqrt(2 * self.n_layers)) |
| 255 | + |
| 256 | + def tie_weights(self): |
| 257 | + self.lm_head.weight = self.embed_tokens.weight |
| 258 | + |
| 259 | + def count_params(self, non_embedding=True): |
| 260 | + n_params = sum(p.numel() for p in self.parameters()) |
| 261 | + if non_embedding: |
| 262 | + n_params -= self.embed_tokens.weight.numel() |
| 263 | +            if (self.lm_head.weight |
| 264 | +                    is not self.embed_tokens.weight):  # if no weight tying |
| 265 | + n_params -= self.lm_head.weight.numel() |
| 266 | + return n_params |
| 267 | + |
| 268 | + |
| 269 | +def main(): |
| 270 | + print("Initializing transformer model and running forward pass...") |
| 271 | + |
| 272 | + seq_length = 512 |
| 273 | + |
| 274 | + # Define model configuration |
| 275 | + config = ModelConfig( |
| 276 | + vocab_size=32000, # Common vocab size for tokenizers like BPE or SentencePiece |
| 277 | + seq_len=seq_length, # Maximum sequence length |
| 278 | + dim=768, # Embedding dimension |
| 279 | + expand=4.0, # MLP expansion factor |
| 280 | + n_layers=12, # Number of transformer layers |
| 281 | + n_heads=12, # Number of attention heads |
| 282 | + rmsnorm_eps=1e-6, # RMSNorm epsilon |
| 283 | + tie_embeddings=True # Tie embedding and output weights |
| 284 | + ) |
| 285 | + |
| 286 | +    # Build the model and run one forward pass, as the banner above states |
| 287 | +    model = Transformer(config) |
| 288 | +    print(f"Non-embedding parameters: {model.count_params():,}") |
| 289 | + |
| 290 | +    # Forward a small batch of random token ids through the model |
| 291 | +    device = next(model.parameters()).device |
| 292 | +    tokens = torch.randint(0, config.vocab_size, (2, seq_length), device=device) |
| 293 | +    logits = model(tokens) |
| 294 | +    print(f"Output logits shape: {tuple(logits.shape)}")  # (2, seq_length, vocab_size) |
| 295 | + |
| 296 | + |
| 297 | +if __name__ == "__main__": |
| 298 | +    main() |
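| 299 | + |
| 300 | + |
| 301 | +# Sketch of assumed usage for the debug-oriented predict() helper above, with `model` |
| 302 | +# and `config` built as in main(); the prompt must live on the model's device: |
| 303 | +# |
| 304 | +#     device = next(model.parameters()).device |
| 305 | +#     prompt = torch.randint(0, config.vocab_size, (1, 16), device=device) |
| 306 | +#     original, new_ids = model.predict(prompt, k=3)  # new_ids: (1, 3) generated ids |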