diff --git a/torch_em/transform/generic.py b/torch_em/transform/generic.py
index e7526443..a0c2d3f4 100644
--- a/torch_em/transform/generic.py
+++ b/torch_em/transform/generic.py
@@ -108,10 +108,11 @@ def __call__(self, inputs):
 
 
 class ResizeLongestSideInputs:
-    def __init__(self, target_shape, is_label=False, is_rgb=False):
+    def __init__(self, target_shape, is_label=False, is_rgb=False, padding_mode="constant"):
         self.target_shape = target_shape
         self.is_label = is_label
         self.is_rgb = is_rgb
+        self.padding_mode = padding_mode
 
         h, w = self.target_shape[-2], self.target_shape[-1]
         if h != w:  # We currently support resize feature for square-shaped target shape only.
@@ -135,7 +136,7 @@ def _get_preprocess_shape(self, oldh, oldw):
         newh = int(newh + 0.5)
         return (newh, neww)
-    def convert_transformed_inputs_to_original_shape(self, resized_inputs):
+    def convert_transformed_inputs_to_original_shape(self, resized_inputs, resize_kwargs=None):
         if not hasattr(self, "pre_pad_shape"):
             raise RuntimeError(
                 "'convert_transformed_inputs_to_original_shape' is only valid after the '__call__' method has run."
             )
@@ -144,8 +145,15 @@ def convert_transformed_inputs_to_original_shape(self, resized_inputs):
         # First step is to remove the padded region
         inputs = resized_inputs[tuple(self.pre_pad_shape)]
 
         # Next, we resize the inputs to original shape
+
+        if resize_kwargs is None:  # This allows the user to change resize parameters, e.g. for labels, if desired.
+            resize_kwargs = self.kwargs
+        else:
+            if not isinstance(resize_kwargs, dict):
+                raise RuntimeError("If 'resize_kwargs' is provided, it must be a dictionary.")
+
         inputs = resize(
-            image=inputs, output_shape=self.original_shape, preserve_range=True, **self.kwargs
+            image=inputs, output_shape=self.original_shape, preserve_range=True, **resize_kwargs
         )
         return inputs
@@ -181,13 +189,14 @@ def __call__(self, inputs):
         # NOTE: We store this in case we would like to unpad the inputs.
         self.pre_pad_shape = [slice(pw[0], -pw[1] if pw[1] > 0 else None) for pw in pad_width]
 
-        inputs = np.pad(array=inputs, pad_width=pad_width, mode="constant")
+        inputs = np.pad(array=inputs, pad_width=pad_width, mode=self.padding_mode)
         return inputs
 
 
 class PadIfNecessary:
-    def __init__(self, shape):
+    def __init__(self, shape, padding_mode="reflect"):
         self.shape = tuple(shape)
+        self.padding_mode = padding_mode
 
     def _pad_if_necessary(self, data):
         if data.ndim == len(self.shape):
@@ -204,7 +213,7 @@ def _pad_if_necessary(self, data):
         pad_width = [sh - dsh for dsh, sh in zip(data_shape, pad_shape)]
         assert all(pw >= 0 for pw in pad_width)
         pad_width = [(0, pw) for pw in pad_width]
-        return np.pad(data, pad_width, mode="reflect")
+        return np.pad(data, pad_width, mode=self.padding_mode)
 
     def __call__(self, *inputs):
         outputs = tuple(self._pad_if_necessary(input_) for input_ in inputs)
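For context, here is a minimal usage sketch of the new parameters. The `image`, `labels`, and `segmentation` arrays and the concrete shapes below are illustrative stand-ins, not part of the change; the `resize_kwargs` entries are assumed to be plain `skimage.transform.resize` keyword arguments, matching the `resize(image=..., output_shape=..., preserve_range=True, ...)` call in the diff.

```python
import numpy as np

from torch_em.transform.generic import PadIfNecessary, ResizeLongestSideInputs

# Illustrative inputs (shapes and values are arbitrary stand-ins).
image = np.random.rand(256, 384).astype("float32")
labels = np.zeros((256, 384), dtype="int64")

# Resize the longest side to 512 and pad the shorter side up to a square.
# 'padding_mode' is forwarded to np.pad, so any np.pad mode is accepted;
# "reflect" here replaces the previously hard-coded zero ("constant") padding.
resize_trafo = ResizeLongestSideInputs(target_shape=(512, 512), padding_mode="reflect")
resized_image = resize_trafo(image)  # also records the pre-padding slices internally

# Suppose we have a label image predicted at the resized shape (stand-in
# below). Passing 'resize_kwargs' overrides the transform's default resize
# parameters when mapping back to the original shape, e.g. nearest-neighbor
# interpolation without anti-aliasing so that no new label values appear.
segmentation = np.zeros_like(resized_image, dtype="int64")
restored = resize_trafo.convert_transformed_inputs_to_original_shape(
    segmentation, resize_kwargs={"order": 0, "anti_aliasing": False}
)

# PadIfNecessary exposes the same knob: zero-padding ("constant") can be
# preferable to the default "reflect" when padding label masks, so that the
# padded region stays background.
pad_trafo = PadIfNecessary(shape=(512, 512), padding_mode="constant")
padded_image, padded_labels = pad_trafo(image, labels)
```

Note that both defaults ("constant" for `ResizeLongestSideInputs`, "reflect" for `PadIfNecessary`) match the previously hard-coded modes, so existing pipelines should be unaffected.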