Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
84edd1e
support for volume patching
danielenricocahall Oct 20, 2025
ef9316c
correct `call`
danielenricocahall Oct 20, 2025
229e877
fix pydoc
danielenricocahall Oct 20, 2025
ab0c738
fix `test_extract_volume_patches_basic` casting residual
danielenricocahall Oct 20, 2025
78bb290
fix `test_extract_volume_patches_same_padding` casting
danielenricocahall Oct 20, 2025
6c3a2e5
fix `test_extract_volume_patches_overlapping` casting
danielenricocahall Oct 20, 2025
b41fd35
add extra testing
danielenricocahall Oct 20, 2025
304b5be
Update keras/src/ops/image.py
danielenricocahall Oct 20, 2025
9d940fd
fix dimension orderiing
danielenricocahall Oct 20, 2025
d6b43b4
Update keras/src/ops/image.py
danielenricocahall Oct 20, 2025
335cfe5
Merge remote-tracking branch 'refs/remotes/origin/add-volume-patching…
danielenricocahall Oct 20, 2025
1f377dc
fix dimensional ordering + add test
danielenricocahall Oct 20, 2025
68f4f02
fix docstring
danielenricocahall Oct 20, 2025
5f7775b
docstring corrections
danielenricocahall Oct 20, 2025
b4b382d
add validation checks for operation
danielenricocahall Oct 20, 2025
6aeae60
comment
danielenricocahall Oct 20, 2025
9bf0656
set default data format in tests, add a few parametrized tests to val…
danielenricocahall Oct 20, 2025
aacde01
add back test for channels first/last in dilation
danielenricocahall Oct 20, 2025
40a6682
fix dilation test
danielenricocahall Oct 20, 2025
0973857
Update keras/src/ops/image.py
danielenricocahall Oct 20, 2025
9260d8a
delete redundant prop set
danielenricocahall Oct 20, 2025
36ca76a
Update keras/src/ops/image_test.py
danielenricocahall Oct 20, 2025
441eb32
Merge remote-tracking branch 'refs/remotes/origin/add-volume-patching…
danielenricocahall Oct 20, 2025
974d500
Update keras/src/ops/image_test.py
danielenricocahall Oct 20, 2025
4014a77
Merge remote-tracking branch 'refs/remotes/origin/add-volume-patching…
danielenricocahall Oct 20, 2025
4cf0d7a
rename per francois feedback
danielenricocahall Oct 22, 2025
52ec086
fix pydoc to reflect method
danielenricocahall Oct 22, 2025
e9f4e26
fix name of operation + test
danielenricocahall Oct 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from keras.src.ops.image import crop_images as crop_images
from keras.src.ops.image import elastic_transform as elastic_transform
from keras.src.ops.image import extract_patches as extract_patches
from keras.src.ops.image import extract_patches_3d as extract_patches_3d
from keras.src.ops.image import gaussian_blur as gaussian_blur
from keras.src.ops.image import hsv_to_rgb as hsv_to_rgb
from keras.src.ops.image import map_coordinates as map_coordinates
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from keras.src.ops.image import crop_images as crop_images
from keras.src.ops.image import elastic_transform as elastic_transform
from keras.src.ops.image import extract_patches as extract_patches
from keras.src.ops.image import extract_patches_3d as extract_patches_3d
from keras.src.ops.image import gaussian_blur as gaussian_blur
from keras.src.ops.image import hsv_to_rgb as hsv_to_rgb
from keras.src.ops.image import map_coordinates as map_coordinates
Expand Down
181 changes: 181 additions & 0 deletions keras/src/ops/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,187 @@ def _extract_patches(
return patches


class ExtractPatches3D(Operation):
def __init__(
self,
size,
strides=None,
dilation_rate=1,
padding="valid",
data_format=None,
*,
name=None,
):
super().__init__(name=name)
if isinstance(size, int):
size = (size, size, size)
elif len(size) != 3:
raise TypeError(
"Invalid `size` argument. Expected an "
f"int or a tuple of length 3. Received: size={size}"
)
self.size = size
if strides is not None:
if isinstance(strides, int):
strides = (strides, strides, strides)
elif len(strides) != 3:
raise ValueError(f"Invalid `strides` argument. Got: {strides}")
else:
strides = size
self.strides = strides
self.dilation_rate = dilation_rate
self.padding = padding
self.data_format = backend.standardize_data_format(data_format)

def call(self, volumes):
return _extract_patches_3d(
volumes,
self.size,
self.strides,
self.dilation_rate,
self.padding,
self.data_format,
)

def compute_output_spec(self, volumes):
volumes_shape = list(volumes.shape)
original_ndim = len(volumes_shape)
strides = self.strides
if self.data_format == "channels_last":
channels_in = volumes_shape[-1]
else:
channels_in = volumes_shape[-4]
if original_ndim == 4:
volumes_shape = [1] + volumes_shape
filters = self.size[0] * self.size[1] * self.size[2] * channels_in
kernel_size = (self.size[0], self.size[1], self.size[2])
out_shape = compute_conv_output_shape(
volumes_shape,
filters,
kernel_size,
strides=strides,
padding=self.padding,
data_format=self.data_format,
dilation_rate=self.dilation_rate,
)
if original_ndim == 4:
out_shape = out_shape[1:]
return KerasTensor(shape=out_shape, dtype=volumes.dtype)


def _extract_patches_3d(
volumes,
size,
strides=None,
dilation_rate=1,
padding="valid",
data_format=None,
):
if isinstance(size, int):
patch_d = patch_h = patch_w = size
elif len(size) == 3:
patch_d, patch_h, patch_w = size
else:
raise TypeError(
"Invalid `size` argument. Expected an "
f"int or a tuple of length 3. Received: size={size}"
)
if strides is None:
strides = size
if isinstance(strides, int):
strides = (strides, strides, strides)
if len(strides) != 3:
raise ValueError(f"Invalid `strides` argument. Got: {strides}")
data_format = backend.standardize_data_format(data_format)
if data_format == "channels_last":
channels_in = volumes.shape[-1]
elif data_format == "channels_first":
channels_in = volumes.shape[-4]
out_dim = patch_d * patch_w * patch_h * channels_in
kernel = backend.numpy.eye(out_dim, dtype=volumes.dtype)
kernel = backend.numpy.reshape(
kernel, (patch_d, patch_h, patch_w, channels_in, out_dim)
)
_unbatched = False
if len(volumes.shape) == 4:
_unbatched = True
volumes = backend.numpy.expand_dims(volumes, axis=0)
patches = backend.nn.conv(
inputs=volumes,
kernel=kernel,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
)
if _unbatched:
patches = backend.numpy.squeeze(patches, axis=0)
return patches


@keras_export("keras.ops.image.extract_patches_3d")
def extract_patches_3d(
volumes,
size,
strides=None,
dilation_rate=1,
padding="valid",
data_format=None,
):
"""Extracts patches from the volume(s).

Args:
volumes: Input volume or batch of volumes. Must be 4D or 5D.
size: Patch size int or tuple (patch_depth, patch_height, patch_width)
strides: strides along depth, height, and width. If not specified, or
if `None`, it defaults to the same value as `size`.
dilation_rate: This is the input stride, specifying how far two
consecutive patch samples are in the input. Note that using
`dilation_rate > 1` is not supported in conjunction with
`strides > 1` on the TensorFlow backend.
padding: The type of padding algorithm to use: `"same"` or `"valid"`.
data_format: A string specifying the data format of the input tensor.
It can be either `"channels_last"` or `"channels_first"`.
`"channels_last"` corresponds to inputs with shape
`(batch, depth, height, width, channels)`, while `"channels_first"`
corresponds to inputs with shape
`(batch, channels, depth, height, width)`. If not specified,
the value will default to `keras.config.image_data_format()`.

Returns:
Extracted patches 4D (if not batched) or 5D (if batched)

Examples:

>>> import numpy as np
>>> import keras
>>> # Batched case
>>> volumes = np.random.random(
... (2, 10, 10, 10, 3)
... ).astype("float32") # batch of 2 volumes
>>> patches = keras.ops.image.extract_patches_3d(volumes, (3, 3, 3))
>>> patches.shape
(2, 3, 3, 3, 81)
>>> # Unbatched case
>>> volume = np.random.random((10, 10, 10, 3)).astype("float32") # 1 volume
>>> patches = keras.ops.image.extract_patches_3d(volume, (3, 3, 3))
>>> patches.shape
(3, 3, 3, 81)
"""
if any_symbolic_tensors((volumes,)):
return ExtractPatches3D(
size=size,
strides=strides,
dilation_rate=dilation_rate,
padding=padding,
data_format=data_format,
).symbolic_call(volumes)

return _extract_patches_3d(
volumes, size, strides, dilation_rate, padding, data_format=data_format
)


class MapCoordinates(Operation):
def __init__(self, order, fill_mode="constant", fill_value=0, *, name=None):
super().__init__(name=name)
Expand Down
Loading