Skip to content

Commit 028b94c

Browse files
authored
Fix get ppl (InternLM#3268)
* fix get_ppl * remove useless code * remove debug logs
1 parent d95ecc0 commit 028b94c

File tree

1 file changed

+22
-18
lines changed

1 file changed

+22
-18
lines changed

lmdeploy/serve/utils.py

+22-18
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,11 @@ def get_reward_score(self, input_ids: List) -> List[float]:
3535
assert all(isinstance(x, int) for x in input_ids) or all(isinstance(x, List) for x in input_ids)
3636
# Make input_ids a list of token_id list
3737
input_ids = [input_ids] if isinstance(input_ids[0], int) else input_ids
38-
logits = self._run(coro=self._async_get_logits(input_ids=input_ids)).result()
38+
logits, session_ids = self._run(coro=self._async_get_logits(input_ids=input_ids)).result()
3939
logits = [x.squeeze() for x in logits]
4040
scores = [x[-1].cpu().item() for x in logits]
41+
for session_id in session_ids:
42+
self.end_session(session_id)
4143
return scores
4244

4345
async def _async_get_logits(self,
@@ -69,13 +71,12 @@ async def _proc(i):
6971
async for outputs in gen:
7072
pass
7173
logits[i] = outputs.logits[:input_len, :]
72-
if sequence_end and self.backend == 'pytorch':
73-
await inst.async_end(session_id=i)
7474

75+
session_ids = list(range(len(input_ids)))
7576
tasks = [_proc(i) for i in range(len(input_ids))]
7677
await asyncio.gather(*tasks)
7778

78-
return logits
79+
return logits, session_ids
7980

8081
def get_ppl(self, input_ids: Union[List[int], List[List[int]]]) -> List[float]:
8182
"""Get perplexity scores given a list of input tokens that have to be
@@ -108,15 +109,17 @@ def get_ppl(self, input_ids: Union[List[int], List[List[int]]]) -> List[float]:
108109
logger.info(f'start: {start}, end: {end}')
109110
if start == end:
110111
_input_ids = input_ids[indices[start]]
111-
res = self._get_long_text_ppl(input_ids=_input_ids, max_input_len=max_input_len)
112+
res, session_ids = self._get_long_text_ppl(input_ids=_input_ids, max_input_len=max_input_len)
112113
result.append(res)
113114
else:
114115
_input_ids = [input_ids[indices[i]] for i in range(start, end)]
115-
res = self._get_ppl(
116+
res, session_ids = self._get_ppl(
116117
input_ids=_input_ids,
117118
max_input_len=max_input_len,
118119
)
119120
result.extend(res)
121+
for session_id in session_ids:
122+
self.end_session(session_id)
120123
output = list(range(len(result)))
121124
for index, sorted_index in enumerate(indices):
122125
output[sorted_index] = result[index]
@@ -152,23 +155,24 @@ def _get_long_text_ppl(self, input_ids, max_input_len):
152155

153156
losses = []
154157
target_counts = []
158+
session_ids = []
155159
for i in range(0, seq_len, max_input_len):
156160
token_ids = input_ids[i:i + max_input_len]
157161
step = [i]
158162
# shift token_ids by 1 to the left
159163
target_ids = input_ids[i + 1:i + 1 + max_input_len]
160-
161-
loss, target_count = self._get_ppl(input_ids=[token_ids],
162-
max_input_len=max_input_len,
163-
target_ids=[target_ids],
164-
steps=step,
165-
sequence_start=(i == 0),
166-
sequence_end=(i + max_input_len >= seq_len))
164+
loss, session_ids = self._get_ppl(input_ids=[token_ids],
165+
max_input_len=len(token_ids),
166+
target_ids=[target_ids],
167+
steps=step,
168+
sequence_start=(i == 0),
169+
sequence_end=False)
167170
losses.extend(loss)
168-
target_counts.extend(target_count)
171+
target_counts.append(len(target_ids))
172+
losses = [loss * target_count for loss, target_count in zip(losses, target_counts)]
169173
loss_sum = sum(losses)
170174
target_count = sum(target_counts)
171-
return loss_sum / target_count
175+
return loss_sum / target_count, session_ids
172176

173177
def _get_ppl(self,
174178
input_ids,
@@ -186,10 +190,10 @@ def _get_ppl(self,
186190
assert sum(lens) <= max_input_len
187191

188192
logger.info(f'get_ppl: bs: {len(input_ids)}, lens: {lens}, '
189-
f'total_len: {total_len}')
193+
f'total_len: {total_len}, steps: {steps}')
190194
torch.cuda.empty_cache()
191195

192-
logits = self._run(coro=self._async_get_logits(
196+
logits, session_ids = self._run(coro=self._async_get_logits(
193197
input_ids=input_ids, steps=steps, sequence_start=sequence_start, sequence_end=sequence_end)).result()
194198
padding_token_id = -100
195199
if target_ids is None:
@@ -218,4 +222,4 @@ def _get_ppl(self,
218222
target_count = target_mask.sum()
219223
result.append(loss.item() / target_count.item())
220224
logger.info(f'ppl result: {result}')
221-
return result
225+
return result, session_ids

0 commit comments

Comments
 (0)