Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix get ppl #3268

Merged
merged 3 commits into from
Mar 18, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions lmdeploy/serve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ def get_reward_score(self, input_ids: List) -> List[float]:
assert all(isinstance(x, int) for x in input_ids) or all(isinstance(x, List) for x in input_ids)
# Make input_ids a list of token_id list
input_ids = [input_ids] if isinstance(input_ids[0], int) else input_ids
logits = self._run(coro=self._async_get_logits(input_ids=input_ids)).result()
logits, session_ids = self._run(coro=self._async_get_logits(input_ids=input_ids)).result()
logits = [x.squeeze() for x in logits]
scores = [x[-1].cpu().item() for x in logits]
for session_id in session_ids:
self.end_session(session_id)
return scores

async def _async_get_logits(self,
Expand Down Expand Up @@ -69,13 +71,12 @@ async def _proc(i):
async for outputs in gen:
pass
logits[i] = outputs.logits[:input_len, :]
if sequence_end and self.backend == 'pytorch':
await inst.async_end(session_id=i)

session_ids = list(range(len(input_ids)))
tasks = [_proc(i) for i in range(len(input_ids))]
await asyncio.gather(*tasks)

return logits
return logits, session_ids

def get_ppl(self, input_ids: Union[List[int], List[List[int]]]) -> List[float]:
"""Get perplexity scores given a list of input tokens that have to be
Expand Down Expand Up @@ -108,15 +109,17 @@ def get_ppl(self, input_ids: Union[List[int], List[List[int]]]) -> List[float]:
logger.info(f'start: {start}, end: {end}')
if start == end:
_input_ids = input_ids[indices[start]]
res = self._get_long_text_ppl(input_ids=_input_ids, max_input_len=max_input_len)
res, session_ids = self._get_long_text_ppl(input_ids=_input_ids, max_input_len=max_input_len)
result.append(res)
else:
_input_ids = [input_ids[indices[i]] for i in range(start, end)]
res = self._get_ppl(
res, session_ids = self._get_ppl(
input_ids=_input_ids,
max_input_len=max_input_len,
)
result.extend(res)
for session_id in session_ids:
self.end_session(session_id)
output = list(range(len(result)))
for index, sorted_index in enumerate(indices):
output[sorted_index] = result[index]
Expand Down Expand Up @@ -152,23 +155,24 @@ def _get_long_text_ppl(self, input_ids, max_input_len):

losses = []
target_counts = []
session_ids = []
for i in range(0, seq_len, max_input_len):
token_ids = input_ids[i:i + max_input_len]
step = [i]
# shift token_ids by 1 to the left
target_ids = input_ids[i + 1:i + 1 + max_input_len]

loss, target_count = self._get_ppl(input_ids=[token_ids],
max_input_len=max_input_len,
target_ids=[target_ids],
steps=step,
sequence_start=(i == 0),
sequence_end=(i + max_input_len >= seq_len))
loss, session_ids = self._get_ppl(input_ids=[token_ids],
max_input_len=len(token_ids),
target_ids=[target_ids],
steps=step,
sequence_start=(i == 0),
sequence_end=False)
losses.extend(loss)
target_counts.extend(target_count)
target_counts.append(len(target_ids))
losses = [loss * target_count for loss, target_count in zip(losses, target_counts)]
loss_sum = sum(losses)
target_count = sum(target_counts)
return loss_sum / target_count
return loss_sum / target_count, session_ids

def _get_ppl(self,
input_ids,
Expand All @@ -186,10 +190,10 @@ def _get_ppl(self,
assert sum(lens) <= max_input_len

logger.info(f'get_ppl: bs: {len(input_ids)}, lens: {lens}, '
f'total_len: {total_len}')
f'total_len: {total_len}, steps: {steps}')
torch.cuda.empty_cache()

logits = self._run(coro=self._async_get_logits(
logits, session_ids = self._run(coro=self._async_get_logits(
input_ids=input_ids, steps=steps, sequence_start=sequence_start, sequence_end=sequence_end)).result()
padding_token_id = -100
if target_ids is None:
Expand Down Expand Up @@ -218,4 +222,4 @@ def _get_ppl(self,
target_count = target_mask.sum()
result.append(loss.item() / target_count.item())
logger.info(f'ppl result: {result}')
return result
return result, session_ids