
Commit 89ab479

refactoring

1 parent 1f209c4 commit 89ab479

3 files changed: +231 -233 lines changed

model_py.py

Lines changed: 69 additions & 11 deletions
@@ -1,6 +1,8 @@
-import numpy as np
 import math
+import json
 import copy
+import numpy as np
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -18,20 +20,53 @@ def swish(x):
     'gelu': gelu
 }
 
+def load_openai_pretrained_model(model, n_ctx, n_special, cfg, path='model'):
+    # Load weights from the TF checkpoint released by OpenAI
+    n_transfer = cfg.n_transfer
+    shapes = json.load(open(path + '/params_shapes.json'))
+    names = json.load(open(path + '/parameters_names.json'))
+    offsets = np.cumsum([np.prod(shape) for shape in shapes])
+    init_params = [np.load(path + '/params_{}.npy'.format(n)) for n in range(10)]
+    init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
+    init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
+    init_params[0] = init_params[0][:n_ctx]
+    # One shared matrix: word embeddings, fresh special-token embeddings, position embeddings
+    init_params[0] = np.concatenate([init_params[1], (np.random.randn(n_special, cfg.n_embd)*0.02).astype(np.float32), init_params[0]], 0)
+    del init_params[1]
+    if n_transfer == -1:
+        n_transfer = 0
+    else:
+        n_transfer = 1 + n_transfer*12  # embedding matrix + 12 tensors per transformer block
+    assert model.embed.weight.shape == init_params[0].shape
+    model.embed.weight.data = torch.from_numpy(init_params[0])
+    for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]):
+        name = name[6:]  # skip "model/"
+        assert name[-2:] == ":0"
+        name = name[:-2]
+        name = name.split('/')
+        pointer = model
+        for m_name in name:
+            l = re.split(r'(\d+)', m_name)
+            pointer = getattr(pointer, l[0])
+            if len(l) > 1:
+                num = int(l[1])
+                pointer = pointer[num]
+        assert pointer.shape == ip.shape
+        pointer.data = torch.from_numpy(ip)
+
 
 class LayerNorm(nn.Module):
     "Construct a layernorm module (See citation for details)."
-    def __init__(self, n_state, eps=1e-6):
+    def __init__(self, n_state, e=1e-5):
         super(LayerNorm, self).__init__()
         self.g = nn.Parameter(torch.ones(n_state))
         self.b = nn.Parameter(torch.zeros(n_state))
-        self.eps = eps
+        self.e = e
 
     def forward(self, x):
-        mean = x.mean(-1, keepdim=True)
-        std = x.std(-1, keepdim=True)
-        # One difference with the TF version here: we add epsilon outside of sqrt
-        return self.g * (x - mean) / (std + self.eps) + self.b
+        u = x.mean(-1, keepdim=True)
+        s = (x - u).pow(2).mean(-1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.e)
+        return self.g * x + self.b
 
 
 class Conv1D(nn.Module):
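
The LayerNorm change is behavioral, not just a rename: the old forward added epsilon to an unbiased standard deviation outside the square root, while the new one follows the TF reference and adds it to the biased variance inside the square root. A minimal sketch of the two conventions (not part of the commit; shapes assumed):

    import torch

    x = torch.randn(2, 77, 768)                        # (batch, n_ctx, n_embd), sizes assumed
    u = x.mean(-1, keepdim=True)
    s = (x - u).pow(2).mean(-1, keepdim=True)          # biased variance, as in the TF code
    new = (x - u) / torch.sqrt(s + 1e-5)               # this commit: eps inside the sqrt
    old = (x - u) / (x.std(-1, keepdim=True) + 1e-6)   # previous version: unbiased std, eps outside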
@@ -145,6 +180,7 @@ class Model(nn.Module):
     """ Transformer model """
     def __init__(self, vocab, cfg):
         super(Model, self).__init__()
+        self.vocab = vocab
         self.embed = nn.Embedding(vocab, cfg.n_embd)
         self.drop = nn.Dropout(cfg.embd_pdrop)
         block = Block(cfg, scale=True)
@@ -160,18 +196,40 @@ def forward(self, x, m):
         h = e.sum(dim=2)
         for block in self.h:
             h = block(h)
+        return h
 
-        # Language modeling logits
+
+class LMHead(nn.Module):
+    """ Language Model Head """
+    def __init__(self, model, cfg):
+        super(LMHead, self).__init__()
+        self.n_embed = cfg.n_embd
+        self.decoder = nn.Linear(cfg.n_embd, model.vocab, bias=False)
+        self.decoder.weight = model.embed.weight  # Tied weights
+
+    def forward(self, h):
+        # Truncated language modeling logits (the last position has no next-token target)
         h_trunc = h[:, :-1].contiguous().view(-1, self.n_embed) # Shape: 252, 768
         lm_logits = self.decoder(h_trunc)
+        return lm_logits
+
 
+class ClfHead(nn.Module):
+    """ Classifier Head for the model """
+    def __init__(self, model, clf_token, cfg):
+        super(ClfHead, self).__init__()
+        self.n_embed = cfg.n_embd
+        self.clf_token = clf_token
+        self.dropout = nn.Dropout2d(cfg.clf_pdrop)  # To reproduce the noise_shape parameter of the TF implementation
+        self.linear = nn.Linear(cfg.n_embd, 1)
+
+    def forward(self, h, x):
         # Classification logits
         clf_h = h.view(-1, self.n_embed)
         pool_idx = torch.eq(x[:, :, 0].contiguous().view(-1), self.clf_token)
         clf_h = clf_h[pool_idx, :]
         clf_h = clf_h.view(-1, 2, self.n_embed, 1)
-        clf_h = self.clf_dropout(clf_h)
+        clf_h = self.dropout(clf_h)
         clf_h = clf_h.view(-1, self.n_embed)
         clf_logits = self.linear(clf_h)
-
-        return lm_logits, clf_logits
+        return clf_logits.view(-1, 2)
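
Net effect of the refactor: Model now stops at the hidden states, and the language-modeling and classification outputs come from two detachable heads. A sketch of how the pieces compose, mirroring the updated train.py below (cfg, vocab, clf_token and the xmb/mmb batches are supplied by the caller):

    model = Model(vocab, cfg)                  # transformer body, returns hidden states
    lm_head = LMHead(model, cfg)               # decoder tied to model.embed.weight
    clf_head = ClfHead(model, clf_token, cfg)
    load_openai_pretrained_model(model, n_ctx, n_special, cfg)

    h = model(xmb, mmb)                        # hidden states for every position
    lm_logits = lm_head(h)                     # next-token logits, last position dropped
    clf_logits = clf_head(h, xmb)              # one score per candidate, viewed as (-1, 2)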

train.py

Lines changed: 80 additions & 137 deletions
@@ -17,12 +17,13 @@
 from sklearn.utils import shuffle
 from sklearn.metrics import accuracy_score
 
-from model_py import Model
+from model_py import Model, LMHead, ClfHead, load_openai_pretrained_model
 from opt import adam, warmup_cosine, warmup_linear, warmup_constant
 from datasets import rocstories
 from analysis import rocstories as rocstories_analysis
 from text_utils import TextEncoder
-from utils import encode_dataset, flatten, iter_data, find_trainable_variables, get_ema_vars, convert_gradient_to_tensor, shape_list, ResultLogger, assign_to_gpu, average_grads, make_path
+from utils import (encode_dataset, flatten, iter_data,
+                   ResultLogger, make_path)
 
 OPT_FNS = {
     'adam':adam,
@@ -36,14 +37,11 @@
 
 class LossCompute:
     "A Loss compute and train function."
-    def __init__(self, generator, lm_criterion, n_embed, clf_token, opt=None):
-        self.generator = generator
+    def __init__(self, lm_criterion, clf_criterion):
         self.lm_criterion = lm_criterion
-        self.opt = opt
-        self.n_embed = n_embed
-        self.clf_token = clf_token
+        self.clf_criterion = clf_criterion
 
-    def __call__(self, X, Y, M, lm_logits, clf_logits, norm):
+    def __call__(self, X, Y, M, lm_logits, clf_logits):
         # Language modeling loss
         x_shifted = X[:, 1:, 0].contiguous().view(-1) # Shape: 252
         lm_losses = self.lm_criterion(lm_logits, x_shifted)
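
The shift here pairs up with LMHead's truncation: lm_logits come from h[:, :-1] (every position except the last), and x_shifted drops the first token, so the logits at position t are scored against the token at t+1. A toy illustration (values invented):

    import torch

    X = torch.tensor([[[5, 101], [7, 102], [9, 103]]])  # (1, n_ctx=3, 2): token ids, position ids
    x_shifted = X[:, 1:, 0].contiguous().view(-1)        # tensor([7, 9]): targets for positions 0 and 1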
@@ -59,39 +57,6 @@ def __call__(self, X, Y, M, lm_logits, clf_logits, norm):
         train_loss = clf_losses.sum()
         return train_loss
 
-# def mgpu_train(*xs):
-#     gpu_ops = []
-#     gpu_grads = []
-#     xs = (tf.split(x, n_gpu, 0) for x in xs)
-#     for i, xs in enumerate(zip(*xs)):
-#         do_reuse = True if i > 0 else None
-#         with tf.device(assign_to_gpu(i, "/gpu:0")), tf.variable_scope(tf.get_variable_scope(), reuse=do_reuse):
-#             clf_logits, clf_losses, lm_losses = model(*xs, train=True, reuse=do_reuse)
-#             if lm_coef > 0:
-#                 train_loss = tf.reduce_mean(clf_losses) + lm_coef*tf.reduce_mean(lm_losses)
-#             else:
-#                 train_loss = tf.reduce_mean(clf_losses)
-#             params = find_trainable_variables("model")
-#             grads = tf.gradients(train_loss, params)
-#             grads = list(zip(grads, params))
-#             gpu_grads.append(grads)
-#             gpu_ops.append([clf_logits, clf_losses, lm_losses])
-#     ops = [tf.concat(op, 0) for op in zip(*gpu_ops)]
-#     grads = average_grads(gpu_grads)
-#     grads = [g for g, p in grads]
-#     train = opt_fns[opt](params, grads, lr, partial(lr_schedules[lr_schedule], warmup=lr_warmup), n_updates_total, l2=l2, max_grad_norm=max_grad_norm, vector_l2=vector_l2, b1=b1, b2=b2, e=e)
-#     return [train]+ops
-
-# def mgpu_predict(*xs):
-#     gpu_ops = []
-#     xs = (tf.split(x, n_gpu, 0) for x in xs)
-#     for i, xs in enumerate(zip(*xs)):
-#         with tf.device(assign_to_gpu(i, "/gpu:0")), tf.variable_scope(tf.get_variable_scope(), reuse=True):
-#             clf_logits, clf_losses, lm_losses = model(*xs, train=False, reuse=True)
-#             gpu_ops.append([clf_logits, clf_losses, lm_losses])
-#     ops = [tf.concat(op, 0) for op in zip(*gpu_ops)]
-#     return ops
-
 def transform_roc(X1, X2, X3):
     n_batch = len(X1)
     xmb = np.zeros((n_batch, 2, n_ctx, 2), dtype=np.int32)
@@ -110,50 +75,60 @@ def transform_roc(X1, X2, X3):
     xmb[:, :, :, 1] = np.arange(n_vocab+n_special, n_vocab+n_special+n_ctx)
     return xmb, mmb
 
-def iter_apply(Xs, Ms, Ys):
-    fns = [lambda x:np.concatenate(x, 0), lambda x:float(np.sum(x))]
-    results = []
-    for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
-        n = len(xmb)
-        if n == n_batch_train:
-            res = sess.run([eval_mgpu_logits, eval_mgpu_clf_loss], {X_train:xmb, M_train:mmb, Y_train:ymb})
-        else:
-            res = sess.run([eval_logits, eval_clf_loss], {X:xmb, M:mmb, Y:ymb})
-        res = [r*n for r in res]
-        results.append(res)
-    results = zip(*results)
-    return [fn(res) for res, fn in zip(results, fns)]
+# def iter_apply(Xs, Ms, Ys):
+#     fns = [lambda x:np.concatenate(x, 0), lambda x:float(np.sum(x))]
+#     results = []
+#     for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
+#         n = len(xmb)
+#         if n == n_batch_train:
+#             res = sess.run([eval_mgpu_logits, eval_mgpu_clf_loss], {X_train:xmb, M_train:mmb, Y_train:ymb})
+#         else:
+#             res = sess.run([eval_logits, eval_clf_loss], {X:xmb, M:mmb, Y:ymb})
+#         res = [r*n for r in res]
+#         results.append(res)
+#     results = zip(*results)
+#     return [fn(res) for res, fn in zip(results, fns)]
 
-def iter_predict(Xs, Ms):
-    logits = []
-    for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True):
-        n = len(xmb)
-        if n == n_batch_train:
-            logits.append(sess.run(eval_mgpu_logits, {X_train:xmb, M_train:mmb}))
-        else:
-            logits.append(sess.run(eval_logits, {X:xmb, M:mmb}))
-    logits = np.concatenate(logits, 0)
-    return logits
+# def iter_predict(Xs, Ms):
+#     logits = []
+#     for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True):
+#         n = len(xmb)
+#         if n == n_batch_train:
+#             logits.append(sess.run(eval_mgpu_logits, {X_train:xmb, M_train:mmb}))
+#         else:
+#             logits.append(sess.run(eval_logits, {X:xmb, M:mmb}))
+#     logits = np.concatenate(logits, 0)
+#     return logits
 
-def save(path):
-    ps = sess.run(params)
-    joblib.dump(ps, make_path(path))
+# def log():
+#     global best_score
+#     tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
+#     va_logits, va_cost = iter_apply(vaX, vaM, vaY)
+#     tr_cost = tr_cost/len(trY[:n_valid])
+#     va_cost = va_cost/n_valid
+#     tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1))*100.
+#     va_acc = accuracy_score(vaY, np.argmax(va_logits, 1))*100.
+#     logger.log(n_epochs=n_epochs, n_updates=n_updates, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc)
+#     print('%d %d %.3f %.3f %.2f %.2f'%(n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))
+#     if submit:
+#         score = va_acc
+#         if score > best_score:
+#             best_score = score
+#             save(os.path.join(save_dir, desc, 'best_params.jl'))
 
-def log():
-    global best_score
-    tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
-    va_logits, va_cost = iter_apply(vaX, vaM, vaY)
-    tr_cost = tr_cost/len(trY[:n_valid])
-    va_cost = va_cost/n_valid
-    tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1))*100.
-    va_acc = accuracy_score(vaY, np.argmax(va_logits, 1))*100.
-    logger.log(n_epochs=n_epochs, n_updates=n_updates, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc)
-    print('%d %d %.3f %.3f %.2f %.2f'%(n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))
-    if submit:
-        score = va_acc
-        if score > best_score:
-            best_score = score
-            save(os.path.join(save_dir, desc, 'best_params.jl'))
+# def predict():
+#     filename = filenames[dataset]
+#     pred_fn = pred_fns[dataset]
+#     label_decoder = label_decoders[dataset]
+#     predictions = pred_fn(iter_predict(teX, teM))
+#     if label_decoder is not None:
+#         predictions = [label_decoder[prediction] for prediction in predictions]
+#     path = os.path.join(submission_dir, filename)
+#     os.makedirs(os.path.dirname(path), exist_ok=True)
+#     with open(path, 'w') as f:
+#         f.write('{}\t{}\n'.format('index', 'prediction'))
+#         for i, prediction in enumerate(predictions):
+#             f.write('{}\t{}\n'.format(i, prediction))
 
 argmax = lambda x:np.argmax(x, 1)
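
A note on the unchanged transform_roc context above: the second channel of xmb holds position indices offset past the vocabulary and special tokens, because word and position embeddings share one matrix (the body then sums the two lookups via e.sum(dim=2)). Under assumed ROCStories sizes:

    import numpy as np

    n_vocab, n_special, n_ctx = 40478, 3, 77   # sizes assumed for illustration
    pos_ids = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)
    print(pos_ids[0])                          # 40481: the first position row sits
                                               # right after the word/special rows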

@@ -169,20 +144,6 @@ def log():
     'rocstories':None,
 }
 
-def predict():
-    filename = filenames[dataset]
-    pred_fn = pred_fns[dataset]
-    label_decoder = label_decoders[dataset]
-    predictions = pred_fn(iter_predict(teX, teM))
-    if label_decoder is not None:
-        predictions = [label_decoder[prediction] for prediction in predictions]
-    path = os.path.join(submission_dir, filename)
-    os.makedirs(os.path.dirname(path), exist_ok=True)
-    with open(path, 'w') as f:
-        f.write('{}\t{}\n'.format('index', 'prediction'))
-        for i, prediction in enumerate(predictions):
-            f.write('{}\t{}\n'.format(i, prediction))
-
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--desc', type=str)
@@ -260,56 +221,38 @@ def predict():
     n_updates_total = (n_train//n_batch_train)*n_iter
 
     model = Model(vocab, cfg)
-    # TODO Initialize model
+    lm_head = LMHead(model, cfg)
+    clf_head = ClfHead(model, clf_token, cfg)
+    compute_loss = LossCompute(nn.CrossEntropyLoss(reduce=False), nn.CrossEntropyLoss(reduce=False))
+    # TODO Initialize model (?)
 
-    # Load weights from TF model
-    shapes = json.load(open('model/params_shapes.json'))
-    names = json.load(open('model/parameters_names.json'))
-    offsets = np.cumsum([np.prod(shape) for shape in shapes])
-    init_params = [np.load('model/params_{}.npy'.format(n)) for n in range(10)]
-    init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
-    init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
-    init_params[0] = init_params[0][:n_ctx]
-    init_params[0] = np.concatenate([init_params[1], (np.random.randn(n_special, n_embd)*0.02).astype(np.float32), init_params[0]], 0)
-    del init_params[1]
-    if n_transfer == -1:
-        n_transfer = 0
-    else:
-        n_transfer = 1+n_transfer*12
-    assert model.embed.weight.shape == init_params[0].shape
-    model.embed.weight = init_params[0]
-    for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]):
-        name = name[6:] # skip "model/"
-        assert name[-2:] == ":0"
-        name = name[:-2]
-        name = name.split('/')
-        pointer = model
-        for m_name in name:
-            l = re.split('(\d+)', m_name)
-            pointer = getattr(pointer, l[0])
-            if len(l) == 1:
-                num = int(l[1])
-                pointer = pointer[num]
-        assert pointer.shape == ip.shape
-        pointer = ip
+    load_openai_pretrained_model(model, n_ctx, n_special, cfg)
 
     n_updates = 0
     n_epochs = 0
     if dataset != 'stsb':
         trYt = trY
     if submit:
-        save(os.path.join(save_dir, desc, 'best_params.jl'))
+        path = os.path.join(save_dir, desc, 'best_params')
+        torch.save(model.state_dict(), make_path(path))
+
     best_score = 0
     for i in range(n_iter):
-        for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random), n_batch=n_batch_train, truncate=True, verbose=True):
-            cost, _ = sess.run([clf_loss, train], {X_train:xmb, M_train:mmb, Y_train:ymb})
+        for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
+                                       n_batch=n_batch_train, truncate=True, verbose=True):
+            h = model(xmb, mmb)
+            lm_logits = lm_head(h)
+            clf_logits = clf_head(h, xmb)
+            loss = compute_loss(xmb, ymb, mmb, lm_logits, clf_logits)
+            loss.backward()
+
             n_updates += 1
            if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:
-                log()
+                pass  # log()
         n_epochs += 1
-        log()
-    if submit:
-        sess.run([p.assign(ip) for p, ip in zip(params, joblib.load(os.path.join(save_dir, desc, 'best_params.jl')))])
-        predict()
-    if analysis:
-        rocstories_analysis(data_dir, os.path.join(submission_dir, 'ROCStories.tsv'), os.path.join(log_dir, 'rocstories.jsonl'))
+        # log()
+    # if submit:
+    #     sess.run([p.assign(ip) for p, ip in zip(params, joblib.load(os.path.join(save_dir, desc, 'best_params.jl')))])
+    #     predict()
+    # if analysis:
+    #     rocstories_analysis(data_dir, os.path.join(submission_dir, 'ROCStories.tsv'), os.path.join(log_dir, 'rocstories.jsonl'))
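
As committed, the loop only calls loss.backward(): no optimizer is stepped, gradients are never cleared, and the batches are still numpy arrays, which matches the remaining TODO. A hypothetical sketch of the missing wiring with plain torch.optim (the custom adam from opt.py, with its warmup schedules, is imported but not yet hooked up):

    # Hypothetical completion of the training step; not part of this commit.
    params = list(model.parameters()) + list(clf_head.parameters())  # lm_head shares model.embed.weight
    optimizer = torch.optim.Adam(params, lr=6.25e-5)                 # lr value assumed

    h = model(xmb, mmb)
    lm_logits = lm_head(h)
    clf_logits = clf_head(h, xmb)
    loss = compute_loss(xmb, ymb, mmb, lm_logits, clf_logits)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()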
