Skip to content

Commit

Permalink
fix qwen2 vl crash in continous batching (#3004)
Browse files Browse the repository at this point in the history
Signed-off-by: Wang, Yi A <[email protected]>
  • Loading branch information
sywangyi authored Feb 20, 2025
1 parent ed96ba6 commit 06dfe9a
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions server/text_generation_server/models/flash_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,16 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
adapter_segment_builder = None
else:
input_ids = batches[0].input_ids.new_empty(total_batch_size)
position_ids = batches[0].position_ids.new_empty(total_batch_size)
if (
batches[0].position_ids is not None
and batches[0].position_ids.dim() == 2
):
# Qwen2_vl case:
position_ids = batches[0].position_ids.new_empty(
(total_batch_size, batches[0].position_ids.shape[-1])
)
else:
position_ids = batches[0].position_ids.new_empty(total_batch_size)
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
total_batch_size
Expand Down Expand Up @@ -2133,7 +2142,11 @@ def generate_token(
if not prefill or (prefill and finished_prefilling):
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
batch.speculative_ids = speculative_ids
batch.position_ids += accepted_ids
if batch.position_ids.dim() == 2:
# Qwen2_vl case:
batch.position_ids += accepted_ids.unsqueeze(-1)
else:
batch.position_ids += accepted_ids
batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
batch.slot_indices += accepted_ids
Expand Down

0 comments on commit 06dfe9a

Please sign in to comment.