@@ -35,9 +35,11 @@ def get_reward_score(self, input_ids: List) -> List[float]:
         assert all(isinstance(x, int) for x in input_ids) or all(isinstance(x, List) for x in input_ids)
         # Make input_ids a list of token_id list
         input_ids = [input_ids] if isinstance(input_ids[0], int) else input_ids
-        logits = self._run(coro=self._async_get_logits(input_ids=input_ids)).result()
+        logits, session_ids = self._run(coro=self._async_get_logits(input_ids=input_ids)).result()
         logits = [x.squeeze() for x in logits]
         scores = [x[-1].cpu().item() for x in logits]
+        for session_id in session_ids:
+            self.end_session(session_id)
         return scores

     async def _async_get_logits(self,
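In this hunk `get_reward_score` takes ownership of the sessions opened by `_async_get_logits` and releases them once the scores have been copied to the CPU. A minimal sketch of the same acquire/score/release pattern, wrapped in try/finally so the sessions are freed even if score extraction raises (the `engine` name is illustrative; the calls mirror the hunk above):

# Sketch only: same calls as above, with cleanup guaranteed by try/finally.
def score_with_cleanup(engine, input_ids):
    logits, session_ids = engine._run(
        coro=engine._async_get_logits(input_ids=input_ids)).result()
    try:
        # Last-token logit of each sequence is used as the reward score.
        return [x.squeeze()[-1].cpu().item() for x in logits]
    finally:
        for session_id in session_ids:
            engine.end_session(session_id)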
@@ -69,13 +71,12 @@ async def _proc(i):
                     async for outputs in gen:
                         pass
                     logits[i] = outputs.logits[:input_len, :]
-                    if sequence_end and self.backend == 'pytorch':
-                        await inst.async_end(session_id=i)

+        session_ids = list(range(len(input_ids)))
         tasks = [_proc(i) for i in range(len(input_ids))]
         await asyncio.gather(*tasks)

-        return logits
+        return logits, session_ids

     def get_ppl(self, input_ids: Union[List[int], List[List[int]]]) -> List[float]:
         """Get perplexity scores given a list of input tokens that have to be
@@ -108,15 +109,17 @@ def get_ppl(self, input_ids: Union[List[int], List[List[int]]]) -> List[float]:
             logger.info(f'start: {start}, end: {end}')
             if start == end:
                 _input_ids = input_ids[indices[start]]
-                res = self._get_long_text_ppl(input_ids=_input_ids, max_input_len=max_input_len)
+                res, session_ids = self._get_long_text_ppl(input_ids=_input_ids, max_input_len=max_input_len)
                 result.append(res)
             else:
                 _input_ids = [input_ids[indices[i]] for i in range(start, end)]
-                res = self._get_ppl(
+                res, session_ids = self._get_ppl(
                     input_ids=_input_ids,
                     max_input_len=max_input_len,
                 )
                 result.extend(res)
+            for session_id in session_ids:
+                self.end_session(session_id)
         output = list(range(len(result)))
         for index, sorted_index in enumerate(indices):
             output[sorted_index] = result[index]
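In `get_ppl`, each batch (or long-text input) now reports the `session_ids` it used, and they are ended after that batch is scored. The final loop then scatters the results back into the caller's original input order, since they were produced in whatever order `indices` imposes. A worked mini-example of that scatter step with made-up values (only the scatter itself is shown in the hunk; the ordering criterion is not):

indices = [1, 0, 2]                            # inputs visited in this order
result = ['ppl(x1)', 'ppl(x0)', 'ppl(x2)']     # produced in visit order
output = [None] * len(result)
for index, sorted_index in enumerate(indices):
    output[sorted_index] = result[index]
assert output == ['ppl(x0)', 'ppl(x1)', 'ppl(x2)']   # back in input order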
@@ -152,23 +155,24 @@ def _get_long_text_ppl(self, input_ids, max_input_len):

         losses = []
         target_counts = []
+        session_ids = []
         for i in range(0, seq_len, max_input_len):
             token_ids = input_ids[i:i + max_input_len]
             step = [i]
             # shift token_ids by 1 to the left
             target_ids = input_ids[i + 1:i + 1 + max_input_len]
-
-            loss, target_count = self._get_ppl(input_ids=[token_ids],
-                                               max_input_len=max_input_len,
-                                               target_ids=[target_ids],
-                                               steps=step,
-                                               sequence_start=(i == 0),
-                                               sequence_end=(i + max_input_len >= seq_len))
+            loss, session_ids = self._get_ppl(input_ids=[token_ids],
+                                              max_input_len=len(token_ids),
+                                              target_ids=[target_ids],
+                                              steps=step,
+                                              sequence_start=(i == 0),
+                                              sequence_end=False)
             losses.extend(loss)
-            target_counts.extend(target_count)
+            target_counts.append(len(target_ids))
+        losses = [loss * target_count for loss, target_count in zip(losses, target_counts)]
         loss_sum = sum(losses)
         target_count = sum(target_counts)
-        return loss_sum / target_count
+        return loss_sum / target_count, session_ids

     def _get_ppl(self,
                  input_ids,
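Since `_get_ppl` returns a mean negative log-likelihood per chunk, `_get_long_text_ppl` now rescales each chunk mean by the number of target tokens it covered before summing; dividing by the total target count then gives a token-weighted mean over the whole sequence rather than a plain average of chunk means. A small worked example with made-up numbers:

# Two chunks of very different size: a plain average of the chunk means
# would over-weight the short chunk.
chunk_mean_nll = [2.0, 3.0]    # mean loss returned per chunk
target_counts = [100, 10]      # target tokens scored in each chunk
losses = [m * c for m, c in zip(chunk_mean_nll, target_counts)]   # [200.0, 30.0]
weighted_mean = sum(losses) / sum(target_counts)                  # 230 / 110 ≈ 2.09
naive_mean = sum(chunk_mean_nll) / len(chunk_mean_nll)            # 2.5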
@@ -186,10 +190,10 @@ def _get_ppl(self,
         assert sum(lens) <= max_input_len

         logger.info(f'get_ppl: bs: {len(input_ids)}, lens: {lens}, '
-                    f'total_len: {total_len}')
+                    f'total_len: {total_len}, steps: {steps}')
         torch.cuda.empty_cache()

-        logits = self._run(coro=self._async_get_logits(
+        logits, session_ids = self._run(coro=self._async_get_logits(
             input_ids=input_ids, steps=steps, sequence_start=sequence_start, sequence_end=sequence_end)).result()
         padding_token_id = -100
         if target_ids is None:
@@ -218,4 +222,4 @@ def _get_ppl(self,
             target_count = target_mask.sum()
             result.append(loss.item() / target_count.item())
         logger.info(f'ppl result: {result}')
-        return result
+        return result, session_ids
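`_get_ppl` now hands back the session ids along with the per-sample values. Note that what gets appended to `result` is a mean negative log-likelihood (`loss / target_count`), not an exponentiated perplexity; a caller wanting a conventional PPL number would exponentiate it. A hedged usage sketch (the `pipe` and `token_id_lists` names are placeholders for an object exposing the public `get_ppl` shown earlier):

import math

mean_nll = pipe.get_ppl(token_id_lists)     # one mean NLL per input
ppl = [math.exp(x) for x in mean_nll]       # conventional perplexity values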