
Commit 84bf004

Fix qwen2.5-omni use_audio_in_video (#3987)
1 parent cbea8ae commit 84bf004

8 files changed, +102 −48 lines changed


requirements/framework.txt

+1 −1

@@ -28,7 +28,7 @@ sentencepiece
 tensorboard
 tiktoken
 tqdm
-transformers>=4.33,<4.52
+transformers>=4.33,<4.53
 transformers_stream_generator
 trl>=0.13,<0.17
 uvicorn
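
The relaxed ceiling is the only dependency change; the Qwen2.5-Omni template changes below import Qwen2_5OmniProcessorKwargs from transformers, which needs a sufficiently recent release. A quick way to align a local environment with the new range (the commands are illustrative, not part of the commit):

    pip install 'transformers>=4.33,<4.53'
    python -c 'import transformers; print(transformers.__version__)'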

swift/llm/model/utils.py

+1

@@ -39,6 +39,7 @@ def update_attn_impl(config: PretrainedConfig,
                      attn_impl_keys: Optional[List[str]] = None) -> None:
     if attn_impl is None:
         return
+    logger.info(f'attn_impl: {attn_impl}')
     use_flash_attn = AttnImpl.to_use_flash_attn(attn_impl)
     if use_flash_attn:
         attn_impl = 'flash_attention_2'

swift/llm/template/base.py

+1 −1

@@ -739,7 +739,7 @@ def _pre_tokenize_images(self, context_list: List[Context], loss_scale_list: Lis
             if context == '<image>' and inputs.is_multimodal and inputs.image_idx < len(inputs.images):
                 c_list = self.replace_tag('image', inputs.image_idx, inputs)
                 inputs.image_idx += 1
-                loss_scale = 0.
+                loss_scale = 0. if self.template_backend == 'swift' else 1.
             else:
                 c_list = [context]
             res += c_list

swift/llm/template/template/gemma.py

+1 −1

@@ -109,7 +109,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         input_ids = encoded['input_ids']
         labels = encoded['labels']
         idx_list = findall(input_ids, self.boi_token_id)
-        img_tokens = self.tokenizer.encode(self.processor.full_image_sequence)
+        img_tokens = self._tokenize(self.processor.full_image_sequence)
         input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda _: img_tokens)
 
         # TODO: customize

swift/llm/template/template/qwen.py

+84 −29

@@ -1,5 +1,4 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from contextlib import contextmanager
 from dataclasses import dataclass, field
 from functools import partial
 from typing import Any, Dict, List, Literal, Optional, Tuple
@@ -384,32 +383,50 @@ class Qwen2_5OmniTemplate(Qwen2_5VLTemplate):
     version = 'omni'
     placeholder_tokens = ['<|IMAGE|>', '<|AUDIO|>', '<|VIDEO|>']
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import Qwen2_5OmniProcessorKwargs
+        default = Qwen2_5OmniProcessorKwargs._defaults
+        self.seconds_per_chunk = default['videos_kwargs']['seconds_per_chunk']
+        self.position_id_per_seconds = default['videos_kwargs']['position_id_per_seconds']
+        self.use_audio_in_video = get_env_args('use_audio_in_video', bool, False)
+        self.sampling_rate = get_env_args('sampling_rate', int, self.processor.feature_extractor.sampling_rate)
+
     def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                     inputs: StdTemplateInputs) -> List[Context]:
         from qwen_omni_utils import fetch_image, fetch_video
-        sampling_rate = self.processor.feature_extractor.sampling_rate
         if media_type == 'image':
             inputs.images[index] = fetch_image({'image': inputs.images[index]})
             return ['<|vision_bos|><|IMAGE|><|vision_eos|>']
         elif media_type == 'audio':
-            sampling_rate = get_env_args('sampling_rate', int, sampling_rate)
-            inputs.audios[index] = load_audio(inputs.audios[index], sampling_rate)
+            inputs.audios[index] = load_audio(inputs.audios[index], self.sampling_rate)
             return ['<|audio_bos|><|AUDIO|><|audio_eos|>']
         elif media_type == 'video':
             video = inputs.videos[index]
             inputs.videos[index] = fetch_video({'video': video}).to(torch.uint8)
-            use_audio_in_video = get_env_args('use_audio_in_video', bool, False)
-            if use_audio_in_video:
+            if self.use_audio_in_video:
                 import librosa
-                sampling_rate = get_env_args('sampling_rate', int, sampling_rate)
-                video = librosa.load(video, sr=sampling_rate)[0]
-                inputs.audios.insert(inputs.audio_idx, video)
+                if video.startswith('http://') or video.startswith('https://'):
+                    import audioread
+                    video = audioread.ffdec.FFmpegAudioFile(video)
+                video = librosa.load(video, sr=self.sampling_rate)[0]
+                inputs.audios.insert(inputs.audio_idx, (video, 'video'))
                 inputs.audio_idx += 1
+                return ['<|vision_bos|><|audio_bos|><|VIDEO|><|audio_eos|><|vision_eos|>']
             return ['<|vision_bos|><|VIDEO|><|vision_eos|>']
 
     def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         encoded = Template._encode(self, inputs)
-        media_inputs = self.processor(
+        processor = self.processor
+        video_audios_mask = []
+        for i, audio in enumerate(inputs.audios):
+            if isinstance(audio, tuple) and audio[1] == 'video':
+                inputs.audios[i] = audio[0]
+                video_audios_mask.append(True)
+            else:
+                video_audios_mask.append(False)
+        video_audios_mask = torch.tensor(video_audios_mask)
+        media_inputs = processor(
             text='',
             audio=inputs.audios or None,
             images=inputs.images or None,
@@ -420,31 +437,70 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         media_inputs = to_float_dtype(media_inputs, self.model_info.torch_dtype)
         input_ids = encoded['input_ids']
         labels = encoded['labels']
+        # audio
+        audio_token_id = self._tokenize('<|AUDIO|>')
+        idx_list = findall(input_ids, audio_token_id)
+        feature_attention_mask = media_inputs.get('feature_attention_mask')
+        if feature_attention_mask is not None:
+            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
+            audio_lengths = (((audio_feature_lengths - 1) // 2 + 1 - 2) // 2 + 1)
+        else:
+            audio_lengths = None
+        audio_lengths_origin = audio_lengths
+        if idx_list:
+            if self.use_audio_in_video:
+                audio_lengths = audio_lengths[~video_audios_mask]
+
+            def _get_new_audio_tokens(i):
+                return audio_token_id * audio_lengths[i]
+
+            input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_audio_tokens)
+
         for media_type in ['image', 'video']:
             token = f'<|{media_type.upper()}|>'
             token_id = self._tokenize(token)
             idx_list = findall(input_ids, token_id)
             if idx_list:
-                merge_length = self.processor.image_processor.merge_size**2
+                merge_size = processor.image_processor.merge_size
                 media_grid_thw = media_inputs.get(f'{media_type}_grid_thw')
+                if media_type == 'video' and self.use_audio_in_video:
+                    audio_lengths = audio_lengths_origin[video_audios_mask]
+                    video_second_per_grid = media_inputs['video_second_per_grid']
+
+                    def _get_new_tokens_use_audio_in_video(i):
+                        audio_token_indices = torch.arange(audio_lengths[i])
+                        grid_thw = media_grid_thw[i]
+                        height = grid_thw[1] // merge_size
+                        width = grid_thw[2] // merge_size
+                        video_token_indices = torch.arange(grid_thw[0]).reshape(-1, 1, 1)
+                        video_token_indices = torch.broadcast_to(
+                            video_token_indices, (video_token_indices.shape[0], height, width)).reshape(-1)
+                        video_token_indices = (
+                            video_token_indices * video_second_per_grid[i] * self.position_id_per_seconds)
+                        tokens_per_chunk = int(self.position_id_per_seconds * self.seconds_per_chunk)
+                        video_chunk_indexes = processor.get_chunked_index(video_token_indices, tokens_per_chunk)
+                        audio_chunk_indexes = processor.get_chunked_index(audio_token_indices, tokens_per_chunk)
+
+                        res = []
+                        for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
+                            if j < len(video_chunk_indexes):
+                                video_seq_length = video_chunk_indexes[j][1] - video_chunk_indexes[j][0]
+                                res += token_id * video_seq_length
+                            if j < len(audio_chunk_indexes):
+                                audio_seq_length = audio_chunk_indexes[j][1] - audio_chunk_indexes[j][0]
+                                res += audio_token_id * audio_seq_length
+                        return res
+
+                    input_ids, labels = self._extend_tokens(input_ids, labels, idx_list,
+                                                            _get_new_tokens_use_audio_in_video)
 
-                def _get_new_tokens(i):
-                    token_len = (media_grid_thw[i].prod() // merge_length)
-                    return token_id * token_len
-
-                input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
-        # audio
-        feature_attention_mask = media_inputs.get('feature_attention_mask')
-        if feature_attention_mask is not None:
-            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1).tolist()
-            token_id = self._tokenize('<|AUDIO|>')
-            idx_list = findall(input_ids, token_id)
+                else:
 
-            def _get_new_tokens(i):
-                place_num = ((audio_feature_lengths[i] - 1) // 2 + 1 - 2) // 2 + 1
-                return token_id * place_num
+                    def _get_new_tokens(i):
+                        token_len = (media_grid_thw[i].prod() // (merge_size**2))
+                        return token_id * token_len
 
-            input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
+                    input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
 
         encoded['input_ids'] = input_ids
         encoded['labels'] = labels
@@ -460,7 +516,6 @@ def _get_position_ids(self, inputs: Dict[str, Any]):
             audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
         else:
             audio_feature_lengths = None
-        use_audio_in_video = get_env_args('use_audio_in_video', bool, False)
         video_second_per_grid = inputs.pop('video_second_per_grid', None)
         input_ids = inputs['input_ids']
         attention_mask = inputs.get('attention_mask')
@@ -471,7 +526,7 @@ def _get_position_ids(self, inputs: Dict[str, Any]):
                 inputs.get('image_grid_thw'),
                 inputs.get('video_grid_thw'),
                 attention_mask,
-                use_audio_in_video,
+                self.use_audio_in_video,
                 audio_feature_lengths,
                 video_second_per_grid,
             )
@@ -493,7 +548,7 @@ def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
 
     def generate(self, model, *args, **kwargs):
         if kwargs.get('video_grid_thw') is not None:
-            kwargs['use_audio_in_video'] = get_env_args('use_audio_in_video', bool, False)
+            kwargs['use_audio_in_video'] = self.use_audio_in_video
         return super().generate(model, *args, **kwargs)
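
To follow the new audio branch in _encode above: audio_lengths turns the number of valid mel frames per clip (from feature_attention_mask) into the number of <|AUDIO|> placeholders via two successive stride-2 reductions. A minimal standalone sketch of that arithmetic; the 100-frames-per-second figure assumes the usual Whisper-style feature extractor at 16 kHz and is not stated in the diff:

    def audio_placeholder_len(num_mel_frames: int) -> int:
        # Same arithmetic as `audio_lengths` in Qwen2_5OmniTemplate._encode.
        after_first = (num_mel_frames - 1) // 2 + 1   # first stride-2 reduction
        return (after_first - 2) // 2 + 1             # second stride-2 reduction

    # Assuming ~100 mel frames per second, a 30 s clip gives 3000 frames.
    print(audio_placeholder_len(3000))  # 750 <|AUDIO|> placeholders

The use_audio_in_video path then interleaves video and audio placeholders chunk by chunk, so tokens covering the same stretch of time end up adjacent. The sketch below is illustrative only: chunked_spans is a hypothetical stand-in for processor.get_chunked_index (whose real implementation lives in transformers), grouping positions into windows of tokens_per_chunk, and the interleaving loop mirrors _get_new_tokens_use_audio_in_video with plain strings instead of token ids:

    from typing import List, Tuple

    def chunked_spans(time_indices: List[float], tokens_per_chunk: int) -> List[Tuple[int, int]]:
        # Group consecutive positions whose time index falls into the same
        # window of size `tokens_per_chunk`; returns (start, end) index pairs.
        spans, start = [], 0
        for end in range(1, len(time_indices) + 1):
            if end == len(time_indices) or \
                    time_indices[end] // tokens_per_chunk != time_indices[start] // tokens_per_chunk:
                spans.append((start, end))
                start = end
        return spans

    def interleave_placeholders(video_times, audio_times, tokens_per_chunk):
        # Emit each chunk's video placeholders, then that chunk's audio placeholders.
        video_spans = chunked_spans(video_times, tokens_per_chunk)
        audio_spans = chunked_spans(audio_times, tokens_per_chunk)
        res = []
        for j in range(max(len(video_spans), len(audio_spans))):
            if j < len(video_spans):
                res += ['<|VIDEO|>'] * (video_spans[j][1] - video_spans[j][0])
            if j < len(audio_spans):
                res += ['<|AUDIO|>'] * (audio_spans[j][1] - audio_spans[j][0])
        return res

    # Toy numbers: four video time rows at 0, 25, 50, 75 position units and 100 audio
    # positions, chunked every 50 units -> 2 video, 50 audio, 2 video, 50 audio tokens.
    print(interleave_placeholders([0, 25, 50, 75], list(range(100)), 50)[:4])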

tests/test_align/test_template/test_audio.py

+2

@@ -56,6 +56,8 @@ def test_step_audio_chat():
 
 
 def test_qwen2_5_omni():
+    USE_AUDIO_IN_VIDEO = True
+    os.environ['USE_AUDIO_IN_VIDEO'] = str(USE_AUDIO_IN_VIDEO)
     pt_engine = PtEngine('Qwen/Qwen2.5-Omni-7B')
     response = _infer_model(pt_engine)
     pt_engine.default_template.template_backend = 'jinja'
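
Because Qwen2_5OmniTemplate.__init__ now caches the flag, the environment variable has to be set before the engine (and with it the template) is constructed, as the test does above. A minimal sketch for use outside the test suite, assuming PtEngine is imported from swift.llm as in these tests:

    import os

    # Set before constructing the engine: __init__ reads USE_AUDIO_IN_VIDEO once.
    os.environ['USE_AUDIO_IN_VIDEO'] = 'True'

    from swift.llm import PtEngine  # assumption: same class the tests construct
    pt_engine = PtEngine('Qwen/Qwen2.5-Omni-7B')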

tests/test_align/test_template/test_video.py

+6 −11

@@ -131,10 +131,9 @@ def test_qwen2_5_vl():
 
 
 def test_qwen2_5_omni():
-    os.environ['VIDEO_MAX_PIXELS'] = str(28 * 28 * 64)
-    USE_AUDIO_IN_VIDEO = False
+    USE_AUDIO_IN_VIDEO = True
     os.environ['USE_AUDIO_IN_VIDEO'] = str(USE_AUDIO_IN_VIDEO)
-    pt_engine = PtEngine('Qwen/Qwen2.5-Omni-7B')
+    pt_engine = PtEngine('Qwen/Qwen2.5-Omni-7B', attn_impl='flash_attn')
     system = ('You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, '
               'capable of perceiving auditory and visual inputs, as well as generating text and speech.')
     messages = [{'role': 'system', 'content': system}, {'role': 'user', 'content': '<video>'}]
@@ -143,15 +142,11 @@ def test_qwen2_5_omni():
     pt_engine.default_template.template_backend = 'jinja'
     response2 = _infer_model(pt_engine, messages=messages, videos=videos)
     if USE_AUDIO_IN_VIDEO:
-
-        ground_truth = ('Oh, that sounds like a really cool project! Are you using a specific app on the tablet for '
-                        "drawing? And what kind of details are you adding to the guitar? It'd be interesting to hear "
-                        'more about your creative process.')
+        ground_truth = ("Oh, that's a really cool drawing! It looks like a guitar. You've got the body "
+                        'and the neck drawn in a simple yet effective way. The lines are clean and the '
+                        'shape is well-defined. What made you choose to draw a guitar?')
     else:
-        ground_truth = (
-            "Oh, that sounds like a really cool project! So, you're using a tablet to draw a guitar and a key? "
-            "That's a creative way to combine two different things. Have you thought about what you'll do "
-            'with the final drawing? Maybe could use it for a poster or something? Let me know how it turns out!')
+        ground_truth = ('嗯,你是在用平板画画呢。你画的这把吉他,看起来很简洁明了。你用的笔触也很流畅,线条很清晰。你对颜色的运用也很不错,整体看起来很协调。你要是还有啥想法或者问题,随时跟我说哈。')
     assert response == response2 == ground_truth

tests/test_align/test_template/test_vision.py

+6 −5

@@ -509,10 +509,11 @@ def test_phi4_vision():
 
 def test_gemma3_vision():
     pt_engine = PtEngine('LLM-Research/gemma-3-4b-it')
-    response = _infer_model(pt_engine)
+    response = _infer_model(pt_engine, messages=[{'role': 'user', 'content': '<image>Describe this image in detail.'}])
     pt_engine.default_template.template_backend = 'jinja'
-    response2 = _infer_model(pt_engine)
-    assert response == response2
+    response2 = _infer_model(pt_engine, messages=[{'role': 'user', 'content': '<image>Describe this image in detail.'}])
+    assert response[:80] == response2[:80] == (
+        "Here's a detailed description of the image:\n\n**Overall Impression:**\n\nThe image ")
 
 
 def test_mistral_2503():
@@ -596,9 +597,9 @@ def test_kimi_vl():
     # test_minicpmo()
     # test_valley()
     # test_ui_tars()
-    # test_gemma3_vision()
+    test_gemma3_vision()
     # test_mistral_2503()
     # test_llama4()
     # test_internvl3_8b()
     # test_internvl3_9b()
-    test_kimi_vl()
+    # test_kimi_vl()
