Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement resize and pad trafo WIP #191

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 69 additions & 4 deletions torch_em/transform/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(self,
control_point_spacing=1,
sigma=(4.0, 4.0),
alpha=(32.0, 32.0),
resample=kornia.constants.Resample.BILINEAR,
interpolation=kornia.constants.Resample.BILINEAR,
p=0.5,
keepdim=False,
same_on_batch=False):
Expand All @@ -84,9 +84,9 @@ def __init__(self,
else:
self.control_point_spacing = control_point_spacing
assert len(self.control_point_spacing) == 2
self.resample = resample
self.interpolation = interpolation
self.flags = dict(
resample=torch.tensor(self.resample.value),
interpolation=torch.tensor(self.interpolation.value),
sigma=sigma,
alpha=alpha
)
Expand Down Expand Up @@ -114,13 +114,78 @@ def __call__(self, input, params=None):
params = self.generate_parameters(input.shape)
self._params = params
noise = params["noise"]
mode = "bilinear" if (self.flags["resample"] == 1).all() else "nearest"
mode = "bilinear" if (self.flags["interpolation"] == 1).all() else "nearest"
return kornia.geometry.transform.elastic_transform2d(
input, noise, sigma=self.flags["sigma"], alpha=self.flags["alpha"], mode=mode,
padding_mode="reflection"
)


class RandomResizeAndPad(kornia.augmentation.AugmentationBase2D):
"""Bring inputs to output shape by randomly resizing and padding.
"""
def __init__(self, output_shape, padding_mode="constant", same_on_batch=False):
super().__init__(
p=1.0, same_on_batch=same_on_batch
)
if len(output_shape) != 2:
raise ValueError(f"Can only resize 2d shape, got {len(output_shape)}")
self.output_shape = output_shape
self.padding_mode = padding_mode
self.flags = dict(interpolation=torch.tensor(kornia.constants.Resample.BILINEAR.value))

def generate_parameters(self, shape):
assert len(shape) == len(self.output_shape)
resize_shape = []
for ims, outs in zip(shape, self.output_shape):
diff = outs - ims

# The output shape is bigger than the input shape.
# We resize to a random size in between the two (the rest will get padded).
if diff > 0:
res = np.random.randint(ims, outs + 1)

# The output shape is smaller than or equal to the input shape.
# We just resize to the output shape.
else:
res = outs

resize_shape.append(res)
return {"resize_shape": tuple(resize_shape)}

def resize_and_pad(self, data, resize_shape):
# shapes match already, we don't have to do anything
if tuple(data[-2:].shape) == self.output_shape:
return data

# interpolation mode
mode = "bilinear" if (self.flags["interpolation"] == 1).all() else "nearest"
antialias = True if (self.flags["interpolation"] == 1).all() else False

print("Resizing", data.shape, "to", resize_shape, "with", mode)
out = kornia.geometry.transform.resize(
data, resize_shape, interpolation=mode, antialias=antialias
)

# pad the rest
if tuple(out.shape[-2:]) != self.output_shape:
pad_shape = tuple(
outsh - sh for outsh, sh in zip(self.output_shape, out.shape[-2:])
)
pad_shape = (pad_shape[1], 0, pad_shape[0], 0)
assert all(ps >= 0 for ps in pad_shape), f"{pad_shape}"
out = torch.nn.functional.pad(out, pad_shape, mode=self.padding_mode)

assert tuple(out.shape[-2:]) == self.output_shape, f"{out.shape}, {self.output_shape}"
return out

def __call__(self, input, params=None):
if params is None:
params = self.generate_parameters(input.shape[-2:])
self._params = params
return self.resize_and_pad(input, params["resize_shape"])


# TODO implement 'require_halo', and estimate the halo from the transformations
# so that we can load a bigger block and cut it away
class KorniaAugmentationPipeline(torch.nn.Module):
Expand Down