5 changes: 5 additions & 0 deletions docs/source/en/model_doc/glpn.md
@@ -61,6 +61,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] GLPNImageProcessor
- preprocess

## GLPNImageProcessorFast

[[autodoc]] GLPNImageProcessorFast
- preprocess

## GLPNModel

[[autodoc]] GLPNModel
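For reference, a minimal usage sketch of the new fast processor (the checkpoint name and the exact output shape are illustrative assumptions, not part of this PR):

import torch
from transformers import GLPNImageProcessorFast

# Checkpoint name is illustrative; any GLPN checkpoint with a preprocessor config should work.
processor = GLPNImageProcessorFast.from_pretrained("vinvino02/glpn-kitti")
image = torch.randint(0, 256, (3, 500, 333), dtype=torch.uint8)  # dummy RGB image
inputs = processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # H and W floored to multiples of 32 -> (1, 3, 480, 320)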
2 changes: 1 addition & 1 deletion src/transformers/models/auto/image_processing_auto.py
@@ -103,7 +103,7 @@
("gemma3n", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("glm4v", ("Glm4vImageProcessor", "Glm4vImageProcessorFast")),
("glpn", ("GLPNImageProcessor", None)),
("glpn", ("GLPNImageProcessor", "GLPNImageProcessorFast")),
("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
("grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")),
("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
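With this mapping entry, `AutoImageProcessor` can resolve the fast class; a short sketch (checkpoint name assumed, as above):

from transformers import AutoImageProcessor

# use_fast=True selects GLPNImageProcessorFast via the mapping above.
processor = AutoImageProcessor.from_pretrained("vinvino02/glpn-kitti", use_fast=True)
print(type(processor).__name__)  # GLPNImageProcessorFast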
1 change: 1 addition & 0 deletions src/transformers/models/glpn/__init__.py
@@ -21,6 +21,7 @@
from .configuration_glpn import *
from .feature_extraction_glpn import *
from .image_processing_glpn import *
from .image_processing_glpn_fast import *
from .modeling_glpn import *
else:
import sys
33 changes: 33 additions & 0 deletions src/transformers/models/glpn/image_processing_glpn.py
@@ -39,6 +39,7 @@
valid_images,
validate_preprocess_arguments,
)
from ...processing_utils import ImagesKwargs
from ...utils import TensorType, filter_out_non_signature_kwargs, logging, requires_backends


@@ -49,6 +50,17 @@
logger = logging.get_logger(__name__)


class GLPNImageProcessorKwargs(ImagesKwargs, total=False):
"""
size_divisor (`int`, *optional*, defaults to 32):
When `do_resize` is `True`, images are resized so their height and width are rounded down to the closest
        multiple of `size_divisor`.
    resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
        Resampling filter to use when resizing the image.
    """

size_divisor: int
resample: PILImageResampling


@requires(backends=("vision",))
class GLPNImageProcessor(BaseImageProcessor):
r"""
@@ -69,6 +81,7 @@ class GLPNImageProcessor(BaseImageProcessor):
"""

model_input_names = ["pixel_values"]
valid_kwargs = GLPNImageProcessorKwargs

def __init__(
self,
@@ -223,6 +236,26 @@ def preprocess(
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]

if return_tensors:
shapes = {tuple(img.shape) for img in images}
if len(shapes) > 1:
# Find max dimensions
max_height = max(img.shape[-2] for img in images)
max_width = max(img.shape[-1] for img in images)

# Pad each image to max dimensions
padded_images = []
for img in images:
h, w = img.shape[-2:]
if h < max_height or w < max_width:
# Create padded array with zeros
padded = np.zeros((*img.shape[:-2], max_height, max_width), dtype=img.dtype)
padded[..., :h, :w] = img
padded_images.append(padded)
else:
padded_images.append(img)
images = padded_images

data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)

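A small sketch of the new padding path in the slow processor (input sizes are illustrative):

import numpy as np
from transformers import GLPNImageProcessor

processor = GLPNImageProcessor()
images = [
    np.zeros((3, 480, 640), dtype=np.uint8),
    np.zeros((3, 416, 544), dtype=np.uint8),  # smaller image; will be zero-padded
]
# After flooring H and W to multiples of 32, the shapes still differ, so each image
# is padded bottom/right to the per-batch maximum before stacking.
batch = processor(images, return_tensors="np")
print(batch["pixel_values"].shape)  # (2, 3, 480, 640)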
208 changes: 208 additions & 0 deletions src/transformers/models/glpn/image_processing_glpn_fast.py
@@ -0,0 +1,208 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for GLPN."""

from typing import Optional, Union

import torch
from torchvision.transforms.v2 import functional as F

from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    PILImageResampling,
    pil_torch_interpolation_mapping,
)
from ...utils import (
TensorType,
auto_docstring,
requires_backends,
)
from .image_processing_glpn import GLPNImageProcessorKwargs


@auto_docstring
class GLPNImageProcessorFast(BaseImageProcessorFast):
"""
Fast image processor for GLPN using the Torch/TorchVision backend.
Performs:
- Resize H and W down to the nearest multiple of `size_divisor` (default 32)
- Rescale [0,255] → [0,1]
- (No normalization by default)
"""

# Persist ONLY the same keys as the slow processor
do_resize = True
do_rescale = True
do_normalize = False
resample = PILImageResampling.BILINEAR
size_divisor = 32
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
interpolation = F.InterpolationMode.BILINEAR
valid_kwargs = GLPNImageProcessorKwargs

    # If BaseImageProcessorFast supports it, this makes persistence explicit:
    config_keys = {"do_resize", "size_divisor", "resample", "do_rescale"}

def __init__(self, **kwargs) -> None:
if "ensure_multiple_of" in kwargs and "size_divisor" not in kwargs:
kwargs = dict(kwargs)
kwargs["size_divisor"] = kwargs.pop("ensure_multiple_of")
        # ensure `resample` and `size` defaults so base-class argument validation passes (GLPN ignores `size`)
kwargs.setdefault("resample", PILImageResampling.BILINEAR)
kwargs.setdefault("size", {"height": 480, "width": 640})
super().__init__(**kwargs)

@staticmethod
    def _resize_to_multiple(
images: torch.Tensor,
size_divisor: int = 32,
interpolation: "F.InterpolationMode" = F.InterpolationMode.BILINEAR,
) -> torch.Tensor:
"""
        Resize images (B, C, H, W) by flooring H and W to the nearest multiple of `size_divisor`.
        Uses interpolation (a resize, not a crop) to match the slow GLPN processor's behavior.
"""
_, _, h, w = images.shape
new_h = (h // size_divisor) * size_divisor
new_w = (w // size_divisor) * size_divisor
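        # Illustrative example (values assumed): h=500, w=333 with size_divisor=32
        # floor to new_h=480, new_w=320, so the image is resized down to (480, 320).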
if (new_h, new_w) == (h, w):
return images
# Resize (not crop) to match slow processor behavior
return F.resize(images, size=(new_h, new_w), interpolation=interpolation, antialias=True)

def _preprocess(
self,
images: list["torch.Tensor"],
do_resize: bool,
size: Optional[dict] = None,
size_divisor: Optional[int] = None,
interpolation: Optional["F.InterpolationMode"] = None,
do_rescale: bool = True,
rescale_factor: Optional[float] = 1 / 255,
do_normalize: bool = False,
image_mean: Optional[Union[float, list[float]]] = None,
image_std: Optional[Union[float, list[float]]] = None,
disable_grouping: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
resample: Optional[PILImageResampling] = None,
**kwargs,
) -> BatchFeature:
"""
GLPN fast preprocessing:
        - resize H and W to the floored multiple of `size_divisor`
- rescale [0,1]
- normalize (off by default)
"""
        # Resolve the torchvision interpolation mode: prefer an explicit `interpolation`,
        # else map `resample`, else fall back to the class default so the resize step is
        # never called with `None`.
        if interpolation is None:
            interpolation = pil_torch_interpolation_mapping[resample] if resample is not None else self.interpolation

grouped_images, grouped_index = group_images_by_shape(images, disable_grouping=disable_grouping)
processed_groups = {}
sd = size_divisor if size_divisor is not None else self.size_divisor

for shape, stacked_images in grouped_images.items():
if do_resize:
                stacked_images = self._resize_to_multiple(stacked_images, sd, interpolation)
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_groups[shape] = stacked_images

reordered = reorder_images(processed_groups, grouped_index)

if return_tensors:
# Detect heterogeneous shapes
shapes = {tuple(img.shape) for img in reordered}
if len(shapes) > 1:
# Pad to max height and width in batch
max_height = max(img.shape[-2] for img in reordered)
max_width = max(img.shape[-1] for img in reordered)

padded = []
for img in reordered:
h, w = img.shape[-2:]
if h < max_height or w < max_width:
# Pad to max dimensions
pad_h = max_height - h
pad_w = max_width - w
# Pad on right and bottom
img = torch.nn.functional.pad(img, (0, pad_w, 0, pad_h))
padded.append(img)
reordered = padded

processed = torch.stack(reordered, dim=0)
tensor_type = return_tensors
else:
processed = reordered
tensor_type = None

return BatchFeature(data={"pixel_values": processed}, tensor_type=tensor_type)

# ensure only slow keys are serialized
def to_dict(self):
output_dict = super().to_dict()

# Keep only these keys with their values (everything else gets set to None)
keys_to_keep = {
"image_processor_type",
"_processor_class", # Identity metadata
"do_resize",
"size_divisor",
"resample",
"do_rescale", # Core GLPN params
"default_to_square",
"data_format", # Fast processor params
}

# Set all other keys to None (don't persist their values)
for key in list(output_dict.keys()):
if key not in keys_to_keep:
output_dict[key] = None

return output_dict

def post_process_depth_estimation(self, outputs, target_sizes=None):
"""
Convert raw model outputs to final depth predictions.
        Mirrors the slow GLPN processor: `torch.nn.functional.interpolate` with bicubic mode and `align_corners=False`.
"""
requires_backends(self, "torch")
        predicted_depth = outputs.predicted_depth  # shape: (B, H, W)

results = []
target_sizes = target_sizes or [None] * predicted_depth.shape[0]
for depth, target_size in zip(predicted_depth, target_sizes):
if target_size is not None:
# Add batch and channel dimensions for interpolation
depth_4d = depth[None, None, ...]
resized = torch.nn.functional.interpolate(
depth_4d, size=target_size, mode="bicubic", align_corners=False
)
depth = resized.squeeze(0).squeeze(0)
results.append({"predicted_depth": depth})

return results


__all__ = ["GLPNImageProcessorFast"]
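End-to-end sketch pairing the fast processor with depth post-processing (checkpoint name and sizes are assumptions for illustration):

import torch
from transformers import GLPNForDepthEstimation, GLPNImageProcessorFast

# Checkpoint name is illustrative; any GLPN depth-estimation checkpoint should work.
processor = GLPNImageProcessorFast.from_pretrained("vinvino02/glpn-kitti")
model = GLPNForDepthEstimation.from_pretrained("vinvino02/glpn-kitti")

image = torch.randint(0, 256, (3, 500, 333), dtype=torch.uint8)  # dummy RGB image
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Resize the predicted depth back to the original resolution with bicubic interpolation.
results = processor.post_process_depth_estimation(outputs, target_sizes=[(500, 333)])
print(results[0]["predicted_depth"].shape)  # torch.Size([500, 333])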