@@ -87,6 +87,8 @@ class PatchInferer(Inferer):
     Args:
         splitter: a `Splitter` object that splits the inputs into patches. Defaults to None.
             If not provided or None, the inputs are considered to be already split into patches.
+            In this case, the output `merged_shape` and the optional `cropped_shape` cannot be inferred
+            and should be explicitly provided.
         merger_cls: a `Merger` subclass that can be instantiated to merge patch outputs.
             It can also be a string that matches the name of a class inherited from `Merger` class.
             Defaults to `AvgMerger`.
@@ -100,34 +102,29 @@ class PatchInferer(Inferer):
         output_keys: if the network output is a dictionary, this defines the keys of
             the output dictionary to be used for merging.
             Defaults to None, where all the keys are used.
+        match_spatial_shape: whether to crop the output to match the input shape. Defaults to True.
         merger_kwargs: arguments to be passed to `merger_cls` for instantiation.
-            `output_shape` is calculated automatically based on the input shape and
+            `merged_shape` is calculated automatically based on the input shape and
             the output patch shape unless it is passed here.
     """

     def __init__(
         self,
-        splitter: Splitter | Callable | None = None,
+        splitter: Splitter | None = None,
         merger_cls: type[Merger] | str = AvgMerger,
         batch_size: int = 1,
         preprocessing: Callable | None = None,
         postprocessing: Callable | None = None,
         output_keys: Sequence | None = None,
+        match_spatial_shape: bool = True,
         **merger_kwargs: Any,
     ) -> None:
         Inferer.__init__(self)
-
         # splitter
-        if splitter is not None and not isinstance(splitter, Splitter):
-            if callable(splitter):
-                warnings.warn(
-                    "`splitter` is a callable instead of `Splitter` object, please make sure that it returns "
-                    "the correct values. Either Iterable[tuple[torch.Tensor, Sequence[int]]], or "
-                    "a MetaTensor with defined `PatchKey.LOCATION` metadata."
-                )
-            else:
+        if not isinstance(splitter, (Splitter, type(None))):
+            if not isinstance(splitter, Splitter):
                 raise TypeError(
-                    f"'splitter' should be a `Splitter` object (or a callable that returns "
+                    f"'splitter' should be a `Splitter` object that returns: "
                     "an iterable of pairs of (patch, location) or a MetaTensor that has `PatchKeys.LOCATION` metadata)."
                     f"{type(splitter)} is given."
                 )
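For context, a minimal usage sketch of the updated constructor follows. It is not part of the diff: the `SlidingWindowSplitter`, the identity "network", and the input shape are assumptions chosen only to show how `splitter` and the new `match_spatial_shape` argument fit together (assuming a MONAI version that exports these classes from `monai.inferers`).

```python
# Illustrative sketch only; the splitter, dummy network, and shapes are assumptions.
import torch
from monai.inferers import AvgMerger, PatchInferer, SlidingWindowSplitter

network = lambda x: x  # placeholder model: returns each patch unchanged

inferer = PatchInferer(
    splitter=SlidingWindowSplitter(patch_size=(64, 64)),  # pads a 100x100 input up to 128x128
    merger_cls=AvgMerger,
    batch_size=4,
    match_spatial_shape=True,  # crop the merged output back to the original 100x100
)
output = inferer(torch.rand(1, 3, 100, 100), network)
print(output.shape)  # expected: torch.Size([1, 3, 100, 100])
```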
@@ -165,6 +162,9 @@ def __init__(
         # model output keys
         self.output_keys = output_keys

+        # whether to crop the output to match the input shape
+        self.match_spatial_shape = match_spatial_shape
+
     def _batch_sampler(
         self, patches: Iterable[tuple[torch.Tensor, Sequence[int]]] | MetaTensor
     ) -> Iterator[tuple[torch.Tensor, Sequence, int]]:
@@ -226,14 +226,24 @@ def _initialize_mergers(self, inputs, outputs, patches, batch_size):
             out_patch = torch.chunk(out_patch_batch, batch_size)[0]
             # calculate the ratio of input and output patch sizes
             ratio = tuple(op / ip for ip, op in zip(in_patch.shape[2:], out_patch.shape[2:]))
-            ratios.append(ratio)
-            # calculate output_shape only if it is not provided and splitter is not None.
-            if self.splitter is not None and "output_shape" not in self.merger_kwargs:
-                output_shape = self._get_output_shape(inputs, out_patch, ratio)
-                merger = self.merger_cls(output_shape=output_shape, **self.merger_kwargs)
-            else:
-                merger = self.merger_cls(**self.merger_kwargs)
+
+            # calculate merged_shape and cropped_shape
+            merger_kwargs = self.merger_kwargs.copy()
+            cropped_shape, merged_shape = self._get_merged_shapes(inputs, out_patch, ratio)
+            if "merged_shape" not in merger_kwargs:
+                merger_kwargs["merged_shape"] = merged_shape
+                if merger_kwargs["merged_shape"] is None:
+                    raise ValueError("`merged_shape` cannot be `None`.")
+            if "cropped_shape" not in merger_kwargs:
+                merger_kwargs["cropped_shape"] = cropped_shape
+
+            # initialize the merger
+            merger = self.merger_cls(**merger_kwargs)
+
+            # store mergers and input/output ratios
             mergers.append(merger)
+            ratios.append(ratio)
+
         return mergers, ratios

     def _aggregate(self, outputs, locations, batch_size, mergers, ratios):
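The merger bookkeeping above reduces to a small amount of shape arithmetic. A standalone sketch of that arithmetic follows; the concrete patch shapes and the padded input size are made-up values, not taken from the diff.

```python
# Hypothetical numbers illustrating the ratio/merged_shape computation in _initialize_mergers.
in_patch_shape = (1, 1, 64, 64)    # one input patch: N, C, H, W
out_patch_shape = (1, 2, 32, 32)   # the network halves the spatial size and doubles the channels
padded_input_spatial = (128, 128)  # input spatial shape after the splitter's padding

# output/input spatial ratio per dimension, as in the diff
ratio = tuple(op / ip for ip, op in zip(in_patch_shape[2:], out_patch_shape[2:]))  # (0.5, 0.5)
# merged (padded) output spatial shape, scaled by the ratio
merged_spatial = tuple(round(s * r) for s, r in zip(padded_input_spatial, ratio))  # (64, 64)
# full merged_shape handed to the merger: batch/channel dims come from the output patch
merged_shape = out_patch_shape[:2] + merged_spatial                                 # (1, 2, 64, 64)
print(ratio, merged_shape)
```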
@@ -243,12 +253,27 @@ def _aggregate(self, outputs, locations, batch_size, mergers, ratios):
                 out_loc = [round(l * r) for l, r in zip(in_loc, ratio)]
                 merger.aggregate(out_patch, out_loc)

-    def _get_output_shape(self, inputs, out_patch, ratio):
-        """Define the shape of output merged tensors"""
-        in_spatial_shape = inputs.shape[2:]
-        out_spatial_shape = tuple(round(s * r) for s, r in zip(in_spatial_shape, ratio))
-        output_shape = out_patch.shape[:2] + out_spatial_shape
-        return output_shape
+    def _get_merged_shapes(self, inputs, out_patch, ratio):
+        """Define the shape of merged tensors (non-padded and padded)"""
+        if self.splitter is None:
+            return None, None
+
+        # input spatial shapes
+        original_spatial_shape = self.splitter.get_input_shape(inputs)
+        padded_spatial_shape = self.splitter.get_padded_shape(inputs)
+
+        # output spatial shapes
+        output_spatial_shape = tuple(round(s * r) for s, r in zip(original_spatial_shape, ratio))
+        padded_output_spatial_shape = tuple(round(s * r) for s, r in zip(padded_spatial_shape, ratio))
+
+        # output shapes
+        cropped_shape = out_patch.shape[:2] + output_spatial_shape
+        merged_shape = out_patch.shape[:2] + padded_output_spatial_shape
+
+        if not self.match_spatial_shape:
+            cropped_shape = merged_shape
+
+        return cropped_shape, merged_shape

     def __call__(
         self,
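To make the two shapes concrete: with assumed numbers (a 100x100 input padded by the splitter to 128x128 and an output-to-input ratio of 1), `_get_merged_shapes` yields a `merged_shape` covering the padded extent and a `cropped_shape` matching the original input; setting `match_spatial_shape=False` makes the two identical. A toy sketch, independent of MONAI:

```python
# Toy values (assumptions, not from the diff) mirroring _get_merged_shapes when ratio == 1.
original_spatial = (100, 100)  # what splitter.get_input_shape(inputs) would report
padded_spatial = (128, 128)    # what splitter.get_padded_shape(inputs) would report
out_patch_prefix = (1, 2)      # batch and channel dims taken from one output patch

merged_shape = out_patch_prefix + padded_spatial     # buffer the merger accumulates into
cropped_shape = out_patch_prefix + original_spatial  # returned when match_spatial_shape=True

match_spatial_shape = False
if not match_spatial_shape:
    cropped_shape = merged_shape  # no cropping: the padded extent is returned as-is

print(merged_shape, cropped_shape)
```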
@@ -270,6 +295,7 @@ def __call__(
         """
         patches_locations: Iterable[tuple[torch.Tensor, Sequence[int]]] | MetaTensor
         if self.splitter is None:
+            # handle situations where the splitter is not provided
             if isinstance(inputs, torch.Tensor):
                 if isinstance(inputs, MetaTensor):
                     if PatchKeys.LOCATION not in inputs.meta:
@@ -288,6 +314,7 @@ def __call__(
                 )
             patches_locations = inputs
         else:
+            # apply splitter
             patches_locations = self.splitter(inputs)

         ratios: list[float] = []
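As noted in the docstring change above, when `splitter=None` the inputs are expected to be already split into patches, and `merged_shape` must be supplied because it cannot be inferred. A hedged sketch of that path follows; the patch tensors, locations, and shapes are assumptions for illustration, and it assumes `merged_shape` is forwarded through `**merger_kwargs` to the default `AvgMerger` as the docstring describes.

```python
# Sketch with assumed values: two pre-split 2D patches and their top-left spatial locations.
# With no splitter, merged_shape (and optionally cropped_shape) must be given explicitly.
import torch
from monai.inferers import PatchInferer

patches_locations = [
    (torch.rand(1, 3, 64, 64), (0, 0)),   # (patch, location)
    (torch.rand(1, 3, 64, 64), (0, 64)),
]
inferer = PatchInferer(
    splitter=None,
    merged_shape=(1, 3, 64, 128),  # passed through **merger_kwargs to the merger
)
output = inferer(patches_locations, lambda x: x)
print(output.shape)  # expected: torch.Size([1, 3, 64, 128])
```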
@@ -302,7 +329,8 @@ def __call__(
             self._aggregate(outputs, locations, batch_size, mergers, ratios)

         # finalize the mergers and get the results
-        merged_outputs = tuple(merger.finalize() for merger in mergers)
+        merged_outputs = [merger.finalize() for merger in mergers]
+
         # return according to the model output
         if self.output_keys:
             return dict(zip(self.output_keys, merged_outputs))
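Because the finalized results are zipped with `output_keys`, a dictionary-returning network round-trips to a dictionary. A small sketch under assumed names and shapes (not part of this diff), assuming dict outputs are handled as the `output_keys` docstring describes:

```python
# Sketch with assumed names/shapes: a network that returns a dict of two heads.
import torch
from monai.inferers import PatchInferer, SlidingWindowSplitter

def two_head_network(x):
    return {"seg": x, "aux": x}  # both heads keep the patch shape here for simplicity

inferer = PatchInferer(
    splitter=SlidingWindowSplitter(patch_size=(32, 32)),
    output_keys=["seg", "aux"],  # which dict entries to merge and how to key the result
)
results = inferer(torch.rand(1, 1, 64, 64), two_head_network)
print(sorted(results))  # expected: ['aux', 'seg'], each mapping to a merged tensor
```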