Skip to content

Commit ef45263

Browse files
committed
delete checkpoint and variable names
1 parent 9a18495 commit ef45263

File tree

2 files changed

+5
-71
lines changed

2 files changed

+5
-71
lines changed

checkpoint.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

pretrain.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -245,15 +245,15 @@ def get_loss(model, batch, global_step): # make sure loss is tensor
245245
logits_lm, logits_clsf = model(input_ids, segment_ids, input_mask, masked_pos)
246246
loss_lm = criterion1(logits_lm.transpose(1, 2), masked_ids) # for masked LM
247247
loss_lm = (loss_lm*masked_weights.float()).mean()
248-
loss_clsf = criterion2(logits_clsf, is_next) # for sentence classification
248+
loss_sop = criterion2(logits_clsf, is_next) # for sentence classification
249249
writer.add_scalars('data/scalar_group',
250250
{'loss_lm': loss_lm.item(),
251-
'loss_clsf': loss_clsf.item(),
252-
'loss_total': (loss_lm + loss_clsf).item(),
251+
'loss_sop': loss_sop.item(),
252+
'loss_total': (loss_lm + loss_sop).item(),
253253
'lr': optimizer.get_lr()[0],
254254
},
255255
global_step)
256-
return loss_lm + loss_clsf
256+
return loss_lm + loss_sop
257257

258258
trainer.train(get_loss, model_file=None, data_parallel=True)
259259

@@ -266,7 +266,7 @@ def get_loss(model, batch, global_step): # make sure loss is tensor
266266
parser.add_argument('--model_cfg', type=str, default='./config/albert_unittest.json')
267267

268268
# The official google-research/bert uses 20, but 20/512 (= seq_len) masks only about 3.9% of the tokens
269-
# So, official XLNET zihangdai/xlnet use 85 with name of num_predict(SAME HERE!)
269+
# So use 76 (= int(0.15 * 512), i.e. mask_prob * seq_len) as the default `max_pred`
270270
parser.add_argument('--max_pred', type=int, default=76, help='max tokens of prediction')
271271
parser.add_argument('--mask_prob', type=float, default=0.15, help='masking probability')
272272

0 commit comments

Comments
 (0)