Add Fast Image Processor for vilt #37304

Open · wants to merge 12 commits into main

Changes from all commits
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/vilt.md
@@ -72,6 +72,11 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi
[[autodoc]] ViltImageProcessor
- preprocess

## ViltImageProcessorFast

[[autodoc]] ViltImageProcessorFast
- preprocess

## ViltProcessor

[[autodoc]] ViltProcessor
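For a quick look at what the newly documented class does, a minimal usage sketch (assuming the public `dandelin/vilt-b32-finetuned-vqa` checkpoint and a local RGB image named `example.jpg`; shapes in the comments are illustrative):

from PIL import Image

from transformers import ViltImageProcessorFast

# Load the fast processor from an existing ViLT checkpoint.
processor = ViltImageProcessorFast.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

image = Image.open("example.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# With do_pad=True (the default), both model inputs are produced.
print(inputs["pixel_values"].shape)  # e.g. torch.Size([1, 3, 384, 512]) for a 480x640 input
print(inputs["pixel_mask"].shape)    # e.g. torch.Size([1, 384, 512])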
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1371,6 +1371,7 @@
_import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
_import_structure["models.siglip"].append("SiglipImageProcessorFast")
_import_structure["models.siglip2"].append("Siglip2ImageProcessorFast")
_import_structure["models.vilt"].append("ViltImageProcessorFast")
_import_structure["models.vit"].append("ViTImageProcessorFast")

try:
@@ -6689,6 +6690,7 @@
from .models.rt_detr import RTDetrImageProcessorFast
from .models.siglip import SiglipImageProcessorFast
from .models.siglip2 import Siglip2ImageProcessorFast
from .models.vilt import ViltImageProcessorFast
from .models.vit import ViTImageProcessorFast

try:
2 changes: 1 addition & 1 deletion src/transformers/models/auto/image_processing_auto.py
@@ -157,7 +157,7 @@
("upernet", ("SegformerImageProcessor",)),
("van", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("videomae", ("VideoMAEImageProcessor",)),
("vilt", ("ViltImageProcessor",)),
("vilt", ("ViltImageProcessor", "ViltImageProcessorFast")),
("vipllava", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("vit", ("ViTImageProcessor", "ViTImageProcessorFast")),
("vit_hybrid", ("ViTHybridImageProcessor",)),
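With this mapping entry in place, `AutoImageProcessor` can resolve the fast variant. A sketch of the expected behavior (the checkpoint name is illustrative; any ViLT checkpoint should resolve the same way):

from transformers import AutoImageProcessor

# use_fast=True picks ViltImageProcessorFast from the ("vilt", ...) tuple above.
fast = AutoImageProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa", use_fast=True)
print(type(fast).__name__)  # ViltImageProcessorFast

# use_fast=False keeps the slow ViltImageProcessor.
slow = AutoImageProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa", use_fast=False)
print(type(slow).__name__)  # ViltImageProcessor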
1 change: 1 addition & 0 deletions src/transformers/models/vilt/__init__.py
@@ -21,6 +21,7 @@
from .configuration_vilt import *
from .feature_extraction_vilt import *
from .image_processing_vilt import *
from .image_processing_vilt_fast import *
from .modeling_vilt import *
from .processing_vilt import *
else:
290 changes: 290 additions & 0 deletions src/transformers/models/vilt/image_processing_vilt_fast.py
@@ -0,0 +1,290 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for Vilt."""

from typing import List, Optional, Union

from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
get_max_height_width,
group_images_by_shape,
reorder_images,
)
from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, PILImageResampling, SizeDict
from ...utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
)


if is_torch_available():
import torch

if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F

# Set maximum size based on the typical aspect ratio of the COCO dataset
MAX_LONGER_EDGE = 1333
MAX_SHORTER_EDGE = 800
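# For the default shortest_edge of 384 (see the class below), the implied cap on
# the longer edge is int(1333 / 800 * 384) = int(639.84) = 639 pixels.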


class ViltFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
do_pad: Optional[bool]
size_divisor: Optional[int]
rescale_factor: Optional[float]


@add_start_docstrings(
"Constructs a fast Vilt image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
)
class ViltImageProcessorFast(BaseImageProcessorFast):
# This generated class can be used as a starting point for the fast image processor.
# if the image processor is only used for simple augmentations, such as resizing, center cropping, rescaling, or normalizing,
# only the default values should be set in the class.
# If the image processor requires more complex augmentations, methods from BaseImageProcessorFast can be overridden.
# In most cases, only the `_preprocess` method should be overridden.

# For an example of a fast image processor requiring more complex augmentations, see `LlavaNextImageProcessorFast`.

# Default values should be checked against the slow image processor
# None values left after checking can be removed
resample = PILImageResampling.BICUBIC
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
size = {"shortest_edge": 384}
do_resize = True
do_rescale = True
do_normalize = True
size_divisor = 32
do_pad = True
default_to_square = False
model_input_names = ["pixel_values", "pixel_mask"]
Member commented: I think you should have this for this processor:

Suggested change:
-model_input_names = ["pixel_values", "pixel_mask"]
+default_to_square = False
+model_input_names = ["pixel_values", "pixel_mask"]

valid_kwargs = ViltFastImageProcessorKwargs

def _preprocess(
self,
images: list["torch.Tensor"],
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
"""
Preprocess an image or batch of images.

This method overrides the base class method to include padding and pixel mask generation.
"""
size_divisor = kwargs.get("size_divisor", self.size_divisor)
do_pad = kwargs.get("do_pad", self.do_pad)

# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
resized_images_grouped = {}

for shape, stacked_images in grouped_images.items():
if do_resize:
# Convert list to tensor if needed
if isinstance(stacked_images, list):
stacked_images = torch.stack(stacked_images)
stacked_images = self._resize(stacked_images, size, interpolation, size_divisor)

resized_images_grouped[shape] = stacked_images

resized_images = reorder_images(resized_images_grouped, grouped_images_index)

# Group images by size for further processing
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
processed_images_grouped = {}

for shape, stacked_images in grouped_images.items():
# Convert list to tensor if needed
if isinstance(stacked_images, list):
stacked_images = torch.stack(stacked_images)

# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)

# Handle padding if required
data = {}
if do_pad:
data = self._pad_batch(processed_images, return_tensors)
else:
# If no padding, just return the processed images
if return_tensors == "pt":
processed_images = torch.stack(processed_images)
data["pixel_values"] = processed_images

return BatchFeature(data=data, tensor_type=return_tensors)
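# Flow summary for two inputs that share the shape (3, 480, 640): they are grouped
# and stacked once, resized together, rescaled/normalized in one fused op, and
# (with do_pad=True) padded to a common size with matching pixel masks.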

def _resize(
self,
images: "torch.Tensor",
size: SizeDict,
interpolation: Optional["F.InterpolationMode"] = None,
size_divisor: Optional[int] = None,
**kwargs,
) -> "torch.Tensor":
"""
Resize an image or batch of images to the specified size.

Args:
images (`torch.Tensor`): Image or batch of images to resize.
size (`Dict[str, int]`): Size dictionary with a `shortest_edge` key.
interpolation (`F.InterpolationMode`, *optional*): Interpolation method to use.
size_divisor (`int`, *optional*): Value to ensure height/width are divisible by.

Returns:
`torch.Tensor`: Resized image or batch of images.
"""
if interpolation is None:
interpolation = self.resample

# Resize with aspect ratio preservation
shorter = size.shortest_edge
longer = int(MAX_LONGER_EDGE / MAX_SHORTER_EDGE * shorter)

heights = images.shape[-2]
widths = images.shape[-1]

# Determine the new dimensions
if heights < widths:
new_heights = shorter
new_widths = widths * (shorter / heights)
else:
new_heights = heights * (shorter / widths)
new_widths = shorter

# Check if the longer side exceeds max size
if max(new_heights, new_widths) > longer:
scale = longer / max(new_heights, new_widths)
new_heights = new_heights * scale
new_widths = new_widths * scale

new_heights = int(new_heights + 0.5)
new_widths = int(new_widths + 0.5)

# Make dimensions divisible by size_divisor
if size_divisor is not None:
new_heights = new_heights // size_divisor * size_divisor
new_widths = new_widths // size_divisor * size_divisor

# Resize the image
return F.resize(images, [new_heights, new_widths], interpolation=interpolation)
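# Worked example: a 480x640 input with shortest_edge=384 and size_divisor=32.
# Since heights < widths, new_heights = 384 and new_widths = 640 * (384 / 480) = 512.
# The longer-edge cap is int(1333 / 800 * 384) = 639, and max(384, 512) <= 639,
# so no extra scaling is applied; 384 and 512 are already divisible by 32,
# giving a final size of 384x512.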

def _pad_batch(
self,
images: list["torch.Tensor"],
return_tensors: Optional[Union[str, TensorType]],
) -> dict:
"""
Pad a batch of images to the same size based on the maximum dimensions.

Args:
images (`list[torch.Tensor]`): List of images to pad.
return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return.

Returns:
`dict`: Dictionary containing padded images and pixel masks.
"""
# Calculate global maximum dimensions across all images
max_size = get_max_height_width(images)

# Group images by shape before padding
grouped_images, grouped_images_index = group_images_by_shape(images)
processed_grouped = {}

for shape, stacked_images in grouped_images.items():
# Convert list to tensor if needed
if isinstance(stacked_images, list):
stacked_images = torch.stack(stacked_images)

# Create mask template for efficient masking; it must exist whenever padding is
# needed (regardless of return_tensors), since the padding branch clones it below
device = stacked_images.device
mask_template = torch.zeros(max_size, dtype=torch.int64, device=device)

# Process each image in the group
padded_images = []
pixel_masks = []

for image in stacked_images:
original_size = image.shape[-2:]
needs_padding = original_size[0] != max_size[0] or original_size[1] != max_size[1]

if needs_padding:
padding_bottom = max_size[0] - original_size[0]
padding_right = max_size[1] - original_size[1]
padding = [0, 0, padding_right, padding_bottom]

# Pad the image
padded_image = F.pad(image, padding, fill=0)

# Create pixel mask (1 for valid pixels, 0 for padding)
pixel_mask = mask_template.clone()
pixel_mask[: original_size[0], : original_size[1]].fill_(1)
else:
padded_image = image
pixel_mask = torch.ones(max_size, dtype=torch.int64, device=image.device)

padded_images.append(padded_image)
pixel_masks.append(pixel_mask)

# Stack for this group if tensors are requested
if return_tensors == "pt" and padded_images:
padded_images = torch.stack(padded_images)
pixel_masks = torch.stack(pixel_masks)

# Store processed group
processed_grouped[shape] = (padded_images, pixel_masks)

# Reorder images back to original order
padded_images = []
pixel_masks = []

for _, (group_key, position) in grouped_images_index.items():
padded_images.append(processed_grouped[group_key][0][position])
pixel_masks.append(processed_grouped[group_key][1][position])

# Stack if tensors are requested for final result
if return_tensors == "pt" and padded_images:
padded_images = torch.stack(padded_images)
pixel_masks = torch.stack(pixel_masks)

return {"pixel_values": padded_images, "pixel_mask": pixel_masks}


__all__ = ["ViltImageProcessorFast"]
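As an end-to-end sketch of the resize-then-pad behavior with mixed-size inputs (random tensors stand in for real images, and values are kept in [0, 1] so rescaling is skipped):

import torch

from transformers import ViltImageProcessorFast

processor = ViltImageProcessorFast()

# Two images of different sizes: 480x640 resizes to 384x512, 600x400 to 576x384.
images = [torch.rand(3, 480, 640), torch.rand(3, 600, 400)]
outputs = processor(images=images, do_rescale=False, return_tensors="pt")

# Both images are padded to the batch-wide maximum height and width (576x512),
# with pixel_mask marking valid (1) versus padded (0) positions.
print(outputs["pixel_values"].shape)  # torch.Size([2, 3, 576, 512])
print(outputs["pixel_mask"].shape)    # torch.Size([2, 576, 512])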
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_torchvision_objects.py
@@ -142,6 +142,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])


class ViltImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])


class ViTImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]

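The dummy keeps `from transformers import ViltImageProcessorFast` working when torchvision is absent and defers the failure to instantiation. A sketch of the expected behavior (assuming torchvision is not installed):

from transformers import ViltImageProcessorFast  # import succeeds via the dummy

try:
    ViltImageProcessorFast()
except ImportError as err:
    # requires_backends raises ImportError explaining that torchvision is required
    print(err)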