Skip to content

Commit d7a8774

Browse files
5991 Enable lazy resampling for SpatialResample (#6060)
Part of #5991. ### Description This PR adds the `SpatialResample` changes from #5860. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Yiheng Wang <[email protected]>
1 parent 69d807a commit d7a8774

File tree

10 files changed

+135
-20
lines changed

10 files changed

+135
-20
lines changed

docs/source/transforms.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,16 @@ Generic Interfaces
9191
Functionals
9292
-----------
9393

94+
Crop and Pad (functional)
95+
^^^^^^^^^^^^^^^^^^^^^^^^^
9496
.. automodule:: monai.transforms.croppad.functional
9597
:members:
9698

99+
Spatial (functional)
100+
^^^^^^^^^^^^^^^^^^^^
101+
.. automodule:: monai.transforms.spatial.functional
102+
:members:
103+
97104
.. currentmodule:: monai.transforms
98105

99106
Vanilla Transforms

monai/data/meta_tensor.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ def affine(self) -> torch.Tensor:
461461
@affine.setter
462462
def affine(self, d: NdarrayTensor) -> None:
463463
"""Set the affine."""
464-
self.meta[MetaKeys.AFFINE] = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.double)
464+
self.meta[MetaKeys.AFFINE] = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64)
465465

466466
@property
467467
def pixdim(self):
@@ -471,7 +471,10 @@ def pixdim(self):
471471
return affine_to_spacing(self.affine)
472472

473473
def peek_pending_shape(self):
474-
"""Get the currently expected spatial shape as if all the pending operations are executed."""
474+
"""
475+
Get the currently expected spatial shape as if all the pending operations are executed.
476+
For tensors that have more than 3 spatial dimensions, only the shapes of the top 3 dimensions will be returned.
477+
"""
475478
res = None
476479
if self.pending_operations:
477480
res = self.pending_operations[-1].get(LazyAttr.SHAPE, None)
@@ -480,11 +483,13 @@ def peek_pending_shape(self):
480483

481484
def peek_pending_affine(self):
482485
res = self.affine
486+
r = len(res) - 1
483487
for p in self.pending_operations:
484-
next_matrix = convert_to_tensor(p.get(LazyAttr.AFFINE))
488+
next_matrix = convert_to_tensor(p.get(LazyAttr.AFFINE), dtype=torch.float64)
485489
if next_matrix is None:
486490
continue
487491
res = convert_to_dst_type(res, next_matrix)[0]
492+
next_matrix = monai.data.utils.to_affine_nd(r, next_matrix)
488493
res = monai.transforms.lazy.utils.combine_transforms(res, next_matrix)
489494
return res
490495

monai/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
SpatialPadD,
9090
SpatialPadDict,
9191
)
92+
from .croppad.functional import pad_func, pad_nd
9293
from .intensity.array import (
9394
AdjustContrast,
9495
ComputeHoVerMaps,
@@ -453,6 +454,7 @@
453454
ZoomD,
454455
ZoomDict,
455456
)
457+
from .spatial.functional import spatial_resample
456458
from .traits import LazyTrait, MultiSampleTrait, RandomizableTrait, ThreadUnsafe
457459
from .transform import LazyTransform, MapTransform, Randomizable, RandomizableTransform, Transform, apply_transform
458460
from .utility.array import (

monai/transforms/inverse.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,9 @@ def push_transform(self, data, *args, **kwargs):
121121
return data.copy_meta_from(meta_obj)
122122
if do_transform:
123123
xform = data.pending_operations.pop()
124+
extra = xform.copy()
124125
xform.update(transform_info)
125-
meta_obj = self.push_transform(data, transform_info=xform, lazy_evaluation=lazy_eval)
126+
meta_obj = self.push_transform(data, transform_info=xform, lazy_evaluation=lazy_eval, extra_info=extra)
126127
return data.copy_meta_from(meta_obj)
127128
return data
128129
kwargs["lazy_evaluation"] = lazy_eval
@@ -177,9 +178,9 @@ def track_transform_meta(
177178
if not lazy_evaluation and affine is not None and isinstance(data_t, MetaTensor):
178179
# not lazy evaluation, directly update the metatensor affine (don't push to the stack)
179180
orig_affine = data_t.peek_pending_affine()
180-
orig_affine = convert_to_dst_type(orig_affine, affine)[0]
181-
affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=affine.dtype)
182-
out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"))
181+
orig_affine = convert_to_dst_type(orig_affine, affine, dtype=torch.float64)[0]
182+
affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64)
183+
out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"), dtype=torch.float64)
183184

