
Commit 1f209c4

first version of model + weights transfer
1 parent 5eb7937 commit 1f209c4

2 files changed: +95 -95 lines changed

model.py renamed to model_py.py

Lines changed: 40 additions & 41 deletions
@@ -6,24 +6,18 @@
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
-vocab = n_vocab + n_special + n_ctx
-
 def gelu(x):
     return 0.5*x*(1+torch.tanh(math.sqrt(2/math.pi)*(x+0.044715*torch.pow(x, 3))))
 
 def swish(x):
     return x*torch.sigmoid(x)
 
 ACT_FNS = {
-    'relu': nn.relu,
+    'relu': nn.ReLU,
     'swish': swish,
     'gelu': gelu
 }
 
-def clones(module, N):
-    "Produce N identical layers."
-    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
-
 
 class LayerNorm(nn.Module):
     "Construct a layernorm module (See citation for details)."
@@ -44,6 +38,7 @@ class Conv1D(nn.Module):
     def __init__(self, nf, rf, nx):
         super(Conv1D, self).__init__()
         self.rf = rf
+        self.nf = nf
         if rf == 1: # faster 1x1 conv
             self.w = Parameter(torch.ones(nx, nf)) # TODO change to random normal
             self.b = Parameter(torch.zeros(nf))
@@ -52,7 +47,7 @@ def __init__(self, nf, rf, nx):
 
     def forward(self, x):
         if self.rf == 1:
-            size_out = x.size()[:-1] + [nf]
+            size_out = x.size()[:-1] + [self.nf]
             x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w)
             x = x.view(*size_out)
         else:
@@ -61,14 +56,17 @@
 
 
 class Attention(nn.Module):
-    def __init__(self, nx, n_state, n_head, attn_pdrop, resid_pdrop, scale=False):
+    def __init__(self, nx, cfg, scale=False):
         super(Attention, self).__init__()
-        self.c_attn = Conv1D(n_state*3, 1, nx)
-        self.c_proj = Conv1D(n_state, 1, nx)
+        n_state = nx # in Attention: n_state=768 (nx=n_embed)
+        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
+        assert n_state % cfg.n_head == 0
+        self.n_head = cfg.n_head
         self.scale = scale
-        self.n_head = n_head
-        self.attn_dropout = nn.Dropout(attn_pdrop)
-        self.resid_dropout = nn.Dropout(resid_pdrop)
+        self.c_attn = Conv1D(n_state * 3, 1, nx)
+        self.c_proj = Conv1D(n_state, 1, nx)
+        self.attn_dropout = nn.Dropout(cfg.attn_pdrop)
+        self.resid_dropout = nn.Dropout(cfg.resid_pdrop)
 
     @staticmethod
     def mask_attn_weights(w):
@@ -87,12 +85,12 @@ def _attn(self, q, k, v):
 
     def merge_heads(self, x):
         new_x_shape = x.size()[:-2] + [np.prod(x.size()[-2:])]
-        x = x.view(*new_x_shape) # in Tensorflow version: merge_states
+        x = x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
         return x.permute(0, 2, 1, 3)
 
     def split_heads(self, x, k=False):
         new_x_shape = x.size()[:-1] + [self.n_head, x.size(-1)//self.n_head]
-        x = x.view(*new_x_shape) # in Tensorflow version: split_states
+        x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
         if k:
             return x.permute(0, 2, 3, 1)
         else:
@@ -112,53 +110,55 @@ def forward(self, x):
 
 
 class MLP(nn.Module):
-    def __init__(self, nx, n_state, afn, resid_pdrop):
+    def __init__(self, n_state, cfg): # in MLP: n_state=3072 (4 * n_embed)
         super(MLP, self).__init__()
+        nx = cfg.n_embed
         self.c_fc = Conv1D(n_state, 1, nx)
         self.c_proj = Conv1D(nx, 1, nx)
-        self.act = ACT_FNS[afn]
-        self.dropout = nn.Dropout(resid_pdrop)
+        self.act = ACT_FNS[cfg.afn]
+        self.dropout = nn.Dropout(cfg.resid_pdrop)
 
     def forward(self, x):
         h = self.act(self.c_fc(x))
-        h = self.c_proj(h)
-        return self.dropout(h)
+        h2 = self.c_proj(h)
+        return self.dropout(h2)
 
 
 class Block(nn.Module):
-    def __init__(self, nx, n_head, attn_pdrop, resid_pdrop, afn, scale=False):
+    def __init__(self, cfg, scale=False):
         super(Block, self).__init__()
-        self.attn = Attention(nx, nx, n_head, attn_pdrop, resid_pdrop, scale)
+        nx = cfg.n_embed
+        self.attn = Attention(nx, cfg, scale)
         self.ln_1 = LayerNorm(nx)
-        self.mlp = MLP(nx, nx*4, afn, resid_pdrop)
+        self.mlp = MLP(4*nx, cfg)
         self.ln_2 = LayerNorm(nx)
 
     def forward(self, x):
-        h = self.attn(x)
-        h = self.ln_1(x)
-        h = self.mlp(x)
-        h = self.ln_2(x)
+        a = self.attn(x)
+        n = self.ln_1(x+a)
+        m = self.mlp(n)
+        h = self.ln_2(n+m)
         return h
 
 
 class Model(nn.Module):
     """ Transformer model """
-    def __init__(self, vocab, n_embd, pdrop, n_layers,
-                 nx, n_head, attn_pdrop, resid_pdrop, afn):
+    def __init__(self, vocab, cfg):
         super(Model, self).__init__()
-        self.embed = nn.Embedding(vocab, n_embd)
-        self.drop = nn.Dropout(pdrop)
-        self.blocks = clones(Block(nx, n_head, attn_pdrop,
-                                   resid_pdrop, afn, scale=True), n_layers)
-        self.decoder = nn.Linear(nhid, vocab, bias=False)
-        self.decoder.weight = self.embed.weight
+        self.embed = nn.Embedding(vocab, cfg.n_embd)
+        self.drop = nn.Dropout(cfg.embd_pdrop)
+        block = Block(cfg, scale=True)
+        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)])
+        self.decoder = nn.Linear(cfg.n_embed, vocab, bias=False)
+        self.decoder.weight = self.embed.weight # Tied weights
+        self.clf_dropout = nn.Dropout2d(cfg.clf_pdrop) # To reproduce the noise_shape parameter of TF implementation
 
     def forward(self, x, m):
         x = x.view(-1, x.size(2), x.size(3))
         m = m.view(-1, m.size(2))
         e = self.embed(x)
         h = e.sum(dim=2)
-        for block in self.blocks:
+        for block in self.h:
             h = block(h)
 
         # Language modeling logits
@@ -167,12 +167,11 @@ def forward(self, x, m):
 
         # Classification logits
         clf_h = h.view(-1, self.n_embed)
-        pool_idx = torch.eq(X[:, :, 0].contiguous().view(-1), self.clf_token)
+        pool_idx = torch.eq(x[:, :, 0].contiguous().view(-1), self.clf_token)
         clf_h = clf_h[pool_idx, :]
         clf_h = clf_h.view(-1, 2, self.n_embed, 1)
-        m = nn.Dropout2d(clf_pdrop) # To reproduce the noise_shape parameter of TF implementation
-        clf_h = m(clf_h)
+        clf_h = self.clf_dropout(clf_h)
         clf_h = clf_h.view(-1, self.n_embed)
         clf_logits = self.linear(clf_h)
 
-        return lm_logits, clf_logits
+        return lm_logits, clf_logits
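
Note on the new interface: the refactor threads a single cfg object through Attention, MLP, Block and Model instead of passing each hyper-parameter separately. Below is a minimal sketch of what such a config could look like, using a hypothetical types.SimpleNamespace stand-in and the sizes implied by the comments above (768-dim states, 12 heads); the diff reads both cfg.n_embd and cfg.n_embed, so both spellings are filled in here, and all values are illustrative only.

    from types import SimpleNamespace

    from model_py import Model

    # Hypothetical config; the field names are exactly the cfg.* attributes read in model_py.py.
    cfg = SimpleNamespace(
        n_embd=768, n_embed=768,          # both spellings are referenced in the diff
        n_head=12, n_layer=12,            # number of attention heads and of Blocks
        afn='gelu',                       # key into ACT_FNS
        embd_pdrop=0.1, attn_pdrop=0.1,   # dropout rates (illustrative values)
        resid_pdrop=0.1, clf_pdrop=0.1,
    )

    # vocab = n_vocab + n_special + n_ctx is computed in train.py below; 40990 is illustrative.
    model = Model(40990, cfg)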

train.py

Lines changed: 55 additions & 54 deletions
@@ -1,3 +1,4 @@
+import re
 import os
 import time
 import math
@@ -16,6 +17,7 @@
 from sklearn.utils import shuffle
 from sklearn.metrics import accuracy_score
 
+from model_py import Model
 from opt import adam, warmup_cosine, warmup_linear, warmup_constant
 from datasets import rocstories
 from analysis import rocstories as rocstories_analysis
@@ -51,45 +53,44 @@ def __call__(self, X, Y, M, lm_logits, clf_logits, norm):
 
         # Classification loss
         clf_losses = self.clf_criterion(clf_logits, Y)
-
         if lm_coef > 0:
             train_loss = clf_losses.sum() + lm_coef * lm_losses.sum()
         else:
             train_loss = clf_losses.sum()
         return train_loss
 
-def mgpu_train(*xs):
-    gpu_ops = []
-    gpu_grads = []
-    xs = (tf.split(x, n_gpu, 0) for x in xs)
-    for i, xs in enumerate(zip(*xs)):
-        do_reuse = True if i > 0 else None
-        with tf.device(assign_to_gpu(i, "/gpu:0")), tf.variable_scope(tf.get_variable_scope(), reuse=do_reuse):
-            clf_logits, clf_losses, lm_losses = model(*xs, train=True, reuse=do_reuse)
-            if lm_coef > 0:
-                train_loss = tf.reduce_mean(clf_losses) + lm_coef*tf.reduce_mean(lm_losses)
-            else:
-                train_loss = tf.reduce_mean(clf_losses)
-            params = find_trainable_variables("model")
-            grads = tf.gradients(train_loss, params)
-            grads = list(zip(grads, params))
-            gpu_grads.append(grads)
-            gpu_ops.append([clf_logits, clf_losses, lm_losses])
-    ops = [tf.concat(op, 0) for op in zip(*gpu_ops)]
-    grads = average_grads(gpu_grads)
-    grads = [g for g, p in grads]
-    train = opt_fns[opt](params, grads, lr, partial(lr_schedules[lr_schedule], warmup=lr_warmup), n_updates_total, l2=l2, max_grad_norm=max_grad_norm, vector_l2=vector_l2, b1=b1, b2=b2, e=e)
-    return [train]+ops
-
-def mgpu_predict(*xs):
-    gpu_ops = []
-    xs = (tf.split(x, n_gpu, 0) for x in xs)
-    for i, xs in enumerate(zip(*xs)):
-        with tf.device(assign_to_gpu(i, "/gpu:0")), tf.variable_scope(tf.get_variable_scope(), reuse=True):
-            clf_logits, clf_losses, lm_losses = model(*xs, train=False, reuse=True)
-            gpu_ops.append([clf_logits, clf_losses, lm_losses])
-    ops = [tf.concat(op, 0) for op in zip(*gpu_ops)]
-    return ops
+# def mgpu_train(*xs):
+#     gpu_ops = []
+#     gpu_grads = []
+#     xs = (tf.split(x, n_gpu, 0) for x in xs)
+#     for i, xs in enumerate(zip(*xs)):
+#         do_reuse = True if i > 0 else None
+#         with tf.device(assign_to_gpu(i, "/gpu:0")), tf.variable_scope(tf.get_variable_scope(), reuse=do_reuse):
+#             clf_logits, clf_losses, lm_losses = model(*xs, train=True, reuse=do_reuse)
+#             if lm_coef > 0:
+#                 train_loss = tf.reduce_mean(clf_losses) + lm_coef*tf.reduce_mean(lm_losses)
+#             else:
+#                 train_loss = tf.reduce_mean(clf_losses)
+#             params = find_trainable_variables("model")
+#             grads = tf.gradients(train_loss, params)
+#             grads = list(zip(grads, params))
+#             gpu_grads.append(grads)
+#             gpu_ops.append([clf_logits, clf_losses, lm_losses])
+#     ops = [tf.concat(op, 0) for op in zip(*gpu_ops)]
+#     grads = average_grads(gpu_grads)
+#     grads = [g for g, p in grads]
+#     train = opt_fns[opt](params, grads, lr, partial(lr_schedules[lr_schedule], warmup=lr_warmup), n_updates_total, l2=l2, max_grad_norm=max_grad_norm, vector_l2=vector_l2, b1=b1, b2=b2, e=e)
+#     return [train]+ops
+
+# def mgpu_predict(*xs):
+#     gpu_ops = []
+#     xs = (tf.split(x, n_gpu, 0) for x in xs)
+#     for i, xs in enumerate(zip(*xs)):
+#         with tf.device(assign_to_gpu(i, "/gpu:0")), tf.variable_scope(tf.get_variable_scope(), reuse=True):
+#             clf_logits, clf_losses, lm_losses = model(*xs, train=False, reuse=True)
+#             gpu_ops.append([clf_logits, clf_losses, lm_losses])
+#     ops = [tf.concat(op, 0) for op in zip(*gpu_ops)]
+#     return ops
 
 def transform_roc(X1, X2, X3):
     n_batch = len(X1)
@@ -247,6 +248,7 @@ def predict():
         +[len(x1[:max_len])+max(len(x2[:max_len]), len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)]
         )+3, n_ctx
     )
+    vocab = n_vocab + n_special + n_ctx
     trX, trM = transform_roc(trX1, trX2, trX3)
    vaX, vaM = transform_roc(vaX1, vaX2, vaX3)
     if submit:
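
The vocab value added here sizes the embedding matrix built in model_py.py: BPE tokens, then the special tokens used to format each example, then one slot per context position, since position embeddings share the token embedding table (Model.forward sums self.embed(x) over the last axis, and the weight loading below concatenates token, special and position vectors into init_params[0]). Illustratively, with hypothetical sizes:

    # Hypothetical sizes, for illustration only; the real values come from the
    # text encoder and from the ROCStories length computation just above.
    n_vocab = 40478     # BPE vocabulary
    n_special = 3       # special formatting tokens
    n_ctx = 77          # longest clipped training example + 3

    vocab = n_vocab + n_special + n_ctx   # rows in the shared token+position embedding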
@@ -257,40 +259,39 @@ def predict():
     n_batch_train = n_batch*n_gpu
     n_updates_total = (n_train//n_batch_train)*n_iter
 
-    X_train = tf.placeholder(tf.int32, [n_batch_train, 2, n_ctx, 2])
-    M_train = tf.placeholder(tf.float32, [n_batch_train, 2, n_ctx])
-    X = tf.placeholder(tf.int32, [None, 2, n_ctx, 2])
-    M = tf.placeholder(tf.float32, [None, 2, n_ctx])
-
-    Y_train = tf.placeholder(tf.int32, [n_batch_train])
-    Y = tf.placeholder(tf.int32, [None])
-
-    train, logits, clf_losses, lm_losses = mgpu_train(X_train, M_train, Y_train)
-    clf_loss = tf.reduce_mean(clf_losses)
-
-    params = find_trainable_variables('model')
-    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
-    sess.run(tf.global_variables_initializer())
+    model = Model(vocab, cfg)
+    # TODO Initialize model
 
+    # Load weights from TF model
     shapes = json.load(open('model/params_shapes.json'))
+    names = json.load(open('model/parameters_names.json'))
     offsets = np.cumsum([np.prod(shape) for shape in shapes])
     init_params = [np.load('model/params_{}.npy'.format(n)) for n in range(10)]
     init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
     init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
     init_params[0] = init_params[0][:n_ctx]
     init_params[0] = np.concatenate([init_params[1], (np.random.randn(n_special, n_embd)*0.02).astype(np.float32), init_params[0]], 0)
     del init_params[1]
-
     if n_transfer == -1:
         n_transfer = 0
     else:
         n_transfer = 1+n_transfer*12
-    sess.run([p.assign(ip) for p, ip in zip(params[:n_transfer], init_params[:n_transfer])])
-
-    eval_mgpu_logits, eval_mgpu_clf_losses, eval_mgpu_lm_losses = mgpu_predict(X_train, M_train, Y_train)
-    eval_logits, eval_clf_losses, eval_lm_losses = model(X, M, Y, train=False, reuse=True)
-    eval_clf_loss = tf.reduce_mean(eval_clf_losses)
-    eval_mgpu_clf_loss = tf.reduce_mean(eval_mgpu_clf_losses)
+    assert model.embed.weight.shape == init_params[0].shape
+    model.embed.weight = init_params[0]
+    for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]):
+        name = name[6:]  # skip "model/"
+        assert name[-2:] == ":0"
+        name = name[:-2]
+        name = name.split('/')
+        pointer = model
+        for m_name in name:
+            l = re.split('(\d+)', m_name)
+            pointer = getattr(pointer, l[0])
+            if len(l) == 1:
+                num = int(l[1])
+                pointer = pointer[num]
+        assert pointer.shape == ip.shape
+        pointer = ip
 
     n_updates = 0
     n_epochs = 0
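
The transfer loop added above walks each TensorFlow variable name down the PyTorch module tree, using re.split to separate attribute names from integer indices (for example the block index in h0). A minimal stand-alone walk-through for one hypothetical name of the shape stored in model/parameters_names.json; note that re.split(r'(\d+)', ...) returns more than one element whenever a digit is present, and that is when the indexing step applies:

    import re

    name = "model/h0/attn/c_attn/w:0"   # hypothetical example entry

    name = name[6:]                     # drop the "model/" prefix -> "h0/attn/c_attn/w:0"
    assert name[-2:] == ":0"
    name = name[:-2]                    # drop the ":0" suffix     -> "h0/attn/c_attn/w"

    for m_name in name.split('/'):      # ["h0", "attn", "c_attn", "w"]
        l = re.split(r'(\d+)', m_name)
        print(l)
        # "h0"   -> ['h', '0', '']  : attribute "h", then integer index 0
        # "attn" -> ['attn']        : plain attribute access, no index
        # The attribute walk is then: pointer = getattr(pointer, l[0]),
        # followed by pointer = pointer[int(l[1])] when a digit was split off.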
