
Commit 4bd93c0

Authored by bhashemian, with co-authors pre-commit-ci[bot], monai-bot, and ericspod
WSI sliding window splitter (#6107)
Fixes #5871

### Description
This PR:
- implements `WSISlidingWindowSplitter`, a splitter for whole slide imaging,
- refactors `SlidingWindowSplitter` to make it inheritable for the WSI version (backward-incompatible arguments),
- adds a new feature to `AvgMerger` to crop the output (backward compatible), and
- updates `PatchInferer` to set the cropping shape for the new feature in `AvgMerger` (backward compatible).

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Behrooz <[email protected]>
Signed-off-by: monai-bot <[email protected]>
Signed-off-by: Dr. Behrooz Hashemian <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: monai-bot <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
1 parent 8494029 · commit 4bd93c0

11 files changed (+842, -187 lines)
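At a glance, the pieces in this commit combine as follows. This is a minimal sketch, not code from the PR: the `SlidingWindowSplitter` keywords shown (`patch_size`, `overlap`) are assumed to survive the refactor, and the commented `WSISlidingWindowSplitter` call assumes an analogous constructor (its actual signature lives in `monai/inferers/splitter.py`, which is changed by this commit but not excerpted below).

```python
import torch

from monai.inferers import PatchInferer, SlidingWindowSplitter

# toy 2D network that preserves spatial size (1 output channel)
net = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)

# split into 32x32 patches with 25% overlap; PatchInferer merges the patch
# outputs and, with match_spatial_shape=True (new in this PR), crops the
# merged result back to the input's spatial shape
splitter = SlidingWindowSplitter(patch_size=(32, 32), overlap=0.25)
inferer = PatchInferer(splitter=splitter, batch_size=4, match_spatial_shape=True)

image = torch.rand(1, 3, 100, 100)  # 100 is not a multiple of the step, so padding kicks in
output = inferer(image, net)
print(output.shape)  # expected: torch.Size([1, 1, 100, 100])

# assumed analogous usage for whole slide images:
# splitter = WSISlidingWindowSplitter(patch_size=(256, 256), overlap=0.25)
```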

docs/source/inferers.rst (7 additions, 0 deletions)

```diff
@@ -57,6 +57,13 @@ Splitters
     :members:
     :special-members: __call__

+`WSISlidingWindowSplitter`
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: WSISlidingWindowSplitter
+    :members:
+    :special-members: __call__
+
+
 Mergers
 -------
 .. currentmodule:: monai.inferers
```

monai/data/utils.py (9 additions, 4 deletions)

```diff
@@ -209,7 +209,7 @@ def iter_patch_position(
     image_size: Sequence[int],
     patch_size: Sequence[int] | int | np.ndarray,
     start_pos: Sequence[int] = (),
-    overlap: Sequence[float] | float = 0.0,
+    overlap: Sequence[float] | float | Sequence[int] | int = 0.0,
     padded: bool = False,
 ):
     """
@@ -221,8 +221,10 @@ def iter_patch_position(
         image_size: dimensions of array to iterate over
         patch_size: size of patches to generate slices for, 0 or None selects whole dimension
         start_pos: starting position in the array, default is 0 for each dimension
-        overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
-            If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
+        overlap: the amount of overlap of neighboring patches in each dimension.
+            Either a float or list of floats between 0.0 and 1.0 to define relative overlap to patch size, or
+            an int or list of ints to define number of pixels for overlap.
+            If only one float/int number is given, it will be applied to all dimensions. Defaults to 0.0.
         padded: if the image is padded so the patches can go beyond the borders. Defaults to False.

     Yields:
@@ -236,7 +238,10 @@ def iter_patch_position(
     overlap = ensure_tuple_rep(overlap, ndim)

     # calculate steps, which depends on the amount of overlap
-    steps = tuple(round(p * (1.0 - o)) for p, o in zip(patch_size_, overlap))
+    if isinstance(overlap[0], float):
+        steps = tuple(round(p * (1.0 - o)) for p, o in zip(patch_size_, overlap))
+    else:
+        steps = tuple(p - o for p, o in zip(patch_size_, overlap))

     # calculate the last starting location (depending on the padding)
     end_pos = image_size if padded else tuple(s - round(p) + 1 for s, p in zip(image_size, patch_size_))
```
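As a sanity check on the branch added above, the step arithmetic can be replayed in isolation; the numbers below are illustrative, not taken from the PR's tests.

```python
# steps between patch start positions, as computed in iter_patch_position:
# a float overlap scales the patch size, an int overlap subtracts pixels directly
patch_size_ = (256, 256)

overlap = (0.25, 0.25)  # relative: 25% of each patch edge overlaps its neighbor
steps = tuple(round(p * (1.0 - o)) for p, o in zip(patch_size_, overlap))
assert steps == (192, 192)

overlap = (64, 64)  # absolute: the same overlap expressed in pixels
steps = tuple(p - o for p, o in zip(patch_size_, overlap))
assert steps == (192, 192)
```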

monai/inferers/__init__.py (1 addition, 1 deletion)

```diff
@@ -21,5 +21,5 @@
     SlidingWindowInfererAdapt,
 )
 from .merger import AvgMerger, Merger
-from .splitter import SlidingWindowSplitter, Splitter
+from .splitter import SlidingWindowSplitter, Splitter, WSISlidingWindowSplitter
 from .utils import sliding_window_inference
```

monai/inferers/inferer.py (54 additions, 26 deletions)

```diff
@@ -87,6 +87,8 @@ class PatchInferer(Inferer):
     Args:
         splitter: a `Splitter` object that split the inputs into patches. Defaults to None.
             If not provided or None, the inputs are considered to be already split into patches.
+            In this case, the output `merged_shape` and the optional `cropped_shape` cannot be inferred
+            and should be explicitly provided.
         merger_cls: a `Merger` subclass that can be instantiated to merges patch outputs.
             It can also be a string that matches the name of a class inherited from `Merger` class.
             Defaults to `AvgMerger`.
@@ -100,34 +102,29 @@
         output_keys: if the network output is a dictionary, this defines the keys of
             the output dictionary to be used for merging.
             Defaults to None, where all the keys are used.
+        match_spatial_shape: whether to crop the output to match the input shape. Defaults to True.
         merger_kwargs: arguments to be passed to `merger_cls` for instantiation.
-            `output_shape` is calculated automatically based on the input shape and
+            `merged_shape` is calculated automatically based on the input shape and
             the output patch shape unless it is passed here.
     """

     def __init__(
         self,
-        splitter: Splitter | Callable | None = None,
+        splitter: Splitter | None = None,
         merger_cls: type[Merger] | str = AvgMerger,
         batch_size: int = 1,
         preprocessing: Callable | None = None,
         postprocessing: Callable | None = None,
         output_keys: Sequence | None = None,
+        match_spatial_shape: bool = True,
         **merger_kwargs: Any,
     ) -> None:
         Inferer.__init__(self)
-
         # splitter
-        if splitter is not None and not isinstance(splitter, Splitter):
-            if callable(splitter):
-                warnings.warn(
-                    "`splitter` is a callable instead of `Splitter` object, please make sure that it returns "
-                    "the correct values. Either Iterable[tuple[torch.Tensor, Sequence[int]]], or "
-                    "a MetaTensor with defined `PatchKey.LOCATION` metadata."
-                )
-            else:
+        if not isinstance(splitter, (Splitter, type(None))):
+            if not isinstance(splitter, Splitter):
                 raise TypeError(
-                    f"'splitter' should be a `Splitter` object (or a callable that returns "
+                    f"'splitter' should be a `Splitter` object that returns: "
                     "an iterable of pairs of (patch, location) or a MetaTensor that has `PatchKeys.LOCATION` metadata)."
                     f"{type(splitter)} is given."
                 )
@@ -165,6 +162,9 @@
         # model output keys
         self.output_keys = output_keys

+        # whether to crop the output to match the input shape
+        self.match_spatial_shape = match_spatial_shape
+
     def _batch_sampler(
         self, patches: Iterable[tuple[torch.Tensor, Sequence[int]]] | MetaTensor
     ) -> Iterator[tuple[torch.Tensor, Sequence, int]]:
@@ -226,14 +226,24 @@ def _initialize_mergers(self, inputs, outputs, patches, batch_size):
             out_patch = torch.chunk(out_patch_batch, batch_size)[0]
             # calculate the ratio of input and output patch sizes
             ratio = tuple(op / ip for ip, op in zip(in_patch.shape[2:], out_patch.shape[2:]))
-            ratios.append(ratio)
-            # calculate output_shape only if it is not provided and splitter is not None.
-            if self.splitter is not None and "output_shape" not in self.merger_kwargs:
-                output_shape = self._get_output_shape(inputs, out_patch, ratio)
-                merger = self.merger_cls(output_shape=output_shape, **self.merger_kwargs)
-            else:
-                merger = self.merger_cls(**self.merger_kwargs)
+
+            # calculate merged_shape and cropped_shape
+            merger_kwargs = self.merger_kwargs.copy()
+            cropped_shape, merged_shape = self._get_merged_shapes(inputs, out_patch, ratio)
+            if "merged_shape" not in merger_kwargs:
+                merger_kwargs["merged_shape"] = merged_shape
+                if merger_kwargs["merged_shape"] is None:
+                    raise ValueError("`merged_shape` cannot be `None`.")
+            if "cropped_shape" not in merger_kwargs:
+                merger_kwargs["cropped_shape"] = cropped_shape
+
+            # initialize the merger
+            merger = self.merger_cls(**merger_kwargs)
+
+            # store mergers and input/output ratios
             mergers.append(merger)
+            ratios.append(ratio)
+
         return mergers, ratios

     def _aggregate(self, outputs, locations, batch_size, mergers, ratios):
@@ -243,12 +253,27 @@ def _aggregate(self, outputs, locations, batch_size, mergers, ratios):
             out_loc = [round(l * r) for l, r in zip(in_loc, ratio)]
             merger.aggregate(out_patch, out_loc)

-    def _get_output_shape(self, inputs, out_patch, ratio):
-        """Define the shape of output merged tensors"""
-        in_spatial_shape = inputs.shape[2:]
-        out_spatial_shape = tuple(round(s * r) for s, r in zip(in_spatial_shape, ratio))
-        output_shape = out_patch.shape[:2] + out_spatial_shape
-        return output_shape
+    def _get_merged_shapes(self, inputs, out_patch, ratio):
+        """Define the shape of merged tensors (non-padded and padded)"""
+        if self.splitter is None:
+            return None, None
+
+        # input spatial shapes
+        original_spatial_shape = self.splitter.get_input_shape(inputs)
+        padded_spatial_shape = self.splitter.get_padded_shape(inputs)
+
+        # output spatial shapes
+        output_spatial_shape = tuple(round(s * r) for s, r in zip(original_spatial_shape, ratio))
+        padded_output_spatial_shape = tuple(round(s * r) for s, r in zip(padded_spatial_shape, ratio))
+
+        # output shapes
+        cropped_shape = out_patch.shape[:2] + output_spatial_shape
+        merged_shape = out_patch.shape[:2] + padded_output_spatial_shape
+
+        if not self.match_spatial_shape:
+            cropped_shape = merged_shape
+
+        return cropped_shape, merged_shape

     def __call__(
         self,
@@ -270,6 +295,7 @@ def __call__(
         """
         patches_locations: Iterable[tuple[torch.Tensor, Sequence[int]]] | MetaTensor
         if self.splitter is None:
+            # handle situations where the splitter is not provided
            if isinstance(inputs, torch.Tensor):
                 if isinstance(inputs, MetaTensor):
                     if PatchKeys.LOCATION not in inputs.meta:
@@ -288,6 +314,7 @@ def __call__(
                 )
             patches_locations = inputs
         else:
+            # apply splitter
             patches_locations = self.splitter(inputs)

         ratios: list[float] = []
@@ -302,7 +329,8 @@ def __call__(
             self._aggregate(outputs, locations, batch_size, mergers, ratios)

         # finalize the mergers and get the results
-        merged_outputs = tuple(merger.finalize() for merger in mergers)
+        merged_outputs = [merger.finalize() for merger in mergers]
+
         # return according to the model output
         if self.output_keys:
             return dict(zip(self.output_keys, merged_outputs))
```
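To make the merged/cropped distinction concrete, the shape arithmetic of `_get_merged_shapes` can be replayed with illustrative numbers (the 500 → 512 padding below is an assumption for this example, not a value from the PR):

```python
# batch and channel dims come from the output patch; spatial dims are scaled by
# the output/input patch-size ratio
out_patch_shape = (1, 1)  # out_patch.shape[:2]
ratio = (1.0, 1.0)        # network preserves spatial size

original_spatial_shape = (500, 500)  # what splitter.get_input_shape(inputs) would return
padded_spatial_shape = (512, 512)    # what splitter.get_padded_shape(inputs) would return

cropped_shape = out_patch_shape + tuple(round(s * r) for s, r in zip(original_spatial_shape, ratio))
merged_shape = out_patch_shape + tuple(round(s * r) for s, r in zip(padded_spatial_shape, ratio))

assert merged_shape == (1, 1, 512, 512)   # buffer the merger allocates
assert cropped_shape == (1, 1, 500, 500)  # final shape when match_spatial_shape=True
```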

monai/inferers/merger.py (23 additions, 10 deletions)

```diff
@@ -32,12 +32,20 @@ class Merger(ABC):
     - finalize: perform any final process and return the merged output

     Args:
-        output_shape: the shape of the merged output tensor. Default to None.
+        merged_shape: the shape of the tensor required to merge the patches.
+        cropped_shape: the shape of the final merged output tensor.
+            If not provided, it will be the same as `merged_shape`.
         device: the device where Merger tensors should reside.
     """

-    def __init__(self, output_shape: Sequence[int] | None = None, device: torch.device | str | None = None) -> None:
-        self.output_shape = output_shape
+    def __init__(
+        self,
+        merged_shape: Sequence[int],
+        cropped_shape: Sequence[int] | None = None,
+        device: torch.device | str | None = None,
+    ) -> None:
+        self.merged_shape = merged_shape
+        self.cropped_shape = self.merged_shape if cropped_shape is None else cropped_shape
         self.device = device
         self.is_finalized = False

@@ -77,26 +85,29 @@ class AvgMerger(Merger):
     """Merge patches by taking average of the overlapping area

     Args:
-        output_shape: the shape of the merged output tensor.
+        merged_shape: the shape of the tensor required to merge the patches.
+        cropped_shape: the shape of the final merged output tensor.
+            If not provided, it will be the same as `merged_shape`.
         device: the device for aggregator tensors and final results.
         value_dtype: the dtype for value aggregating tensor and the final result.
         count_dtype: the dtype for sample counting tensor.
     """

     def __init__(
         self,
-        output_shape: Sequence[int],
+        merged_shape: Sequence[int],
+        cropped_shape: Sequence[int] | None = None,
         device: torch.device | str = "cpu",
         value_dtype: torch.dtype = torch.float32,
         count_dtype: torch.dtype = torch.uint8,
     ) -> None:
-        super().__init__(output_shape=output_shape, device=device)
-        if not self.output_shape:
-            raise ValueError(f"`output_shape` must be provided for `AvgMerger`. {self.output_shape} is give.")
+        super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape, device=device)
+        if not self.merged_shape:
+            raise ValueError(f"`merged_shape` must be provided for `AvgMerger`. {self.merged_shape} is give.")
         self.value_dtype = value_dtype
         self.count_dtype = count_dtype
-        self.values = torch.zeros(self.output_shape, dtype=self.value_dtype, device=self.device)
-        self.counts = torch.zeros(self.output_shape, dtype=self.count_dtype, device=self.device)
+        self.values = torch.zeros(self.merged_shape, dtype=self.value_dtype, device=self.device)
+        self.counts = torch.zeros(self.merged_shape, dtype=self.count_dtype, device=self.device)

     def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None:
         """
@@ -134,6 +145,8 @@ def finalize(self) -> torch.Tensor:
         if not self.is_finalized:
             # use in-place division to save space
             self.values.div_(self.counts)
+            # finalize the shape
+            self.values = self.values[tuple(slice(0, end) for end in self.cropped_shape)]
             # set finalize flag to protect performing in-place division again
             self.is_finalized = True

```
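A short sketch of the updated `AvgMerger` API, following the signatures in the diff above (the shapes are illustrative):

```python
import torch

from monai.inferers import AvgMerger

# allocate a 1x1x2x4 merge buffer, but crop to 1x1x2x3 at finalize()
merger = AvgMerger(merged_shape=(1, 1, 2, 4), cropped_shape=(1, 1, 2, 3))

# two non-overlapping 2x2 patches fill the buffer
merger.aggregate(torch.ones(1, 1, 2, 2), location=(0, 0))
merger.aggregate(torch.ones(1, 1, 2, 2), location=(0, 2))

result = merger.finalize()  # divides values by counts, then crops to cropped_shape
assert result.shape == (1, 1, 2, 3)
```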
