
Simplify Idefics2, Idefics3, SmolVLM images handling #37291

Open · wants to merge 8 commits into base: main
131 changes: 47 additions & 84 deletions src/transformers/models/idefics2/image_processing_idefics2.py
@@ -29,7 +29,7 @@
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
make_nested_list_of_images,
make_flat_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
@@ -86,18 +86,17 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:


def get_max_height_width(
images_list: List[List[np.ndarray]], input_data_format: Optional[Union[str, ChannelDimension]] = None
images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> List[int]:
"""
Get the maximum height and width across all images in a batch.
"""
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images_list[0][0])
input_data_format = infer_channel_dimension_format(images[0])

image_sizes = []
for images in images_list:
for image in images:
image_sizes.append(get_image_size(image, channel_dim=input_data_format))
for image in images:
image_sizes.append(get_image_size(image, channel_dim=input_data_format))

max_height, max_width = max_across_indices(image_sizes)
return (max_height, max_width)
@@ -284,7 +283,6 @@ def pad(
images: List[np.ndarray],
constant_values: Union[float, Iterable[float]] = 0,
return_pixel_mask: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> BatchFeature:
@@ -299,54 +297,35 @@
The value to use for the padding if `mode` is `"constant"`.
return_pixel_mask (`bool`, *optional*, defaults to `True`):
Whether to return a pixel mask.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
pad_size = get_max_height_width(images, input_data_format=input_data_format)

batch_size = len(images)
max_num_images = max(len(images_) for images_ in images)
input_data_format = (
infer_channel_dimension_format(images[0][0]) if input_data_format is None else input_data_format
infer_channel_dimension_format(images[0]) if input_data_format is None else input_data_format
)
data_format = input_data_format if data_format is None else data_format

def empty_image(size, input_data_format):
if input_data_format == ChannelDimension.FIRST:
return np.zeros((3, *size), dtype=np.uint8)
elif input_data_format == ChannelDimension.LAST:
return np.zeros((*size, 3), dtype=np.uint8)
raise ValueError("Invalid channel dimension format.")

padded_images_list = [
[empty_image(pad_size, data_format) for _ in range(max_num_images)] for _ in range(batch_size)
padded_masks = (
[make_pixel_mask(image, output_size=pad_size, input_data_format=input_data_format) for image in images]
if return_pixel_mask
else None
)
images = [
self._pad_image(
image,
pad_size,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
for image in images
]
padded_masks = [[np.zeros(pad_size) for _ in range(max_num_images)] for _ in range(batch_size)]

for batch_idx in range(batch_size):
for sample_idx, image in enumerate(images[batch_idx]):
padded_images_list[batch_idx][sample_idx] = self._pad_image(
image,
pad_size,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
padded_masks[batch_idx][sample_idx] = make_pixel_mask(
image, output_size=pad_size, input_data_format=input_data_format
)

padded_masks = padded_masks if return_pixel_mask else None
return padded_images_list, padded_masks

return images, padded_masks

def _crop(
self,
@@ -471,9 +450,9 @@ def preprocess(
do_pad = do_pad if do_pad is not None else self.do_pad
do_image_splitting = do_image_splitting if do_image_splitting is not None else self.do_image_splitting

images_list = make_nested_list_of_images(images)
images = make_flat_list_of_images(images)

if not valid_images(images_list[0]):
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
@@ -491,73 +470,57 @@
)

if do_convert_rgb:
images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
images = [convert_to_rgb(image) for image in images]

# All transformations expect numpy arrays.
images_list = [[to_numpy_array(image) for image in images] for images in images_list]
images = [to_numpy_array(image) for image in images]

if do_rescale and is_scaled_image(images_list[0][0]):
if do_rescale and is_scaled_image(images[0]):
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images_list[0][0])
input_data_format = infer_channel_dimension_format(images[0])

if do_image_splitting:
Review comment (Collaborator): why don't we try to do a single for loop while we are at it?
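One possible shape for that suggestion (a rough sketch, not code from this PR; it assumes the surrounding preprocess() locals such as do_image_splitting, do_resize, do_rescale, do_normalize, size, resample, rescale_factor, image_mean and image_std are in scope):

    # Hypothetical single pass over the flat image list, folding the
    # split/resize/rescale/normalize steps into one loop.
    processed_images = []
    for image in images:
        # Split first (or keep the image as-is), then transform each sub-image once.
        sub_images = self.split_image(image, input_data_format) if do_image_splitting else [image]
        for sub_image in sub_images:
            if do_resize:
                sub_image = self.resize(image=sub_image, size=size, resample=resample, input_data_format=input_data_format)
            if do_rescale:
                sub_image = self.rescale(image=sub_image, scale=rescale_factor, input_data_format=input_data_format)
            if do_normalize:
                sub_image = self.normalize(image=sub_image, mean=image_mean, std=image_std, input_data_format=input_data_format)
            processed_images.append(sub_image)
    images = processed_images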

new_images_list = []
for images in images_list:
new_images = []
for image in images:
new_images.extend(self.split_image(image, input_data_format))
new_images_list.append(new_images)
images_list = new_images_list
new_images = []
for image in images:
new_images.extend(self.split_image(image, input_data_format))
images = new_images

if do_resize:
images_list = [
[
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]
for images in images_list
images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]

if do_rescale:
images_list = [
[
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]
for images in images_list
images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]

if do_normalize:
images_list = [
[
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]
for images in images_list
images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]

pixel_attention_mask = None
if do_pad:
images_list, pixel_attention_mask = self.pad(
images_list, return_pixel_mask=True, return_tensors=return_tensors, input_data_format=input_data_format
images, pixel_attention_mask = self.pad(
images, return_pixel_mask=True, input_data_format=input_data_format
)

if data_format is not None:
images_list = [
[
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
for image in images
]
for images in images_list
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
for image in images
]

data = {"pixel_values": np.array(images_list) if do_pad else images_list} # Faster tensor conversion
data = {"pixel_values": np.array(images) if do_pad else images} # Faster tensor conversion
if pixel_attention_mask is not None:
data["pixel_attention_mask"] = np.array(pixel_attention_mask) if do_pad else pixel_attention_mask

35 changes: 2 additions & 33 deletions src/transformers/models/idefics2/modeling_idefics2.py
@@ -1086,19 +1086,7 @@ def inputs_merger(
new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states.to(new_inputs_embeds.device)
return new_inputs_embeds

@add_start_docstrings_to_model_forward(
"""
Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
max_num_images is the maximum number of images among the batch_size samples in the batch.

Padding images are not needed beyond padding the pixel_values at the entrance of the model.
For efficiency, we only pass through the vision_model's forward the real images by
discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
""",
IDEFICS2_INPUTS_DOCSTRING,
)
@add_start_docstrings_to_model_forward(IDEFICS2_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -1128,12 +1116,7 @@ def forward(
)
use_cache = False

# retrieve input_ids and inputs_embeds
if input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
if input_ids is None and inputs_embeds is None:
Review comment (Member): need to also check cases when both are not None:

if (input_ids is None) ^ (inputs_embeds is not None):
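My reading of that suggested check (a sketch, not code from this PR): the XOR is True exactly when both arguments are missing or both are provided, so a single test covers both error cases.

    # Hypothetical combined check: require exactly one of input_ids / inputs_embeds.
    # True when both are None (nothing to embed) or both are set (ambiguous input).
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You have to specify either input_ids or inputs_embeds")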

raise ValueError("You have to specify either input_ids or inputs_embeds")

past_seen_tokens = 0
@@ -1163,14 +1146,7 @@
if pixel_values is not None and image_hidden_states is not None:
raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
elif pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])

# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()

# Handle the vision attention mask
if pixel_attention_mask is None:
@@ -1179,13 +1155,6 @@
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask.view(
batch_size * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)