
Commit 0b73057

model train run - no eval
1 parent 89ab479 commit 0b73057

5 files changed: +195 −88 lines changed

.gitignore

Lines changed: 110 additions & 1 deletion
@@ -1,2 +1,111 @@
+# Model, ipynb_checkpoints
+model
+save
+log
+
 .vscode
-model
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/

datasets.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def _rocstories(path):
                 y.append(int(line[-1])-1)
     return st, ct1, ct2, y

-def rocstories(data_dir, n_train=1497, n_valid=374):
+def rocstories(data_dir, n_train=1497, n_valid=2): #374): # TODO: set this back
     storys, comps1, comps2, ys = _rocstories(os.path.join(data_dir, 'cloze_test_val__spring2016 - cloze_test_ALL_val.csv'))
     teX1, teX2, teX3, _ = _rocstories(os.path.join(data_dir, 'cloze_test_test__spring2016 - cloze_test_ALL_test.csv'))
     tr_storys, va_storys, tr_comps1, va_comps1, tr_comps2, va_comps2, tr_ys, va_ys = train_test_split(storys, comps1, comps2, ys, test_size=n_valid, random_state=seed)
model_py.py

Lines changed: 50 additions & 40 deletions
@@ -1,3 +1,4 @@
+import re
 import math
 import json
 import copy
@@ -36,22 +37,34 @@ def load_openai_pretrained_model(model, n_ctx, n_special, cfg, path='model'):
         n_transfer = 0
     else:
         n_transfer = 1+n_transfer*12
-    assert model.embed.weight.shape == init_params[0].shape
-    model.embed.weight = init_params[0]
+    init_params = [arr.squeeze() for arr in init_params]
+    try:
+        assert model.embed.weight.shape == init_params[0].shape
+    except AssertionError as e:
+        e.args += (model.embed.weight.shape, init_params[0].shape)
+        raise
+    model.embed.weight.data = torch.from_numpy(init_params[0])
     for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]):
         name = name[6:] # skip "model/"
         assert name[-2:] == ":0"
         name = name[:-2]
         name = name.split('/')
         pointer = model
         for m_name in name:
-            l = re.split('(\d+)', m_name)
+            if re.fullmatch(r'[A-Za-z]+\d+', m_name):
+                l = re.split(r'(\d+)', m_name)
+            else:
+                l = [m_name]
             pointer = getattr(pointer, l[0])
-            if len(l) == 1:
+            if len(l) >= 2:
                 num = int(l[1])
                 pointer = pointer[num]
-        assert pointer.shape == ip.shape
-        pointer = ip
+        try:
+            assert pointer.shape == ip.shape
+        except AssertionError as e:
+            e.args += (pointer.shape, ip.shape)
+            raise
+        pointer.data = torch.from_numpy(ip)


 class LayerNorm(nn.Module):
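For reference, the traversal above maps TF-style variable names (the "model/" prefix and ":0" suffix are stripped exactly as in the loader) onto PyTorch submodules by splitting indexed segments such as "h0" on their trailing digits. A standalone sketch of just that splitting step, using an illustrative name rather than one read from the checkpoint:

import re

# Illustrative TF-style parameter name; real names come from the checkpoint.
name = "model/h0/attn/c_attn/w:0"
name = name[6:]                 # drop "model/"
assert name[-2:] == ":0"
segments = name[:-2].split('/') # ['h0', 'attn', 'c_attn', 'w']

for m_name in segments:
    # Segments like 'h0' carry a layer index; plain segments like 'attn' do not.
    if re.fullmatch(r'[A-Za-z]+\d+', m_name):
        l = re.split(r'(\d+)', m_name)   # 'h0' -> ['h', '0', '']
    else:
        l = [m_name]
    print(l[0], int(l[1]) if len(l) >= 2 else None)
# h 0 / attn None / c_attn None / w None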
@@ -82,7 +95,7 @@ def __init__(self, nf, rf, nx):

     def forward(self, x):
         if self.rf == 1:
-            size_out = x.size()[:-1] + [self.nf]
+            size_out = x.size()[:-1] + (self.nf,)
             x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w)
             x = x.view(*size_out)
         else:
@@ -93,38 +106,35 @@ def forward(self, x):
 class Attention(nn.Module):
     def __init__(self, nx, cfg, scale=False):
         super(Attention, self).__init__()
-        n_state = nx # in Attention: n_state=768 (nx=n_embed)
+        n_state = nx # in Attention: n_state=768 (nx=n_embd)
         #[switch nx => n_state from Block to Attention to keep identical to TF implem]
         assert n_state % cfg.n_head==0
+        mask_size = n_state // cfg.n_head
+        self.register_buffer('b', torch.tril(torch.ones(mask_size, mask_size)).view(1, 1, mask_size, mask_size))
         self.n_head = cfg.n_head
+        self.split_size = n_state
         self.scale = scale
         self.c_attn = Conv1D(n_state * 3, 1, nx)
         self.c_proj = Conv1D(n_state, 1, nx)
         self.attn_dropout = nn.Dropout(cfg.attn_pdrop)
         self.resid_dropout = nn.Dropout(cfg.resid_pdrop)

-    @staticmethod
-    def mask_attn_weights(w):
-        n = w.size(-1)
-        b = torch.tril(np.ones(n, n)).view(1, 1, n, n)
-        return w * b + -1e9*(1-b)
-
     def _attn(self, q, k, v):
         w = torch.matmul(q, k)
         if self.scale:
             w = w / math.sqrt(v.size(-1))
-        w = self.mask_attn_weights(w)
+        w = w * self.b + -1e9*(1-self.b) # TF implem method: mask_attn_weights
         w = nn.Softmax()(w)
         w = self.attn_dropout(w)
         return torch.matmul(w, v)

     def merge_heads(self, x):
-        new_x_shape = x.size()[:-2] + [np.prod(x.size()[-2:])]
-        x = x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
-        return x.permute(0, 2, 1, 3)
+        x = x.permute(0, 2, 1, 3).contiguous()
+        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
+        return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states

     def split_heads(self, x, k=False):
-        new_x_shape = x.size()[:-1] + [self.n_head, x.size(-1)//self.n_head]
+        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1)//self.n_head)
         x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
         if k:
             return x.permute(0, 2, 3, 1)
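The rewritten `_attn` applies the pre-registered lower-triangular buffer `self.b` instead of rebuilding the mask on every call. A minimal sketch of the masking pattern itself, with toy sizes unrelated to the model config:

import torch
import torch.nn as nn

n = 4                                               # toy sequence length
b = torch.tril(torch.ones(n, n)).view(1, 1, n, n)   # causal mask, as in the buffer above
w = torch.randn(1, 1, n, n)                         # raw attention scores q @ k

w = w * b + -1e9 * (1 - b)       # future positions receive a large negative score
w = nn.Softmax(dim=-1)(w)        # so softmax assigns them ~0 probability
print(w[0, 0])                   # upper triangle is (numerically) zero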
@@ -133,7 +143,7 @@ def split_heads(self, x, k=False):

     def forward(self, x):
         x = self.c_attn(x)
-        query, key, value = x.split(3, dim=2)
+        query, key, value = x.split(self.split_size, dim=2)
         query = self.split_heads(query)
         key = self.split_heads(key, k=True)
         value = self.split_heads(value)
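The `x.split(...)` fix above matters because `Tensor.split` takes a chunk size, not a number of chunks: `split(3, dim=2)` yields many 3-wide slices, while `split(self.split_size, dim=2)` yields exactly the three query/key/value slabs. A toy illustration with hypothetical sizes:

import torch

n_state = 8                                   # toy embedding width (the real model uses 768)
x = torch.randn(2, 5, n_state * 3)            # output of c_attn: query/key/value concatenated

chunks = x.split(3, dim=2)                    # chunk size 3 -> 8 thin slices, not 3 tensors
print(len(chunks), chunks[0].shape)           # 8 torch.Size([2, 5, 3])

query, key, value = x.split(n_state, dim=2)   # chunk size n_state -> exactly 3 slabs
print(query.shape, key.shape, value.shape)    # each torch.Size([2, 5, 8])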
@@ -145,11 +155,11 @@ def forward(self, x):


 class MLP(nn.Module):
-    def __init__(self, n_state, cfg): # in MLP: n_state=3072 (4 * n_embed)
+    def __init__(self, n_state, cfg): # in MLP: n_state=3072 (4 * n_embd)
         super(MLP, self).__init__()
-        nx = cfg.n_embed
+        nx = cfg.n_embd
         self.c_fc = Conv1D(n_state, 1, nx)
-        self.c_proj = Conv1D(nx, 1, nx)
+        self.c_proj = Conv1D(nx, 1, n_state)
         self.act = ACT_FNS[cfg.afn]
         self.dropout = nn.Dropout(cfg.resid_pdrop)

@@ -162,7 +172,7 @@ def forward(self, x):
 class Block(nn.Module):
     def __init__(self, cfg, scale=False):
         super(Block, self).__init__()
-        nx = cfg.n_embed
+        nx = cfg.n_embd
         self.attn = Attention(nx, cfg, scale)
         self.ln_1 = LayerNorm(nx)
         self.mlp = MLP(4*nx, cfg)
@@ -185,13 +195,12 @@ def __init__(self, vocab, cfg):
         self.drop = nn.Dropout(cfg.embd_pdrop)
         block = Block(cfg, scale=True)
         self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)])
-        self.decoder = nn.Linear(cfg.n_embed, vocab, bias=False)
+        self.decoder = nn.Linear(cfg.n_embd, vocab, bias=False)
         self.decoder.weight = self.embed.weight # Tied weights
         self.clf_dropout = nn.Dropout2d(cfg.clf_pdrop) # To reproduce the noise_shape parameter of TF implementation

-    def forward(self, x, m):
+    def forward(self, x):
         x = x.view(-1, x.size(2), x.size(3))
-        m = m.view(-1, m.size(2))
         e = self.embed(x)
         h = e.sum(dim=2)
         for block in self.h:
@@ -200,36 +209,37 @@ def forward(self, x, m):


 class LMHead(nn.Module):
-    """ Language Model Head """
+    """ Language Model Head for the transformer """
     def __init__(self, model, cfg):
         super(LMHead, self).__init__()
-        self.n_embed = cfg.n_embed
-        self.decoder = nn.Linear(cfg.n_embed, model.vocab, bias=False)
+        self.n_embd = cfg.n_embd
+        self.decoder = nn.Linear(cfg.n_embd, model.vocab, bias=False)
         self.decoder.weight = model.embed.weight # Tied weights

     def forward(self, h):
         # Truncated Language modeling logits
-        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embed) # Shape: 252, 768
+        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd) # Shape: 252, 768
         lm_logits = self.decoder(h_trunc)
         return lm_logits


 class ClfHead(nn.Module):
-    """ Classifier Head for the model"""
-    def __init__(self, model, clf_token, cfg):
+    """ Classifier Head for the transformer """
+    def __init__(self, clf_token, cfg):
         super(ClfHead, self).__init__()
-        self.n_embed = cfg.n_embed
+        self.n_embd = cfg.n_embd
         self.clf_token = clf_token
         self.dropout = nn.Dropout2d(cfg.clf_pdrop) # To reproduce the noise_shape parameter of TF implementation
-        self.linear = nn.Linear(cfg.n_embed, 1)
+        self.linear = nn.Linear(cfg.n_embd, 1)

     def forward(self, h, x):
         # Classification logits
-        clf_h = h.view(-1, self.n_embed)
-        pool_idx = torch.eq(x[:, :, 0].contiguous().view(-1), self.clf_token)
-        clf_h = clf_h[pool_idx, :]
-        clf_h = clf_h.view(-1, 2, self.n_embed, 1)
+        clf_h = h.view(-1, self.n_embd)
+        flat = x[:, :, :, 0].contiguous().view(-1)
+        #pool_idx = torch.eq(x[:, :, 0].contiguous().view(-1), self.clf_token)
+        clf_h = clf_h[flat == self.clf_token, :] #.index_select(0, pool_idx)
+        clf_h = clf_h.view(-1, 2, self.n_embd, 1)
         clf_h = self.dropout(clf_h)
-        clf_h = clf_h.view(-1, self.n_embed)
+        clf_h = clf_h.view(-1, self.n_embd)
         clf_logits = self.linear(clf_h)
         return clf_logits.view(-1, 2)
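The reworked `ClfHead.forward` pools one hidden vector per sequence by flattening the token ids and boolean-indexing on the classifier token. A small self-contained sketch of that selection, assuming a [batch, choice, position, (token, position-id)] input layout; sizes and the `clf_token` value are made up:

import torch

n_embd, n_ctx, clf_token = 4, 6, 99          # toy sizes; the real clf_token comes from the encoder
h = torch.randn(2 * 2, n_ctx, n_embd)        # hidden states for (batch=2) x (2 choices)
x = torch.randint(0, 50, (2, 2, n_ctx, 2))   # [batch, choice, position, (token, position-id)]
x[:, :, -1, 0] = clf_token                   # last position of every sequence is the clf token

clf_h = h.view(-1, n_embd)                   # flatten to [batch*2*n_ctx, n_embd]
flat = x[:, :, :, 0].contiguous().view(-1)   # flatten the token ids the same way
clf_h = clf_h[flat == clf_token, :]          # keep exactly one vector per sequence
print(clf_h.view(-1, 2, n_embd).shape)       # torch.Size([2, 2, 4])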

opt.py

Lines changed: 4 additions & 30 deletions
@@ -1,18 +1,14 @@
 import math
 import numpy as np
-import tensorflow as tf

 def warmup_cosine(x, warmup=0.002):
-    s = tf.cast(x <= warmup, tf.float32)
-    return s*(x/warmup) + (1-s)*(0.5 * (1 + tf.cos(math.pi * x)))
+    pass

 def warmup_constant(x, warmup=0.002):
-    s = tf.cast(x <= warmup, tf.float32)
-    return s*(x/warmup) + (1-s)*1
+    pass

 def warmup_linear(x, warmup=0.002):
-    s = tf.cast(x <= warmup, tf.float32)
-    return (s*(x/warmup) + (1-s))*(1-x)
+    pass

 schedules = {
     'warmup_cosine':warmup_cosine,
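The schedule bodies are stubbed out with `pass` in this commit; the removed TensorFlow versions translate to plain Python almost one-to-one. A hedged sketch of what framework-free equivalents could look like (not part of this commit), mirroring the deleted formulas:

import math

# x is the fraction of training completed, in [0, 1]; s selects the warmup branch,
# as in the removed tf.cast-based versions above.
def warmup_cosine(x, warmup=0.002):
    s = 1.0 if x <= warmup else 0.0
    return s * (x / warmup) + (1 - s) * (0.5 * (1 + math.cos(math.pi * x)))

def warmup_constant(x, warmup=0.002):
    s = 1.0 if x <= warmup else 0.0
    return s * (x / warmup) + (1 - s) * 1

def warmup_linear(x, warmup=0.002):
    s = 1.0 if x <= warmup else 0.0
    return (s * (x / warmup) + (1 - s)) * (1 - x)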
@@ -24,26 +20,4 @@ def adam(params, grads, lr, schedule, t_total, b1=0.9, b2=0.999, e=1e-8, l2=0, v
     """
     adam with weight decay fix
     """
-    t = tf.Variable(0, dtype=tf.float32, trainable=False)
-    tt = t+1
-    updates = [t.assign(tt)]
-    if max_grad_norm > 0:
-        grads, _ = tf.clip_by_global_norm(grads, max_grad_norm)
-    for p, g in zip(params, grads):
-        if p is None or g is None:
-            print("can't train", p.name, g)
-        else:
-            if isinstance(g, tf.IndexedSlices):
-                g = tf.convert_to_tensor(g)
-            m = tf.Variable(p*0, dtype=tf.float32, trainable=False)
-            v = tf.Variable(p*0, dtype=tf.float32, trainable=False)
-            lrt = lr*tf.sqrt(1-b2**tt)/(1-b1**tt)
-            lrt *= schedule(t/t_total)
-            mt = b1*m + (1-b1)*g
-            vt = b2*v + (1-b2)*g*g
-            if (len(p.get_shape()) > 1 or vector_l2) and l2 > 0:
-                pt = p - lrt * (mt / (tf.sqrt(vt) + e) + l2*p)
-            else:
-                pt = p - lrt * (mt / (tf.sqrt(vt) + e))
-            updates.extend([m.assign(mt), v.assign(vt), p.assign(pt)])
-    return tf.group(*updates)
+    pass

0 commit comments
