|
| 1 | +# Self-contained version of the DecoderOnly Transformer from NanoDO |
| 2 | + |
| 3 | +import dataclasses |
| 4 | +from functools import partial |
| 5 | + |
| 6 | +from flax import linen as nn |
| 7 | +import jax |
| 8 | +import jax.numpy as jnp |
| 9 | + |
| 10 | +# =========== Transformer Decoder-only Model ========== |
| 11 | + |
| 12 | + |
| 13 | + |
| 14 | +@dataclasses.dataclass |
| 15 | +class DoConfig: |
| 16 | + """Hyper-parameters for Transformer decoder-only.""" |
| 17 | + |
| 18 | + D: int # model/embed dim = qkv dim |
| 19 | + H: int # num attention heads |
| 20 | + L: int # max context/sequence length |
| 21 | + N: int # number of transformer block layers |
| 22 | + V: int # vocab size |
| 23 | + F: int # FF inner dimension |
| 24 | + kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform() |
| 25 | + embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling( |
| 26 | + 1.0, "fan_in", "normal", out_axis=0 |
| 27 | + ) |
| 28 | + dtype: jnp.dtype = jnp.float32 |
| 29 | + rmsnorm_epsilon: float = 1e-6 |
| 30 | + multiple_of: int = 256 |
| 31 | + tie_embeddings: bool = True # Whether to tie input and output embeddings |
| 32 | + |
| 33 | + |
| 34 | +class Mlp(nn.Module): |
| 35 | + """Multilayer perceptron with GLU activation.""" |
| 36 | + |
| 37 | + cfg: DoConfig |
| 38 | + |
| 39 | + @nn.compact |
| 40 | + def __call__(self, x_BxLxD: jax.Array): |
| 41 | + cfg = self.cfg |
| 42 | + # Use Xavier uniform initialization explicitly |
| 43 | + xavier_init = nn.initializers.xavier_uniform() |
| 44 | + linear = partial( |
| 45 | + nn.Dense, kernel_init=xavier_init, use_bias=False, dtype=cfg.dtype |
| 46 | + ) |
| 47 | + hidden_dim = cfg.multiple_of * ( |
| 48 | + (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of |
| 49 | + ) |
| 50 | + # Double the hidden dimension for GLU |
| 51 | + x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD) |
| 52 | + # Apply GLU activation |
| 53 | + x_BxLxF = nn.glu(x_BxLx2F, axis=-1) |
| 54 | + x_BxLxD = linear(cfg.D)(x_BxLxF) |
| 55 | + return x_BxLxD |
| 56 | + |
| 57 | +@partial(jax.jit, static_argnums=(0,1,2)) |
| 58 | +def init_rope(dim=256, seq_len=128, n_heads=4): |
| 59 | + """Initialize rotary embeddings.""" |
| 60 | + def precompute_freqs_cis_jax(dim, end, theta=10000.0): |
| 61 | + inv_freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim)) |
| 62 | + t = jnp.arange(end) / 1.0 |
| 63 | + freqs = jnp.outer(t, inv_freqs).astype(jnp.float32) |
| 64 | + return jnp.stack([ |
| 65 | + jnp.cos(freqs)[None, :, None, :], |
| 66 | + jnp.sin(freqs)[None, :, None, :] |
| 67 | + ], axis=3) |
| 68 | + |
| 69 | + freqs_cis = precompute_freqs_cis_jax(dim // n_heads, seq_len, theta=500000) |
| 70 | + return freqs_cis.transpose(0, 1, 2, 4, 3) |
| 71 | + |
| 72 | +@jax.jit |
| 73 | +def apply_rope(q, k, freqs_cis): |
| 74 | + """Apply rotary embeddings to Q and K.""" |
| 75 | + def rotate_tensor(x): |
| 76 | + # Split into real and imaginary parts |
| 77 | + x_r2 = x.reshape(*x.shape[:-1], -1, 2) |
| 78 | + L = x.shape[1] |
| 79 | + freqs = freqs_cis[:, :L, :, :, :] |
| 80 | + |
| 81 | + # Apply rotation |
| 82 | + rotated_x_r2 = jnp.stack([ |
| 83 | + x_r2[..., 0] * freqs[..., 0] - x_r2[..., 1] * freqs[..., 1], |
| 84 | + x_r2[..., 1] * freqs[..., 0] + x_r2[..., 0] * freqs[..., 1] |
| 85 | + ], axis=-1) |
| 86 | + |
| 87 | + return rotated_x_r2.reshape(*x.shape) |
| 88 | + |
| 89 | + # Apply rotation to Q and K separately |
| 90 | + rotated_q = rotate_tensor(q) |
| 91 | + rotated_k = rotate_tensor(k) |
| 92 | + |
| 93 | + return rotated_q, rotated_k |
| 94 | + |
| 95 | + |
| 96 | +class CausalAttn(nn.Module): |
| 97 | + """Causal attention layer with rotary embeddings.""" |
| 98 | + |
| 99 | + cfg: DoConfig |
| 100 | + |
| 101 | + def setup(self): |
| 102 | + cfg = self.cfg |
| 103 | + assert cfg.D % cfg.H == 0, f"D {cfg.D} not divisible by H {cfg.H}" |
| 104 | + self.Dh = cfg.D // cfg.H |
| 105 | + |
| 106 | + # Initialize rotary embeddings |
| 107 | + self.freqs_cis = init_rope(cfg.D, cfg.L, cfg.H) |
| 108 | + |
| 109 | + # Maps D -> (H, Dh) |
| 110 | + self.multilinear = partial( |
| 111 | + nn.DenseGeneral, |
| 112 | + axis=-1, |
| 113 | + features=(cfg.H, self.Dh), |
| 114 | + kernel_init=cfg.kernel_init, |
| 115 | + use_bias=False, |
| 116 | + dtype=cfg.dtype, |
| 117 | + ) |
| 118 | + |
| 119 | + self.multilinear_query = self.multilinear(name="query") |
| 120 | + self.multilinear_key = self.multilinear(name="key") |
| 121 | + self.multilinear_value = self.multilinear(name="value") |
| 122 | + self.output_projection = nn.DenseGeneral( |
| 123 | + features=cfg.D, |
| 124 | + name="attn_out_proj", |
| 125 | + # axis=(-2, -1), # |
| 126 | + kernel_init=cfg.kernel_init, |
| 127 | + use_bias=False, |
| 128 | + dtype=cfg.dtype, |
| 129 | + ) |
| 130 | + |
| 131 | + def __call__(self, x_BxLxD: jax.Array): |
| 132 | + cfg = self.cfg |
| 133 | + |
| 134 | + # Project inputs to Q, K, V |
| 135 | + q_BxLxHxDh = self.multilinear_query(x_BxLxD) |
| 136 | + k_BxLxHxDh = self.multilinear_key(x_BxLxD) |
| 137 | + v_BxLxHxDh = self.multilinear_value(x_BxLxD) |
| 138 | + |
| 139 | + # Apply rotary embeddings to Q and K |
| 140 | + q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis) |
| 141 | + |
| 142 | + # Scale queries |
| 143 | + q_BxLxHxDh /= self.Dh**0.5 |
| 144 | + |
| 145 | + # Compute attention scores |
| 146 | + att_BxHxLxL = jnp.einsum("...qhd,...khd->...hqk", q_BxLxHxDh, k_BxLxHxDh) |
| 147 | + |
| 148 | + # Causal attention mask |
| 149 | + L = x_BxLxD.shape[1] |
| 150 | + mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_)) |
| 151 | + |
| 152 | + # Apply mask and softmax |
| 153 | + _NEG_INF = jnp.finfo(cfg.dtype).min |
| 154 | + att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF) |
| 155 | + att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1) |
| 156 | + att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype) |
| 157 | + |
| 158 | + # Compute attention output |
| 159 | + out_BxLxHxDh = jnp.einsum("...hqk,...khd->...qhd", att_BxHxLxL, v_BxLxHxDh) |
| 160 | + |
| 161 | + # Reshape and project output |
| 162 | + out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape) |
| 163 | + |
| 164 | + # Output projection |
| 165 | + out_BxLxD = self.output_projection(out_BxLxD) |
| 166 | + |
| 167 | + return out_BxLxD |
| 168 | + |
| 169 | + |
| 170 | +class TBlock(nn.Module): |
| 171 | + """Transformer Block.""" |
| 172 | + |
| 173 | + docfg: DoConfig |
| 174 | + |
| 175 | + @nn.compact |
| 176 | + def __call__(self, in_BxLxD: jax.Array): |
| 177 | + cfg = self.docfg |
| 178 | + |
| 179 | + # x = x + attn( attn_norm(x) ) |
| 180 | + x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( |
| 181 | + in_BxLxD |
| 182 | + ) |
| 183 | + x_BxLxD = CausalAttn(cfg)(x_BxLxD) |
| 184 | + x_BxLxD += in_BxLxD |
| 185 | + |
| 186 | + # x = x + mlp( mlp_norm(x) ) |
| 187 | + z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)( |
| 188 | + x_BxLxD |
| 189 | + ) |
| 190 | + z_BxLxD = Mlp(cfg)(z_BxLxD) |
| 191 | + |
| 192 | + return x_BxLxD + z_BxLxD |
| 193 | + |
| 194 | + |
| 195 | +class TransformerDo(nn.Module): |
| 196 | + """Transformer decoder-only.""" |
| 197 | + |
| 198 | + docfg: DoConfig |
| 199 | + |
| 200 | + def setup(self): |
| 201 | + cfg = self.docfg |
| 202 | + self.embed = nn.Embed( |
| 203 | + num_embeddings=cfg.V, |
| 204 | + features=cfg.D, |
| 205 | + embedding_init=cfg.embed_init, |
| 206 | + ) |
| 207 | + |
| 208 | + self.blocks = [TBlock(cfg) for _ in range(cfg.N)] |
| 209 | + self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon) |
| 210 | + |
| 211 | + # Output projection - tied to input embeddings if configured |
| 212 | + if cfg.tie_embeddings: |
| 213 | + self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32)) |
| 214 | + else: |
| 215 | + self.output_proj = nn.Dense( |
| 216 | + cfg.V, |
| 217 | + kernel_init=cfg.embed_init, |
| 218 | + dtype=cfg.dtype, |
| 219 | + name="output_proj" |
| 220 | + ) |
| 221 | + |
| 222 | + def __call__(self, y_BxL: jax.Array): |
| 223 | + # For training on concatenated examples. |
| 224 | + y_BxLxD = self.embed(y_BxL) |
| 225 | + for block in self.blocks: |
| 226 | + y_BxLxD = block(y_BxLxD) |
| 227 | + y_BxLxD = self.out_ln(y_BxLxD) |
| 228 | + logits_BxLxV = self.output_proj(y_BxLxD) |
| 229 | + return logits_BxLxV |
| 230 | + |
| 231 | + def predict(self, y_BxL: jax.Array, k: int = 1): |
| 232 | + """Generate k tokens autoregressively. |
| 233 | +
|
| 234 | + Args: |
| 235 | + y_BxL: Input token sequence of shape (batch_size, seq_len) |
| 236 | + k: Number of tokens to predict |
| 237 | +
|
| 238 | + Returns: |
| 239 | + Tuple of (input_ids, predicted_ids) |
| 240 | + """ |
| 241 | + cfg = self.docfg |
| 242 | + batch_size = y_BxL.shape[0] |
| 243 | + seq_len = y_BxL.shape[1] |
| 244 | + |
| 245 | + # Store original input |
| 246 | + original_input = y_BxL |
| 247 | + |
| 248 | + # Make sure we don't exceed the model's context length |
| 249 | + if seq_len + k > cfg.L: |
| 250 | + raise ValueError( |
| 251 | + f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.L})" |
| 252 | + ) |
| 253 | + |
| 254 | + # Generate k tokens autoregressively |
| 255 | + for _ in range(k): |
| 256 | + # Get logits for the entire sequence |
| 257 | + logits = self(y_BxL) |
| 258 | + |
| 259 | + # Get the logits for the last token in each sequence |
| 260 | + next_token_logits = logits[:, -1, :] |
| 261 | + |
| 262 | + # Get the most likely token |
| 263 | + next_token = jnp.argmax(next_token_logits, axis=-1) |
| 264 | + |
| 265 | + # Append the predicted token to the sequence |
| 266 | + y_BxL = jnp.concatenate([y_BxL, next_token[:, None]], axis=1) |
| 267 | + |
| 268 | + # Return original input and the k predicted tokens |
| 269 | + return original_input, y_BxL[:, -k:] |
| 270 | + |
| 271 | + |
| 272 | +# =========== Demo Code ========== |
| 273 | + |
| 274 | + |
| 275 | +def main(): |
| 276 | + """Create and run the DecoderOnly Transformer model.""" |
| 277 | + # Initialize model configuration with smaller parameters for demo |
| 278 | + B, L = (2, 128) # Batch size, sequence length |
| 279 | + cfg = DoConfig(D=128, H=4, L=L, N=2, V=256, F=4 * 128) |
| 280 | + model = TransformerDo(cfg) |
| 281 | + |
| 282 | + # Print model info |
| 283 | + print(f"\nModel Configuration:") |
| 284 | + print(f" - Model dimension (D): {cfg.D}") |
| 285 | + print(f" - Number of heads (H): {cfg.H}") |
| 286 | + print(f" - Max sequence length (L): {cfg.L}") |
| 287 | + print(f" - Number of layers (N): {cfg.N}") |
| 288 | + print(f" - Vocabulary size (V): {cfg.V}") |
| 289 | + print(f" - Feed forward dimension (F): {cfg.F}") |
| 290 | + |
| 291 | + # Create random input tokens (simulated token IDs) |
| 292 | + rng_key = jax.random.PRNGKey(42) |
| 293 | + input_rng, init_rng = jax.random.split(rng_key) |
| 294 | + |
| 295 | + # Generate random token IDs (integers between 0 and vocab_size-1) |
| 296 | + x_BxL = jax.random.randint( |
| 297 | + input_rng, shape=(B, L), minval=0, maxval=cfg.V, dtype=jnp.int32 |
| 298 | + ) |
| 299 | + |
| 300 | + # Initialize model parameters |
| 301 | + print("\nInitializing model parameters...") |
| 302 | + params = model.init(init_rng, x_BxL) |
| 303 | + |
| 304 | + # Print parameter count |
| 305 | + param_count = sum(x.size for x in jax.tree_util.tree_leaves(params)) |
| 306 | + print(f"Total parameters: {param_count:,}") |
| 307 | + |
| 308 | + # Make a prediction (forward pass) |
| 309 | + print("\nRunning forward pass...") |
| 310 | + logits = model.apply(params, x_BxL) |
| 311 | + |
| 312 | + # Print output shape and sample values |
| 313 | + print(f"\nOutput shape: {logits.shape} (batch_size, sequence_length, vocab_size)") |
| 314 | + print(f"Output data type: {logits.dtype}") |
| 315 | + |
| 316 | + # Print sample logits (first 5 positions of the first sequence) |
| 317 | + print("\nSample logits (first sequence, first 5 positions, first 5 values):") |
| 318 | + for position in range(min(5, L)): |
| 319 | + print(f" Position {position}: {logits[0, position, :5]}") |
| 320 | + |
| 321 | + # Get predictions (token with highest logit at each position) |
| 322 | + predictions = jnp.argmax(logits, axis=-1) |
| 323 | + print("\nPredicted token IDs (first sequence, first 10 positions):") |
| 324 | + print(predictions[0, :10]) |
| 325 | + |
| 326 | + # Test the predict function |
| 327 | + print("\nTesting predict function...") |
| 328 | + # Use a shorter |
| 329 | + short_seq = x_BxL[:, :10] |
| 330 | + print(f"Input sequence shape: {short_seq.shape}") |
| 331 | + |
| 332 | + # Predict 5 tokens |
| 333 | + k = 5 |
| 334 | + original, predicted = model.apply(params, short_seq, k, method=model.predict) |
| 335 | + |
| 336 | + # Get predictions (token with highest logit at each position) |
| 337 | + predictions = jnp.argmax(logits, axis=-1) |
| 338 | + print("\nPredicted token IDs (first sequence, first 10 positions):") |
| 339 | + print(predictions[0, :10]) |
| 340 | + |
| 341 | + print("\nDone!") |
| 342 | + |
| 343 | + |
| 344 | +if __name__ == "__main__": |
| 345 | + main() |
0 commit comments