
Commit 706d9f7

torch model

1 parent 99c7b9b

4 files changed: +341 −18 lines

algoperf/param_utils.py

Lines changed: 2 additions & 0 deletions

@@ -43,6 +43,8 @@ def pytorch_param_types(
       param_types[name] = spec.ParameterType.ATTENTION_BIAS
     elif 'in_proj' in name:
       param_types[name] = spec.ParameterType.ATTENTION_QKV
+    elif 'qkv' in name:
+      param_types[name] = spec.ParameterType.ATTENTION_QKV
     elif 'kv_proj' in name:
       param_types[name] = spec.ParameterType.ATTENTION_KV
     elif 'k_proj' in name or 'key' in name:
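The new branch classifies parameters from the PyTorch model added in this commit, whose attention layer packs the query, key and value projections into a single w_qkv linear layer; its parameter names (e.g. layers.0.attn.w_qkv.weight) therefore contain 'qkv' rather than 'in_proj'. A minimal sketch of the substring matching, with a hypothetical classify helper and example name (the real pytorch_param_types walks the workload's parameter names and returns spec.ParameterType values):

# Hypothetical illustration of the name-based classification; not the real function.
def classify(name):
  if 'in_proj' in name or 'qkv' in name:
    return 'ATTENTION_QKV'
  elif 'kv_proj' in name:
    return 'ATTENTION_KV'
  return 'OTHER'

assert classify('layers.0.attn.w_qkv.weight') == 'ATTENTION_QKV'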
Lines changed: 298 additions & 0 deletions
@@ -0,0 +1,298 @@
import math
from dataclasses import dataclass
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn


@dataclass
class ModelConfig:
  vocab_size: int
  seq_len: int
  dim: int
  expand: float
  n_layers: int
  n_heads: int
  rmsnorm_eps: float = 1e-6
  tie_embeddings: bool = False


class MLP(nn.Module):

  def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256):
    super().__init__()
    # Round the hidden dimension up to the nearest multiple of `multiple_of`.
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
    self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False)
    self.fc2 = nn.Linear(hidden_dim, dim, bias=False)
    self.glu = nn.GLU(dim=2)

    # Initialize with Xavier uniform.
    nn.init.xavier_uniform_(self.fc1.weight)
    nn.init.xavier_uniform_(self.fc2.weight)

  def forward(self, x):
    # x: (bsz, T, dim)
    return self.fc2(self.glu(self.fc1(x)))


def precompute_freqs_cis(dim: int,
                         end: int,
                         theta: float = 10000.0,
                         condense_ratio: int = 1):
  inv_freqs = 1.0 / (theta**(torch.arange(
      0, dim, 2, dtype=torch.float32, device=torch.device("cpu")) / dim))
  t = torch.arange(end, dtype=torch.float32,
                   device=inv_freqs.device) / condense_ratio
  freqs = torch.outer(t, inv_freqs).float()
  # Shape: (1, end, 1, dim // 2, 2), holding (cos, sin) pairs per position.
  return torch.stack([
      torch.cos(freqs)[None, :, None, :],
      torch.sin(freqs)[None, :, None, :]
  ],
                     dim=4)


def apply_rotary_emb_complex_like(
    q: torch.Tensor, k: torch.Tensor,
    freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  # Rotate query and key vectors using RoPE.
  qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float()
  rotated_qk_r2 = torch.stack(
      [
          qk_r2[..., 0] * freqs_cis[..., 0] -
          qk_r2[..., 1] * freqs_cis[..., 1],
          qk_r2[..., 1] * freqs_cis[..., 0] +
          qk_r2[..., 0] * freqs_cis[..., 1],
      ],
      -1,
  ).flatten(3)
  return torch.split(rotated_qk_r2.type_as(q), q.shape[2], dim=2)
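For reference, a quick shape check of the two RoPE helpers above, using illustrative sizes that are not part of the committed file:

# Hypothetical sanity check (not part of the commit): verifies the broadcast
# shapes and the norm-preserving property of the rotation.
bsz, seqlen, n_heads, head_dim = 2, 16, 4, 64
freqs_cis = precompute_freqs_cis(head_dim, seqlen)  # (1, seqlen, 1, head_dim // 2, 2)
q = torch.randn(bsz, seqlen, n_heads, head_dim)
k = torch.randn(bsz, seqlen, n_heads, head_dim)
q_rot, k_rot = apply_rotary_emb_complex_like(q, k, freqs_cis=freqs_cis)
assert q_rot.shape == k_rot.shape == (bsz, seqlen, n_heads, head_dim)
torch.testing.assert_close(q_rot.norm(), q.norm(), rtol=1e-4, atol=1e-4)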
class Attention(nn.Module):

  def __init__(self, cfg: ModelConfig):
    super().__init__()
    assert cfg.dim % cfg.n_heads == 0
    self.dim = cfg.dim
    self.n_heads = cfg.n_heads
    self.head_dim = cfg.dim // cfg.n_heads

    self.w_qkv = nn.Linear(cfg.dim, 3 * cfg.dim, bias=False)
    self.w_out = nn.Linear(cfg.dim, cfg.dim, bias=False)

  def forward(self, x, freqs_cis):
    bsz, seqlen, d = x.shape  # (bsz, seqlen, d)

    q, k, v = self.w_qkv(x).split(d, dim=2)  # (bsz, seqlen, d)
    q = q.view(bsz, seqlen, self.n_heads, self.head_dim)  # (bsz, seqlen, nh, h_dim)
    k = k.view(bsz, seqlen, self.n_heads, self.head_dim)  # (bsz, seqlen, nh, h_dim)
    v = v.view(bsz, seqlen, self.n_heads, self.head_dim)  # (bsz, seqlen, nh, h_dim)

    q, k = apply_rotary_emb_complex_like(
        q, k, freqs_cis=freqs_cis)  # (bsz, seqlen, nh, h_dim)

    q = q.transpose(1, 2)  # (bsz, nh, seqlen, h_dim)
    k = k.transpose(1, 2)  # (bsz, nh, seqlen, h_dim)
    v = v.transpose(1, 2)  # (bsz, nh, seqlen, h_dim)

    out = F.scaled_dot_product_attention(
        q, k, v, is_causal=True)  # (bsz, nh, seqlen, h_dim)

    out = out.transpose(1, 2).contiguous().view(bsz, seqlen, d)  # (bsz, seqlen, d)

    return self.w_out(out)


class Block(nn.Module):

  def __init__(self, layer_id: int, cfg: ModelConfig):
    super().__init__()
    self.attn = Attention(cfg)
    self.attn_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps)
    self.mlp = MLP(dim=cfg.dim, hidden_dim=int(cfg.expand * cfg.dim))
    self.mlp_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps)
    self.layer_id = layer_id

  def forward(self, x, freqs_cis):
    # x: (bsz, seqlen, dim)
    x = x + self.attn(self.attn_norm(x), freqs_cis)
    x = x + self.mlp(self.mlp_norm(x))
    return x


class Transformer(nn.Module):

  def __init__(self, cfg):
    super().__init__()
    self.n_layers = cfg.n_layers
    self.cfg = cfg
    head_dim = cfg.dim // cfg.n_heads
    assert cfg.dim % cfg.n_heads == 0

    self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.dim)
    self.layers = nn.ModuleList(
        [Block(idx, cfg) for idx in range(cfg.n_layers)])
    self.out_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps)
    self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False)

    # Initialize freqs_cis on CPU first (more memory efficient).
    self.register_buffer(
        'freqs_cis',
        precompute_freqs_cis(head_dim, cfg.seq_len, 500000)[0:cfg.seq_len],
        persistent=False)

    # Init all weights, then scale the residual branches.
    self.apply(self._init_weights)
    self._scale_residual_branches()

    # Move the model to the GPU (which also moves freqs_cis).
    if torch.cuda.is_available():
      self.cuda()

    if cfg.tie_embeddings:
      self.tie_weights()

  def forward(self, x):
    # x: (bsz, seqlen)
    x = self.embed_tokens(x)  # (bsz, seqlen, dim)
    L = x.shape[1]

    # Make sure we have enough precomputed frequencies.
    if L > self.freqs_cis.shape[1]:
      # Recompute for the longer sequence.
      head_dim = self.cfg.dim // self.cfg.n_heads
      new_freqs = precompute_freqs_cis(head_dim, max(L, self.cfg.seq_len), 500000)
      self.register_buffer(
          'freqs_cis', new_freqs[0:max(L, self.cfg.seq_len)], persistent=False)
      if torch.cuda.is_available():
        self.freqs_cis = self.freqs_cis.cuda()

    # Select the frequencies for the current sequence length and ensure the
    # correct device.
    freqs_cis = self.freqs_cis[:, :L, :].to(x.device)

    for layer in self.layers:
      x = layer(x, freqs_cis)  # (bsz, seqlen, dim)
    return self.lm_head(self.out_norm(x))  # (bsz, seqlen, vocab_size)
  def predict(self, x, k=1):
    """Generate k tokens autoregressively.

    Args:
      x: Input token sequence of shape (batch_size, seq_len)
      k: Number of tokens to predict

    Returns:
      Tuple of (input_ids, predicted_ids)
    """
    # For debugging.
    predictions = []

    batch_size = x.shape[0]
    seq_len = x.shape[1]

    # Store the original input.
    original_input = x.clone()
    generated_input = x.clone()

    # Generate k tokens autoregressively.
    for i in range(k):
      # Get logits for the entire sequence.
      logits = self(generated_input)

      # Get the logits for the last token in each sequence.
      next_token_logits = logits[:, -1, :]

      # Mask out the last token id so the model cannot immediately repeat it,
      # which is a common failure mode of greedy decoding.
      last_token_id = generated_input[:, -1]
      next_token_logits.scatter_(1, last_token_id.unsqueeze(1), float('-inf'))

      # Print the top-5 tokens of the first batch element for debugging.
      if i == 0:
        print("\nPyTorch detailed prediction:")
        top5_values, top5_indices = torch.topk(next_token_logits[0], 5)
        probs = torch.softmax(next_token_logits[0], dim=-1)
        for j, (idx, val) in enumerate(
            zip(top5_indices.tolist(), top5_values.tolist())):
          print(f"  Top {j+1}: Token {idx}, logit={val:.2f}, "
                f"prob={probs[idx].item():.6f}")

      # Greedily pick the most likely token.
      next_token = torch.argmax(next_token_logits, dim=-1)
      predictions.append(next_token[0].item())  # track the first batch element

      # Append the predicted token to the sequence.
      next_token = next_token.unsqueeze(1)  # add the sequence dimension
      generated_input = torch.cat([generated_input, next_token], dim=1)

    print(f"  Full predictions step by step: {predictions}")

    # Return the original input and the k newly generated tokens.
    return original_input, generated_input[:, -k:]
  def _init_weights(self, module):
    if isinstance(module, nn.Linear):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
      if module.bias is not None:
        torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

  def _scale_residual_branches(self):
    for n, p in self.named_parameters():
      if n.endswith("fc2.weight"):  # MLP/GLU output layer
        torch.nn.init.normal_(p,
                              mean=0.0,
                              std=0.02 / math.sqrt(2 * self.n_layers))
      if n.endswith("w_out.weight"):  # attention output layer
        torch.nn.init.normal_(p,
                              mean=0.0,
                              std=0.02 / math.sqrt(2 * self.n_layers))

  def tie_weights(self):
    self.lm_head.weight = self.embed_tokens.weight

  def count_params(self, non_embedding=True):
    n_params = sum(p.numel() for p in self.parameters())
    if non_embedding:
      n_params -= self.embed_tokens.weight.numel()
      if self.lm_head.weight is not self.embed_tokens.weight:  # no weight tying
        n_params -= self.lm_head.weight.numel()
    return n_params
def main():
  print("Initializing transformer model and running forward pass...")

  seq_length = 512

  # Define the model configuration.
  config = ModelConfig(
      vocab_size=32000,  # common vocab size for tokenizers like BPE or SentencePiece
      seq_len=seq_length,  # maximum sequence length
      dim=768,  # embedding dimension
      expand=4.0,  # MLP expansion factor
      n_layers=12,  # number of transformer layers
      n_heads=12,  # number of attention heads
      rmsnorm_eps=1e-6,  # RMSNorm epsilon
      tie_embeddings=True)  # tie embedding and output weights

  # Build the model and run a forward pass on a batch of random token ids.
  model = Transformer(config)
  print(f"Non-embedding parameter count: {model.count_params():,}")

  tokens = torch.randint(0, config.vocab_size, (2, seq_length))
  if torch.cuda.is_available():
    tokens = tokens.cuda()
  logits = model(tokens)
  print(f"Output logits shape: {tuple(logits.shape)}")  # (2, seq_length, vocab_size)
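For reference, a small usage sketch of the predict() helper defined above; the configuration values here are illustrative and are not taken from the commit:

# Hypothetical example: greedy generation of k tokens with the model above.
import torch

cfg = ModelConfig(vocab_size=1000, seq_len=64, dim=64, expand=4.0,
                  n_layers=2, n_heads=4, tie_embeddings=True)
model = Transformer(cfg).eval()

prompt = torch.randint(0, cfg.vocab_size, (1, 8))
if torch.cuda.is_available():
  prompt = prompt.cuda()  # Transformer.__init__ moves the model to the GPU when available

with torch.no_grad():
  original, generated = model.predict(prompt, k=4)
print(original.shape, generated.shape)  # torch.Size([1, 8]) and torch.Size([1, 4])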
