# Copyright (c) Alibaba, Inc. and its affiliates.
-from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Literal, Optional, Tuple
@@ -384,32 +383,50 @@ class Qwen2_5OmniTemplate(Qwen2_5VLTemplate):
    version = 'omni'
    placeholder_tokens = ['<|IMAGE|>', '<|AUDIO|>', '<|VIDEO|>']

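+    # Read the processor defaults and the env-configurable settings once at init,
+    # instead of re-reading them on every call.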
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import Qwen2_5OmniProcessorKwargs
+        default = Qwen2_5OmniProcessorKwargs._defaults
+        self.seconds_per_chunk = default['videos_kwargs']['seconds_per_chunk']
+        self.position_id_per_seconds = default['videos_kwargs']['position_id_per_seconds']
+        self.use_audio_in_video = get_env_args('use_audio_in_video', bool, False)
+        self.sampling_rate = get_env_args('sampling_rate', int, self.processor.feature_extractor.sampling_rate)
+
    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        from qwen_omni_utils import fetch_image, fetch_video
-        sampling_rate = self.processor.feature_extractor.sampling_rate
        if media_type == 'image':
            inputs.images[index] = fetch_image({'image': inputs.images[index]})
            return ['<|vision_bos|><|IMAGE|><|vision_eos|>']
        elif media_type == 'audio':
-            sampling_rate = get_env_args('sampling_rate', int, sampling_rate)
-            inputs.audios[index] = load_audio(inputs.audios[index], sampling_rate)
+            inputs.audios[index] = load_audio(inputs.audios[index], self.sampling_rate)
            return ['<|audio_bos|><|AUDIO|><|audio_eos|>']
        elif media_type == 'video':
            video = inputs.videos[index]
            inputs.videos[index] = fetch_video({'video': video}).to(torch.uint8)
-            use_audio_in_video = get_env_args('use_audio_in_video', bool, False)
-            if use_audio_in_video:
+            if self.use_audio_in_video:
                import librosa
-                sampling_rate = get_env_args('sampling_rate', int, sampling_rate)
-                video = librosa.load(video, sr=sampling_rate)[0]
-                inputs.audios.insert(inputs.audio_idx, video)
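+                # remote videos: librosa cannot open URLs directly, so decode them through ffmpeg via audioread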
+                if video.startswith('http://') or video.startswith('https://'):
+                    import audioread
+                    video = audioread.ffdec.FFmpegAudioFile(video)
+                video = librosa.load(video, sr=self.sampling_rate)[0]
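+                # tag the extracted audio track so _encode can tell it apart from standalone audio inputs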
+                inputs.audios.insert(inputs.audio_idx, (video, 'video'))
                inputs.audio_idx += 1
+                return ['<|vision_bos|><|audio_bos|><|VIDEO|><|audio_eos|><|vision_eos|>']
            return ['<|vision_bos|><|VIDEO|><|vision_eos|>']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = Template._encode(self, inputs)
-        media_inputs = self.processor(
+        processor = self.processor
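+        # mark which audios were extracted from videos (inserted as (audio, 'video') tuples in replace_tag)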
+        video_audios_mask = []
+        for i, audio in enumerate(inputs.audios):
+            if isinstance(audio, tuple) and audio[1] == 'video':
+                inputs.audios[i] = audio[0]
+                video_audios_mask.append(True)
+            else:
+                video_audios_mask.append(False)
+        video_audios_mask = torch.tensor(video_audios_mask)
+        media_inputs = processor(
            text='',
            audio=inputs.audios or None,
            images=inputs.images or None,
@@ -420,31 +437,70 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        media_inputs = to_float_dtype(media_inputs, self.model_info.torch_dtype)
        input_ids = encoded['input_ids']
        labels = encoded['labels']
+        # audio
+        audio_token_id = self._tokenize('<|AUDIO|>')
+        idx_list = findall(input_ids, audio_token_id)
+        feature_attention_mask = media_inputs.get('feature_attention_mask')
+        if feature_attention_mask is not None:
+            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
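+            # mel feature length -> number of <|AUDIO|> tokens (two stride-2 reductions in the audio encoder)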
+            audio_lengths = (((audio_feature_lengths - 1) // 2 + 1 - 2) // 2 + 1)
+        else:
+            audio_lengths = None
+        audio_lengths_origin = audio_lengths
+        if idx_list:
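+            # audio tracks that came from videos are interleaved with the video tokens below, so skip them here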
+            if self.use_audio_in_video:
+                audio_lengths = audio_lengths[~video_audios_mask]
+
+            def _get_new_audio_tokens(i):
+                return audio_token_id * audio_lengths[i]
+
+            input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_audio_tokens)
+
        for media_type in ['image', 'video']:
            token = f'<|{media_type.upper()}|>'
            token_id = self._tokenize(token)
            idx_list = findall(input_ids, token_id)
            if idx_list:
-                merge_length = self.processor.image_processor.merge_size**2
+                merge_size = processor.image_processor.merge_size
                media_grid_thw = media_inputs.get(f'{media_type}_grid_thw')
+                if media_type == 'video' and self.use_audio_in_video:
+                    audio_lengths = audio_lengths_origin[video_audios_mask]
+                    video_second_per_grid = media_inputs['video_second_per_grid']
+
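+                    # expand <|VIDEO|> into video and audio placeholder tokens interleaved chunk by chunk,
+                    # keeping both modalities aligned on the same time axis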
+                    def _get_new_tokens_use_audio_in_video(i):
+                        audio_token_indices = torch.arange(audio_lengths[i])
+                        grid_thw = media_grid_thw[i]
+                        height = grid_thw[1] // merge_size
+                        width = grid_thw[2] // merge_size
+                        video_token_indices = torch.arange(grid_thw[0]).reshape(-1, 1, 1)
+                        video_token_indices = torch.broadcast_to(
+                            video_token_indices, (video_token_indices.shape[0], height, width)).reshape(-1)
+                        video_token_indices = (
+                            video_token_indices * video_second_per_grid[i] * self.position_id_per_seconds)
+                        tokens_per_chunk = int(self.position_id_per_seconds * self.seconds_per_chunk)
+                        video_chunk_indexes = processor.get_chunked_index(video_token_indices, tokens_per_chunk)
+                        audio_chunk_indexes = processor.get_chunked_index(audio_token_indices, tokens_per_chunk)
+
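+                        # per time chunk: emit that chunk's video tokens first, then its audio tokens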
+                        res = []
+                        for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
+                            if j < len(video_chunk_indexes):
+                                video_seq_length = video_chunk_indexes[j][1] - video_chunk_indexes[j][0]
+                                res += token_id * video_seq_length
+                            if j < len(audio_chunk_indexes):
+                                audio_seq_length = audio_chunk_indexes[j][1] - audio_chunk_indexes[j][0]
+                                res += audio_token_id * audio_seq_length
+                        return res
+
+                    input_ids, labels = self._extend_tokens(input_ids, labels, idx_list,
+                                                            _get_new_tokens_use_audio_in_video)

-                def _get_new_tokens(i):
-                    token_len = (media_grid_thw[i].prod() // merge_length)
-                    return token_id * token_len
-
-                input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
-        # audio
-        feature_attention_mask = media_inputs.get('feature_attention_mask')
-        if feature_attention_mask is not None:
-            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1).tolist()
-            token_id = self._tokenize('<|AUDIO|>')
-            idx_list = findall(input_ids, token_id)
+                else:

-            def _get_new_tokens(i):
-                place_num = ((audio_feature_lengths[i] - 1) // 2 + 1 - 2) // 2 + 1
-                return token_id * place_num
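+                    # standard path: each placeholder expands to t * h * w // merge_size**2 tokens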
+                    def _get_new_tokens(i):
+                        token_len = (media_grid_thw[i].prod() // (merge_size**2))
+                        return token_id * token_len

-            input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
+                    input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)

        encoded['input_ids'] = input_ids
        encoded['labels'] = labels
@@ -460,7 +516,6 @@ def _get_position_ids(self, inputs: Dict[str, Any]):
            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
        else:
            audio_feature_lengths = None
-        use_audio_in_video = get_env_args('use_audio_in_video', bool, False)
        video_second_per_grid = inputs.pop('video_second_per_grid', None)
        input_ids = inputs['input_ids']
        attention_mask = inputs.get('attention_mask')
@@ -471,7 +526,7 @@ def _get_position_ids(self, inputs: Dict[str, Any]):
            inputs.get('image_grid_thw'),
            inputs.get('video_grid_thw'),
            attention_mask,
-            use_audio_in_video,
+            self.use_audio_in_video,
            audio_feature_lengths,
            video_second_per_grid,
        )
@@ -493,7 +548,7 @@ def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:

    def generate(self, model, *args, **kwargs):
        if kwargs.get('video_grid_thw') is not None:
-            kwargs['use_audio_in_video'] = get_env_args('use_audio_in_video', bool, False)
+            kwargs['use_audio_in_video'] = self.use_audio_in_video
        return super().generate(model, *args, **kwargs)