
Commit 99c7b9b

committed
add nanodo model
1 parent 0c22f3d commit 99c7b9b

File tree

3 files changed
+386 -19 lines changed
Lines changed: 345 additions & 0 deletions
@@ -0,0 +1,345 @@
# Self-contained version of the DecoderOnly Transformer from NanoDO

import dataclasses
from functools import partial

from flax import linen as nn
import jax
import jax.numpy as jnp

# =========== Transformer Decoder-only Model ==========


@dataclasses.dataclass
class DoConfig:
    """Hyper-parameters for Transformer decoder-only."""

    D: int  # model/embed dim = qkv dim
    H: int  # num attention heads
    L: int  # max context/sequence length
    N: int  # number of transformer block layers
    V: int  # vocab size
    F: int  # FF inner dimension
    kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform()
    embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling(
        1.0, "fan_in", "normal", out_axis=0
    )
    dtype: jnp.dtype = jnp.float32
    rmsnorm_epsilon: float = 1e-6
    multiple_of: int = 256
    tie_embeddings: bool = True  # Whether to tie input and output embeddings
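
# Note on naming: array names in this file carry their shapes as suffixes,
# e.g. x_BxLxD is (batch, length, model dim) and q_BxLxHxDh is
# (batch, length, heads, head dim).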


class Mlp(nn.Module):
    """Multilayer perceptron with GLU activation."""

    cfg: DoConfig

    @nn.compact
    def __call__(self, x_BxLxD: jax.Array):
        cfg = self.cfg
        # Use Xavier uniform initialization explicitly
        xavier_init = nn.initializers.xavier_uniform()
        linear = partial(
            nn.Dense, kernel_init=xavier_init, use_bias=False, dtype=cfg.dtype
        )
        # Round the inner dimension up to the nearest multiple of cfg.multiple_of
        hidden_dim = cfg.multiple_of * (
            (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of
        )
        # Double the hidden dimension for GLU
        x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD)
        # Apply GLU activation
        x_BxLxF = nn.glu(x_BxLx2F, axis=-1)
        x_BxLxD = linear(cfg.D)(x_BxLxF)
        return x_BxLxD
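
# For example, with the demo config F = 4 * 128 = 512 and multiple_of = 256,
# hidden_dim = 256 * ceil(512 / 256) = 512; the first projection then emits
# 2 * 512 = 1024 features, which nn.glu gates back down to 512.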


@partial(jax.jit, static_argnums=(0, 1, 2))
def init_rope(dim=256, seq_len=128, n_heads=4):
    """Initialize rotary embeddings."""

    def precompute_freqs_cis_jax(dim, end, theta=10000.0):
        # One inverse frequency per (even, odd) channel pair
        inv_freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim))
        t = jnp.arange(end, dtype=jnp.float32)
        freqs = jnp.outer(t, inv_freqs).astype(jnp.float32)
        return jnp.stack(
            [
                jnp.cos(freqs)[None, :, None, :],
                jnp.sin(freqs)[None, :, None, :],
            ],
            axis=3,
        )

    freqs_cis = precompute_freqs_cis_jax(dim // n_heads, seq_len, theta=500000)
    return freqs_cis.transpose(0, 1, 2, 4, 3)
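
# Shape check: for head dim Dh = dim // n_heads, precompute_freqs_cis_jax
# returns (1, seq_len, 1, 2, Dh // 2); the final transpose yields
# (1, seq_len, 1, Dh // 2, 2), i.e. one (cos, sin) pair per position and
# channel pair, broadcastable against the (B, L, H, Dh // 2, 2) view of Q/K.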


@jax.jit
def apply_rope(q, k, freqs_cis):
    """Apply rotary embeddings to Q and K."""

    def rotate_tensor(x):
        # Split into real and imaginary parts
        x_r2 = x.reshape(*x.shape[:-1], -1, 2)
        L = x.shape[1]
        freqs = freqs_cis[:, :L, :, :, :]

        # Apply rotation
        rotated_x_r2 = jnp.stack(
            [
                x_r2[..., 0] * freqs[..., 0] - x_r2[..., 1] * freqs[..., 1],
                x_r2[..., 1] * freqs[..., 0] + x_r2[..., 0] * freqs[..., 1],
            ],
            axis=-1,
        )

        return rotated_x_r2.reshape(*x.shape)

    # Apply rotation to Q and K separately
    rotated_q = rotate_tensor(q)
    rotated_k = rotate_tensor(k)

    return rotated_q, rotated_k
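
# The stacked expressions above are complex multiplication in disguise:
# treating each channel pair (x0, x1) as x0 + i*x1 and each (cos, sin) pair
# as e^(i*m*theta_j), they compute
#   (x0 + i*x1) * (cos + i*sin) = (x0*cos - x1*sin) + i*(x1*cos + x0*sin),
# rotating each query/key pair by an angle proportional to its position m.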


class CausalAttn(nn.Module):
    """Causal attention layer with rotary embeddings."""

    cfg: DoConfig

    def setup(self):
        cfg = self.cfg
        assert cfg.D % cfg.H == 0, f"D {cfg.D} not divisible by H {cfg.H}"
        self.Dh = cfg.D // cfg.H

        # Initialize rotary embeddings
        self.freqs_cis = init_rope(cfg.D, cfg.L, cfg.H)

        # Maps D -> (H, Dh)
        self.multilinear = partial(
            nn.DenseGeneral,
            axis=-1,
            features=(cfg.H, self.Dh),
            kernel_init=cfg.kernel_init,
            use_bias=False,
            dtype=cfg.dtype,
        )

        self.multilinear_query = self.multilinear(name="query")
        self.multilinear_key = self.multilinear(name="key")
        self.multilinear_value = self.multilinear(name="value")
        self.output_projection = nn.DenseGeneral(
            features=cfg.D,
            name="attn_out_proj",
            # axis=(-2, -1),
            kernel_init=cfg.kernel_init,
            use_bias=False,
            dtype=cfg.dtype,
        )

    def __call__(self, x_BxLxD: jax.Array):
        cfg = self.cfg

        # Project inputs to Q, K, V
        q_BxLxHxDh = self.multilinear_query(x_BxLxD)
        k_BxLxHxDh = self.multilinear_key(x_BxLxD)
        v_BxLxHxDh = self.multilinear_value(x_BxLxD)

        # Apply rotary embeddings to Q and K
        q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis)

        # Scale queries
        q_BxLxHxDh /= self.Dh**0.5

        # Compute attention scores
        att_BxHxLxL = jnp.einsum("...qhd,...khd->...hqk", q_BxLxHxDh, k_BxLxHxDh)

        # Causal attention mask
        L = x_BxLxD.shape[1]
        mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_))

        # Apply mask and softmax
        _NEG_INF = jnp.finfo(cfg.dtype).min
        att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF)
        att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1)
        att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype)

        # Compute attention output
        out_BxLxHxDh = jnp.einsum("...hqk,...khd->...qhd", att_BxHxLxL, v_BxLxHxDh)

        # Reshape and project output
        out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape)

        # Output projection
        out_BxLxD = self.output_projection(out_BxLxD)

        return out_BxLxD
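
# In the einsum "...qhd,...khd->...hqk", q and k index query/key positions,
# h the head, and d the head channels; contracting over d yields one LxL
# score matrix per head. Newer JAX releases (assuming one is available) also
# ship jax.nn.dot_product_attention with an is_causal option, which could
# replace the explicit mask/softmax/einsum pipeline above.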


class TBlock(nn.Module):
    """Transformer Block."""

    docfg: DoConfig

    @nn.compact
    def __call__(self, in_BxLxD: jax.Array):
        cfg = self.docfg

        # x = x + attn( attn_norm(x) )
        x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)(
            in_BxLxD
        )
        x_BxLxD = CausalAttn(cfg)(x_BxLxD)
        x_BxLxD += in_BxLxD

        # x = x + mlp( mlp_norm(x) )
        z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)(
            x_BxLxD
        )
        z_BxLxD = Mlp(cfg)(z_BxLxD)

        return x_BxLxD + z_BxLxD
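
# This is the pre-norm residual arrangement: each sublayer sees an
# RMS-normalized input, while the residual stream itself is only normalized
# once, by out_ln, after the final block.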


class TransformerDo(nn.Module):
    """Transformer decoder-only."""

    docfg: DoConfig

    def setup(self):
        cfg = self.docfg
        self.embed = nn.Embed(
            num_embeddings=cfg.V,
            features=cfg.D,
            embedding_init=cfg.embed_init,
        )

        self.blocks = [TBlock(cfg) for _ in range(cfg.N)]
        self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)

        # Output projection - tied to input embeddings if configured
        if cfg.tie_embeddings:
            self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32))
        else:
            self.output_proj = nn.Dense(
                cfg.V,
                kernel_init=cfg.embed_init,
                dtype=cfg.dtype,
                name="output_proj",
            )

    def __call__(self, y_BxL: jax.Array):
        # For training on concatenated examples.
        y_BxLxD = self.embed(y_BxL)
        for block in self.blocks:
            y_BxLxD = block(y_BxLxD)
        y_BxLxD = self.out_ln(y_BxLxD)
        logits_BxLxV = self.output_proj(y_BxLxD)
        return logits_BxLxV
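
    # When tie_embeddings is set, nn.Embed.attend computes y @ embedding.T,
    # reusing the (V, D) input embedding matrix as the output projection and
    # saving V * D parameters relative to a separate nn.Dense head.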

    def predict(self, y_BxL: jax.Array, k: int = 1):
        """Generate k tokens autoregressively.

        Args:
            y_BxL: Input token sequence of shape (batch_size, seq_len)
            k: Number of tokens to predict

        Returns:
            Tuple of (input_ids, predicted_ids)
        """
        cfg = self.docfg
        batch_size = y_BxL.shape[0]
        seq_len = y_BxL.shape[1]

        # Store original input
        original_input = y_BxL

        # Make sure we don't exceed the model's context length
        if seq_len + k > cfg.L:
            raise ValueError(
                f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.L})"
            )

        # Generate k tokens autoregressively
        for _ in range(k):
            # Get logits for the entire sequence
            logits = self(y_BxL)

            # Get the logits for the last token in each sequence
            next_token_logits = logits[:, -1, :]

            # Get the most likely token
            next_token = jnp.argmax(next_token_logits, axis=-1)

            # Append the predicted token to the sequence
            y_BxL = jnp.concatenate([y_BxL, next_token[:, None]], axis=1)

        # Return original input and the k predicted tokens
        return original_input, y_BxL[:, -k:]
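
# Note that predict re-runs the full forward pass over the growing sequence
# for every generated token (greedy argmax decoding, no KV cache), so
# generating k tokens costs k full-length forward passes.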


# =========== Demo Code ==========


def main():
    """Create and run the DecoderOnly Transformer model."""
    # Initialize model configuration with smaller parameters for demo
    B, L = (2, 128)  # Batch size, sequence length
    cfg = DoConfig(D=128, H=4, L=L, N=2, V=256, F=4 * 128)
    model = TransformerDo(cfg)

    # Print model info
    print("\nModel Configuration:")
    print(f"  - Model dimension (D): {cfg.D}")
    print(f"  - Number of heads (H): {cfg.H}")
    print(f"  - Max sequence length (L): {cfg.L}")
    print(f"  - Number of layers (N): {cfg.N}")
    print(f"  - Vocabulary size (V): {cfg.V}")
    print(f"  - Feed forward dimension (F): {cfg.F}")

    # Create random input tokens (simulated token IDs)
    rng_key = jax.random.PRNGKey(42)
    input_rng, init_rng = jax.random.split(rng_key)

    # Generate random token IDs (integers between 0 and vocab_size-1)
    x_BxL = jax.random.randint(
        input_rng, shape=(B, L), minval=0, maxval=cfg.V, dtype=jnp.int32
    )

    # Initialize model parameters
    print("\nInitializing model parameters...")
    params = model.init(init_rng, x_BxL)

    # Print parameter count
    param_count = sum(x.size for x in jax.tree_util.tree_leaves(params))
    print(f"Total parameters: {param_count:,}")

    # Make a prediction (forward pass)
    print("\nRunning forward pass...")
    logits = model.apply(params, x_BxL)

    # Print output shape and sample values
    print(f"\nOutput shape: {logits.shape} (batch_size, sequence_length, vocab_size)")
    print(f"Output data type: {logits.dtype}")

    # Print sample logits (first 5 positions of the first sequence)
    print("\nSample logits (first sequence, first 5 positions, first 5 values):")
    for position in range(min(5, L)):
        print(f"  Position {position}: {logits[0, position, :5]}")

    # Get predictions (token with highest logit at each position)
    predictions = jnp.argmax(logits, axis=-1)
    print("\nPredicted token IDs (first sequence, first 10 positions):")
    print(predictions[0, :10])

    # Test the predict function
    print("\nTesting predict function...")
    # Use a shorter prefix so there is room left in the context to generate
    short_seq = x_BxL[:, :10]
    print(f"Input sequence shape: {short_seq.shape}")

    # Predict 5 tokens
    k = 5
    original, predicted = model.apply(params, short_seq, k, method=model.predict)

    # Show the prompt and the k greedily decoded continuation tokens
    print(f"Original tokens (first sequence): {original[0]}")
    print(f"Predicted next {k} tokens (first sequence): {predicted[0]}")

    print("\nDone!")
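

# A minimal training-step sketch (an illustrative addition, not part of the
# demo above): next-token cross-entropy with optax. The names train_step and
# tx, and the use of optax itself, are assumptions for illustration.
def train_step(model, params, opt_state, tokens_BxL, tx):
    """One gradient step of next-token cross-entropy training (sketch)."""
    import optax  # assumed available; only this sketch depends on it

    def loss_fn(p):
        logits_BxLxV = model.apply(p, tokens_BxL)
        # Shift by one: logits at position t predict the token at t + 1
        ce = optax.softmax_cross_entropy_with_integer_labels(
            logits_BxLxV[:, :-1, :], tokens_BxL[:, 1:]
        )
        return ce.mean()

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss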


if __name__ == "__main__":
    main()
