Skip to content

Commit

Permalink
refactor transforms into submodules
Browse files Browse the repository at this point in the history
  • Loading branch information
ziw-liu committed Jan 22, 2025
1 parent 4d07b55 commit bc49671
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 151 deletions.
31 changes: 31 additions & 0 deletions viscy/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from viscy.transforms._redef import (
RandAdjustContrastd,
RandAffined,
RandGaussianNoised,
RandGaussianSmoothd,
RandScaleIntensityd,
RandWeightedCropd,
ScaleIntensityRangePercentilesd,
)
from viscy.transforms._transforms import (
BatchedZoom,
NormalizeSampled,
RandInvertIntensityd,
StackChannelsd,
TiledSpatialCropSamplesd,
)

__all__ = [
"BatchedZoom",
"NormalizeSampled",
"RandAdjustContrastd",
"RandAffined",
"RandGaussianNoised",
"RandGaussianSmoothd",
"RandInvertIntensityd",
"RandScaleIntensityd",
"RandWeightedCropd",
"ScaleIntensityRangePercentilesd",
"StackChannelsd",
"TiledSpatialCropSamplesd",
]
73 changes: 73 additions & 0 deletions viscy/transforms/_gaussian_blur.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""3D version of `kornia.augmentation._2d.intensity.gaussian_blur`."""

from typing import Any, Iterable

from kornia.augmentation import random_generator as rg
from kornia.augmentation._3d.intensity.base import IntensityAugmentationBase3D
from kornia.constants import BorderType
from kornia.filters import filter3d, get_gaussian_kernel3d
from monai.transforms import MapTransform, RandomizableTransform
from torch import Tensor


class RandomGaussianBlur(IntensityAugmentationBase3D):
def __init__(
self,
kernel_size: tuple[int, int, int] | int,
sigma: tuple[float, float, float] | Tensor,
border_type: str = "reflect",
same_on_batch: bool = False,
p: float = 0.5,
keepdim: bool = False,
) -> None:
super().__init__(p=p, same_on_batch=same_on_batch, p_batch=1.0, keepdim=keepdim)

self.flags = {
"kernel_size": kernel_size,
"border_type": BorderType.get(border_type),
}
self._param_generator = rg.RandomGaussianBlurGenerator(sigma)

def apply_transform(
self,
input: Tensor,
params: dict[str, Tensor],
flags: dict[str, Any],
transform: Tensor | None = None,
) -> Tensor:
sigma = params["sigma"].unsqueeze(-1).expand(-1, 2)
kernel = get_gaussian_kernel3d(
kernel_size=self.flags["kernel_size"], sigma=sigma
)
return filter3d(input, kernel, border_type=self.flags["border_type"])


class BatchedRandGaussianBlurd(MapTransform, RandomizableTransform):
def __init__(
self,
keys: str | Iterable[str],
kernel_size: tuple[int, int] | int,
sigma: tuple[float, float],
border_type: str = "reflect",
same_on_batch: bool = False,
prob: float = 0.1,
allow_missing_keys: bool = False,
) -> None:
MapTransform.__init__(self, keys, allow_missing_keys=allow_missing_keys)
RandomizableTransform.__init__(self, prob)
self.filter = RandomGaussianBlur(
kernel_size=kernel_size,
sigma=sigma,
border_type=border_type,
same_on_batch=same_on_batch,
p=prob,
)

def __call__(self, sample: dict[str, Tensor]) -> dict[str, Tensor]:
self.randomize(None)
if not self._do_transform:
return sample
for key in self.keys:
if key in sample:
sample[key] = -sample[key]
return sample
134 changes: 134 additions & 0 deletions viscy/transforms/_redef.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Redefine transforms from MONAI for jsonargparse."""

from typing import Sequence

from monai.transforms import (
RandAdjustContrastd,
RandAffined,
RandGaussianNoised,
RandGaussianSmoothd,
RandScaleIntensityd,
RandWeightedCropd,
ScaleIntensityRangePercentilesd,
)
from numpy.typing import DTypeLike


class RandWeightedCropd(RandWeightedCropd):
def __init__(
self,
keys: Sequence[str] | str,
w_key: str,
spatial_size: Sequence[int],
num_samples: int = 1,
**kwargs,
):
super().__init__(
keys=keys,
w_key=w_key,
spatial_size=spatial_size,
num_samples=num_samples,
**kwargs,
)


class RandAffined(RandAffined):
def __init__(
self,
keys: Sequence[str] | str,
prob: float,
rotate_range: Sequence[float] | float,
shear_range: Sequence[float] | float,
scale_range: Sequence[float] | float,
**kwargs,
):
super().__init__(
keys=keys,
prob=prob,
rotate_range=rotate_range,
shear_range=shear_range,
scale_range=scale_range,
**kwargs,
)


class RandAdjustContrastd(RandAdjustContrastd):
def __init__(
self,
keys: Sequence[str] | str,
prob: float,
gamma: tuple[float, float] | float,
**kwargs,
):
super().__init__(keys=keys, prob=prob, gamma=gamma, **kwargs)


class RandScaleIntensityd(RandScaleIntensityd):
def __init__(
self,
keys: Sequence[str] | str,
factors: tuple[float, float] | float,
prob: float,
**kwargs,
):
super().__init__(keys=keys, factors=factors, prob=prob, **kwargs)


class RandGaussianNoised(RandGaussianNoised):
def __init__(
self,
keys: Sequence[str] | str,
prob: float,
mean: float,
std: float,
**kwargs,
):
super().__init__(keys=keys, prob=prob, mean=mean, std=std, **kwargs)


class RandGaussianSmoothd(RandGaussianSmoothd):
def __init__(
self,
keys: Sequence[str] | str,
prob: float,
sigma_x: tuple[float, float] | float,
sigma_y: tuple[float, float] | float,
sigma_z: tuple[float, float] | float,
**kwargs,
):
super().__init__(
keys=keys,
prob=prob,
sigma_x=sigma_x,
sigma_y=sigma_y,
sigma_z=sigma_z,
**kwargs,
)


class ScaleIntensityRangePercentilesd(ScaleIntensityRangePercentilesd):
def __init__(
self,
keys: Sequence[str] | str,
lower: float,
upper: float,
b_min: float | None,
b_max: float | None,
clip: bool = False,
relative: bool = False,
channel_wise: bool = False,
dtype: DTypeLike | None = None,
allow_missing_keys: bool = False,
):
super().__init__(
keys=keys,
lower=lower,
upper=upper,
b_min=b_min,
b_max=b_max,
clip=clip,
relative=relative,
channel_wise=channel_wise,
dtype=dtype,
allow_missing_keys=allow_missing_keys,
)
Loading

0 comments on commit bc49671

Please sign in to comment.