Commit 0969cc6: solve Infinite loop when masking
1 parent 757fd6d

6 files changed: +41 additions, -18 deletions

models.py

Lines changed: 5 additions & 4 deletions
@@ -1,7 +1,8 @@
-# Copyright 2019 Tae Hwan Jung(@graykode)
-# forked from https://github.com/dhlee347/pytorchic-bert
-# (Strongly inspired by original Google BERT code and Hugging Face's code)
-# Remove All Dropout
+"""
+Copyright 2019 Tae Hwan Jung
+ALBERT Implementation with forking
+Clean Pytorch Code from https://github.com/dhlee347/pytorchic-bert
+"""
 
 """ Transformer Model Classes & Config Class """

optim.py

Lines changed: 5 additions & 2 deletions
@@ -1,5 +1,8 @@
-# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team,
-# and Dong-Hyun Lee, Kakao Brain.
+"""
+Copyright 2019 Tae Hwan Jung
+ALBERT Implementation with forking
+Clean Pytorch Code from https://github.com/dhlee347/pytorchic-bert
+"""
 
 """ a slightly modified version of Hugging Face's BERTAdam class """

pretrain.py

Lines changed: 12 additions & 6 deletions
@@ -18,8 +18,7 @@
 import optim
 import train
 
-from utils import set_seeds, get_device, get_random_word, truncate_tokens_pair, \
-    _is_start_piece, _sample_mask
+from utils import set_seeds, get_device, truncate_tokens_pair, _sample_mask
 
 # Input file format :
 # 1. One sentence per line. These should ideally be actual sentences,
@@ -113,11 +112,12 @@ def __call__(self, instance):
 
 class Preprocess4Pretrain(Pipeline):
     """ Pre-processing steps for pretraining transformer """
-    def __init__(self, max_pred, vocab_words, indexer, max_len,
+    def __init__(self, max_pred, mask_prob, vocab_words, indexer, max_len,
                  mask_alpha, mask_beta, max_gram):
         super().__init__()
         self.max_len = max_len
         self.max_pred = max_pred # max tokens of prediction
+        self.mask_prob = mask_prob # masking probability
         self.vocab_words = vocab_words # vocabulary (sub)words
 
         self.indexer = indexer # function from token to token index
@@ -137,9 +137,13 @@ def __call__(self, instance):
         segment_ids = [0]*(len(tokens_a)+2) + [1]*(len(tokens_b)+1)
         input_mask = [1]*len(tokens)
 
+        # the number of prediction is sometimes less than max_pred when sequence is short
+        n_pred = min(self.max_pred, max(1, int(round(len(tokens) * self.mask_prob))))
+
         # For masked Language Models
         masked_tokens, masked_pos, tokens = _sample_mask(tokens, self.mask_alpha,
-                                                         self.mask_beta, self.max_gram, self.max_pred)
+                                                         self.mask_beta, self.max_gram,
+                                                         goal_num_predict=n_pred)
 
         masked_weights = [1]*len(masked_tokens)
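The change above means _sample_mask is no longer always asked for self.max_pred masked tokens: the goal is now mask_prob times the actual token count, clipped to [1, max_pred], so a short sequence is never handed a goal it cannot reach. A rough sanity check of that cap, as a minimal sketch assuming this commit's defaults (max_pred=76, mask_prob=0.15), not code from the repository:

    # How the per-example masking goal scales with sequence length
    # (assumed defaults: max_pred=76, mask_prob=0.15).
    max_pred, mask_prob = 76, 0.15
    for seq_len in (32, 128, 512):
        n_pred = min(max_pred, max(1, int(round(seq_len * mask_prob))))
        print(seq_len, n_pred)   # 32 -> 5, 128 -> 19, 512 -> 76 (77, capped by max_pred)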

@@ -213,6 +217,7 @@ def main(args):
     tokenize = lambda x: tokenizer.tokenize(tokenizer.convert_to_unicode(x))
 
     pipeline = [Preprocess4Pretrain(args.max_pred,
+                                    args.mask_prob,
                                     list(tokenizer.vocab.keys()),
                                     tokenizer.convert_tokens_to_ids,
                                     model_cfg.max_len,
@@ -262,11 +267,12 @@ def get_loss(model, batch, global_step): # make sure loss is tensor
 
     # official google-reacher/bert is use 20, but 20/512(=seq_len)*100 make only 3% Mask
     # So, official XLNET zihangdai/xlnet use 85 with name of num_predict(SAME HERE!)
-    parser.add_argument('--max_pred', type=int, default=85)
+    parser.add_argument('--max_pred', type=int, default=76, help='max tokens of prediction')
+    parser.add_argument('--mask_prob', type=float, default=0.15, help='masking probability')
 
     # try to n-gram masking SpanBERT(Joshi et al., 2019)
     parser.add_argument('--mask_alpha', type=int,
-                        default=6, help="How many tokens to form a group.")
+                        default=4, help="How many tokens to form a group.")
     parser.add_argument('--mask_beta', type=int,
                         default=1, help="How many tokens to mask within each group.")
     parser.add_argument('--max_gram', type=int,
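The percentages in the comment above work out as follows, a quick check assuming the 512-token sequence length the comment refers to:

    # Fraction of a full 512-token sequence covered by each prediction budget.
    seq_len = 512
    print(20 / seq_len)   # ~0.04: BERT's default of 20 masks only ~4% of a full sequence
    print(85 / seq_len)   # ~0.17: XLNet's num_predict of 85
    print(76 / seq_len)   # ~0.15: the new default of 76 lines up with mask_prob=0.15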

tokenization.py

Lines changed: 5 additions & 3 deletions
@@ -1,6 +1,8 @@
-# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
-
-""" Tokenization classes (It's exactly the same code as Google BERT code """
+"""
+Copyright 2019 Tae Hwan Jung
+ALBERT Implementation with forking
+Clean Pytorch Code from https://github.com/dhlee347/pytorchic-bert
+"""
 
 from __future__ import absolute_import
 from __future__ import division

train.py

Lines changed: 5 additions & 1 deletion
@@ -1,4 +1,8 @@
-# Copyright 2018 Dong-Hyun Lee, Kakao Brain.
+"""
+Copyright 2019 Tae Hwan Jung
+ALBERT Implementation with forking
+Clean Pytorch Code from https://github.com/dhlee347/pytorchic-bert
+"""
 
 """ Training Config & Helper Classes """

utils.py

Lines changed: 9 additions & 2 deletions
@@ -1,4 +1,8 @@
-# Copyright 2018 Dong-Hyun Lee, Kakao Brain.
+"""
+Copyright 2019 Tae Hwan Jung
+ALBERT Implementation with forking
+Clean Pytorch Code from https://github.com/dhlee347/pytorchic-bert
+"""
 
 """ Utils Functions """

@@ -114,12 +118,16 @@ def _sample_mask(seg, mask_alpha, mask_beta,
     pvals /= pvals.sum(keepdims=True) # p(n) = 1/n / sigma(1/k)
 
     cur_len = 0
+
     while cur_len < seg_len:
         if goal_num_predict is not None and num_predict >= goal_num_predict: break
 
         n = np.random.choice(ngrams, p=pvals)
         if goal_num_predict is not None:
             n = min(n, goal_num_predict - num_predict)
+
+        # `mask_alpha` : number of tokens forming group
+        # `mask_beta` : number of tokens to be masked in each groups.
         ctx_size = (n * mask_alpha) // mask_beta
         l_ctx = np.random.choice(ctx_size)
         r_ctx = ctx_size - l_ctx
@@ -164,5 +172,4 @@ def _sample_mask(seg, mask_alpha, mask_beta,
             tokens.append('[MASK]')
         else:
             tokens.append(seg[i])
-
     return masked_tokens, masked_pos, tokens
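Two details behind the hunks above: the n-gram length n is drawn from the p(n) = 1/n / sigma(1/k) distribution noted in the comment, and each sampled n-gram is masked inside a group of (n * mask_alpha) // mask_beta tokens, per the two comments added in this commit. A small illustration, assuming max_gram=3 as an example value (its default is not visible in this diff) together with the new defaults mask_alpha=4 and mask_beta=1; illustrative only, not repository code:

    import numpy as np

    # n-gram length distribution p(n) = (1/n) / sum(1/k): shorter spans are favored.
    max_gram, mask_alpha, mask_beta = 3, 4, 1    # max_gram=3 is an assumed example
    ngrams = np.arange(1, max_gram + 1)          # [1, 2, 3]
    pvals = 1.0 / ngrams
    pvals /= pvals.sum(keepdims=True)
    print(pvals)                                 # approx. [0.545, 0.273, 0.182]

    # Size of the token group each sampled n-gram is masked within.
    for n in ngrams:
        print(n, (n * mask_alpha) // mask_beta)  # 1 -> 4, 2 -> 8, 3 -> 12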