184185
if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)):
185186
if isinstance(data, Mapping):
@@ -199,6 +200,8 @@ def track_transform_meta(
199200
info[TraceKeys.ORIG_SIZE] = data_t.shape[1:]
200201
# include extra_info
201202
if extra_info is not None:
203+
extra_info.pop(LazyAttr.SHAPE, None)
204+
extra_info.pop(LazyAttr.AFFINE, None)
202205
info[TraceKeys.EXTRA_INFO] = extra_info
203206

204207
# push the transform info to the applied_operation or pending_operation stack

monai/transforms/lazy/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,6 @@ def resample(data: torch.Tensor, matrix: NdarrayOrTensor, spatial_size, kwargs:
129129
"padding_mode": kwargs.pop(LazyAttr.PADDING_MODE, None),
130130
}
131131
resampler = monai.transforms.SpatialResample(**init_kwargs)
132-
# resampler.lazy_evaluation = False # resampler is a lazytransform
132+
resampler.lazy_evaluation = False # resampler is a lazytransform
133133
with resampler.trace_transform(False): # don't track this transform in `img`
134134
return resampler(img=img, **call_kwargs)

monai/transforms/spatial/array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def __call__(
201201
"""
202202
# get dtype as torch (e.g., torch.float64)
203203
dtype_pt = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor)
204-
align_corners = self.align_corners if align_corners is None else align_corners
204+
align_corners = align_corners if align_corners is not None else self.align_corners
205205
mode = mode if mode is not None else self.mode
206206
padding_mode = padding_mode if padding_mode is not None else self.padding_mode
207207
return spatial_resample(

monai/transforms/spatial/dictionary.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
Zoom,
5555
)
5656
from monai.transforms.traits import MultiSampleTrait
57-
from monai.transforms.transform import MapTransform, RandomizableTransform
57+
from monai.transforms.transform import LazyTransform, MapTransform, RandomizableTransform
5858
from monai.transforms.utils import create_grid
5959
from monai.utils import (
6060
GridSampleMode,
@@ -142,7 +142,7 @@
142142
]
143143

144144

145-
class SpatialResampled(MapTransform, InvertibleTransform):
145+
class SpatialResampled(MapTransform, InvertibleTransform, LazyTransform):
146146
"""
147147
Dictionary-based wrapper of :py:class:`monai.transforms.SpatialResample`.
148148
@@ -204,6 +204,11 @@ def __init__(
204204
self.dtype = ensure_tuple_rep(dtype, len(self.keys))
205205
self.dst_keys = ensure_tuple_rep(dst_keys, len(self.keys))
206206

207+
@LazyTransform.lazy_evaluation.setter # type: ignore
208+
def lazy_evaluation(self, val: bool) -> None:
209+
self._lazy_evaluation = val
210+
self.sp_transform.lazy_evaluation = val
211+
207212
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
208213
d: dict = dict(data)
209214
for key, mode, padding_mode, align_corners, dtype, dst_key in self.key_iterator(

monai/transforms/spatial/functional.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,37 @@
5050
def spatial_resample(
5151
img, dst_affine, spatial_size, mode, padding_mode, align_corners, dtype_pt, transform_info
5252
) -> torch.Tensor:
53+
"""
54+
Functional implementation of resampling the input image to the specified ``dst_affine`` matrix and ``spatial_size``.
55+
This function operates eagerly or lazily according to
56+
``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``).
57+
58+
Args:
59+
img: data to be resampled, assuming `img` is channel-first.
60+
dst_affine: target affine matrix, if None, use the input affine matrix, effectively no resampling.
61+
spatial_size: output spatial size, if the component is ``-1``, use the corresponding input spatial size.
62+
mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
63+
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
64+
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
65+
When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
66+
and the value represents the order of the spline interpolation.
67+
See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
68+
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
69+
Padding mode for outside grid values. Defaults to ``"border"``.
70+
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
71+
When `mode` is an integer, using numpy/cupy backends, this argument accepts
72+
{'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
73+
See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
74+
align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
75+
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
76+
Defaults to ``None``, effectively using the value of `self.align_corners`.
77+
dtype_pt: data `dtype` for resampling computation.
78+
transform_info: a dictionary with the relevant information pertaining to an applied transform.
79+
"""
5380
original_spatial_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
5481
src_affine: torch.Tensor = img.peek_pending_affine() if isinstance(img, MetaTensor) else torch.eye(4)
5582
img = convert_to_tensor(data=img, track_meta=get_track_meta())
83+
# ensure spatial rank is <= 3
5684
spatial_rank = min(len(img.shape) - 1, src_affine.shape[0] - 1, 3)
5785
if (not isinstance(spatial_size, int) or spatial_size != -1) and spatial_size is not None:
5886
spatial_rank = min(len(ensure_tuple(spatial_size)), 3) # infer spatial rank based on spatial_size
@@ -101,7 +129,7 @@ def spatial_resample(
101129
# no significant change or lazy change, return original image
102130
out = convert_to_tensor(img, track_meta=get_track_meta())
103131
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore
104-
im_size = torch.tensor(img.shape).tolist()
132+
im_size = list(img.shape)
105133
chns, in_sp_size, additional_dims = im_size[0], im_size[1 : spatial_rank + 1], im_size[spatial_rank + 1 :]
106134

107135
if additional_dims:

tests/test_spatial_resample.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from monai.data.meta_tensor import MetaTensor
2323
from monai.data.utils import to_affine_nd
2424
from monai.transforms import SpatialResample
25+
from monai.transforms.lazy.functional import apply_transforms
2526
from monai.utils import optional_import
2627
from tests.utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose
2728

@@ -131,6 +132,28 @@
131132
TEST_TORCH_INPUT.append(t + [track_meta])
132133

133134

135+
def get_apply_param(init_param=None, call_param=None):
136+
apply_param = {}
137+
for key in ["pending", "mode", "padding_mode", "dtype", "align_corners"]:
138+
if init_param:
139+
if key in init_param.keys():
140+
apply_param[key] = init_param[key]
141+
if call_param:
142+
if key in call_param.keys():
143+
apply_param[key] = call_param[key]
144+
return apply_param
145+
146+
147+
def test_resampler_lazy(resampler, non_lazy_out, init_param=None, call_param=None):
148+
resampler.lazy_evaluation = True
149+
pending_out = resampler(**call_param)
150+
assert_allclose(pending_out.peek_pending_affine(), non_lazy_out.affine)
151+
assert_allclose(pending_out.peek_pending_shape(), non_lazy_out.shape[1:4])
152+
apply_param = get_apply_param(init_param, call_param)
153+
lazy_out = apply_transforms(pending_out, **apply_param)[0]
154+
assert_allclose(lazy_out, non_lazy_out, rtol=1e-5)
155+
156+
134157
class TestSpatialResample(unittest.TestCase):
135158
@parameterized.expand(TESTS)
136159
def test_flips(self, img, device, data_param, expected_output):
@@ -140,9 +163,14 @@ def test_flips(self, img, device, data_param, expected_output):
140163
img.affine = torch.eye(4)
141164
if hasattr(img, "to"):
142165
img = img.to(device)
143-
out = SpatialResample()(img=img, **data_param)
166+
resampler = SpatialResample()
167+
call_param = data_param.copy()
168+
call_param["img"] = img
169+
out = resampler(**call_param)
144170
assert_allclose(out, expected_output, rtol=1e-2, atol=1e-2)
145-
assert_allclose(to_affine_nd(len(out.shape) - 1, out.affine), data_param["dst_affine"])
171+
assert_allclose(to_affine_nd(len(out.shape) - 1, out.affine), call_param["dst_affine"])
172+
173+
test_resampler_lazy(resampler, out, init_param=None, call_param=call_param)
146174

147175
@parameterized.expand(TEST_4_5_D)
148176
def test_4d_5d(self, new_shape, tile, device, dtype, expected_data):
@@ -152,10 +180,15 @@ def test_4d_5d(self, new_shape, tile, device, dtype, expected_data):
152180

153181
dst = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 1.5], [0.0, 0.0, 0.0, 1.0]])
154182
dst = dst.to(dtype)
155-
out = SpatialResample(dtype=dtype, align_corners=True)(img=img, dst_affine=dst, align_corners=False)
183+
init_param = {"dtype": dtype, "align_corners": True}
184+
call_param = {"img": img, "dst_affine": dst, "align_corners": False}
185+
resampler = SpatialResample(**init_param)
186+
out = resampler(**call_param)
156187
assert_allclose(out, expected_data[None], rtol=1e-2, atol=1e-2)
157188
assert_allclose(out.affine, dst.to(torch.float32), rtol=1e-2, atol=1e-2)
158189

190+
test_resampler_lazy(resampler, out, init_param, call_param)
191+
159192
@parameterized.expand(TEST_DEVICES)
160193
def test_ill_affine(self, device):
161194
img = MetaTensor(torch.arange(12).reshape(1, 2, 2, 3)).to(device)
@@ -182,9 +215,14 @@ def test_input_torch(self, new_shape, tile, device, dtype, expected_data, track_
182215
img = torch.as_tensor(np.tile(img, tile)).to(device)
183216
dst = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 1.5], [0.0, 0.0, 0.0, 1.0]])
184217
dst = dst.to(dtype).to(device)
185-
186-
out = SpatialResample(dtype=dtype)(img=img, dst_affine=dst)
218+
init_param = {"dtype": dtype}
219+
call_param = {"img": img, "dst_affine": dst}
220+
resampler = SpatialResample(**init_param)
221+
out = resampler(**call_param)
187222
assert_allclose(out, expected_data[None], rtol=1e-2, atol=1e-2)
223+
224+
test_resampler_lazy(resampler, out, init_param, call_param)
225+
188226
if track_meta:
189227
self.assertIsInstance(out, MetaTensor)
190228
assert_allclose(out.affine, dst.to(torch.float32), rtol=1e-2, atol=1e-2)

