
Commit 69d807a

add pad transforms with unit tests for lazy resampling (#6031)
Part of #5991.

### Description

This PR adds the pad transforms from #5860 related to lazy resampling; it belongs to the second part described in #5991.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: KumoLiu <[email protected]>
1 parent b9e17e8 commit 69d807a

15 files changed: +401 −201 lines
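For context, here is a minimal usage sketch (not part of this commit) of eager versus lazy padding. It assumes the `BorderPad` transform, the `lazy_evaluation` flag inherited from `LazyTransform`, and `MetaTensor.peek_pending_shape()` as they appear in the diffs below; the exact return values are indicative only.

```python
# Minimal sketch, assuming the Pad/BorderPad API and lazy_evaluation flag shown below.
import torch
from monai.data.meta_tensor import MetaTensor
from monai.transforms import BorderPad

img = MetaTensor(torch.zeros(1, 64, 64))
padder = BorderPad(spatial_border=2)

eager = padder(img)                   # data is padded immediately
print(eager.shape)                    # expected: (1, 68, 68)

padder.lazy_evaluation = True         # flag provided by LazyTransform
pending = padder(img)                 # only the pending metadata is recorded
print(pending.shape)                  # still (1, 64, 64); the data is untouched
print(pending.peek_pending_shape())   # pending spatial shape, expected (68, 68)
```

The idea behind the lazy resampling work tracked in #5991 is that such pending operations can later be composed and applied in a single resample instead of one resample per transform.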

docs/source/transforms.rst

Lines changed: 8 additions & 0 deletions
@@ -88,6 +88,14 @@ Generic Interfaces
 .. autoclass:: RandomOrder
     :members:
 
+Functionals
+-----------
+
+.. automodule:: monai.transforms.croppad.functional
+    :members:
+
+.. currentmodule:: monai.transforms
+
 Vanilla Transforms
 ------------------

monai/transforms/croppad/array.py

Lines changed: 6 additions & 63 deletions
@@ -22,19 +22,18 @@
 
 import numpy as np
 import torch
-from torch.nn.functional import pad as pad_pt
 
 from monai.config import IndexSelection
 from monai.config.type_definitions import NdarrayOrTensor
 from monai.data.meta_obj import get_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.data.utils import get_random_patch, get_valid_patch_size
+from monai.transforms.croppad.functional import pad_func
 from monai.transforms.inverse import InvertibleTransform, TraceableTransform
 from monai.transforms.traits import MultiSampleTrait
-from monai.transforms.transform import Randomizable, Transform
+from monai.transforms.transform import LazyTransform, Randomizable, Transform
 from monai.transforms.utils import (
     compute_divisible_spatial_size,
-    convert_pad_mode,
     create_translate,
     generate_label_classes_crop_centers,
     generate_pos_neg_label_crop_centers,
@@ -82,7 +81,7 @@
 ]
 
 
-class Pad(InvertibleTransform):
+class Pad(InvertibleTransform, LazyTransform):
     """
     Perform padding for a given an amount of padding in each dimension.
 
@@ -124,24 +123,6 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> list[tuple[int, int
         """
         raise NotImplementedError(f"subclass {self.__class__.__name__} must implement this method.")
 
-    @staticmethod
-    def _np_pad(img: torch.Tensor, pad_width, mode, **kwargs) -> torch.Tensor:
-        img_np = img.detach().cpu().numpy() if isinstance(img, torch.Tensor) else img
-        mode = convert_pad_mode(dst=img_np, mode=mode).value
-        if mode == "constant" and "value" in kwargs:
-            val = kwargs.pop("value")
-            kwargs["constant_values"] = val
-        out = torch.as_tensor(np.pad(img, pad_width, mode=mode, **kwargs))
-        if isinstance(img, MetaTensor):
-            out = convert_to_dst_type(out, dst=img)[0]
-        return out
-
-    @staticmethod
-    def _pt_pad(img: torch.Tensor, pad_width, mode, **kwargs) -> torch.Tensor:
-        pt_pad_width = [val for sublist in pad_width[1:] for val in sublist[::-1]][::-1]
-        # torch.pad expects `[B, C, H, W, [D]]` shape
-        return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **kwargs).squeeze(0)
-
     def __call__(  # type: ignore
         self, img: torch.Tensor, to_pad: list[tuple[int, int]] | None = None, mode: str | None = None, **kwargs
     ) -> torch.Tensor:
@@ -162,52 +143,14 @@ def __call__(  # type: ignore
         """
         to_pad_ = self.to_pad if to_pad is None else to_pad
         if to_pad_ is None:
-            to_pad_ = self.compute_pad_width(img.shape[1:])
+            spatial_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
+            to_pad_ = self.compute_pad_width(spatial_shape)
         mode_ = self.mode if mode is None else mode
         kwargs_ = dict(self.kwargs)
         kwargs_.update(kwargs)
 
         img_t = convert_to_tensor(data=img, track_meta=get_track_meta())
-        _orig_size = img_t.shape[1:]
-
-        # all zeros, skip padding
-        if np.asarray(to_pad_).any():
-            to_pad_ = list(to_pad_)
-            if len(to_pad_) < len(img_t.shape):
-                to_pad_ = list(to_pad_) + [(0, 0)] * (len(img_t.shape) - len(to_pad_))
-            if mode_ in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}:
-                out = self._np_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_)
-            else:
-                mode_ = convert_pad_mode(dst=img_t, mode=mode_).value
-                try:
-                    _pad = (
-                        self._pt_pad
-                        if mode_ in {"reflect", "replicate"}
-                        and img_t.dtype not in {torch.int16, torch.int64, torch.bool, torch.uint8}
-                        else self._np_pad
-                    )
-                    out = _pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_)
-                except (ValueError, TypeError, RuntimeError) as err:
-                    if isinstance(err, NotImplementedError) or any(
-                        k in str(err) for k in ("supported", "unexpected keyword", "implemented")
-                    ):
-                        out = self._np_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_)
-                    else:
-                        raise ValueError(
-                            f"{img_t.shape} {to_pad_} {mode_} {kwargs_} {img_t.dtype} {img_t.device}"
-                        ) from err
-        else:
-            out = img_t
-        if get_track_meta():
-            self.update_meta(tensor=out, to_pad=to_pad_)  # type: ignore
-            self.push_transform(out, orig_size=_orig_size, extra_info={"padded": to_pad_})
-        return out
-
-    def update_meta(self, tensor: MetaTensor, to_pad: list[tuple[int, int]]):
-        spatial_rank = max(len(tensor.affine) - 1, 1)
-        to_shift = [-s[0] for s in to_pad[1:]]  # skipping the channel pad
-        mat = create_translate(spatial_rank, to_shift)
-        tensor.affine = tensor.affine @ convert_to_dst_type(mat, tensor.affine)[0]
+        return pad_func(img_t, to_pad_, mode_, self.get_transform_info(), kwargs_)  # type: ignore
 
     def inverse(self, data: MetaTensor) -> MetaTensor:
         transform = self.pop_transform(data)
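With this refactor, `Pad.__call__` only computes the pad amounts (using the pending shape when one is queued) and delegates the padding and the metadata tracking to `pad_func`. Below is a short hedged sketch of calling `Pad` directly with an explicit `to_pad`, per the signature above; the leading `(0, 0)` entry is the channel dimension, which is never padded.

```python
# Minimal sketch of Pad.__call__ with an explicit to_pad (channel dimension first).
import torch
from monai.transforms import Pad

img = torch.zeros(1, 10, 12)
padded = Pad()(img, to_pad=[(0, 0), (2, 3), (1, 1)], mode="constant")
print(padded.shape)  # expected: (1, 15, 14)
```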

monai/transforms/croppad/dictionary.py

Lines changed: 8 additions & 2 deletions
@@ -48,7 +48,7 @@
 )
 from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.traits import MultiSampleTrait
-from monai.transforms.transform import MapTransform, Randomizable
+from monai.transforms.transform import LazyTransform, MapTransform, Randomizable
 from monai.transforms.utils import is_positive
 from monai.utils import MAX_SEED, Method, PytorchPadMode, deprecated_arg_default, ensure_tuple_rep
 
@@ -110,7 +110,7 @@
 ]
 
 
-class Padd(MapTransform, InvertibleTransform):
+class Padd(MapTransform, InvertibleTransform, LazyTransform):
     """
     Dictionary-based wrapper of :py:class:`monai.transforms.Pad`.
 
@@ -144,6 +144,12 @@ def __init__(
         self.padder = padder
         self.mode = ensure_tuple_rep(mode, len(self.keys))
 
+    @LazyTransform.lazy_evaluation.setter  # type: ignore
+    def lazy_evaluation(self, value: bool) -> None:
+        self._lazy_evaluation = value
+        if isinstance(self.padder, LazyTransform):
+            self.padder.lazy_evaluation = value
+
     def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
         d = dict(data)
         for key, m in self.key_iterator(d, self.mode):
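The new `lazy_evaluation` setter keeps the dictionary wrapper and its wrapped `padder` in sync, so switching the wrapper to lazy mode is enough. A hedged sketch using `BorderPadd` (one of the `Padd` subclasses), with expected behavior based on the setter shown above:

```python
# Minimal sketch: toggling lazy evaluation on a dictionary pad transform
# also toggles it on the wrapped padder (see the setter above).
import torch
from monai.transforms import BorderPadd

xform = BorderPadd(keys="image", spatial_border=1)
xform.lazy_evaluation = True          # forwarded to xform.padder as well
out = xform({"image": torch.zeros(1, 8, 8)})
print(out["image"].shape)             # expected (1, 8, 8); the padding is pending
```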
monai/transforms/croppad/functional.py

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+A collection of "functional" transforms for spatial operations
+https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+from torch.nn.functional import pad as pad_pt
+
+from monai.data.meta_obj import get_track_meta
+from monai.data.meta_tensor import MetaTensor
+from monai.transforms.inverse import TraceableTransform
+from monai.transforms.utils import convert_pad_mode, create_translate
+from monai.utils import TraceKeys, convert_to_dst_type, convert_to_tensor
+
+__all__ = ["pad_nd", "pad_func"]
+
+
+def _np_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor:
+    img_np = img.detach().cpu().numpy() if isinstance(img, torch.Tensor) else img
+    mode = convert_pad_mode(dst=img_np, mode=mode).value
+    if mode == "constant" and "value" in kwargs:
+        kwargs["constant_values"] = kwargs.pop("value")
+    out = torch.as_tensor(np.pad(img, pad_width, mode=mode, **kwargs))  # type: ignore
+    if isinstance(img, MetaTensor):
+        out = convert_to_dst_type(out, dst=img)[0]
+    return out
+
+
+def _pt_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor:
+    pt_pad_width = [val for sublist in pad_width[1:] for val in sublist[::-1]][::-1]
+    # torch.pad expects `[B, C, H, W, [D]]` shape
+    return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **kwargs).squeeze(0)
+
+
+def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs):
+    """
+    PyTorch/Numpy pad ``img`` with integers ``to_pad`` amounts. Depending on the ``mode`` and input dtype,
+    a suitable backend will be used automatically.
+
+    Args:
+        img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim.
+        to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...].
+            default to `self.to_pad`.
+        mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
+            ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
+            (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
+            One of the listed string values or a user supplied function. Defaults to ``"constant"``.
+            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
+            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+        kwargs: other arguments for the `np.pad` or `torch.pad` function.
+            note that `np.pad` treats channel dimension as the first dimension.
+    """
+    if mode in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}:
+        return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
+    mode = convert_pad_mode(dst=img, mode=mode).value
+    try:
+        _pad = (
+            _pt_pad
+            if mode in {"reflect", "replicate"} and img.dtype not in {torch.int16, torch.int64, torch.bool, torch.uint8}
+            else _np_pad
+        )
+        return _pad(img, pad_width=to_pad, mode=mode, **kwargs)
+    except (ValueError, TypeError, RuntimeError) as err:
+        if isinstance(err, NotImplementedError) or any(
+            k in str(err) for k in ("supported", "unexpected keyword", "implemented")
+        ):
+            return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
+        raise ValueError(f"{img.shape} {to_pad} {mode} {kwargs} {img.dtype} {img.device}") from err
+
+
+def pad_func(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, transform_info: dict, kwargs):
+    """
+    Functional implementation of padding a MetaTensor. This function operates eagerly or lazily according
+    to ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``).
+
+    Args:
+        img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim.
+        to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...].
+            default to `self.to_pad`.
+        mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
+            ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
+            (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
+            One of the listed string values or a user supplied function. Defaults to ``"constant"``.
+            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
+            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+        transform_info: a dictionary with the relevant information pertaining to an applied transform.
+        kwargs: other arguments for the `np.pad` or `torch.pad` function.
+            note that `np.pad` treats channel dimension as the first dimension.
+    """
+    extra_info = {"padded": to_pad}
+    img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
+    spatial_rank = img.peek_pending_rank() if isinstance(img, MetaTensor) else 3
+    do_pad = np.asarray(to_pad).any()
+    if do_pad:
+        to_pad = list(to_pad)
+        if len(to_pad) < len(img.shape):
+            to_pad = list(to_pad) + [(0, 0)] * (len(img.shape) - len(to_pad))
+        to_shift = [-s[0] for s in to_pad[1:]]  # skipping the channel pad
+        xform = create_translate(spatial_rank, to_shift)
+        shape = [d + s + e for d, (s, e) in zip(img_size, to_pad[1:])]
+    else:
+        shape = img_size
+        xform = torch.eye(int(spatial_rank) + 1, device=torch.device("cpu"), dtype=torch.float64)
+    meta_info = TraceableTransform.track_transform_meta(
+        img,
+        sp_size=shape,
+        affine=xform,
+        extra_info=extra_info,
+        orig_size=img_size,
+        transform_info=transform_info,
+        lazy_evaluation=transform_info.get(TraceKeys.LAZY_EVALUATION, False),
+    )
+    out = convert_to_tensor(img.as_tensor() if isinstance(img, MetaTensor) else img, track_meta=get_track_meta())
+    if transform_info.get(TraceKeys.LAZY_EVALUATION, False):
+        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
+    out = pad_nd(out, to_pad, mode, **kwargs) if do_pad else out
+    out = convert_to_tensor(out, track_meta=get_track_meta())
+    return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
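The split gives a plain functional entry point, `pad_nd`, which picks a NumPy or PyTorch backend from the mode and dtype, while `pad_func` adds the metadata bookkeeping (the recorded affine is a translation by the negative low-pad amounts) and the lazy/eager branching; `pad_func` is normally reached through `Pad.__call__`, which supplies `self.get_transform_info()`. A hedged sketch of calling `pad_nd` directly, with expected shapes only:

```python
# Minimal sketch of the functional API; to_pad includes the channel dimension first.
import torch
from monai.transforms.croppad.functional import pad_nd

x = torch.arange(16, dtype=torch.float32).reshape(1, 4, 4)

y = pad_nd(x, to_pad=[(0, 0), (1, 1), (2, 2)], mode="constant")  # NumPy backend
print(y.shape)  # expected: (1, 6, 8)

z = pad_nd(x, to_pad=[(0, 0), (1, 1), (2, 2)], mode="reflect")   # PyTorch backend for float input
print(z.shape)  # expected: (1, 6, 8)
```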
