@@ -73,8 +73,8 @@ def train_batch(self, bodies, summaries, special_append=None, no_preinput=False)
         inputs, summ_inp, summ_out = self.preprocess_batch(bodies, summaries, special_append)
         past = None
         if not no_preinput:
-            _, past = self.model(input_ids=inputs, past=None)
-        logits, _ = self.model(input_ids=summ_inp, past=past)
+            _, past = self.model(input_ids=inputs, past_key_values=None)
+        logits, _ = self.model(input_ids=summ_inp, past_key_values=past)
         crit = torch.nn.CrossEntropyLoss(ignore_index=-1)
         loss = crit(logits.view(-1, self.tokenizer.vocab_size), summ_out.contiguous().view(-1))
         return loss
@@ -97,11 +97,11 @@ def decode_batch(self, bodies, special_append=None, max_output_length=100, sampl
         # Sometimes, we process the same input, as we run it once as a sampled, and once as an argmax, in which case we should reuse the computation
         if input_past is None:
             inputs = self.preprocess_input(bodies, special_append)
-            _, input_past = self.model(input_ids=inputs, past=None)
+            _, input_past = self.model(input_ids=inputs, past_key_values=None)

         past = input_past
         while build_up is None or (build_up.shape[1] < max_output_length and not all([self.tokenizer.end_id in build for build in build_up])):
-            logits, past = self.model(input_ids=current, past=past)
+            logits, past = self.model(input_ids=current, past_key_values=past)
             probs = torch.nn.functional.softmax(logits, dim=2).squeeze(1)
             logprobs = torch.nn.functional.log_softmax(logits, dim=2)
             if sample:
@@ -149,12 +149,12 @@ def decode_beam_batch(self, bodies, beam_size=3, max_output_length=100, sample=F
         one_every_k = torch.FloatTensor([1] + [0] * (beam_size - 1)).repeat(batch_size * beam_size).to(self.device)

         # Sometimes, we process the same input, as we run it once as a sampled, and once as an argmax, in which case we should reuse the computation
-        _, input_past = self.model(input_ids=inputs, past=None)
+        _, input_past = self.model(input_ids=inputs, past_key_values=None)
         input_past = [torch.repeat_interleave(p, repeats=beam_size, dim=1) for p in input_past]

         past = input_past
         while build_up is None or (build_up.shape[1] < max_output_length and not all([self.tokenizer.end_id in build for build in build_up])):
-            logits, past = self.model(input_ids=next_words, past=past)
+            logits, past = self.model(input_ids=next_words, past_key_values=past)
             probs = torch.nn.functional.softmax(logits, dim=2).squeeze(1)
             logprobs = torch.nn.functional.log_softmax(logits, dim=2)

@@ -254,7 +254,7 @@ def score(self, summaries, bodies, bodies_tokenized=None, lengths=None, extra=No
         summ_out = summ_out.contiguous()

         with torch.no_grad():
-            logits, _ = self.model(input_ids=summ_inp, past=None)
+            logits, _ = self.model(input_ids=summ_inp, past_key_values=None)

         crit = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
         loss = crit(logits.view(-1, self.tokenizer.vocab_size), summ_out.view(-1)).view(summ_out.shape)
@@ -272,8 +272,8 @@ def score_pairs(self, bodies, summaries):
         inputs, summ_inp, summ_out = self.preprocess_batch(bodies, summaries)

         with torch.no_grad():
-            _, past = self.model(input_ids=inputs, past=None)
-            logits, _ = self.model(input_ids=summ_inp, past=past)
+            _, past = self.model(input_ids=inputs, past_key_values=None)
+            logits, _ = self.model(input_ids=summ_inp, past_key_values=past)

         crit = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
         loss = crit(logits.view(-1, self.tokenizer.vocab_size), summ_out.view(-1)).view(summ_out.shape)
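The hunks above do one thing: rename the attention-cache keyword from `past` to `past_key_values`, matching newer Hugging Face `transformers` releases. For reference, here is a minimal, self-contained sketch of the same cached-decoding pattern against a stock GPT-2 model. It is an illustration of the renamed keyword only, not this repository's `GeneTransformer` wrapper; the prompt text and loop length are arbitrary, and it assumes a `transformers` version whose forward pass accepts `past_key_values` and returns an output object with `.logits` and `.past_key_values`.

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Minimal sketch of incremental decoding with the renamed keyword
# (assumed environment: a recent transformers release, plain GPT-2).
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

prompt = tokenizer("a short example prompt", return_tensors="pt")

with torch.no_grad():
    # First pass over the full prompt caches the attention keys/values,
    # mirroring the `_, past = self.model(..., past_key_values=None)` calls above.
    out = model(input_ids=prompt.input_ids, past_key_values=None, use_cache=True)
    past = out.past_key_values
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)

    # Later passes feed only the newest token plus the cached past, which is
    # the computation reuse the in-code comments above refer to.
    for _ in range(5):
        out = model(input_ids=next_token, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
```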