[chat-template] Unify tests and clean up 🧼 #37275

Merged
merged 15 commits into from
Apr 10, 2025
29 changes: 0 additions & 29 deletions docs/source/en/chat_templating_multimodal.md
@@ -181,35 +181,6 @@ processed_chat = processor.apply_chat_template(
print(processed_chat.keys())
```

</hfoption>
<hfoption id="custom frame sampling">

Some models don't sample frames *uniformly* and require more complex logic to determine which frames to use. For example, a model may use *adaptive frame selection* or prioritize *key moments* in a video rather than evenly spaced frames.

If a model uses a different sampling strategy, you can write a function that customizes frame selection. The function must meet the following requirements.

- Use the `sample_indices_fn` parameter to pass a callable function for sampling.
- If provided, this function *overrides* the standard `num_frames` and `fps` parameters.
- The function receives all the parameters passed to `load_video` and must return valid frame indices to sample from.

An example function is shown below. This gives you full control over frame selection, making the model more adaptable to different video scenarios.

```py
def sample_indices_fn(metadata, **kwargs):
# samples only the first and the second frame
return [0, 1]

processed_chat = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
sample_indices_fn=sample_indices_fn,
video_load_backend="decord",
)
print(processed_chat.keys())
```
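
A more adaptive sampler can use the video metadata passed to the function. The sketch below picks evenly spaced frames up to a fixed budget; it assumes the metadata behaves like a dictionary with a `total_num_frames` entry, which may differ depending on the video backend.

```py
import numpy as np


def adaptive_sample_indices_fn(metadata, **kwargs):
    # Sketch of an adaptive sampler: sample at most 8 evenly spaced frames,
    # always including the first and last frame of the video.
    # `total_num_frames` is an assumed metadata key; adjust it to whatever
    # your video backend actually reports.
    total_num_frames = metadata.get("total_num_frames", 2)
    num_frames = min(8, total_num_frames)
    return np.linspace(0, total_num_frames - 1, num_frames, dtype=int).tolist()


processed_chat = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    sample_indices_fn=adaptive_sample_indices_fn,
    video_load_backend="decord",
)
print(processed_chat.keys())
```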

</hfoption>
<hfoption id="list of image frames">

53 changes: 34 additions & 19 deletions src/transformers/models/smolvlm/processing_smolvlm.py
@@ -20,10 +20,13 @@
from datetime import timedelta
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import numpy as np

from ...feature_extraction_utils import BatchFeature
from ...image_utils import (
ImageInput,
VideoInput,
load_video,
make_batched_videos,
make_nested_list_of_images,
)
@@ -425,32 +428,44 @@ def model_input_names(self):
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(image_processor_input_names + tokenizer_input_names))

# Add model-specific video sampling method when applying the template
def apply_chat_template(
# TODO: raushan, has to be public method under `VideoProcessorBase` when API is added
def _load_video_for_model(
Comment on lines +431 to +432

Member:

Making sure I understand this: when we have a proper `VideoProcessorBase` API, there won't be model-specific `_load_video_for_model` methods anymore?

@zucchini-nlp (Member, Author), Apr 8, 2025:

There will be a model-specific method in each video processor. Currently, some models load videos with different sampling (SmolVLM) or even resizing (Qwen-Omni) before the video is processed.

As a second option, we can consider merging all of these steps into `self.preprocess(videos)`, as long as it doesn't slow things down. I haven't thought about it, but since you mentioned it, it sounds interesting. I'll give it a try after the initial API design is merged (still under review).

self,
conversation,
max_frames=None,
target_fps=None,
skip_secs=1,
video_load_backend="pyav",
sample_indices_fn=None,
**kwargs,
):
max_frames = self.default_max_frames if max_frames is None else max_frames
target_fps = self.default_fps if target_fps is None else target_fps
video: Union[str, "VideoInput"],
num_frames: Optional[int] = None,
fps: Optional[int] = None,
backend: str = "opencv",
skip_secs: int = 0.0,
) -> np.array:
"""
Loads `video` to a numpy array.

Args:
video (`str` or `VideoInput`):
The video to convert to the numpy array format. Can be a link to video or local path.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not passed, the whole video is loaded.
fps (`int`, *optional*):
Number of frames to sample per second. Should be passed only when `num_frames=None`.
If not specified and `num_frames==None`, all frames are sampled.
backend (`str`, *optional*, defaults to `"opencv"`):
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".

Returns:
Tuple[`np.array`, Dict]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- Metadata dictionary.
"""
max_frames = self.default_max_frames if num_frames is None else num_frames
target_fps = self.default_fps if fps is None else fps

def sample_indices_fn_func(metadata, **fn_kwargs):
return smolvlm_sample_indices_fn(
metadata, max_frames=max_frames, target_fps=target_fps, skip_secs=skip_secs, **fn_kwargs
)

# word of caution- we are blindly overriding a callable kwarg here.
# typed kwargs would be a way to avoid that @molbap
if not sample_indices_fn:
sample_indices_fn = sample_indices_fn_func
return super().apply_chat_template(
conversation, video_load_backend=video_load_backend, sample_indices_fn=sample_indices_fn, **kwargs
)
video, metadata = load_video(video, backend=backend, sample_indices_fn=sample_indices_fn_func)
return video, metadata


__all__ = ["SmolVLMProcessor"]
94 changes: 68 additions & 26 deletions src/transformers/processing_utils.py
@@ -23,7 +23,7 @@
import typing
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, TypedDict, Union
from typing import Any, Dict, List, Optional, TypedDict, Union

import numpy as np
import typing_extensions
@@ -415,7 +415,6 @@ def sample_indices_fn(num_frames, fps, metadata, **kwargs):
video_load_backend: Optional[str] = "pyav"
video_fps: Optional[int] = None
sampling_rate: Optional[int] = 16_000
sample_indices_fn: Optional[Callable] = None
load_audio_from_video: Optional[bool] = False


@@ -435,7 +434,16 @@ class ProcessorChatTemplateKwargs(ChatTemplateLoadKwargs, TokenizerChatTemplateK

class AllKwargsForChatTemplate(
TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, ProcessorChatTemplateKwargs
): ...
):
processor_kwargs: ProcessingKwargs = {
**ProcessingKwargs.__annotations__,
}
mm_load_kwargs: ChatTemplateLoadKwargs = {
**TextKwargs.__annotations__,
}
template_kwargs: ProcessorChatTemplateKwargs = {
**ProcessorChatTemplateKwargs.__annotations__,
}


class ProcessorMixin(PushToHubMixin):
@@ -1315,19 +1323,20 @@ def apply_chat_template(
"https://huggingface.co/docs/transformers/main/en/chat_templating for more information."
)

# Fill two sets of kwargs that should be used by tokenizer's `apply_chat_template`
# and for multimodal data loading. Everything else will be used in `__call__`
tokenizer_template_kwargs = {}
for tokenizer_key in TokenizerChatTemplateKwargs.__annotations__.keys():
default_value = getattr(TokenizerChatTemplateKwargs, tokenizer_key, None)
value = kwargs.pop(tokenizer_key, default_value)
tokenizer_template_kwargs[tokenizer_key] = value
# Fill sets of kwargs that should be used by different parts of template
processed_kwargs = {
"processor_kwargs": {},
"mm_load_kwargs": {},
"template_kwargs": {},
}

mm_load_kwargs = {}
for mm_load_key in ChatTemplateLoadKwargs.__annotations__.keys():
default_value = getattr(ChatTemplateLoadKwargs, mm_load_key, None)
value = kwargs.pop(mm_load_key, default_value)
mm_load_kwargs[mm_load_key] = value
for kwarg_type in processed_kwargs:
for key in AllKwargsForChatTemplate.__annotations__[kwarg_type].__annotations__.keys():
kwarg_type_defaults = AllKwargsForChatTemplate.__annotations__[kwarg_type]
default_value = getattr(kwarg_type_defaults, key, None)
value = kwargs.pop(key, default_value)
if value is not None and not isinstance(value, dict):
processed_kwargs[kwarg_type][key] = value

if isinstance(conversation, (list, tuple)) and (
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
@@ -1338,8 +1347,9 @@
is_batched = False
conversations = [conversation]

tokenize = kwargs.pop("tokenize", False)
return_dict = kwargs.pop("return_dict", False)
tokenize = processed_kwargs["template_kwargs"].pop("tokenize", False)
return_dict = processed_kwargs["template_kwargs"].pop("return_dict", False)
mm_load_kwargs = processed_kwargs["mm_load_kwargs"]

if tokenize:
batch_images, batch_videos = [], []
@@ -1382,7 +1392,7 @@

for fname in video_fnames:
if isinstance(fname, (list, tuple)) and isinstance(fname[0], str):
video = [np.array(load_image(image_fname)).T for image_fname in fname]
video = [np.array(load_image(image_fname)) for image_fname in fname]
# create a 4D video because `load_video` always returns a 4D array
video = np.stack(video)
metadata = None
@@ -1391,12 +1401,13 @@
"If your model uses this metadata during processing, please load the whole video and let the model sample frames instead."
)
else:
video, metadata = load_video(
# TODO: raushan, should be `self.video_processor.load_video_for_model` when API is added
video, metadata = self._load_video_for_model(
fname,
num_frames=mm_load_kwargs["num_frames"],
fps=mm_load_kwargs["video_fps"],
num_frames=mm_load_kwargs.get("num_frames", None),
fps=mm_load_kwargs.get("video_fps", None),
backend=mm_load_kwargs["video_load_backend"],
sample_indices_fn=mm_load_kwargs["sample_indices_fn"],
**kwargs,
)
videos.append(video)
video_metadata.append(metadata)
@@ -1415,15 +1426,15 @@
batch_images=batch_images,
batch_videos=batch_videos,
batch_video_metadata=batch_video_metadata,
**mm_load_kwargs,
**processed_kwargs["mm_load_kwargs"],
)

prompt = self.tokenizer.apply_chat_template(
conversations,
chat_template=chat_template,
tokenize=False,
return_dict=False,
**tokenizer_template_kwargs,
**processed_kwargs["template_kwargs"],
)

if not is_batched:
@@ -1438,21 +1449,52 @@
# without actionable solution for users
single_prompt = prompt[0] if is_batched else prompt
if self.tokenizer.bos_token is not None and single_prompt.startswith(self.tokenizer.bos_token):
kwargs["add_special_tokens"] = False
processed_kwargs["processor_kwargs"]["add_special_tokens"] = False

out = self(
text=prompt,
images=batch_images if batch_images else None,
videos=batch_videos if batch_videos else None,
audio=batch_audios if batch_audios else None,
**kwargs,
**processed_kwargs["processor_kwargs"],
)
if return_dict:
return out
else:
return out["input_ids"]
return prompt

# TODO: raushan, has to be public method under `VideoProcessorBase` when API is added
# Keep private so we can simply remove when needed
def _load_video_for_model(
self,
video: Union[str, "VideoInput"],
num_frames: Optional[int] = None,
fps: Optional[int] = None,
backend: str = "opencv",
) -> np.array:
"""
Loads `video` to a numpy array.

Args:
video (`str` or `VideoInput`):
The video to convert to the numpy array format. Can be a link to video or local path.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not passed, the whole video is loaded.
fps (`int`, *optional*):
Number of frames to sample per second. Should be passed only when `num_frames=None`.
If not specified and `num_frames==None`, all frames are sampled.
backend (`str`, *optional*, defaults to `"opencv"`):
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".

Returns:
Tuple[`np.array`, Dict]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- Metadata dictionary.
"""
video, metadata = load_video(video, num_frames, fps=fps, backend=backend)
return video, metadata

def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=True, **kwargs):
"""
Post-process the output of a vlm to decode the text.
49 changes: 0 additions & 49 deletions tests/models/aria/test_processor_aria.py
@@ -236,55 +236,6 @@ def test_apply_chat_template(self):
"""
self.assertEqual(rendered, expected_rendered)

# Override as AriaImageProcessor doesn't accept `do_rescale`
def test_image_chat_template_accepts_processing_kwargs(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")

messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
],
},
]
]

formatted_prompt_tokenized = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
padding="max_length",
max_length=50,
)
self.assertEqual(len(formatted_prompt_tokenized[0]), 50)

formatted_prompt_tokenized = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
truncation=True,
max_length=5,
)
self.assertEqual(len(formatted_prompt_tokenized[0]), 5)

# Now test the ability to return dict
messages[0][0]["content"].append(
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
)
out_dict = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
max_image_size=980,
return_tensors="np",
)
self.assertListEqual(list(out_dict[self.images_input_name].shape), [1, 3, 980, 980])

# Override as AriaProcessor needs image tokens in prompts
def prepare_text_inputs(self, batch_size: Optional[int] = None):
if batch_size is None:
5 changes: 0 additions & 5 deletions tests/models/aya_vision/test_processor_aya_vision.py
@@ -79,11 +79,6 @@ def get_processor(self, **kwargs):
def tearDownClass(cls):
shutil.rmtree(cls.tmpdirname, ignore_errors=True)

# todo: yoni, fix this test
@unittest.skip("Chat template has long system prompt")
def test_chat_template_accepts_processing_kwargs(self, **kwargs):
pass

# Override as AyaVisionProcessor needs image tokens in prompts
def prepare_text_inputs(self, batch_size: Optional[int] = None):
if batch_size is None: