from sklearn.utils import shuffle
from sklearn.metrics import accuracy_score

-from model_py import Model
+from model_py import Model, LMHead, ClfHead, load_openai_pretrained_model
from opt import adam, warmup_cosine, warmup_linear, warmup_constant
from datasets import rocstories
from analysis import rocstories as rocstories_analysis
from text_utils import TextEncoder
-from utils import encode_dataset, flatten, iter_data, find_trainable_variables, get_ema_vars, convert_gradient_to_tensor, shape_list, ResultLogger, assign_to_gpu, average_grads, make_path
+from utils import (encode_dataset, flatten, iter_data,
+                   ResultLogger, make_path)

OPT_FNS = {
    'adam':adam,
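# The widened model_py import above reflects the new single-transformer,
# two-head design used later in this diff: Model produces hidden states that
# feed both LMHead (language modeling) and ClfHead (classification), and
# load_openai_pretrained_model replaces the inline TF weight-loading code.
# One assumption worth flagging: the added code below also uses
# nn.CrossEntropyLoss and torch.save, so `import torch` and
# `import torch.nn as nn` are presumably present (or needed) above this hunk.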

class LossCompute:
    "A Loss compute and train function."
-    def __init__(self, generator, lm_criterion, n_embed, clf_token, opt=None):
-        self.generator = generator
+    def __init__(self, lm_criterion, clf_criterion):
        self.lm_criterion = lm_criterion
-        self.opt = opt
-        self.n_embed = n_embed
-        self.clf_token = clf_token
+        self.clf_criterion = clf_criterion

-    def __call__(self, X, Y, M, lm_logits, clf_logits, norm):
+    def __call__(self, X, Y, M, lm_logits, clf_logits):
        # Language modeling loss
        x_shifted = X[:, 1:, 0].contiguous().view(-1)  # Shape: 252
        lm_losses = self.lm_criterion(lm_logits, x_shifted)
@@ -59,39 +57,6 @@ def __call__(self, X, Y, M, lm_logits, clf_logits, norm):
        train_loss = clf_losses.sum()
        return train_loss
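# A minimal usage sketch for the refactored LossCompute (an assumption, not
# part of this commit). Note that the construction further down passes the
# nn.CrossEntropyLoss *class* twice; since self.lm_criterion is invoked like
# a function, instances with elementwise losses (reduce=False in 0.4-era
# PyTorch, reduction='none' today) are presumably intended:
#
#   compute_loss = LossCompute(nn.CrossEntropyLoss(reduction='none'),
#                              nn.CrossEntropyLoss(reduction='none'))
#   loss = compute_loss(X, Y, M, lm_logits, clf_logits)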

-# def mgpu_train(*xs):
-#     gpu_ops = []
-#     gpu_grads = []
-#     xs = (tf.split(x, n_gpu, 0) for x in xs)
-#     for i, xs in enumerate(zip(*xs)):
-#         do_reuse = True if i > 0 else None
-#         with tf.device(assign_to_gpu(i, "/gpu:0")), tf.variable_scope(tf.get_variable_scope(), reuse=do_reuse):
-#             clf_logits, clf_losses, lm_losses = model(*xs, train=True, reuse=do_reuse)
-#             if lm_coef > 0:
-#                 train_loss = tf.reduce_mean(clf_losses) + lm_coef*tf.reduce_mean(lm_losses)
-#             else:
-#                 train_loss = tf.reduce_mean(clf_losses)
-#             params = find_trainable_variables("model")
-#             grads = tf.gradients(train_loss, params)
-#             grads = list(zip(grads, params))
-#             gpu_grads.append(grads)
-#             gpu_ops.append([clf_logits, clf_losses, lm_losses])
-#     ops = [tf.concat(op, 0) for op in zip(*gpu_ops)]
-#     grads = average_grads(gpu_grads)
-#     grads = [g for g, p in grads]
-#     train = opt_fns[opt](params, grads, lr, partial(lr_schedules[lr_schedule], warmup=lr_warmup), n_updates_total, l2=l2, max_grad_norm=max_grad_norm, vector_l2=vector_l2, b1=b1, b2=b2, e=e)
-#     return [train]+ops
-
-# def mgpu_predict(*xs):
-#     gpu_ops = []
-#     xs = (tf.split(x, n_gpu, 0) for x in xs)
-#     for i, xs in enumerate(zip(*xs)):
-#         with tf.device(assign_to_gpu(i, "/gpu:0")), tf.variable_scope(tf.get_variable_scope(), reuse=True):
-#             clf_logits, clf_losses, lm_losses = model(*xs, train=False, reuse=True)
-#             gpu_ops.append([clf_logits, clf_losses, lm_losses])
-#     ops = [tf.concat(op, 0) for op in zip(*gpu_ops)]
-#     return ops
-
def transform_roc(X1, X2, X3):
    n_batch = len(X1)
    xmb = np.zeros((n_batch, 2, n_ctx, 2), dtype=np.int32)
@@ -110,50 +75,60 @@ def transform_roc(X1, X2, X3):
    xmb[:, :, :, 1] = np.arange(n_vocab+n_special, n_vocab+n_special+n_ctx)
    return xmb, mmb
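# Layout sketch for the tensors transform_roc returns, inferred from the
# arange call above and the usual ROCStories packing (the elided middle of the
# function is assumed to follow [start] story [delim] ending [clf_token]):
#
#   xmb[i, j, :, 0] -> token ids for ending j of example i, zero-padded
#   xmb[i, j, :, 1] -> position ids n_vocab+n_special .. n_vocab+n_special+n_ctx-1
#   mmb[i, j, :]    -> 1.0 over real tokens, 0.0 over padding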

-def iter_apply(Xs, Ms, Ys):
-    fns = [lambda x:np.concatenate(x, 0), lambda x:float(np.sum(x))]
-    results = []
-    for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
-        n = len(xmb)
-        if n == n_batch_train:
-            res = sess.run([eval_mgpu_logits, eval_mgpu_clf_loss], {X_train:xmb, M_train:mmb, Y_train:ymb})
-        else:
-            res = sess.run([eval_logits, eval_clf_loss], {X:xmb, M:mmb, Y:ymb})
-        res = [r*n for r in res]
-        results.append(res)
-    results = zip(*results)
-    return [fn(res) for res, fn in zip(results, fns)]
+# def iter_apply(Xs, Ms, Ys):
+#     fns = [lambda x:np.concatenate(x, 0), lambda x:float(np.sum(x))]
+#     results = []
+#     for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
+#         n = len(xmb)
+#         if n == n_batch_train:
+#             res = sess.run([eval_mgpu_logits, eval_mgpu_clf_loss], {X_train:xmb, M_train:mmb, Y_train:ymb})
+#         else:
+#             res = sess.run([eval_logits, eval_clf_loss], {X:xmb, M:mmb, Y:ymb})
+#         res = [r*n for r in res]
+#         results.append(res)
+#     results = zip(*results)
+#     return [fn(res) for res, fn in zip(results, fns)]
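# A PyTorch replacement is presumably the eventual goal; a rough sketch under
# that assumption (CPU tensors, evaluation without gradient tracking):
#
#   def iter_apply(Xs, Ms, Ys):
#       logits, cost = [], 0.
#       with torch.no_grad():
#           for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train,
#                                          truncate=False, verbose=True):
#               h = model(xmb, mmb)
#               clf_logits = clf_head(h, xmb)
#               cost += float(compute_loss(xmb, ymb, mmb, lm_head(h), clf_logits))
#               logits.append(clf_logits.numpy())
#       return np.concatenate(logits, 0), cost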

-def iter_predict(Xs, Ms):
-    logits = []
-    for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True):
-        n = len(xmb)
-        if n == n_batch_train:
-            logits.append(sess.run(eval_mgpu_logits, {X_train:xmb, M_train:mmb}))
-        else:
-            logits.append(sess.run(eval_logits, {X:xmb, M:mmb}))
-    logits = np.concatenate(logits, 0)
-    return logits
+# def iter_predict(Xs, Ms):
+#     logits = []
+#     for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True):
+#         n = len(xmb)
+#         if n == n_batch_train:
+#             logits.append(sess.run(eval_mgpu_logits, {X_train:xmb, M_train:mmb}))
+#         else:
+#             logits.append(sess.run(eval_logits, {X:xmb, M:mmb}))
+#     logits = np.concatenate(logits, 0)
+#     return logits
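# And its prediction-only sibling, under the same assumptions:
#
#   def iter_predict(Xs, Ms):
#       logits = []
#       with torch.no_grad():
#           for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train,
#                                     truncate=False, verbose=True):
#               logits.append(clf_head(model(xmb, mmb), xmb).numpy())
#       return np.concatenate(logits, 0)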

-def save(path):
-    ps = sess.run(params)
-    joblib.dump(ps, make_path(path))
+# def log():
+#     global best_score
+#     tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
+#     va_logits, va_cost = iter_apply(vaX, vaM, vaY)
+#     tr_cost = tr_cost/len(trY[:n_valid])
+#     va_cost = va_cost/n_valid
+#     tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1))*100.
+#     va_acc = accuracy_score(vaY, np.argmax(va_logits, 1))*100.
+#     logger.log(n_epochs=n_epochs, n_updates=n_updates, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc)
+#     print('%d %d %.3f %.3f %.2f %.2f'%(n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))
+#     if submit:
+#         score = va_acc
+#         if score > best_score:
+#             best_score = score
+#             save(os.path.join(save_dir, desc, 'best_params.jl'))

-def log():
-    global best_score
-    tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
-    va_logits, va_cost = iter_apply(vaX, vaM, vaY)
-    tr_cost = tr_cost/len(trY[:n_valid])
-    va_cost = va_cost/n_valid
-    tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1))*100.
-    va_acc = accuracy_score(vaY, np.argmax(va_logits, 1))*100.
-    logger.log(n_epochs=n_epochs, n_updates=n_updates, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc)
-    print('%d %d %.3f %.3f %.2f %.2f'%(n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))
-    if submit:
-        score = va_acc
-        if score > best_score:
-            best_score = score
-            save(os.path.join(save_dir, desc, 'best_params.jl'))
+# def predict():
+#     filename = filenames[dataset]
+#     pred_fn = pred_fns[dataset]
+#     label_decoder = label_decoders[dataset]
+#     predictions = pred_fn(iter_predict(teX, teM))
+#     if label_decoder is not None:
+#         predictions = [label_decoder[prediction] for prediction in predictions]
+#     path = os.path.join(submission_dir, filename)
+#     os.makedirs(os.path.dirname(path), exist_ok=True)
+#     with open(path, 'w') as f:
+#         f.write('{}\t{}\n'.format('index', 'prediction'))
+#         for i, prediction in enumerate(predictions):
+#             f.write('{}\t{}\n'.format(i, prediction))

argmax = lambda x:np.argmax(x, 1)
@@ -169,20 +144,6 @@ def log():
    'rocstories':None,
}
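# The None entry above is presumably label_decoders['rocstories']; per the
# predict() logic (removed below, kept as a comment earlier), a None decoder
# means raw argmax indices are written to the submission file unchanged.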

-def predict():
-    filename = filenames[dataset]
-    pred_fn = pred_fns[dataset]
-    label_decoder = label_decoders[dataset]
-    predictions = pred_fn(iter_predict(teX, teM))
-    if label_decoder is not None:
-        predictions = [label_decoder[prediction] for prediction in predictions]
-    path = os.path.join(submission_dir, filename)
-    os.makedirs(os.path.dirname(path), exist_ok=True)
-    with open(path, 'w') as f:
-        f.write('{}\t{}\n'.format('index', 'prediction'))
-        for i, prediction in enumerate(predictions):
-            f.write('{}\t{}\n'.format(i, prediction))
-

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--desc', type=str)
@@ -260,56 +221,38 @@ def predict():
    n_updates_total = (n_train//n_batch_train)*n_iter

    model = Model(vocab, cfg)
-    # TODO Initialize model
+    lm_head = LMHead(model, cfg)
+    clf_head = ClfHead(model, clf_token, cfg)
+    compute_loss = LossCompute(nn.CrossEntropyLoss, nn.CrossEntropyLoss)
+    # TODO Initialize model (?)

-    # Load weights from TF model
-    shapes = json.load(open('model/params_shapes.json'))
-    names = json.load(open('model/parameters_names.json'))
-    offsets = np.cumsum([np.prod(shape) for shape in shapes])
-    init_params = [np.load('model/params_{}.npy'.format(n)) for n in range(10)]
-    init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
-    init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
-    init_params[0] = init_params[0][:n_ctx]
-    init_params[0] = np.concatenate([init_params[1], (np.random.randn(n_special, n_embd)*0.02).astype(np.float32), init_params[0]], 0)
-    del init_params[1]
-    if n_transfer == -1:
-        n_transfer = 0
-    else:
-        n_transfer = 1 + n_transfer*12
-    assert model.embed.weight.shape == init_params[0].shape
-    model.embed.weight = init_params[0]
-    for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]):
-        name = name[6:]  # skip "model/"
-        assert name[-2:] == ":0"
-        name = name[:-2]
-        name = name.split('/')
-        pointer = model
-        for m_name in name:
-            l = re.split('(\d+)', m_name)
-            pointer = getattr(pointer, l[0])
-            if len(l) > 1:
-                num = int(l[1])
-                pointer = pointer[num]
-        assert pointer.shape == ip.shape
-        pointer = ip
+    load_openai_pretrained_model(model, n_ctx, n_special, cfg)
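# In the block removed above, `pointer = ip` only rebinds a local name and
# never copies the pretrained array into the module (likewise the plain
# assignment of a numpy array to model.embed.weight). The working pattern,
# and presumably what load_openai_pretrained_model does internally, is along
# the lines of:
#
#   pointer.data = torch.from_numpy(ip)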

    n_updates = 0
    n_epochs = 0
    if dataset != 'stsb':
        trYt = trY
    if submit:
-        save(os.path.join(save_dir, desc, 'best_params.jl'))
+        path = os.path.join(save_dir, desc, 'best_params')
+        torch.save(model.state_dict(), make_path(path))
+
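    # Reloading the checkpoint saved above is the usual state_dict round trip
    # (a sketch, not in this commit):
    #
    #   model.load_state_dict(torch.load(os.path.join(save_dir, desc, 'best_params')))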
    best_score = 0
    for i in range(n_iter):
-        for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random), n_batch=n_batch_train, truncate=True, verbose=True):
-            cost, _ = sess.run([clf_loss, train], {X_train:xmb, M_train:mmb, Y_train:ymb})
+        for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
+                                       n_batch=n_batch_train, truncate=True, verbose=True):
+            h = model(xmb, mmb)
+            lm_logits = lm_head(h)
+            clf_logits = clf_head(h, xmb)
+            loss = compute_loss(xmb, ymb, mmb, lm_logits, clf_logits)
+            loss.backward()
+
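            # Assumption, not in this commit: no optimizer is created or
            # stepped, so the backward() above only accumulates gradients.
            # A complete update with a hypothetical torch.optim-style
            # `optimizer` would follow it with:
            #
            #   optimizer.step()
            #   optimizer.zero_grad()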
            n_updates += 1
            if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:
-                log()
+                # log()
        n_epochs += 1
-        log()
-    if submit:
-        sess.run([p.assign(ip) for p, ip in zip(params, joblib.load(os.path.join(save_dir, desc, 'best_params.jl')))])
-        predict()
-    if analysis:
-        rocstories_analysis(data_dir, os.path.join(submission_dir, 'ROCStories.tsv'), os.path.join(log_dir, 'rocstories.jsonl'))
+        # log()
+    # if submit:
+    #     sess.run([p.assign(ip) for p, ip in zip(params, joblib.load(os.path.join(save_dir, desc, 'best_params.jl')))])
+    #     predict()
+    # if analysis:
+    #     rocstories_analysis(data_dir, os.path.join(submission_dir, 'ROCStories.tsv'), os.path.join(log_dir, 'rocstories.jsonl'))