@@ -245,15 +245,15 @@ def get_loss(model, batch, global_step): # make sure loss is tensor
     logits_lm, logits_clsf = model(input_ids, segment_ids, input_mask, masked_pos)
     loss_lm = criterion1(logits_lm.transpose(1, 2), masked_ids) # for masked LM
     loss_lm = (loss_lm * masked_weights.float()).mean()
-    loss_clsf = criterion2(logits_clsf, is_next) # for sentence classification
+    loss_sop = criterion2(logits_clsf, is_next) # for sentence-order prediction (SOP)
     writer.add_scalars('data/scalar_group',
                        {'loss_lm': loss_lm.item(),
-                        'loss_clsf': loss_clsf.item(),
-                        'loss_total': (loss_lm + loss_clsf).item(),
+                        'loss_sop': loss_sop.item(),
+                        'loss_total': (loss_lm + loss_sop).item(),
                         'lr': optimizer.get_lr()[0],
                        },
                        global_step)
-    return loss_lm + loss_clsf
+    return loss_lm + loss_sop


 trainer.train(get_loss, model_file=None, data_parallel=True)
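For reference, here is a minimal, self-contained sketch of the two-loss computation above. It assumes criterion1 is nn.CrossEntropyLoss(reduction='none') (so the per-token masked-LM losses can be re-weighted by masked_weights) and criterion2 is a plain nn.CrossEntropyLoss(); the shapes and random inputs are illustrative stand-ins for the real model outputs and batch fields, not the repo's exact code:

# Sketch only: criteria, shapes, and dummy tensors are assumptions for illustration.
import torch
import torch.nn as nn

batch_size, max_pred, vocab_size = 2, 76, 30522

criterion1 = nn.CrossEntropyLoss(reduction='none')  # per-token masked-LM loss
criterion2 = nn.CrossEntropyLoss()                  # binary SOP loss

# Stand-ins for model outputs and batch tensors
logits_lm = torch.randn(batch_size, max_pred, vocab_size)  # predictions at masked positions
logits_clsf = torch.randn(batch_size, 2)                   # SOP logits: in-order vs. swapped
masked_ids = torch.randint(0, vocab_size, (batch_size, max_pred))
masked_weights = torch.randint(0, 2, (batch_size, max_pred))  # 1 = real masked slot, 0 = padding
is_next = torch.randint(0, 2, (batch_size,))

# CrossEntropyLoss expects (N, C, ...), hence the transpose to (batch, vocab, max_pred)
loss_lm = criterion1(logits_lm.transpose(1, 2), masked_ids)  # shape (batch, max_pred)
loss_lm = (loss_lm * masked_weights.float()).mean()          # zero out padded prediction slots
loss_sop = criterion2(logits_clsf, is_next)
loss = loss_lm + loss_sop                                    # the tensor get_loss returns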
@@ -266,7 +266,7 @@ def get_loss(model, batch, global_step): # make sure loss is tensor
 parser.add_argument('--model_cfg', type=str, default='./config/albert_unittest.json')

 # the official google-research/bert uses 20, but 20/512(=seq_len)*100 masks only ~3.9% of tokens
-# So, the official XLNet (zihangdai/xlnet) uses 85 under the name num_predict (same here!)
+# So, use 76(=0.15*512) as `max_pred`
 parser.add_argument('--max_pred', type=int, default=76, help='max tokens of prediction')
 parser.add_argument('--mask_prob', type=float, default=0.15, help='masking probability')
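A quick check of the new default: with mask_prob=0.15 and a sequence length of 512, the expected number of masked tokens per sequence is 0.15 * 512 = 76.8, so truncating to 76 keeps max_pred consistent with mask_prob (variable names here are illustrative):

mask_prob, seq_len = 0.15, 512
max_pred = int(mask_prob * seq_len)  # 0.15 * 512 = 76.8 -> truncated to 76
print(max_pred)  # 76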