Skip to content

Commit 533bb59

Browse files
author
Ubuntu
committed
Code compatible with transformers==3.1.0. Freezing some libraries in requirements.txt, ready for Release 0.2
1 parent 83c66c4 commit 533bb59

File tree

3 files changed

+12
-15
lines changed

3 files changed

+12
-15
lines changed

coverage.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
class KeywordExtractor():
1313
def __init__(self, n_kws=15):
1414
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
15-
self.tokenizer.max_len = 10000
1615
self.n_kws = n_kws
1716

1817
self.bert_w2i = {w: i for i, w in enumerate(self.tokenizer.vocab)}
@@ -70,7 +69,6 @@ def extract_keywords(self, unmasked):
7069
class KeywordCoverage():
7170
def __init__(self, device, keyword_model_file, model_file=None, n_kws=15):
7271
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
73-
self.tokenizer.max_len = 10000
7472
self.vocab_size = self.tokenizer.vocab_size
7573
self.n_kws = n_kws
7674

model_generator.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ def train_batch(self, bodies, summaries, special_append=None, no_preinput=False)
7373
inputs, summ_inp, summ_out = self.preprocess_batch(bodies, summaries, special_append)
7474
past = None
7575
if not no_preinput:
76-
_, past = self.model(input_ids=inputs, past=None)
77-
logits, _ = self.model(input_ids=summ_inp, past=past)
76+
_, past = self.model(input_ids=inputs, past_key_values=None)
77+
logits, _ = self.model(input_ids=summ_inp, past_key_values=past)
7878
crit = torch.nn.CrossEntropyLoss(ignore_index=-1)
7979
loss = crit(logits.view(-1, self.tokenizer.vocab_size), summ_out.contiguous().view(-1))
8080
return loss
@@ -97,11 +97,11 @@ def decode_batch(self, bodies, special_append=None, max_output_length=100, sampl
9797
# Sometimes, we process the same input, as we run it once as a sampled, and once as an argmax, in which case we should reuse the computation
9898
if input_past is None:
9999
inputs = self.preprocess_input(bodies, special_append)
100-
_, input_past = self.model(input_ids=inputs, past=None)
100+
_, input_past = self.model(input_ids=inputs, past_key_values=None)
101101

102102
past = input_past
103103
while build_up is None or (build_up.shape[1] < max_output_length and not all([self.tokenizer.end_id in build for build in build_up])):
104-
logits, past = self.model(input_ids=current, past=past)
104+
logits, past = self.model(input_ids=current, past_key_values=past)
105105
probs = torch.nn.functional.softmax(logits, dim=2).squeeze(1)
106106
logprobs = torch.nn.functional.log_softmax(logits, dim=2)
107107
if sample:
@@ -149,12 +149,12 @@ def decode_beam_batch(self, bodies, beam_size=3, max_output_length=100, sample=F
149149
one_every_k = torch.FloatTensor([1] + [0] * (beam_size-1)).repeat(batch_size*beam_size).to(self.device)
150150

151151
# Sometimes, we process the same input, as we run it once as a sampled, and once as an argmax, in which case we should reuse the computation
152-
_, input_past = self.model(input_ids=inputs, past=None)
152+
_, input_past = self.model(input_ids=inputs, past_key_values=None)
153153
input_past = [torch.repeat_interleave(p, repeats=beam_size, dim=1) for p in input_past]
154154

155155
past = input_past
156156
while build_up is None or (build_up.shape[1] < max_output_length and not all([self.tokenizer.end_id in build for build in build_up])):
157-
logits, past = self.model(input_ids=next_words, past=past)
157+
logits, past = self.model(input_ids=next_words, past_key_values=past)
158158
probs = torch.nn.functional.softmax(logits, dim=2).squeeze(1)
159159
logprobs = torch.nn.functional.log_softmax(logits, dim=2)
160160

@@ -254,7 +254,7 @@ def score(self, summaries, bodies, bodies_tokenized=None, lengths=None, extra=No
254254
summ_out = summ_out.contiguous()
255255

256256
with torch.no_grad():
257-
logits, _ = self.model(input_ids=summ_inp, past=None)
257+
logits, _ = self.model(input_ids=summ_inp, past_key_values=None)
258258

259259
crit = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
260260
loss = crit(logits.view(-1, self.tokenizer.vocab_size), summ_out.view(-1)).view(summ_out.shape)
@@ -272,8 +272,8 @@ def score_pairs(self, bodies, summaries):
272272
inputs, summ_inp, summ_out = self.preprocess_batch(bodies, summaries)
273273

274274
with torch.no_grad():
275-
_, past = self.model(input_ids=inputs, past=None)
276-
logits, _ = self.model(input_ids=summ_inp, past=past)
275+
_, past = self.model(input_ids=inputs, past_key_values=None)
276+
logits, _ = self.model(input_ids=summ_inp, past_key_values=past)
277277

278278
crit = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
279279
loss = crit(logits.view(-1, self.tokenizer.vocab_size), summ_out.view(-1)).view(summ_out.shape)

requirements.txt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
transformers==3.0.2
2-
sklearn
3-
nltk
1+
transformers==3.1.0
2+
sklearn==0.22.1
3+
nltk==3.5
44
h5py
55
tqdm
66
matplotlib
7-
sklearn

0 commit comments

Comments
 (0)