
Commit 49d628f

fixed modified Adam + added evaluation code

1 parent 0704c84 commit 49d628f

File tree

4 files changed: +155 -133 lines changed


.gitignore

+1

@@ -2,6 +2,7 @@
 model
 save
 log
+submission
 
 .vscode
 

model_py.py

+54 -47

@@ -20,51 +20,6 @@ def swish(x):
     'gelu': gelu
 }
 
-def load_openai_pretrained_model(model, n_ctx, n_special, cfg, path='model'):
-    # Load weights from TF model
-    n_transfer = cfg.n_transfer
-    shapes = json.load(open(path + '/params_shapes.json'))
-    names = json.load(open(path + '/parameters_names.json'))
-    offsets = np.cumsum([np.prod(shape) for shape in shapes])
-    init_params = [np.load(path + '/params_{}.npy'.format(n)) for n in range(10)]
-    init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
-    init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
-    init_params[0] = init_params[0][:n_ctx]
-    init_params[0] = np.concatenate([init_params[1], (np.random.randn(n_special, cfg.n_embd)*0.02).astype(np.float32), init_params[0]], 0)
-    del init_params[1]
-    if n_transfer == -1:
-        n_transfer = 0
-    else:
-        n_transfer = 1+n_transfer*12
-    init_params = [arr.squeeze() for arr in init_params]
-    try:
-        assert model.embed.weight.shape == init_params[0].shape
-    except AssertionError as e:
-        e.args += (model.embed.weight.shape, init_params[0].shape)
-        raise
-    model.embed.weight.data = torch.from_numpy(init_params[0])
-    for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]):
-        name = name[6:]  # skip "model/"
-        assert name[-2:] == ":0"
-        name = name[:-2]
-        name = name.split('/')
-        pointer = model
-        for m_name in name:
-            if re.fullmatch(r'[A-Za-z]+\d+', m_name):
-                l = re.split(r'(\d+)', m_name)
-            else:
-                l = [m_name]
-            pointer = getattr(pointer, l[0])
-            if len(l) >= 2:
-                num = int(l[1])
-                pointer = pointer[num]
-        try:
-            assert pointer.shape == ip.shape
-        except AssertionError as e:
-            e.args += (pointer.shape, ip.shape)
-            raise
-        pointer.data = torch.from_numpy(ip)
-
 
 class LayerNorm(nn.Module):
     "Construct a layernorm module (See citation for details)."

@@ -87,7 +42,9 @@ def __init__(self, nf, rf, nx):
         self.rf = rf
         self.nf = nf
         if rf == 1:  # faster 1x1 conv
-            self.w = Parameter(torch.ones(nx, nf))  # TODO change to random normal
+            w = torch.empty(nx, nf)
+            nn.init.normal_(w, std=0.02)
+            self.w = Parameter(w)
             self.b = Parameter(torch.zeros(nf))
         else:  # was used to train LM
             raise NotImplementedError

@@ -123,7 +80,7 @@ def _attn(self, q, k, v):
         if self.scale:
             w = w / math.sqrt(v.size(-1))
         w = w * self.b + -1e9*(1-self.b)  # TF implem method: mask_attn_weights
-        w = nn.Softmax()(w)
+        w = nn.Softmax(dim=-1)(w)
         w = self.attn_dropout(w)
         return torch.matmul(w, v)

@@ -198,6 +155,8 @@ def __init__(self, vocab, cfg):
         self.decoder.weight = self.embed.weight  # Tied weights
         self.clf_dropout = nn.Dropout2d(cfg.clf_pdrop)  # To reproduce the noise_shape parameter of TF implementation
 
+        nn.init.normal_(self.embed.weight, std=0.02)
+
     def forward(self, x):
         x = x.view(-1, x.size(2), x.size(3))
         e = self.embed(x)

@@ -230,6 +189,8 @@ def __init__(self, clf_token, cfg):
         self.clf_token = clf_token
         self.dropout = nn.Dropout2d(cfg.clf_pdrop)  # To reproduce the noise_shape parameter of TF implementation
         self.linear = nn.Linear(cfg.n_embd, 1)
+        nn.init.normal_(self.linear.weight, std=0.02)
+        nn.init.normal_(self.linear.bias, 0)
 
     def forward(self, h, x):
         # Classification logits

@@ -242,3 +203,49 @@ def forward(self, h, x):
         clf_h = clf_h.view(-1, self.n_embd)
         clf_logits = self.linear(clf_h)
         return clf_logits.view(-1, 2)
+
+
+def load_openai_pretrained_model(model, n_ctx, n_special, cfg, path='model'):
+    # Load weights from TF model
+    n_transfer = cfg.n_transfer
+    shapes = json.load(open(path + '/params_shapes.json'))
+    names = json.load(open(path + '/parameters_names.json'))
+    offsets = np.cumsum([np.prod(shape) for shape in shapes])
+    init_params = [np.load(path + '/params_{}.npy'.format(n)) for n in range(10)]
+    init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
+    init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
+    init_params[0] = init_params[0][:n_ctx]
+    init_params[0] = np.concatenate([init_params[1], (np.random.randn(n_special, cfg.n_embd)*0.02).astype(np.float32), init_params[0]], 0)
+    del init_params[1]
+    if n_transfer == -1:
+        n_transfer = 0
+    else:
+        n_transfer = 1+n_transfer*12
+    init_params = [arr.squeeze() for arr in init_params]
+    try:
+        assert model.embed.weight.shape == init_params[0].shape
+    except AssertionError as e:
+        e.args += (model.embed.weight.shape, init_params[0].shape)
+        raise
+    model.embed.weight.data = torch.from_numpy(init_params[0])
+    for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]):
+        name = name[6:]  # skip "model/"
+        assert name[-2:] == ":0"
+        name = name[:-2]
+        name = name.split('/')
+        pointer = model
+        for m_name in name:
+            if re.fullmatch(r'[A-Za-z]+\d+', m_name):
+                l = re.split(r'(\d+)', m_name)
+            else:
+                l = [m_name]
+            pointer = getattr(pointer, l[0])
+            if len(l) >= 2:
+                num = int(l[1])
+                pointer = pointer[num]
+        try:
+            assert pointer.shape == ip.shape
+        except AssertionError as e:
+            e.args += (pointer.shape, ip.shape)
+            raise
+        pointer.data = torch.from_numpy(ip)
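
Note (not part of the diff): the Conv1D change replaces the constant torch.ones weight with a normal(std=0.02) initialization, and the attention softmax now names its axis explicitly; an implicit-dimension nn.Softmax is deprecated and can normalize over the wrong axis of a 4-D score tensor. A minimal standalone sketch of both patterns, with hypothetical tensor shapes:

import torch
import torch.nn as nn

# Explicit init, as in the new Conv1D constructor (shapes here are hypothetical).
w = torch.empty(768, 2304)
nn.init.normal_(w, std=0.02)
w = nn.Parameter(w)

# Softmax pinned to the last (key) dimension of hypothetical attention scores.
scores = torch.randn(2, 12, 77, 77)                 # (batch, heads, query, key)
probs = nn.Softmax(dim=-1)(scores)
assert torch.allclose(probs.sum(-1), torch.ones(2, 12, 77))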

opt.py

+3 -3

@@ -1,7 +1,7 @@
 import math
 import torch
 from torch.optim import Optimizer
-from torch.nn.utils import clip_grad_norm
+from torch.nn.utils import clip_grad_norm_
 
 def warmup_cosine(x, warmup=0.002):
     s = 1 if x <= warmup else 0

@@ -81,12 +81,12 @@ def step(self, closure=None):
 
                 # Add grad clipping
                 if group['max_grad_norm'] > 0:
-                    clip_grad_norm(p, group['max_grad_norm'])
+                    clip_grad_norm_(p, group['max_grad_norm'])
 
                 # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
                 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
-                denom = exp_avg_sq.sqrt().add_(group['eps'])
+                denom = exp_avg_sq.sqrt().add_(group['e'])
 
                 bias_correction1 = 1 - beta1 ** state['step']
                 bias_correction2 = 1 - beta2 ** state['step']
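
Note (not part of the diff): clip_grad_norm_ is the in-place-named replacement for the old clip_grad_norm, and the key change from 'eps' to 'e' presumably matches the epsilon entry actually stored in this optimizer's param groups. Below is a minimal standalone sketch of the clip-then-update sequence a standard bias-corrected Adam step performs per parameter, written with the modern keyword forms of add_/addcmul_ (the diff keeps older positional forms) and a hypothetical fixed learning rate in place of the warmup schedule:

import math
import torch
from torch.nn.utils import clip_grad_norm_         # trailing underscore: in-place API name

p = torch.randn(10, requires_grad=True)
p.grad = torch.randn(10)                            # pretend loss.backward() already ran
exp_avg, exp_avg_sq = torch.zeros(10), torch.zeros(10)
beta1, beta2, e, max_grad_norm, lr, step = 0.9, 0.999, 1e-8, 1.0, 6.25e-5, 1

clip_grad_norm_(p, max_grad_norm)                   # clips p.grad in place by its L2 norm

exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)                  # first moment
exp_avg_sq.mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2)   # second moment
denom = exp_avg_sq.sqrt().add_(e)                   # 'e' plays the role of Adam's epsilon
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
step_size = lr * math.sqrt(bias_correction2) / bias_correction1
with torch.no_grad():
    p.addcdiv_(exp_avg, denom, value=-step_size)    # p <- p - step_size * exp_avg / denom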

train.py

+97 -83

@@ -33,27 +33,28 @@ def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
         self.lm_coef = lm_coef
         self.opt = opt
 
-    def __call__(self, X, Y, M, lm_logits, clf_logits):
+    def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
         # Language modeling loss
-        x_shifted = X[:, :, 1:, 0].contiguous().view(-1)  # Shape: 252
-        M = M.view(-1, M.size(2))
-        lm_losses = self.lm_criterion(lm_logits, x_shifted)
-        lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2)-1)
-        lm_losses = lm_losses * M[:, 1:]
-        lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
-
+        if lm_logits is not None:
+            x_shifted = X[:, :, 1:, 0].contiguous().view(-1)  # Shape: 252
+            M = M.view(-1, M.size(2))
+            lm_losses = self.lm_criterion(lm_logits, x_shifted)
+            lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2)-1)
+            lm_losses = lm_losses * M[:, 1:]
+            lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
         # Classification loss
         clf_losses = self.clf_criterion(clf_logits, Y)
+        if only_return_losses:
+            return (clf_losses, lm_losses) if lm_logits is not None else clf_losses
 
-        if self.lm_coef > 0:
+        if self.lm_coef > 0 and lm_logits is not None:
             train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum()
         else:
             train_loss = clf_losses.sum()
-
         train_loss.backward()
         if self.opt is not None:
             self.opt.step()
-            self.opt.optimizer.zero_grad()
+            self.opt.zero_grad()
         return train_loss.item()
 
 

@@ -75,60 +76,84 @@ def transform_roc(X1, X2, X3):
     xmb[:, :, :, 1] = np.arange(n_vocab+n_special, n_vocab+n_special+n_ctx)
     return xmb, mmb
 
-# def iter_apply(Xs, Ms, Ys):
-#     fns = [lambda x:np.concatenate(x, 0), lambda x:float(np.sum(x))]
-#     results = []
-#     for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
-#         n = len(xmb)
-#         if n == n_batch_train:
-#             res = sess.run([eval_mgpu_logits, eval_mgpu_clf_loss], {X_train:xmb, M_train:mmb, Y_train:ymb})
-#         else:
-#             res = sess.run([eval_logits, eval_clf_loss], {X:xmb, M:mmb, Y:ymb})
-#         res = [r*n for r in res]
-#         results.append(res)
-#     results = zip(*results)
-#     return [fn(res) for res, fn in zip(results, fns)]
+def iter_apply(Xs, Ms, Ys):
+    fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))]
+    results = []
+    with torch.no_grad():
+        model.eval()
+        for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
+            n = len(xmb)
+            XMB = torch.tensor(xmb, dtype=torch.long).to(device)
+            YMB = torch.tensor(ymb, dtype=torch.long).to(device)
+            MMB = torch.tensor(mmb).to(device)
+            h = model(XMB)
+            clf_logits = clf_head(h, XMB)
+            clf_losses = compute_loss(XMB, YMB, MMB, clf_logits, only_return_losses=True)
+            res = (clf_logits.numpy()*n, clf_losses.numpy()*n)
+            results.append(res)
+    results = zip(*results)
+    return [fn(res) for res, fn in zip(results, fns)]
 
-# def iter_predict(Xs, Ms):
-#     logits = []
-#     for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True):
-#         n = len(xmb)
-#         if n == n_batch_train:
-#             logits.append(sess.run(eval_mgpu_logits, {X_train:xmb, M_train:mmb}))
-#         else:
-#             logits.append(sess.run(eval_logits, {X:xmb, M:mmb}))
-#     logits = np.concatenate(logits, 0)
-#     return logits
+def iter_predict(Xs, Ms):
+    logits = []
+    with torch.no_grad():
+        model.eval()
+        for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True):
+            n = len(xmb)
+            XMB = torch.tensor(xmb, dtype=torch.long).to(device)
+            MMB = torch.tensor(mmb).to(device)
+            h = model(XMB)
+            clf_logits = clf_head(h, XMB)
+            logits.append(clf_logits.numpy())
+    logits = np.concatenate(logits, 0)
+    return logits
 
-# def log():
-#     global best_score
-#     tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
-#     va_logits, va_cost = iter_apply(vaX, vaM, vaY)
-#     tr_cost = tr_cost/len(trY[:n_valid])
-#     va_cost = va_cost/n_valid
-#     tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1))*100.
-#     va_acc = accuracy_score(vaY, np.argmax(va_logits, 1))*100.
-#     logger.log(n_epochs=n_epochs, n_updates=n_updates, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc)
-#     print('%d %d %.3f %.3f %.2f %.2f'%(n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))
-#     if submit:
-#         score = va_acc
-#         if score > best_score:
-#             best_score = score
-#             save(os.path.join(save_dir, desc, 'best_params.jl'))
+def log():
+    global best_score
+    tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
+    va_logits, va_cost = iter_apply(vaX, vaM, vaY)
+    tr_cost = tr_cost/len(trY[:n_valid])
+    va_cost = va_cost/n_valid
+    tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1))*100.
+    va_acc = accuracy_score(vaY, np.argmax(va_logits, 1))*100.
+    logger.log(n_epochs=n_epochs, n_updates=n_updates, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc)
+    print('%d %d %.3f %.3f %.2f %.2f'%(n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))
+    if submit:
+        score = va_acc
+        if score > best_score:
+            best_score = score
+            path = os.path.join(save_dir, desc, 'best_params')
+            torch.save(model.state_dict(), make_path(path))
+
+def predict():
+    filename = filenames[dataset]
+    pred_fn = pred_fns[dataset]
+    label_decoder = label_decoders[dataset]
+    predictions = pred_fn(iter_predict(teX, teM))
+    if label_decoder is not None:
+        predictions = [label_decoder[prediction] for prediction in predictions]
+    path = os.path.join(submission_dir, filename)
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    with open(path, 'w') as f:
+        f.write('{}\t{}\n'.format('index', 'prediction'))
+        for i, prediction in enumerate(predictions):
+            f.write('{}\t{}\n'.format(i, prediction))
 
-# def predict():
-#     filename = filenames[dataset]
-#     pred_fn = pred_fns[dataset]
-#     label_decoder = label_decoders[dataset]
-#     predictions = pred_fn(iter_predict(teX, teM))
-#     if label_decoder is not None:
-#         predictions = [label_decoder[prediction] for prediction in predictions]
-#     path = os.path.join(submission_dir, filename)
-#     os.makedirs(os.path.dirname(path), exist_ok=True)
-#     with open(path, 'w') as f:
-#         f.write('{}\t{}\n'.format('index', 'prediction'))
-#         for i, prediction in enumerate(predictions):
-#             f.write('{}\t{}\n'.format(i, prediction))
+def run_epoch():
+    for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
+                                   n_batch=n_batch_train, truncate=True, verbose=True):
+        global n_updates
+        XMB = torch.tensor(xmb, dtype=torch.long).to(device)
+        YMB = torch.tensor(ymb, dtype=torch.long).to(device)
+        MMB = torch.tensor(mmb).to(device)
+        model.train()
+        h = model(XMB)
+        lm_logits = lm_head(h)
+        clf_logits = clf_head(h, XMB)
+        compute_loss(XMB, YMB, MMB, clf_logits, lm_logits)
+        n_updates += 1
+        if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:
+            log()
 
 argmax = lambda x:np.argmax(x, 1)
 

@@ -235,7 +260,6 @@ def transform_roc(X1, X2, X3):
                            max_grad_norm=max_grad_norm)
 
     compute_loss = LossCompute(criterion, criterion, lm_coef, model_opt)
-    # TODO Initialize model (?)
    # TODO add train() and eval()
     load_openai_pretrained_model(model, n_ctx, n_special, args)
 

@@ -250,26 +274,16 @@ def transform_roc(X1, X2, X3):
     if submit:
         path = os.path.join(save_dir, desc, 'best_params')
         torch.save(model.state_dict(), make_path(path))
-
     best_score = 0
+    log()
     for i in range(n_iter):
-        for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
-                                       n_batch=n_batch_train, truncate=True, verbose=True):
-            XMB = torch.tensor(xmb, dtype=torch.long).to(device)
-            YMB = torch.tensor(ymb, dtype=torch.long).to(device)
-            MMB = torch.tensor(mmb).to(device)
-            model.train()
-            h = model(XMB)
-            lm_logits = lm_head(h)
-            clf_logits = clf_head(h, XMB)
-            loss = compute_loss(XMB, YMB, MMB, lm_logits, clf_logits)
-            n_updates += 1
-            #if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:
-            #    log()
+        run_epoch()
         n_epochs += 1
-    # log()
-    # if submit:
-    #     sess.run([p.assign(ip) for p, ip in zip(params, joblib.load(os.path.join(save_dir, desc, 'best_params.jl')))])
-    #     predict()
-    # if analysis:
-    #     rocstories_analysis(data_dir, os.path.join(submission_dir, 'ROCStories.tsv'), os.path.join(log_dir, 'rocstories.jsonl'))
+        log()
+    if submit:
+        path = os.path.join(save_dir, desc, 'best_params')
+        model.load_state_dict(torch.load(path))
+        predict()
+    if analysis:
+        rocstories_analysis(data_dir, os.path.join(submission_dir, 'ROCStories.tsv'),
+                            os.path.join(log_dir, 'rocstories.jsonl'))
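
Note (not part of the diff): the new iter_apply/iter_predict helpers follow the usual PyTorch evaluation pattern of model.eval() inside torch.no_grad() with per-batch aggregation in NumPy. A self-contained toy sketch of that pattern (stand-in model and data, not the repo's classes); it adds a .cpu() call before .numpy(), which is required whenever the tensors live on a GPU:

import numpy as np
import torch
import torch.nn as nn

device = torch.device('cpu')                        # assumption: CPU run
model = nn.Linear(16, 2).to(device)                 # stand-in for model + clf_head
batches = [np.random.randn(8, 16).astype(np.float32) for _ in range(3)]

logits = []
with torch.no_grad():                               # no autograd graph during evaluation
    model.eval()                                    # disable dropout and similar layers
    for xmb in batches:
        XMB = torch.tensor(xmb).to(device)
        clf_logits = model(XMB)
        logits.append(clf_logits.cpu().numpy())     # .cpu() before .numpy() for GPU tensors
logits = np.concatenate(logits, 0)
print(logits.shape)                                 # (24, 2)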
