Skip to content

Commit

Permalink
Minor update to expose padding mode and make resize flexible
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Jan 28, 2025
1 parent e0b2356 commit 2c7dbfb
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions torch_em/transform/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,11 @@ def __call__(self, inputs):


class ResizeLongestSideInputs:
def __init__(self, target_shape, is_label=False, is_rgb=False):
def __init__(self, target_shape, is_label=False, is_rgb=False, padding_mode="constant"):
self.target_shape = target_shape
self.is_label = is_label
self.is_rgb = is_rgb
self.padding_mode = padding_mode

h, w = self.target_shape[-2], self.target_shape[-1]
if h != w: # We currently support resize feature for square-shaped target shape only.
Expand All @@ -135,7 +136,7 @@ def _get_preprocess_shape(self, oldh, oldw):
newh = int(newh + 0.5)
return (newh, neww)

def convert_transformed_inputs_to_original_shape(self, resized_inputs):
def convert_transformed_inputs_to_original_shape(self, resized_inputs, resize_kwargs=None):
if not hasattr(self, "pre_pad_shape"):
raise RuntimeError(
"'convert_transformed_inputs_to_original_shape' is only valid after the '__call__' method has run."
Expand All @@ -144,8 +145,15 @@ def convert_transformed_inputs_to_original_shape(self, resized_inputs):
# First step is to remove the padded region
inputs = resized_inputs[tuple(self.pre_pad_shape)]
# Next, we resize the inputs to original shape

if resize_kwargs is None: # This allows the user to change resize parameters, eg. for labels, if desired.
resize_kwargs = self.kwargs
else:
if not isinstance(resize_kwargs, dict):
raise RuntimeError("If the 'resize_kwargs' are provided, it must be a dictionary.")

inputs = resize(
image=inputs, output_shape=self.original_shape, preserve_range=True, **self.kwargs
image=inputs, output_shape=self.original_shape, preserve_range=True, **resize_kwargs
)
return inputs

Expand Down Expand Up @@ -181,13 +189,14 @@ def __call__(self, inputs):
# NOTE: We store this in case we would like to unpad the inputs.
self.pre_pad_shape = [slice(pw[0], -pw[1] if pw[1] > 0 else None) for pw in pad_width]

inputs = np.pad(array=inputs, pad_width=pad_width, mode="constant")
inputs = np.pad(array=inputs, pad_width=pad_width, mode=self.padding_mode)
return inputs


class PadIfNecessary:
def __init__(self, shape):
def __init__(self, shape, padding_mode="reflect"):
self.shape = tuple(shape)
self.padding_mode = padding_mode

def _pad_if_necessary(self, data):
if data.ndim == len(self.shape):
Expand All @@ -204,7 +213,7 @@ def _pad_if_necessary(self, data):
pad_width = [sh - dsh for dsh, sh in zip(data_shape, pad_shape)]
assert all(pw >= 0 for pw in pad_width)
pad_width = [(0, pw) for pw in pad_width]
return np.pad(data, pad_width, mode="reflect")
return np.pad(data, pad_width, mode=self.padding_mode)

def __call__(self, *inputs):
outputs = tuple(self._pad_if_necessary(input_) for input_ in inputs)
Expand Down

0 comments on commit 2c7dbfb

Please sign in to comment.