-
Notifications
You must be signed in to change notification settings - Fork 28.7k
Simplify Idefics2, Idefics3, SmolVLM images handling #37291
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
yonigozlan
wants to merge
8
commits into
huggingface:main
Choose a base branch
from
yonigozlan:flatten-idefics3-im-proc
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
68d3f0e
process flatten images directly for idefics2 idefics3 smolvlm
yonigozlan 8051d94
fix missing attention mask for padded image
yonigozlan b4c187c
Merge branch 'main' into flatten-idefics3-im-proc
yonigozlan 01f86c0
fix when pixels_attention_mask is none
yonigozlan 71125eb
fix modeling tests
yonigozlan fb17dc4
fix style
yonigozlan ce2a37a
nit
yonigozlan fdbd9df
fix processors tests
yonigozlan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1086,19 +1086,7 @@ def inputs_merger( | |
new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states.to(new_inputs_embeds.device) | ||
return new_inputs_embeds | ||
|
||
@add_start_docstrings_to_model_forward( | ||
""" | ||
Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to | ||
the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where | ||
max_num_images is the maximum number of images among the batch_size samples in the batch. | ||
|
||
Padding images are not needed beyond padding the pixel_values at the entrance of the model. | ||
For efficiency, we only pass through the vision_model's forward the real images by | ||
discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where | ||
image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3. | ||
""", | ||
IDEFICS2_INPUTS_DOCSTRING, | ||
) | ||
@add_start_docstrings_to_model_forward(IDEFICS2_INPUTS_DOCSTRING) | ||
def forward( | ||
self, | ||
input_ids: Optional[torch.LongTensor] = None, | ||
|
@@ -1128,12 +1116,7 @@ def forward( | |
) | ||
use_cache = False | ||
|
||
# retrieve input_ids and inputs_embeds | ||
if input_ids is not None: | ||
batch_size, seq_length = input_ids.shape | ||
elif inputs_embeds is not None: | ||
batch_size, seq_length, _ = inputs_embeds.shape | ||
else: | ||
if input_ids is None and inputs_embeds is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. need to also check cases when both are not
|
||
raise ValueError("You have to specify either input_ids or inputs_embeds") | ||
|
||
past_seen_tokens = 0 | ||
|
@@ -1163,14 +1146,7 @@ def forward( | |
if pixel_values is not None and image_hidden_states is not None: | ||
raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") | ||
elif pixel_values is not None: | ||
batch_size, num_images, num_channels, height, width = pixel_values.shape | ||
pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility | ||
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) | ||
|
||
# Remove padding images - padding images are full 0. | ||
nb_values_per_image = pixel_values.shape[1:].numel() | ||
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image | ||
pixel_values = pixel_values[real_images_inds].contiguous() | ||
|
||
# Handle the vision attention mask | ||
if pixel_attention_mask is None: | ||
|
@@ -1179,13 +1155,6 @@ def forward( | |
dtype=torch.bool, | ||
device=pixel_values.device, | ||
) | ||
else: | ||
# Remove padding images from the mask/pP p | ||
pixel_attention_mask = pixel_attention_mask.view( | ||
batch_size * num_images, *pixel_attention_mask.shape[2:] | ||
) | ||
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() | ||
|
||
patch_size = self.config.vision_config.patch_size | ||
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) | ||
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why don't we try to do a single for loop while we are at it?