tests/test_spatial_resampled.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from monai.data.meta_tensor import MetaTensor
2121
from monai.data.utils import to_affine_nd
22+
from monai.transforms.lazy.functional import apply_transforms
2223
from monai.transforms.spatial.dictionary import SpatialResampled
2324
from tests.utils import TEST_DEVICES, assert_allclose
2425

@@ -85,19 +86,45 @@
8586
)
8687

8788

89+
def get_apply_param(init_param=None, call_param=None):
90+
apply_param = {}
91+
for key in ["pending", "mode", "padding_mode", "dtype", "align_corners"]:
92+
if init_param:
93+
if key in init_param.keys():
94+
apply_param[key] = init_param[key]
95+
if call_param:
96+
if key in call_param.keys():
97+
apply_param[key] = call_param[key]
98+
return apply_param
99+
100+
88101
class TestSpatialResample(unittest.TestCase):
89102
@parameterized.expand(TESTS)
90103
def test_flips_inverse(self, img, device, dst_affine, kwargs, expected_output):
91104
img = MetaTensor(img, affine=torch.eye(4)).to(device)
92105
data = {"img": img, "dst_affine": dst_affine}
93-
94-
xform = SpatialResampled(keys="img", **kwargs)
95-
output_data = xform(data)
106+
init_param = kwargs.copy()
107+
init_param["keys"] = "img"
108+
call_param = {"data": data}
109+
xform = SpatialResampled(**init_param)
110+
output_data = xform(**call_param)
96111
out = output_data["img"]
97112

98113
assert_allclose(out, expected_output, rtol=1e-2, atol=1e-2)
99114
assert_allclose(to_affine_nd(len(out.shape) - 1, out.affine), dst_affine, rtol=1e-2, atol=1e-2)
100115

116+
# check lazy
117+
lazy_xform = SpatialResampled(**init_param)
118+
lazy_xform.lazy_evaluation = True
119+
pending_output_data = lazy_xform(**call_param)
120+
pending_out = pending_output_data["img"]
121+
assert_allclose(pending_out.peek_pending_affine(), out.affine)
122+
assert_allclose(pending_out.peek_pending_shape(), out.shape[1:4])
123+
apply_param = get_apply_param(init_param=init_param, call_param=call_param)
124+
lazy_out = apply_transforms(pending_out, **apply_param)[0]
125+
assert_allclose(lazy_out, out, rtol=1e-5)
126+
127+
# check inverse
101128
inverted = xform.inverse(output_data)["img"]
102129
self.assertEqual(inverted.applied_operations, []) # no further invert after inverting
103130
expected_affine = to_affine_nd(len(out.affine) - 1, torch.eye(4))

0 commit comments

Comments
 (0)