@@ -33,27 +33,28 @@ def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
         self.lm_coef = lm_coef
         self.opt = opt

-    def __call__(self, X, Y, M, lm_logits, clf_logits):
+    def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
         # Language modeling loss
-        x_shifted = X[:, :, 1:, 0].contiguous().view(-1)  # Shape: 252
-        M = M.view(-1, M.size(2))
-        lm_losses = self.lm_criterion(lm_logits, x_shifted)
-        lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2) - 1)
-        lm_losses = lm_losses * M[:, 1:]
-        lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
-
+        if lm_logits is not None:
+            x_shifted = X[:, :, 1:, 0].contiguous().view(-1)  # Shape: 252
+            M = M.view(-1, M.size(2))
+            lm_losses = self.lm_criterion(lm_logits, x_shifted)
+            lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2) - 1)
+            lm_losses = lm_losses * M[:, 1:]
+            lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
         # Classification loss
         clf_losses = self.clf_criterion(clf_logits, Y)
+        if only_return_losses:
+            return (clf_losses, lm_losses) if lm_logits is not None else clf_losses

-        if self.lm_coef > 0:
+        if self.lm_coef > 0 and lm_logits is not None:
             train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum()
         else:
             train_loss = clf_losses.sum()
-
         train_loss.backward()
         if self.opt is not None:
             self.opt.step()
-            self.opt.optimizer.zero_grad()
+            self.opt.zero_grad()
         return train_loss.item()

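The reworked `LossCompute.__call__` now serves both training and evaluation from a single entry point: `lm_logits` became optional so the language-modeling branch can be skipped, and `only_return_losses` short-circuits before the backward pass. A minimal usage sketch of the two modes (tensor names follow the training loop later in this diff):

    # Training step: both heads contribute; backward() and opt.step() run inside.
    loss = compute_loss(XMB, YMB, MMB, clf_logits, lm_logits)

    # Evaluation: LM head skipped, per-example losses returned before any
    # backward pass or optimizer update.
    clf_losses = compute_loss(XMB, YMB, MMB, clf_logits, only_return_losses=True)

The switch from `self.opt.optimizer.zero_grad()` to `self.opt.zero_grad()` suggests `opt` is now the optimizer itself rather than a wrapper holding an inner `optimizer`.
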
@@ -75,60 +76,84 @@ def transform_roc(X1, X2, X3):
     xmb[:, :, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)
     return xmb, mmb

-# def iter_apply(Xs, Ms, Ys):
-#     fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))]
-#     results = []
-#     for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
-#         n = len(xmb)
-#         if n == n_batch_train:
-#             res = sess.run([eval_mgpu_logits, eval_mgpu_clf_loss], {X_train: xmb, M_train: mmb, Y_train: ymb})
-#         else:
-#             res = sess.run([eval_logits, eval_clf_loss], {X: xmb, M: mmb, Y: ymb})
-#         res = [r * n for r in res]
-#         results.append(res)
-#     results = zip(*results)
-#     return [fn(res) for res, fn in zip(results, fns)]
+def iter_apply(Xs, Ms, Ys):
+    fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))]
+    results = []
+    with torch.no_grad():
+        model.eval()
+        for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
+            n = len(xmb)
+            XMB = torch.tensor(xmb, dtype=torch.long).to(device)
+            YMB = torch.tensor(ymb, dtype=torch.long).to(device)
+            MMB = torch.tensor(mmb).to(device)
+            h = model(XMB)
+            clf_logits = clf_head(h, XMB)
+            clf_losses = compute_loss(XMB, YMB, MMB, clf_logits, only_return_losses=True)
+            res = (clf_logits.numpy() * n, clf_losses.numpy() * n)
+            results.append(res)
+    results = zip(*results)
+    return [fn(res) for res, fn in zip(results, fns)]

-# def iter_predict(Xs, Ms):
-#     logits = []
-#     for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True):
-#         n = len(xmb)
-#         if n == n_batch_train:
-#             logits.append(sess.run(eval_mgpu_logits, {X_train: xmb, M_train: mmb}))
-#         else:
-#             logits.append(sess.run(eval_logits, {X: xmb, M: mmb}))
-#     logits = np.concatenate(logits, 0)
-#     return logits
+def iter_predict(Xs, Ms):
+    logits = []
+    with torch.no_grad():
+        model.eval()
+        for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True):
+            n = len(xmb)
+            XMB = torch.tensor(xmb, dtype=torch.long).to(device)
+            MMB = torch.tensor(mmb).to(device)
+            h = model(XMB)
+            clf_logits = clf_head(h, XMB)
+            logits.append(clf_logits.numpy())
+    logits = np.concatenate(logits, 0)
+    return logits

-# def log():
-#     global best_score
-#     tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
-#     va_logits, va_cost = iter_apply(vaX, vaM, vaY)
-#     tr_cost = tr_cost / len(trY[:n_valid])
-#     va_cost = va_cost / n_valid
-#     tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1)) * 100.
-#     va_acc = accuracy_score(vaY, np.argmax(va_logits, 1)) * 100.
-#     logger.log(n_epochs=n_epochs, n_updates=n_updates, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc)
-#     print('%d %d %.3f %.3f %.2f %.2f' % (n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))
-#     if submit:
-#         score = va_acc
-#         if score > best_score:
-#             best_score = score
-#             save(os.path.join(save_dir, desc, 'best_params.jl'))
+def log():
+    global best_score
+    tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
+    va_logits, va_cost = iter_apply(vaX, vaM, vaY)
+    tr_cost = tr_cost / len(trY[:n_valid])
+    va_cost = va_cost / n_valid
+    tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1)) * 100.
+    va_acc = accuracy_score(vaY, np.argmax(va_logits, 1)) * 100.
+    logger.log(n_epochs=n_epochs, n_updates=n_updates, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc)
+    print('%d %d %.3f %.3f %.2f %.2f' % (n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))
+    if submit:
+        score = va_acc
+        if score > best_score:
+            best_score = score
+            path = os.path.join(save_dir, desc, 'best_params')
+            torch.save(model.state_dict(), make_path(path))
+
+def predict():
+    filename = filenames[dataset]
+    pred_fn = pred_fns[dataset]
+    label_decoder = label_decoders[dataset]
+    predictions = pred_fn(iter_predict(teX, teM))
+    if label_decoder is not None:
+        predictions = [label_decoder[prediction] for prediction in predictions]
+    path = os.path.join(submission_dir, filename)
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    with open(path, 'w') as f:
+        f.write('{}\t{}\n'.format('index', 'prediction'))
+        for i, prediction in enumerate(predictions):
+            f.write('{}\t{}\n'.format(i, prediction))

-# def predict():
-#     filename = filenames[dataset]
-#     pred_fn = pred_fns[dataset]
-#     label_decoder = label_decoders[dataset]
-#     predictions = pred_fn(iter_predict(teX, teM))
-#     if label_decoder is not None:
-#         predictions = [label_decoder[prediction] for prediction in predictions]
-#     path = os.path.join(submission_dir, filename)
-#     os.makedirs(os.path.dirname(path), exist_ok=True)
-#     with open(path, 'w') as f:
-#         f.write('{}\t{}\n'.format('index', 'prediction'))
-#         for i, prediction in enumerate(predictions):
-#             f.write('{}\t{}\n'.format(i, prediction))
+def run_epoch():
+    for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
+                                   n_batch=n_batch_train, truncate=True, verbose=True):
+        global n_updates
+        XMB = torch.tensor(xmb, dtype=torch.long).to(device)
+        YMB = torch.tensor(ymb, dtype=torch.long).to(device)
+        MMB = torch.tensor(mmb).to(device)
+        model.train()
+        h = model(XMB)
+        lm_logits = lm_head(h)
+        clf_logits = clf_head(h, XMB)
+        compute_loss(XMB, YMB, MMB, clf_logits, lm_logits)
+        n_updates += 1
+        if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:
+            log()

 argmax = lambda x: np.argmax(x, 1)

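One caveat in these helpers: `clf_logits.numpy()` and `clf_losses.numpy()` are only valid for CPU tensors, so the evaluation path would raise a `TypeError` when `device` is a GPU. A hypothetical guard, not part of this commit, would copy to host memory first:

    # Hypothetical fix (not in this commit): move tensors to the CPU before
    # converting to NumPy, which is required when device is a GPU.
    res = (clf_logits.cpu().numpy() * n, clf_losses.cpu().numpy() * n)

Note also that `run_epoch()` discards the scalar that `compute_loss` returns; the old inline loop bound it to `loss` but never used it either.
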
@@ -235,7 +260,6 @@ def transform_roc(X1, X2, X3):
                        max_grad_norm=max_grad_norm)

compute_loss = LossCompute(criterion, criterion, lm_coef, model_opt)
-# TODO Initialize model (?)
# TODO add train() and eval()
load_openai_pretrained_model(model, n_ctx, n_special, args)

@@ -250,26 +274,16 @@ def transform_roc(X1, X2, X3):
if submit:
    path = os.path.join(save_dir, desc, 'best_params')
    torch.save(model.state_dict(), make_path(path))
-
best_score = 0
+log()
for i in range(n_iter):
-    for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
-                                   n_batch=n_batch_train, truncate=True, verbose=True):
-        XMB = torch.tensor(xmb, dtype=torch.long).to(device)
-        YMB = torch.tensor(ymb, dtype=torch.long).to(device)
-        MMB = torch.tensor(mmb).to(device)
-        model.train()
-        h = model(XMB)
-        lm_logits = lm_head(h)
-        clf_logits = clf_head(h, XMB)
-        loss = compute_loss(XMB, YMB, MMB, lm_logits, clf_logits)
-        n_updates += 1
-        # if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:
-        #     log()
+    run_epoch()
    n_epochs += 1
-    # log()
-    # if submit:
-    #     sess.run([p.assign(ip) for p, ip in zip(params, joblib.load(os.path.join(save_dir, desc, 'best_params.jl')))])
-    #     predict()
-    # if analysis:
-    #     rocstories_analysis(data_dir, os.path.join(submission_dir, 'ROCStories.tsv'), os.path.join(log_dir, 'rocstories.jsonl'))
+    log()
+if submit:
+    path = os.path.join(save_dir, desc, 'best_params')
+    model.load_state_dict(torch.load(path))
+    predict()
+    if analysis:
+        rocstories_analysis(data_dir, os.path.join(submission_dir, 'ROCStories.tsv'),
+                            os.path.join(log_dir, 'rocstories.jsonl'))
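A portability note on the reload step: `torch.load(path)` restores tensors onto the device they were saved from, so loading a GPU-trained `best_params` checkpoint on a CPU-only host would fail without a `map_location` argument. This commit assumes save and load happen on the same device; a device-agnostic variant would look like this sketch:

    # Sketch of a cross-device reload (assumption: `device` is the target
    # torch.device used elsewhere in this script).
    model.load_state_dict(torch.load(path, map_location=device))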